mirror of
https://github.com/tgdrive/teldrive.git
synced 2024-09-20 08:15:55 +08:00
feat: MultiThreaded stream support
This commit is contained in:
parent
8c9ca99a93
commit
825dc11fe1
36
.github/workflows/build-dev.yml
vendored
Normal file
36
.github/workflows/build-dev.yml
vendored
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
name: Build Dev
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- dev
|
||||||
|
|
||||||
|
permissions: write-all
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
goreleaser:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Login to GitHub Container Registry
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Set Vars
|
||||||
|
run: |
|
||||||
|
echo "IMAGE=ghcr.io/${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Build Image
|
||||||
|
uses: docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
push: true
|
||||||
|
tags: ${{ env.IMAGE }}:dev
|
24
Dockerfile
Normal file
24
Dockerfile
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
|
||||||
|
FROM golang:alpine AS builder
|
||||||
|
|
||||||
|
RUN apk add --no-cache git unzip curl make bash
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
RUN make build
|
||||||
|
|
||||||
|
FROM scratch
|
||||||
|
|
||||||
|
WORKDIR /
|
||||||
|
|
||||||
|
COPY --from=builder /app/bin/teldrive /teldrive
|
||||||
|
|
||||||
|
EXPOSE 8080
|
||||||
|
|
||||||
|
ENTRYPOINT ["/teldrive","run","--tg-session-file","/session.db"]
|
12
Makefile
12
Makefile
|
@ -11,19 +11,16 @@ FRONTEND_ASSET := https://github.com/divyam234/teldrive-ui/releases/download/v1/
|
||||||
GIT_TAG := $(shell git describe --tags --abbrev=0)
|
GIT_TAG := $(shell git describe --tags --abbrev=0)
|
||||||
GIT_COMMIT := $(shell git rev-parse --short HEAD)
|
GIT_COMMIT := $(shell git rev-parse --short HEAD)
|
||||||
GIT_LINK := $(shell git remote get-url origin)
|
GIT_LINK := $(shell git remote get-url origin)
|
||||||
ENV_FILE := $(FRONTEND_DIR)/.env
|
|
||||||
MODULE_PATH := $(shell go list -m)
|
MODULE_PATH := $(shell go list -m)
|
||||||
BUILD_DATE := $(shell $(BUILD_DATE))
|
|
||||||
|
|
||||||
GOOS ?= $(shell go env GOOS)
|
GOOS ?= $(shell go env GOOS)
|
||||||
GOARCH ?= $(shell go env GOARCH)
|
GOARCH ?= $(shell go env GOARCH)
|
||||||
|
VERSION:= $(GIT_TAG)
|
||||||
|
BINARY_EXTENSION :=
|
||||||
|
|
||||||
.PHONY: all build run clean frontend backend run sync-ui retag patch-version minor-version
|
.PHONY: all build run clean frontend backend run sync-ui retag patch-version minor-version
|
||||||
|
|
||||||
all: build
|
all: build
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
frontend:
|
frontend:
|
||||||
@echo "Extract UI"
|
@echo "Extract UI"
|
||||||
ifeq ($(OS),Windows_NT)
|
ifeq ($(OS),Windows_NT)
|
||||||
|
@ -40,10 +37,13 @@ else
|
||||||
rm -rf teldrive-ui.zip
|
rm -rf teldrive-ui.zip
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
ifeq ($(OS),Windows_NT)
|
||||||
|
BINARY_EXTENSION := .exe
|
||||||
|
endif
|
||||||
|
|
||||||
backend:
|
backend:
|
||||||
@echo "Building backend for $(GOOS)/$(GOARCH)..."
|
@echo "Building backend for $(GOOS)/$(GOARCH)..."
|
||||||
go build -trimpath -ldflags "-s -w -X $(MODULE_PATH)/internal/config.Version=$(GIT_TAG) -extldflags=-static" -o $(BUILD_DIR)/$(APP_NAME)$(BINARY_EXTENSION)
|
go build -trimpath -ldflags "-s -w -X $(MODULE_PATH)/internal/config.Version=$(VERSION) -extldflags=-static" -o $(BUILD_DIR)/$(APP_NAME)$(BINARY_EXTENSION)
|
||||||
|
|
||||||
build: frontend backend
|
build: frontend backend
|
||||||
@echo "Building complete."
|
@echo "Building complete."
|
||||||
|
|
|
@ -28,6 +28,8 @@ func InitRouter(r *gin.Engine, c *controller.Controller, cnf *config.Config) *gi
|
||||||
files.PATCH(":fileID", authmiddleware, c.UpdateFile)
|
files.PATCH(":fileID", authmiddleware, c.UpdateFile)
|
||||||
files.HEAD(":fileID/stream/:fileName", c.GetFileStream)
|
files.HEAD(":fileID/stream/:fileName", c.GetFileStream)
|
||||||
files.GET(":fileID/stream/:fileName", c.GetFileStream)
|
files.GET(":fileID/stream/:fileName", c.GetFileStream)
|
||||||
|
files.HEAD(":fileID/download/:fileName", c.GetFileDownload)
|
||||||
|
files.GET(":fileID/download/:fileName", c.GetFileDownload)
|
||||||
files.DELETE(":fileID/parts", authmiddleware, c.DeleteFileParts)
|
files.DELETE(":fileID/parts", authmiddleware, c.DeleteFileParts)
|
||||||
files.GET("/category/stats", authmiddleware, c.GetCategoryStats)
|
files.GET("/category/stats", authmiddleware, c.GetCategoryStats)
|
||||||
files.POST("/move", authmiddleware, c.MoveFiles)
|
files.POST("/move", authmiddleware, c.MoveFiles)
|
||||||
|
|
13
cmd/run.go
13
cmd/run.go
|
@ -89,12 +89,11 @@ func NewRun() *cobra.Command {
|
||||||
runCmd.Flags().IntVar(&config.TG.Uploads.Threads, "tg-uploads-threads", 8, "Uploads threads")
|
runCmd.Flags().IntVar(&config.TG.Uploads.Threads, "tg-uploads-threads", 8, "Uploads threads")
|
||||||
runCmd.Flags().IntVar(&config.TG.Uploads.MaxRetries, "tg-uploads-max-retries", 10, "Uploads Retries")
|
runCmd.Flags().IntVar(&config.TG.Uploads.MaxRetries, "tg-uploads-max-retries", 10, "Uploads Retries")
|
||||||
runCmd.Flags().Int64Var(&config.TG.PoolSize, "tg-pool-size", 8, "Telegram Session pool size")
|
runCmd.Flags().Int64Var(&config.TG.PoolSize, "tg-pool-size", 8, "Telegram Session pool size")
|
||||||
runCmd.Flags().BoolVar(&config.TG.Stream.BufferReader, "tg-stream-buffer-reader", false, "Async Buffered reader for fast streaming")
|
|
||||||
runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 16, "No of Stream buffers")
|
|
||||||
runCmd.Flags().BoolVar(&config.TG.Stream.UseMmap, "tg-stream-use-mmap", false, "Use mmap for stream buffers")
|
|
||||||
runCmd.Flags().BoolVar(&config.TG.Stream.UsePooling, "tg-stream-use-pooling", false, "Use session pooling for stream workers")
|
|
||||||
duration.DurationVar(runCmd.Flags(), &config.TG.ReconnectTimeout, "tg-reconnect-timeout", 5*time.Minute, "Reconnect Timeout")
|
duration.DurationVar(runCmd.Flags(), &config.TG.ReconnectTimeout, "tg-reconnect-timeout", 5*time.Minute, "Reconnect Timeout")
|
||||||
duration.DurationVar(runCmd.Flags(), &config.TG.Uploads.Retention, "tg-uploads-retention", (24*7)*time.Hour, "Uploads retention duration")
|
duration.DurationVar(runCmd.Flags(), &config.TG.Uploads.Retention, "tg-uploads-retention", (24*7)*time.Hour, "Uploads retention duration")
|
||||||
|
runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads")
|
||||||
|
runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers")
|
||||||
|
duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout")
|
||||||
runCmd.MarkFlagRequired("tg-app-id")
|
runCmd.MarkFlagRequired("tg-app-id")
|
||||||
runCmd.MarkFlagRequired("tg-app-hash")
|
runCmd.MarkFlagRequired("tg-app-hash")
|
||||||
runCmd.MarkFlagRequired("db-data-source")
|
runCmd.MarkFlagRequired("db-data-source")
|
||||||
|
@ -162,11 +161,11 @@ func initViperConfig(cmd *cobra.Command) error {
|
||||||
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
|
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
|
||||||
viper.AutomaticEnv()
|
viper.AutomaticEnv()
|
||||||
viper.ReadInConfig()
|
viper.ReadInConfig()
|
||||||
bindFlagsRecursive(cmd.Flags(), "", reflect.ValueOf(config.Config{}))
|
bindFlags(cmd.Flags(), "", reflect.ValueOf(config.Config{}))
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
func bindFlagsRecursive(flags *pflag.FlagSet, prefix string, v reflect.Value) {
|
func bindFlags(flags *pflag.FlagSet, prefix string, v reflect.Value) {
|
||||||
t := v.Type()
|
t := v.Type()
|
||||||
if t.Kind() == reflect.Ptr {
|
if t.Kind() == reflect.Ptr {
|
||||||
t = t.Elem()
|
t = t.Elem()
|
||||||
|
@ -175,7 +174,7 @@ func bindFlagsRecursive(flags *pflag.FlagSet, prefix string, v reflect.Value) {
|
||||||
field := t.Field(i)
|
field := t.Field(i)
|
||||||
switch field.Type.Kind() {
|
switch field.Type.Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
bindFlagsRecursive(flags, fmt.Sprintf("%s.%s", prefix, strings.ToLower(field.Name)), v.Field(i))
|
bindFlags(flags, fmt.Sprintf("%s.%s", prefix, strings.ToLower(field.Name)), v.Field(i))
|
||||||
default:
|
default:
|
||||||
newPrefix := prefix[1:]
|
newPrefix := prefix[1:]
|
||||||
newName := modifyFlag(field.Name)
|
newName := modifyFlag(field.Name)
|
||||||
|
|
|
@ -43,8 +43,6 @@
|
||||||
retention = "7d"
|
retention = "7d"
|
||||||
threads = 8
|
threads = 8
|
||||||
[tg.stream]
|
[tg.stream]
|
||||||
buffer-reader = false
|
multi-threads = 0
|
||||||
buffers = 6
|
buffers = 16
|
||||||
use-mmap = false
|
|
||||||
use-pooling= false
|
|
||||||
|
|
||||||
|
|
3
go.mod
3
go.mod
|
@ -16,7 +16,6 @@ require (
|
||||||
github.com/magiconair/properties v1.8.7
|
github.com/magiconair/properties v1.8.7
|
||||||
github.com/mitchellh/go-homedir v1.1.0
|
github.com/mitchellh/go-homedir v1.1.0
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/rclone/rclone v1.67.0
|
|
||||||
github.com/spf13/cobra v1.8.1
|
github.com/spf13/cobra v1.8.1
|
||||||
github.com/spf13/pflag v1.0.5
|
github.com/spf13/pflag v1.0.5
|
||||||
github.com/spf13/viper v1.19.0
|
github.com/spf13/viper v1.19.0
|
||||||
|
@ -97,7 +96,7 @@ require (
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pressly/goose/v3 v3.20.0
|
github.com/pressly/goose/v3 v3.21.1
|
||||||
github.com/segmentio/asm v1.2.0 // indirect
|
github.com/segmentio/asm v1.2.0 // indirect
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
|
|
10
go.sum
10
go.sum
|
@ -145,8 +145,8 @@ github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbW
|
||||||
github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||||
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
|
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
|
||||||
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
|
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
|
||||||
github.com/microsoft/go-mssqldb v1.7.0 h1:sgMPW0HA6Ihd37Yx0MzHyKD726C2kY/8KJsQtXHNaAs=
|
github.com/microsoft/go-mssqldb v1.7.1 h1:KU/g8aWeM3Hx7IMOFpiwYiUkU+9zeISb4+tx3ScVfsM=
|
||||||
github.com/microsoft/go-mssqldb v1.7.0/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA=
|
github.com/microsoft/go-mssqldb v1.7.1/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA=
|
||||||
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
|
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
|
||||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||||
|
@ -166,10 +166,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/pressly/goose/v3 v3.20.0 h1:uPJdOxF/Ipj7ABVNOAMJXSxwFXZGwMGHNqjC8e61VA0=
|
github.com/pressly/goose/v3 v3.21.1 h1:5SSAKKWej8LVVzNLuT6KIvP1eFDuPvxa+B6H0w78buQ=
|
||||||
github.com/pressly/goose/v3 v3.20.0/go.mod h1:BRfF2GcG4FTG12QfdBVy3q1yveaf4ckL9vWwEcIO3lA=
|
github.com/pressly/goose/v3 v3.21.1/go.mod h1:sqthmzV8PitchEkjecFJII//l43dLOCzfWh8pHEe+vE=
|
||||||
github.com/rclone/rclone v1.67.0 h1:yLRNgHEG2vQ60HCuzFqd0hYwKCRuWuvPUhvhMJ2jI5E=
|
|
||||||
github.com/rclone/rclone v1.67.0/go.mod h1:Cb3Ar47M/SvwfhAjZTbVXdtrP/JLtPFCq2tkdtBVC6w=
|
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||||
|
|
|
@ -2,8 +2,10 @@ package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/divyam234/teldrive/pkg/types"
|
"github.com/divyam234/teldrive/pkg/types"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-jose/go-jose/v3"
|
"github.com/go-jose/go-jose/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -59,3 +61,10 @@ func Decode(secret string, token string) (*types.JWTClaims, error) {
|
||||||
return jwtToken, nil
|
return jwtToken, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetUser(c *gin.Context) (int64, string) {
|
||||||
|
val, _ := c.Get("jwtUser")
|
||||||
|
jwtUser := val.(*types.JWTClaims)
|
||||||
|
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
|
||||||
|
return userId, jwtUser.TgSession
|
||||||
|
}
|
||||||
|
|
|
@ -43,10 +43,9 @@ type TGConfig struct {
|
||||||
Retention time.Duration
|
Retention time.Duration
|
||||||
}
|
}
|
||||||
Stream struct {
|
Stream struct {
|
||||||
BufferReader bool
|
MultiThreads int
|
||||||
Buffers int
|
Buffers int
|
||||||
UseMmap bool
|
ChunkTimeout time.Duration
|
||||||
UsePooling bool
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,206 +0,0 @@
|
||||||
// taken from rclone async reader implmentation
|
|
||||||
package reader
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rclone/rclone/lib/pool"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
BufferSize = 1024 * 1024
|
|
||||||
softStartInitial = 4 * 1024
|
|
||||||
bufferCacheSize = 64
|
|
||||||
bufferCacheFlushTime = 5 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrorStreamAbandoned = errors.New("stream abandoned")
|
|
||||||
|
|
||||||
type AsyncReader struct {
|
|
||||||
in io.ReadCloser
|
|
||||||
ready chan *buffer
|
|
||||||
token chan struct{}
|
|
||||||
exit chan struct{}
|
|
||||||
buffers int
|
|
||||||
err error
|
|
||||||
cur *buffer
|
|
||||||
exited chan struct{}
|
|
||||||
closed bool
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAsyncReader(ctx context.Context, rd io.ReadCloser, buffers int) (*AsyncReader, error) {
|
|
||||||
if buffers <= 0 {
|
|
||||||
return nil, errors.New("number of buffers too small")
|
|
||||||
}
|
|
||||||
if rd == nil {
|
|
||||||
return nil, errors.New("nil reader supplied")
|
|
||||||
}
|
|
||||||
a := &AsyncReader{}
|
|
||||||
a.init(rd, buffers)
|
|
||||||
return a, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AsyncReader) init(rd io.ReadCloser, buffers int) {
|
|
||||||
a.in = rd
|
|
||||||
a.ready = make(chan *buffer, buffers)
|
|
||||||
a.token = make(chan struct{}, buffers)
|
|
||||||
a.exit = make(chan struct{})
|
|
||||||
a.exited = make(chan struct{})
|
|
||||||
a.buffers = buffers
|
|
||||||
a.cur = nil
|
|
||||||
|
|
||||||
for i := 0; i < buffers; i++ {
|
|
||||||
a.token <- struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer close(a.exited)
|
|
||||||
defer close(a.ready)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-a.token:
|
|
||||||
b := a.getBuffer()
|
|
||||||
err := b.read(a.in)
|
|
||||||
a.ready <- b
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case <-a.exit:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
var bufferPool *pool.Pool
|
|
||||||
var bufferPoolOnce sync.Once
|
|
||||||
|
|
||||||
func (a *AsyncReader) putBuffer(b *buffer) {
|
|
||||||
bufferPool.Put(b.buf)
|
|
||||||
b.buf = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AsyncReader) getBuffer() *buffer {
|
|
||||||
bufferPoolOnce.Do(func() {
|
|
||||||
bufferPool = pool.New(bufferCacheFlushTime, BufferSize, bufferCacheSize, false)
|
|
||||||
})
|
|
||||||
return &buffer{
|
|
||||||
buf: bufferPool.Get(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AsyncReader) fill() (err error) {
|
|
||||||
if a.cur.isEmpty() {
|
|
||||||
if a.cur != nil {
|
|
||||||
a.putBuffer(a.cur)
|
|
||||||
a.token <- struct{}{}
|
|
||||||
a.cur = nil
|
|
||||||
}
|
|
||||||
b, ok := <-a.ready
|
|
||||||
if !ok {
|
|
||||||
if a.err == nil {
|
|
||||||
return ErrorStreamAbandoned
|
|
||||||
}
|
|
||||||
return a.err
|
|
||||||
}
|
|
||||||
a.cur = b
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AsyncReader) Read(p []byte) (n int, err error) {
|
|
||||||
a.mu.Lock()
|
|
||||||
defer a.mu.Unlock()
|
|
||||||
|
|
||||||
err = a.fill()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
n = copy(p, a.cur.buffer())
|
|
||||||
a.cur.increment(n)
|
|
||||||
|
|
||||||
if a.cur.isEmpty() {
|
|
||||||
a.err = a.cur.err
|
|
||||||
return n, a.err
|
|
||||||
}
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AsyncReader) StopBuffering() {
|
|
||||||
select {
|
|
||||||
case <-a.exit:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
close(a.exit)
|
|
||||||
<-a.exited
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AsyncReader) Abandon() {
|
|
||||||
a.StopBuffering()
|
|
||||||
a.mu.Lock()
|
|
||||||
defer a.mu.Unlock()
|
|
||||||
if a.cur != nil {
|
|
||||||
a.putBuffer(a.cur)
|
|
||||||
a.cur = nil
|
|
||||||
}
|
|
||||||
for b := range a.ready {
|
|
||||||
a.putBuffer(b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AsyncReader) Close() (err error) {
|
|
||||||
a.Abandon()
|
|
||||||
if a.closed {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
a.closed = true
|
|
||||||
return a.in.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
type buffer struct {
|
|
||||||
buf []byte
|
|
||||||
err error
|
|
||||||
offset int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *buffer) isEmpty() bool {
|
|
||||||
if b == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if len(b.buf)-b.offset <= 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *buffer) readFill(r io.Reader, buf []byte) (n int, err error) {
|
|
||||||
var nn int
|
|
||||||
for n < len(buf) && err == nil {
|
|
||||||
nn, err = r.Read(buf[n:])
|
|
||||||
n += nn
|
|
||||||
}
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *buffer) read(rd io.Reader) error {
|
|
||||||
var n int
|
|
||||||
n, b.err = b.readFill(rd, b.buf)
|
|
||||||
b.buf = b.buf[0:n]
|
|
||||||
b.offset = 0
|
|
||||||
return b.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *buffer) buffer() []byte {
|
|
||||||
return b.buf[b.offset:]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *buffer) increment(n int) {
|
|
||||||
b.offset += n
|
|
||||||
}
|
|
|
@ -6,36 +6,47 @@ import (
|
||||||
|
|
||||||
"github.com/divyam234/teldrive/internal/config"
|
"github.com/divyam234/teldrive/internal/config"
|
||||||
"github.com/divyam234/teldrive/internal/crypt"
|
"github.com/divyam234/teldrive/internal/crypt"
|
||||||
|
"github.com/divyam234/teldrive/internal/tgc"
|
||||||
"github.com/divyam234/teldrive/pkg/types"
|
"github.com/divyam234/teldrive/pkg/types"
|
||||||
"github.com/gotd/td/tg"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type decrpytedReader struct {
|
type decrpytedReader struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
parts []types.Part
|
parts []types.Part
|
||||||
ranges []types.Range
|
ranges []types.Range
|
||||||
pos int
|
pos int
|
||||||
client *tg.Client
|
reader io.ReadCloser
|
||||||
reader io.ReadCloser
|
limit int64
|
||||||
limit int64
|
config *config.TGConfig
|
||||||
err error
|
channelId int64
|
||||||
config *config.TGConfig
|
worker *tgc.StreamWorker
|
||||||
|
client *tgc.Client
|
||||||
|
fileId string
|
||||||
|
concurrency int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDecryptedReader(
|
func NewDecryptedReader(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
client *tg.Client,
|
fileId string,
|
||||||
parts []types.Part,
|
parts []types.Part,
|
||||||
start, end int64,
|
start, end int64,
|
||||||
config *config.TGConfig) (io.ReadCloser, error) {
|
channelId int64,
|
||||||
|
config *config.TGConfig,
|
||||||
|
concurrency int,
|
||||||
|
client *tgc.Client,
|
||||||
|
worker *tgc.StreamWorker) (io.ReadCloser, error) {
|
||||||
|
|
||||||
r := &decrpytedReader{
|
r := &decrpytedReader{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
parts: parts,
|
parts: parts,
|
||||||
client: client,
|
limit: end - start + 1,
|
||||||
limit: end - start + 1,
|
ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize),
|
||||||
ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize),
|
config: config,
|
||||||
config: config,
|
client: client,
|
||||||
|
worker: worker,
|
||||||
|
channelId: channelId,
|
||||||
|
fileId: fileId,
|
||||||
|
concurrency: concurrency,
|
||||||
}
|
}
|
||||||
res, err := r.nextPart()
|
res, err := r.nextPart()
|
||||||
|
|
||||||
|
@ -51,30 +62,24 @@ func NewDecryptedReader(
|
||||||
|
|
||||||
func (r *decrpytedReader) Read(p []byte) (n int, err error) {
|
func (r *decrpytedReader) Read(p []byte) (n int, err error) {
|
||||||
|
|
||||||
if r.err != nil {
|
|
||||||
return 0, r.err
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.limit <= 0 {
|
if r.limit <= 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err = r.reader.Read(p)
|
n, err = r.reader.Read(p)
|
||||||
|
r.limit -= int64(n)
|
||||||
if err == nil {
|
|
||||||
r.limit -= int64(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
if r.limit > 0 {
|
if r.limit > 0 {
|
||||||
err = nil
|
err = nil
|
||||||
|
if r.reader != nil {
|
||||||
|
r.reader.Close()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
r.pos++
|
r.pos++
|
||||||
if r.pos < len(r.ranges) {
|
if r.pos < len(r.ranges) {
|
||||||
r.reader, err = r.nextPart()
|
r.reader, err = r.nextPart()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
r.err = err
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,7 +94,6 @@ func (r *decrpytedReader) Close() (err error) {
|
||||||
|
|
||||||
func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
|
func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
|
||||||
|
|
||||||
location := r.parts[r.ranges[r.pos].PartNo].Location
|
|
||||||
start := r.ranges[r.pos].Start
|
start := r.ranges[r.pos].Start
|
||||||
end := r.ranges[r.pos].End
|
end := r.ranges[r.pos].End
|
||||||
salt := r.parts[r.ranges[r.pos].PartNo].Salt
|
salt := r.parts[r.ranges[r.pos].PartNo].Salt
|
||||||
|
@ -99,21 +103,15 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
|
||||||
func(ctx context.Context,
|
func(ctx context.Context,
|
||||||
underlyingOffset,
|
underlyingOffset,
|
||||||
underlyingLimit int64) (io.ReadCloser, error) {
|
underlyingLimit int64) (io.ReadCloser, error) {
|
||||||
|
|
||||||
var end int64
|
var end int64
|
||||||
|
|
||||||
if underlyingLimit >= 0 {
|
if underlyingLimit >= 0 {
|
||||||
end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
|
end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
|
||||||
}
|
}
|
||||||
rd, err := newTGReader(r.ctx, r.client, location, underlyingOffset, end)
|
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
|
||||||
if err != nil {
|
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
|
||||||
return nil, err
|
client: r.client, concurrency: r.concurrency}
|
||||||
}
|
return newTGReader(r.ctx, start, end, r.config, chunkSrc)
|
||||||
if r.config.Stream.BufferReader {
|
|
||||||
return NewAsyncReader(r.ctx, rd, r.config.Stream.Buffers)
|
|
||||||
|
|
||||||
}
|
|
||||||
return rd, nil
|
|
||||||
|
|
||||||
}, start, end-start+1)
|
}, start, end-start+1)
|
||||||
|
|
|
@ -5,8 +5,8 @@ import (
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/divyam234/teldrive/internal/config"
|
"github.com/divyam234/teldrive/internal/config"
|
||||||
|
"github.com/divyam234/teldrive/internal/tgc"
|
||||||
"github.com/divyam234/teldrive/pkg/types"
|
"github.com/divyam234/teldrive/pkg/types"
|
||||||
"github.com/gotd/td/tg"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
|
func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
|
||||||
|
@ -37,31 +37,42 @@ func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
|
||||||
}
|
}
|
||||||
|
|
||||||
type linearReader struct {
|
type linearReader struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
parts []types.Part
|
parts []types.Part
|
||||||
ranges []types.Range
|
ranges []types.Range
|
||||||
pos int
|
pos int
|
||||||
client *tg.Client
|
reader io.ReadCloser
|
||||||
reader io.ReadCloser
|
limit int64
|
||||||
limit int64
|
config *config.TGConfig
|
||||||
err error
|
channelId int64
|
||||||
config *config.TGConfig
|
worker *tgc.StreamWorker
|
||||||
|
client *tgc.Client
|
||||||
|
fileId string
|
||||||
|
concurrency int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLinearReader(ctx context.Context,
|
func NewLinearReader(ctx context.Context,
|
||||||
client *tg.Client,
|
fileId string,
|
||||||
parts []types.Part,
|
parts []types.Part,
|
||||||
start, end int64,
|
start, end int64,
|
||||||
|
channelId int64,
|
||||||
config *config.TGConfig,
|
config *config.TGConfig,
|
||||||
|
concurrency int,
|
||||||
|
client *tgc.Client,
|
||||||
|
worker *tgc.StreamWorker,
|
||||||
) (reader io.ReadCloser, err error) {
|
) (reader io.ReadCloser, err error) {
|
||||||
|
|
||||||
r := &linearReader{
|
r := &linearReader{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
parts: parts,
|
parts: parts,
|
||||||
client: client,
|
limit: end - start + 1,
|
||||||
limit: end - start + 1,
|
ranges: calculatePartByteRanges(start, end, parts[0].Size),
|
||||||
ranges: calculatePartByteRanges(start, end, parts[0].Size),
|
config: config,
|
||||||
config: config,
|
client: client,
|
||||||
|
worker: worker,
|
||||||
|
channelId: channelId,
|
||||||
|
fileId: fileId,
|
||||||
|
concurrency: concurrency,
|
||||||
}
|
}
|
||||||
|
|
||||||
r.reader, err = r.nextPart()
|
r.reader, err = r.nextPart()
|
||||||
|
@ -73,25 +84,22 @@ func NewLinearReader(ctx context.Context,
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *linearReader) Read(p []byte) (n int, err error) {
|
func (r *linearReader) Read(p []byte) (int, error) {
|
||||||
|
|
||||||
if r.err != nil {
|
|
||||||
return 0, r.err
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.limit <= 0 {
|
if r.limit <= 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err = r.reader.Read(p)
|
n, err := r.reader.Read(p)
|
||||||
|
|
||||||
if err == nil {
|
r.limit -= int64(n)
|
||||||
r.limit -= int64(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
if r.limit > 0 {
|
if r.limit > 0 {
|
||||||
err = nil
|
err = nil
|
||||||
|
if r.reader != nil {
|
||||||
|
r.reader.Close()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
r.pos++
|
r.pos++
|
||||||
if r.pos < len(r.ranges) {
|
if r.pos < len(r.ranges) {
|
||||||
|
@ -99,24 +107,18 @@ func (r *linearReader) Read(p []byte) (n int, err error) {
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
r.err = err
|
return n, err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *linearReader) nextPart() (io.ReadCloser, error) {
|
func (r *linearReader) nextPart() (io.ReadCloser, error) {
|
||||||
|
|
||||||
location := r.parts[r.ranges[r.pos].PartNo].Location
|
start := r.ranges[r.pos].Start
|
||||||
startByte := r.ranges[r.pos].Start
|
end := r.ranges[r.pos].End
|
||||||
endByte := r.ranges[r.pos].End
|
|
||||||
rd, err := newTGReader(r.ctx, r.client, location, startByte, endByte)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if r.config.Stream.BufferReader {
|
|
||||||
return NewAsyncReader(r.ctx, rd, r.config.Stream.Buffers)
|
|
||||||
|
|
||||||
}
|
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
|
||||||
return rd, nil
|
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
|
||||||
|
client: r.client, concurrency: r.concurrency}
|
||||||
|
return newTGReader(r.ctx, start, end, r.config, chunkSrc)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
311
internal/reader/tg_reader.go
Normal file
311
internal/reader/tg_reader.go
Normal file
|
@ -0,0 +1,311 @@
|
||||||
|
package reader
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/divyam234/teldrive/internal/config"
|
||||||
|
"github.com/divyam234/teldrive/internal/tgc"
|
||||||
|
"github.com/gotd/td/tg"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrorStreamAbandoned = errors.New("stream abandoned")
|
||||||
|
|
||||||
|
type ChunkSource interface {
|
||||||
|
Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error)
|
||||||
|
ChunkSize(start, end int64) int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type chunkSource struct {
|
||||||
|
channelId int64
|
||||||
|
worker *tgc.StreamWorker
|
||||||
|
fileId string
|
||||||
|
partId int64
|
||||||
|
concurrency int
|
||||||
|
client *tgc.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *chunkSource) ChunkSize(start, end int64) int64 {
|
||||||
|
return tgc.CalculateChunkSize(start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
|
||||||
|
var (
|
||||||
|
location *tg.InputDocumentFileLocation
|
||||||
|
err error
|
||||||
|
client *tgc.Client
|
||||||
|
)
|
||||||
|
|
||||||
|
client = c.client
|
||||||
|
|
||||||
|
if c.concurrency > 0 {
|
||||||
|
client, _, _ = c.worker.Next(c.channelId)
|
||||||
|
}
|
||||||
|
location, err = tgc.GetLocation(ctx, client, c.fileId, c.channelId, c.partId)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
type tgReader struct {
|
||||||
|
ctx context.Context
|
||||||
|
offset int64
|
||||||
|
limit int64
|
||||||
|
chunkSize int64
|
||||||
|
bufferChan chan *buffer
|
||||||
|
done chan struct{}
|
||||||
|
cur *buffer
|
||||||
|
err chan error
|
||||||
|
mu sync.Mutex
|
||||||
|
concurrency int
|
||||||
|
leftCut int64
|
||||||
|
rightCut int64
|
||||||
|
totalParts int
|
||||||
|
currentPart int
|
||||||
|
closed bool
|
||||||
|
timeout time.Duration
|
||||||
|
chunkSrc ChunkSource
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTGReader(
|
||||||
|
ctx context.Context,
|
||||||
|
start int64,
|
||||||
|
end int64,
|
||||||
|
config *config.TGConfig,
|
||||||
|
chunkSrc ChunkSource,
|
||||||
|
|
||||||
|
) (*tgReader, error) {
|
||||||
|
|
||||||
|
chunkSize := chunkSrc.ChunkSize(start, end)
|
||||||
|
|
||||||
|
offset := start - (start % chunkSize)
|
||||||
|
|
||||||
|
r := &tgReader{
|
||||||
|
ctx: ctx,
|
||||||
|
limit: end - start + 1,
|
||||||
|
bufferChan: make(chan *buffer, config.Stream.Buffers),
|
||||||
|
concurrency: config.Stream.MultiThreads,
|
||||||
|
leftCut: start - offset,
|
||||||
|
rightCut: (end % chunkSize) + 1,
|
||||||
|
totalParts: int((end - offset + chunkSize) / chunkSize),
|
||||||
|
offset: offset,
|
||||||
|
chunkSize: chunkSize,
|
||||||
|
chunkSrc: chunkSrc,
|
||||||
|
timeout: config.Stream.ChunkTimeout,
|
||||||
|
done: make(chan struct{}, 1),
|
||||||
|
err: make(chan error, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.concurrency == 0 {
|
||||||
|
r.currentPart = 1
|
||||||
|
go r.fillBufferSequentially()
|
||||||
|
} else {
|
||||||
|
go r.fillBufferConcurrently()
|
||||||
|
}
|
||||||
|
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tgReader) Close() error {
|
||||||
|
close(r.done)
|
||||||
|
close(r.err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tgReader) Read(p []byte) (int, error) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if r.limit <= 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.cur.isEmpty() {
|
||||||
|
if r.cur != nil {
|
||||||
|
r.cur = nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case cur, ok := <-r.bufferChan:
|
||||||
|
if !ok && r.limit > 0 {
|
||||||
|
return 0, ErrorStreamAbandoned
|
||||||
|
}
|
||||||
|
r.cur = cur
|
||||||
|
|
||||||
|
case err := <-r.err:
|
||||||
|
return 0, fmt.Errorf("error reading chunk: %w", err)
|
||||||
|
case <-r.ctx.Done():
|
||||||
|
return 0, r.ctx.Err()
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
n := copy(p, r.cur.buffer())
|
||||||
|
r.cur.increment(n)
|
||||||
|
r.limit -= int64(n)
|
||||||
|
|
||||||
|
if r.limit <= 0 {
|
||||||
|
return n, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tgReader) fillBufferConcurrently() error {
|
||||||
|
|
||||||
|
var mapMu sync.Mutex
|
||||||
|
|
||||||
|
bufferMap := make(map[int]*buffer)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
close(r.bufferChan)
|
||||||
|
r.closed = true
|
||||||
|
for i := range bufferMap {
|
||||||
|
delete(bufferMap, i)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
cb := func(ctx context.Context, i int) func() error {
|
||||||
|
return func() error {
|
||||||
|
|
||||||
|
chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if r.totalParts == 1 {
|
||||||
|
chunk = chunk[r.leftCut:r.rightCut]
|
||||||
|
} else if r.currentPart+i+1 == 1 {
|
||||||
|
chunk = chunk[r.leftCut:]
|
||||||
|
} else if r.currentPart+i+1 == r.totalParts {
|
||||||
|
chunk = chunk[:r.rightCut]
|
||||||
|
}
|
||||||
|
buf := &buffer{buf: chunk}
|
||||||
|
mapMu.Lock()
|
||||||
|
bufferMap[i] = buf
|
||||||
|
mapMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
|
||||||
|
g := errgroup.Group{}
|
||||||
|
|
||||||
|
g.SetLimit(r.concurrency)
|
||||||
|
|
||||||
|
for i := range r.concurrency {
|
||||||
|
if r.currentPart+i+1 <= r.totalParts {
|
||||||
|
g.Go(cb(r.ctx, i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
done <- g.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
if err != nil {
|
||||||
|
r.err <- err
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
for i := range r.concurrency {
|
||||||
|
if r.currentPart+i+1 <= r.totalParts {
|
||||||
|
r.bufferChan <- bufferMap[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.currentPart += r.concurrency
|
||||||
|
r.offset += r.chunkSize * int64(r.concurrency)
|
||||||
|
for i := range bufferMap {
|
||||||
|
delete(bufferMap, i)
|
||||||
|
}
|
||||||
|
if r.currentPart >= r.totalParts {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case <-time.After(r.timeout):
|
||||||
|
return nil
|
||||||
|
case <-r.done:
|
||||||
|
return nil
|
||||||
|
case <-r.ctx.Done():
|
||||||
|
return r.ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tgReader) fillBufferSequentially() error {
|
||||||
|
|
||||||
|
defer close(r.bufferChan)
|
||||||
|
|
||||||
|
fetchChunk := func(ctx context.Context) (*buffer, error) {
|
||||||
|
chunk, err := r.chunkSrc.Chunk(ctx, r.offset, r.chunkSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if r.totalParts == 1 {
|
||||||
|
chunk = chunk[r.leftCut:r.rightCut]
|
||||||
|
} else if r.currentPart == 1 {
|
||||||
|
chunk = chunk[r.leftCut:]
|
||||||
|
} else if r.currentPart == r.totalParts {
|
||||||
|
chunk = chunk[:r.rightCut]
|
||||||
|
}
|
||||||
|
return &buffer{buf: chunk}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-r.done:
|
||||||
|
return nil
|
||||||
|
case <-r.ctx.Done():
|
||||||
|
return r.ctx.Err()
|
||||||
|
case <-time.After(r.timeout):
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
buf, err := fetchChunk(r.ctx)
|
||||||
|
if err != nil {
|
||||||
|
r.err <- err
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
r.bufferChan <- buf
|
||||||
|
r.currentPart++
|
||||||
|
r.offset += r.chunkSize
|
||||||
|
if r.currentPart > r.totalParts {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type buffer struct {
|
||||||
|
buf []byte
|
||||||
|
offset int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *buffer) isEmpty() bool {
|
||||||
|
if b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(b.buf)-b.offset <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *buffer) buffer() []byte {
|
||||||
|
return b.buf[b.offset:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *buffer) increment(n int) {
|
||||||
|
b.offset += n
|
||||||
|
}
|
142
internal/reader/tg_reader_test.go
Normal file
142
internal/reader/tg_reader_test.go
Normal file
|
@ -0,0 +1,142 @@
|
||||||
|
package reader
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/divyam234/teldrive/internal/config"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testChunkSource struct {
|
||||||
|
buffer []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testChunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
|
||||||
|
return m.buffer[offset : offset+limit], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testChunkSource) ChunkSize(start, end int64) int64 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
type testChunkSourceTimeout struct {
|
||||||
|
buffer []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testChunkSourceTimeout) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
|
||||||
|
if offset == 8 {
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
}
|
||||||
|
return m.buffer[offset : offset+limit], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testChunkSourceTimeout) ChunkSize(start, end int64) int64 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
config *config.TGConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestSuite) SetupTest() {
|
||||||
|
suite.config = &config.TGConfig{Stream: struct {
|
||||||
|
MultiThreads int
|
||||||
|
Buffers int
|
||||||
|
ChunkTimeout time.Duration
|
||||||
|
}{MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestSuite) TestFullRead() {
|
||||||
|
ctx := context.Background()
|
||||||
|
start := int64(0)
|
||||||
|
end := int64(99)
|
||||||
|
data := make([]byte, 100)
|
||||||
|
rand.Read(data)
|
||||||
|
chunkSrc := &testChunkSource{buffer: data}
|
||||||
|
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
||||||
|
assert.NoError(suite.T(), err)
|
||||||
|
test_data, err := io.ReadAll(reader)
|
||||||
|
assert.Equal(suite.T(), nil, err)
|
||||||
|
assert.Equal(suite.T(), data[start:end+1], test_data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestSuite) TestPartialRead() {
|
||||||
|
ctx := context.Background()
|
||||||
|
start := int64(0)
|
||||||
|
end := int64(65)
|
||||||
|
data := make([]byte, 100)
|
||||||
|
rand.Read(data)
|
||||||
|
chunkSrc := &testChunkSource{buffer: data}
|
||||||
|
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
||||||
|
assert.NoError(suite.T(), err)
|
||||||
|
test_data, err := io.ReadAll(reader)
|
||||||
|
assert.NoError(suite.T(), err)
|
||||||
|
assert.Equal(suite.T(), data[start:end+1], test_data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestSuite) TestTimeout() {
|
||||||
|
ctx := context.Background()
|
||||||
|
start := int64(0)
|
||||||
|
end := int64(65)
|
||||||
|
data := make([]byte, 100)
|
||||||
|
rand.Read(data)
|
||||||
|
chunkSrc := &testChunkSourceTimeout{buffer: data}
|
||||||
|
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
||||||
|
assert.NoError(suite.T(), err)
|
||||||
|
test_data, err := io.ReadAll(reader)
|
||||||
|
assert.Greater(suite.T(), len(test_data), 0)
|
||||||
|
assert.Equal(suite.T(), err, ErrorStreamAbandoned)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestSuite) TestClose() {
|
||||||
|
ctx := context.Background()
|
||||||
|
start := int64(0)
|
||||||
|
end := int64(65)
|
||||||
|
data := make([]byte, 100)
|
||||||
|
rand.Read(data)
|
||||||
|
chunkSrc := &testChunkSource{buffer: data}
|
||||||
|
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
||||||
|
assert.NoError(suite.T(), err)
|
||||||
|
_, err = io.ReadAll(reader)
|
||||||
|
assert.NoError(suite.T(), err)
|
||||||
|
assert.NoError(suite.T(), reader.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestSuite) TestCancellation() {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
start := int64(0)
|
||||||
|
end := int64(65)
|
||||||
|
data := make([]byte, 100)
|
||||||
|
rand.Read(data)
|
||||||
|
chunkSrc := &testChunkSource{buffer: data}
|
||||||
|
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
||||||
|
assert.NoError(suite.T(), err)
|
||||||
|
cancel()
|
||||||
|
_, err = io.ReadAll(reader)
|
||||||
|
assert.Equal(suite.T(), err, context.Canceled)
|
||||||
|
assert.Equal(suite.T(), len(reader.bufferChan), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestSuite) TestCancellationWithTimeout() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
_ = cancel
|
||||||
|
start := int64(0)
|
||||||
|
end := int64(65)
|
||||||
|
data := make([]byte, 100)
|
||||||
|
rand.Read(data)
|
||||||
|
chunkSrc := &testChunkSourceTimeout{buffer: data}
|
||||||
|
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
||||||
|
assert.NoError(suite.T(), err)
|
||||||
|
_, err = io.ReadAll(reader)
|
||||||
|
assert.Equal(suite.T(), err, context.DeadlineExceeded)
|
||||||
|
assert.Equal(suite.T(), len(reader.bufferChan), 0)
|
||||||
|
}
|
||||||
|
func Test(t *testing.T) {
|
||||||
|
suite.Run(t, new(TestSuite))
|
||||||
|
}
|
|
@ -1,144 +0,0 @@
|
||||||
package reader
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/gotd/td/tg"
|
|
||||||
)
|
|
||||||
|
|
||||||
type tgReader struct {
|
|
||||||
ctx context.Context
|
|
||||||
client *tg.Client
|
|
||||||
location *tg.InputDocumentFileLocation
|
|
||||||
start int64
|
|
||||||
end int64
|
|
||||||
next func() ([]byte, error)
|
|
||||||
buffer []byte
|
|
||||||
limit int64
|
|
||||||
chunkSize int64
|
|
||||||
i int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func calculateChunkSize(start, end int64) int64 {
|
|
||||||
chunkSize := int64(1024 * 1024)
|
|
||||||
|
|
||||||
for chunkSize > 1024 && chunkSize > (end-start) {
|
|
||||||
chunkSize /= 2
|
|
||||||
}
|
|
||||||
|
|
||||||
return chunkSize
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTGReader(
|
|
||||||
ctx context.Context,
|
|
||||||
client *tg.Client,
|
|
||||||
location *tg.InputDocumentFileLocation,
|
|
||||||
start int64,
|
|
||||||
end int64,
|
|
||||||
|
|
||||||
) (io.ReadCloser, error) {
|
|
||||||
|
|
||||||
r := &tgReader{
|
|
||||||
ctx: ctx,
|
|
||||||
location: location,
|
|
||||||
client: client,
|
|
||||||
start: start,
|
|
||||||
end: end,
|
|
||||||
chunkSize: calculateChunkSize(start, end),
|
|
||||||
limit: end - start + 1,
|
|
||||||
}
|
|
||||||
r.next = r.partStream()
|
|
||||||
return r, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *tgReader) Read(p []byte) (n int, err error) {
|
|
||||||
|
|
||||||
if r.limit <= 0 {
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.i >= int64(len(r.buffer)) {
|
|
||||||
r.buffer, err = r.next()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if len(r.buffer) == 0 {
|
|
||||||
r.next = r.partStream()
|
|
||||||
r.buffer, err = r.next()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
r.i = 0
|
|
||||||
}
|
|
||||||
n = copy(p, r.buffer[r.i:])
|
|
||||||
r.i += int64(n)
|
|
||||||
r.limit -= int64(n)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*tgReader) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *tgReader) chunk(offset int64, limit int64) ([]byte, error) {
|
|
||||||
|
|
||||||
req := &tg.UploadGetFileRequest{
|
|
||||||
Offset: offset,
|
|
||||||
Limit: int(limit),
|
|
||||||
Location: r.location,
|
|
||||||
Precise: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := r.client.UploadGetFile(r.ctx, req)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch result := res.(type) {
|
|
||||||
case *tg.UploadFile:
|
|
||||||
return result.Bytes, nil
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unexpected type %T", r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *tgReader) partStream() func() ([]byte, error) {
|
|
||||||
|
|
||||||
start := r.start
|
|
||||||
end := r.end
|
|
||||||
offset := start - (start % r.chunkSize)
|
|
||||||
|
|
||||||
leftCut := start - offset
|
|
||||||
rightCut := (end % r.chunkSize) + 1
|
|
||||||
totalParts := int((end - offset + r.chunkSize) / r.chunkSize)
|
|
||||||
currentPart := 1
|
|
||||||
|
|
||||||
return func() ([]byte, error) {
|
|
||||||
if currentPart > totalParts {
|
|
||||||
return make([]byte, 0), nil
|
|
||||||
}
|
|
||||||
res, err := r.chunk(offset, r.chunkSize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(res) == 0 {
|
|
||||||
return res, nil
|
|
||||||
} else if totalParts == 1 {
|
|
||||||
res = res[leftCut:rightCut]
|
|
||||||
} else if currentPart == 1 {
|
|
||||||
res = res[leftCut:]
|
|
||||||
} else if currentPart == totalParts {
|
|
||||||
res = res[:rightCut]
|
|
||||||
}
|
|
||||||
|
|
||||||
currentPart++
|
|
||||||
offset += r.chunkSize
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
}
|
|
239
internal/tgc/helpers.go
Normal file
239
internal/tgc/helpers.go
Normal file
|
@ -0,0 +1,239 @@
|
||||||
|
package tgc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/divyam234/teldrive/internal/cache"
|
||||||
|
"github.com/divyam234/teldrive/pkg/types"
|
||||||
|
"github.com/gotd/td/telegram"
|
||||||
|
"github.com/gotd/td/tg"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInValidChannelID = errors.New("invalid channel id")
|
||||||
|
ErrInvalidChannelMessages = errors.New("invalid channel messages")
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetChannelById(ctx context.Context, client *tg.Client, channelId int64) (*tg.InputChannel, error) {
|
||||||
|
inputChannel := &tg.InputChannel{
|
||||||
|
ChannelID: channelId,
|
||||||
|
}
|
||||||
|
channels, err := client.ChannelsGetChannels(ctx, []tg.InputChannelClass{inputChannel})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(channels.GetChats()) == 0 {
|
||||||
|
return nil, ErrInValidChannelID
|
||||||
|
}
|
||||||
|
return channels.GetChats()[0].(*tg.Channel).AsInput(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteMessages(ctx context.Context, client *telegram.Client, channelId int64, ids []int) error {
|
||||||
|
|
||||||
|
return RunWithAuth(ctx, client, "", func(ctx context.Context) error {
|
||||||
|
channel, err := GetChannelById(ctx, client.API(), channelId)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
batchSize := 100
|
||||||
|
|
||||||
|
batchCount := int(math.Ceil(float64(len(ids)) / float64(batchSize)))
|
||||||
|
|
||||||
|
g, _ := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
|
g.SetLimit(runtime.NumCPU())
|
||||||
|
|
||||||
|
for i := 0; i < batchCount; i++ {
|
||||||
|
start := i * batchSize
|
||||||
|
end := min((i+1)*batchSize, len(ids))
|
||||||
|
batchIds := ids[start:end]
|
||||||
|
g.Go(func() error {
|
||||||
|
messageDeleteRequest := tg.ChannelsDeleteMessagesRequest{Channel: channel, ID: batchIds}
|
||||||
|
_, err = client.API().ChannelsDeleteMessages(ctx, &messageDeleteRequest)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return g.Wait()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTGMessagesBatch(ctx context.Context, client *tg.Client, channel *tg.InputChannel, ids []int) (tg.MessagesMessagesClass, error) {
|
||||||
|
|
||||||
|
msgIds := []tg.InputMessageClass{}
|
||||||
|
|
||||||
|
for _, id := range ids {
|
||||||
|
msgIds = append(msgIds, &tg.InputMessageID{ID: id})
|
||||||
|
}
|
||||||
|
|
||||||
|
messageRequest := tg.ChannelsGetMessagesRequest{
|
||||||
|
Channel: channel,
|
||||||
|
ID: msgIds,
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := client.ChannelsGetMessages(ctx, &messageRequest)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetMessages(ctx context.Context, client *tg.Client, ids []int, channelId int64) ([]tg.MessageClass, error) {
|
||||||
|
|
||||||
|
channel, err := GetChannelById(ctx, client, channelId)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
batchSize := 200
|
||||||
|
|
||||||
|
batchCount := int(math.Ceil(float64(len(ids)) / float64(batchSize)))
|
||||||
|
|
||||||
|
g, _ := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
|
g.SetLimit(runtime.NumCPU())
|
||||||
|
|
||||||
|
messageMap := make(map[int]*tg.MessagesChannelMessages)
|
||||||
|
|
||||||
|
var mapMu sync.Mutex
|
||||||
|
|
||||||
|
for i := range batchCount {
|
||||||
|
g.Go(func() error {
|
||||||
|
splitIds := ids[i*batchSize : min((i+1)*batchSize, len(ids))]
|
||||||
|
res, err := getTGMessagesBatch(ctx, client, channel, splitIds)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
messages, ok := res.(*tg.MessagesChannelMessages)
|
||||||
|
if !ok {
|
||||||
|
return ErrInvalidChannelMessages
|
||||||
|
}
|
||||||
|
mapMu.Lock()
|
||||||
|
messageMap[i] = messages
|
||||||
|
mapMu.Unlock()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = g.Wait(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
allMessages := []tg.MessageClass{}
|
||||||
|
|
||||||
|
for i := range batchCount {
|
||||||
|
allMessages = append(allMessages, messageMap[i].Messages...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return allMessages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetChunk(ctx context.Context, client *tg.Client, location tg.InputFileLocationClass, offset int64, limit int64) ([]byte, error) {
|
||||||
|
req := &tg.UploadGetFileRequest{
|
||||||
|
Offset: offset,
|
||||||
|
Limit: int(limit),
|
||||||
|
Location: location,
|
||||||
|
Precise: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := client.UploadGetFile(ctx, req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch result := r.(type) {
|
||||||
|
case *tg.UploadFile:
|
||||||
|
return result.Bytes, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected type %T", r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetMediaContent(ctx context.Context, client *tg.Client, location tg.InputFileLocationClass) (*bytes.Buffer, error) {
|
||||||
|
offset := int64(0)
|
||||||
|
limit := int64(1024 * 1024)
|
||||||
|
buff := &bytes.Buffer{}
|
||||||
|
for {
|
||||||
|
r, err := GetChunk(ctx, client, location, offset, limit)
|
||||||
|
if err != nil {
|
||||||
|
return buff, err
|
||||||
|
}
|
||||||
|
if len(r) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
buff.Write(r)
|
||||||
|
offset += int64(limit)
|
||||||
|
}
|
||||||
|
return buff, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetBotInfo(ctx context.Context, client *telegram.Client, token string) (*types.BotInfo, error) {
|
||||||
|
var user *tg.User
|
||||||
|
err := RunWithAuth(ctx, client, token, func(ctx context.Context) error {
|
||||||
|
user, _ = client.Self(ctx)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetLocation(ctx context.Context, client *Client, fileId string, channelId int64, partId int64) (location *tg.InputDocumentFileLocation, err error) {
|
||||||
|
|
||||||
|
cache := cache.FromContext(ctx)
|
||||||
|
|
||||||
|
key := fmt.Sprintf("location:%s:%s:%d", client.UserId, fileId, partId)
|
||||||
|
|
||||||
|
err = cache.Get(key, location)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
channel, err := GetChannelById(ctx, client.Tg.API(), channelId)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
messageRequest := tg.ChannelsGetMessagesRequest{
|
||||||
|
Channel: channel,
|
||||||
|
ID: []tg.InputMessageClass{&tg.InputMessageID{ID: int(partId)}},
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := client.Tg.API().ChannelsGetMessages(ctx, &messageRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
messages, _ := res.(*tg.MessagesChannelMessages)
|
||||||
|
item := messages.Messages[0].(*tg.Message)
|
||||||
|
media := item.Media.(*tg.MessageMediaDocument)
|
||||||
|
document := media.Document.(*tg.Document)
|
||||||
|
location = document.AsInputDocumentFileLocation()
|
||||||
|
cache.Set(key, location, 3600)
|
||||||
|
}
|
||||||
|
return location, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CalculateChunkSize(start, end int64) int64 {
|
||||||
|
chunkSize := int64(1024 * 1024)
|
||||||
|
|
||||||
|
for chunkSize > 1024 && chunkSize > (end-start) {
|
||||||
|
chunkSize /= 2
|
||||||
|
}
|
||||||
|
return chunkSize
|
||||||
|
}
|
|
@ -55,8 +55,8 @@ func New(ctx context.Context, config *config.TGConfig, handler telegram.UpdateHa
|
||||||
LangCode: config.LangCode,
|
LangCode: config.LangCode,
|
||||||
},
|
},
|
||||||
SessionStorage: storage,
|
SessionStorage: storage,
|
||||||
RetryInterval: 5 * time.Second,
|
RetryInterval: 2 * time.Second,
|
||||||
MaxRetries: 5,
|
MaxRetries: 10,
|
||||||
DialTimeout: 10 * time.Second,
|
DialTimeout: 10 * time.Second,
|
||||||
Middlewares: middlewares,
|
Middlewares: middlewares,
|
||||||
UpdateHandler: handler,
|
UpdateHandler: handler,
|
||||||
|
|
|
@ -2,11 +2,11 @@ package tgc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/divyam234/teldrive/internal/config"
|
"github.com/divyam234/teldrive/internal/config"
|
||||||
"github.com/divyam234/teldrive/internal/kv"
|
"github.com/divyam234/teldrive/internal/kv"
|
||||||
"github.com/divyam234/teldrive/internal/pool"
|
|
||||||
"github.com/gotd/td/telegram"
|
"github.com/gotd/td/telegram"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -42,9 +42,9 @@ func NewUploadWorker() *UploadWorker {
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
Tg *telegram.Client
|
Tg *telegram.Client
|
||||||
Pool pool.Pool
|
|
||||||
Stop StopFunc
|
Stop StopFunc
|
||||||
Status string
|
Status string
|
||||||
|
UserId string
|
||||||
}
|
}
|
||||||
|
|
||||||
type StreamWorker struct {
|
type StreamWorker struct {
|
||||||
|
@ -69,10 +69,7 @@ func (w *StreamWorker) Set(bots []string, channelId int64) {
|
||||||
for _, token := range bots {
|
for _, token := range bots {
|
||||||
middlewares := Middlewares(w.cnf, 5)
|
middlewares := Middlewares(w.cnf, 5)
|
||||||
client, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
|
client, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
|
||||||
c := &Client{Tg: client, Status: "idle"}
|
c := &Client{Tg: client, Status: "idle", UserId: strings.Split(token, ":")[0]}
|
||||||
if w.cnf.Stream.UsePooling {
|
|
||||||
c.Pool = pool.NewPool(client, int64(w.cnf.PoolSize), middlewares...)
|
|
||||||
}
|
|
||||||
w.clients[channelId] = append(w.clients[channelId], c)
|
w.clients[channelId] = append(w.clients[channelId], c)
|
||||||
}
|
}
|
||||||
w.currIdx[channelId] = 0
|
w.currIdx[channelId] = 0
|
||||||
|
@ -108,9 +105,6 @@ func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error)
|
||||||
middlewares := Middlewares(w.cnf, 5)
|
middlewares := Middlewares(w.cnf, 5)
|
||||||
client, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
|
client, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
|
||||||
c := &Client{Tg: client, Status: "idle"}
|
c := &Client{Tg: client, Status: "idle"}
|
||||||
if w.cnf.Stream.UsePooling {
|
|
||||||
c.Pool = pool.NewPool(client, int64(w.cnf.PoolSize), middlewares...)
|
|
||||||
}
|
|
||||||
w.clients[userId] = append(w.clients[userId], c)
|
w.clients[userId] = append(w.clients[userId], c)
|
||||||
}
|
}
|
||||||
nextClient := w.clients[userId][0]
|
nextClient := w.clients[userId][0]
|
||||||
|
|
|
@ -3,11 +3,11 @@ package controller
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/divyam234/teldrive/internal/auth"
|
||||||
"github.com/divyam234/teldrive/internal/cache"
|
"github.com/divyam234/teldrive/internal/cache"
|
||||||
"github.com/divyam234/teldrive/internal/logging"
|
"github.com/divyam234/teldrive/internal/logging"
|
||||||
"github.com/divyam234/teldrive/pkg/httputil"
|
"github.com/divyam234/teldrive/pkg/httputil"
|
||||||
"github.com/divyam234/teldrive/pkg/schemas"
|
"github.com/divyam234/teldrive/pkg/schemas"
|
||||||
"github.com/divyam234/teldrive/pkg/services"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ func (fc *Controller) CreateFile(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userId, _ := services.GetUserAuth(c)
|
userId, _ := auth.GetUser(c)
|
||||||
|
|
||||||
res, err := fc.FileService.CreateFile(c, userId, &fileIn)
|
res, err := fc.FileService.CreateFile(c, userId, &fileIn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -36,7 +36,7 @@ func (fc *Controller) CreateFile(c *gin.Context) {
|
||||||
|
|
||||||
func (fc *Controller) UpdateFile(c *gin.Context) {
|
func (fc *Controller) UpdateFile(c *gin.Context) {
|
||||||
|
|
||||||
userId, _ := services.GetUserAuth(c)
|
userId, _ := auth.GetUser(c)
|
||||||
|
|
||||||
var fileUpdate schemas.FileUpdate
|
var fileUpdate schemas.FileUpdate
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ func (fc *Controller) GetFileByID(c *gin.Context) {
|
||||||
|
|
||||||
func (fc *Controller) ListFiles(c *gin.Context) {
|
func (fc *Controller) ListFiles(c *gin.Context) {
|
||||||
|
|
||||||
userId, _ := services.GetUserAuth(c)
|
userId, _ := auth.GetUser(c)
|
||||||
|
|
||||||
fquery := schemas.FileQuery{
|
fquery := schemas.FileQuery{
|
||||||
PerPage: 500,
|
PerPage: 500,
|
||||||
|
@ -90,7 +90,7 @@ func (fc *Controller) ListFiles(c *gin.Context) {
|
||||||
|
|
||||||
func (fc *Controller) MakeDirectory(c *gin.Context) {
|
func (fc *Controller) MakeDirectory(c *gin.Context) {
|
||||||
|
|
||||||
userId, _ := services.GetUserAuth(c)
|
userId, _ := auth.GetUser(c)
|
||||||
|
|
||||||
var payload schemas.MkDir
|
var payload schemas.MkDir
|
||||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||||
|
@ -118,7 +118,7 @@ func (fc *Controller) CopyFile(c *gin.Context) {
|
||||||
|
|
||||||
func (fc *Controller) MoveFiles(c *gin.Context) {
|
func (fc *Controller) MoveFiles(c *gin.Context) {
|
||||||
|
|
||||||
userId, _ := services.GetUserAuth(c)
|
userId, _ := auth.GetUser(c)
|
||||||
|
|
||||||
var payload schemas.FileOperation
|
var payload schemas.FileOperation
|
||||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||||
|
@ -136,7 +136,7 @@ func (fc *Controller) MoveFiles(c *gin.Context) {
|
||||||
|
|
||||||
func (fc *Controller) DeleteFiles(c *gin.Context) {
|
func (fc *Controller) DeleteFiles(c *gin.Context) {
|
||||||
|
|
||||||
userId, _ := services.GetUserAuth(c)
|
userId, _ := auth.GetUser(c)
|
||||||
|
|
||||||
var payload schemas.DeleteOperation
|
var payload schemas.DeleteOperation
|
||||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||||
|
@ -164,7 +164,7 @@ func (fc *Controller) DeleteFileParts(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fc *Controller) MoveDirectory(c *gin.Context) {
|
func (fc *Controller) MoveDirectory(c *gin.Context) {
|
||||||
userId, _ := services.GetUserAuth(c)
|
userId, _ := auth.GetUser(c)
|
||||||
|
|
||||||
var payload schemas.DirMove
|
var payload schemas.DirMove
|
||||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||||
|
@ -181,7 +181,7 @@ func (fc *Controller) MoveDirectory(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fc *Controller) GetCategoryStats(c *gin.Context) {
|
func (fc *Controller) GetCategoryStats(c *gin.Context) {
|
||||||
userId, _ := services.GetUserAuth(c)
|
userId, _ := auth.GetUser(c)
|
||||||
|
|
||||||
res, err := fc.FileService.GetCategoryStats(userId)
|
res, err := fc.FileService.GetCategoryStats(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -193,5 +193,9 @@ func (fc *Controller) GetCategoryStats(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fc *Controller) GetFileStream(c *gin.Context) {
|
func (fc *Controller) GetFileStream(c *gin.Context) {
|
||||||
fc.FileService.GetFileStream(c)
|
fc.FileService.GetFileStream(c, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *Controller) GetFileDownload(c *gin.Context) {
|
||||||
|
fc.FileService.GetFileStream(c, true)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,8 +4,8 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/divyam234/teldrive/internal/auth"
|
||||||
"github.com/divyam234/teldrive/pkg/httputil"
|
"github.com/divyam234/teldrive/pkg/httputil"
|
||||||
"github.com/divyam234/teldrive/pkg/services"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ func (uc *Controller) UploadFile(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *Controller) UploadStats(c *gin.Context) {
|
func (uc *Controller) UploadStats(c *gin.Context) {
|
||||||
userId, _ := services.GetUserAuth(c)
|
userId, _ := auth.GetUser(c)
|
||||||
|
|
||||||
days := 7
|
days := 7
|
||||||
|
|
||||||
|
|
|
@ -6,9 +6,9 @@ import (
|
||||||
|
|
||||||
"github.com/divyam234/teldrive/internal/config"
|
"github.com/divyam234/teldrive/internal/config"
|
||||||
"github.com/divyam234/teldrive/internal/logging"
|
"github.com/divyam234/teldrive/internal/logging"
|
||||||
|
"github.com/divyam234/teldrive/internal/tgc"
|
||||||
"github.com/divyam234/teldrive/pkg/models"
|
"github.com/divyam234/teldrive/pkg/models"
|
||||||
"github.com/divyam234/teldrive/pkg/schemas"
|
"github.com/divyam234/teldrive/pkg/schemas"
|
||||||
"github.com/divyam234/teldrive/pkg/services"
|
|
||||||
"github.com/go-co-op/gocron"
|
"github.com/go-co-op/gocron"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
@ -87,7 +87,8 @@ func (c *CronService) CleanFiles(ctx context.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
err := services.DeleteTGMessages(ctx, &c.cnf.TG, row.Session, row.ChannelId, row.UserId, ids)
|
client, _ := tgc.AuthClient(ctx, &c.cnf.TG, row.Session)
|
||||||
|
err := tgc.DeleteMessages(ctx, client, row.ChannelId, ids)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Errorw("failed to delete messages", err)
|
c.logger.Errorw("failed to delete messages", err)
|
||||||
|
@ -122,7 +123,9 @@ func (c *CronService) CleanUploads(ctx context.Context) {
|
||||||
for _, result := range upResults {
|
for _, result := range upResults {
|
||||||
|
|
||||||
if result.Session != "" && len(result.Parts) > 0 {
|
if result.Session != "" && len(result.Parts) > 0 {
|
||||||
err := services.DeleteTGMessages(ctx, &c.cnf.TG, result.Session, result.ChannelId, result.UserId, result.Parts)
|
client, _ := tgc.AuthClient(ctx, &c.cnf.TG, result.Session)
|
||||||
|
|
||||||
|
err := tgc.DeleteMessages(ctx, client, result.ChannelId, result.Parts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Errorw("failed to delete messages", err)
|
c.logger.Errorw("failed to delete messages", err)
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,214 +1,21 @@
|
||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"math"
|
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/divyam234/teldrive/internal/cache"
|
"github.com/divyam234/teldrive/internal/cache"
|
||||||
"github.com/divyam234/teldrive/internal/config"
|
|
||||||
"github.com/divyam234/teldrive/internal/crypt"
|
"github.com/divyam234/teldrive/internal/crypt"
|
||||||
"github.com/divyam234/teldrive/internal/kv"
|
|
||||||
"github.com/divyam234/teldrive/internal/tgc"
|
"github.com/divyam234/teldrive/internal/tgc"
|
||||||
"github.com/divyam234/teldrive/pkg/models"
|
"github.com/divyam234/teldrive/pkg/models"
|
||||||
"github.com/divyam234/teldrive/pkg/schemas"
|
"github.com/divyam234/teldrive/pkg/schemas"
|
||||||
"github.com/divyam234/teldrive/pkg/types"
|
"github.com/divyam234/teldrive/pkg/types"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/gotd/td/telegram"
|
|
||||||
"github.com/gotd/td/tg"
|
"github.com/gotd/td/tg"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/thoas/go-funk"
|
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type buffer struct {
|
func getParts(ctx context.Context, client *tg.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
|
||||||
Buf []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *buffer) long() (int64, error) {
|
|
||||||
v, err := b.uint64()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return int64(v), nil
|
|
||||||
}
|
|
||||||
func (b *buffer) uint64() (uint64, error) {
|
|
||||||
const size = 8
|
|
||||||
if len(b.Buf) < size {
|
|
||||||
return 0, io.ErrUnexpectedEOF
|
|
||||||
}
|
|
||||||
v := binary.LittleEndian.Uint64(b.Buf)
|
|
||||||
b.Buf = b.Buf[size:]
|
|
||||||
return v, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func randInt64() (int64, error) {
|
|
||||||
var buf [8]byte
|
|
||||||
if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
b := &buffer{Buf: buf[:]}
|
|
||||||
return b.long()
|
|
||||||
}
|
|
||||||
|
|
||||||
type batchResult struct {
|
|
||||||
Index int
|
|
||||||
Messages *tg.MessagesChannelMessages
|
|
||||||
}
|
|
||||||
|
|
||||||
func getChunk(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass, offset int64, limit int64) ([]byte, error) {
|
|
||||||
|
|
||||||
req := &tg.UploadGetFileRequest{
|
|
||||||
Offset: offset,
|
|
||||||
Limit: int(limit),
|
|
||||||
Location: location,
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := tgClient.API().UploadGetFile(ctx, req)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch result := r.(type) {
|
|
||||||
case *tg.UploadFile:
|
|
||||||
return result.Bytes, nil
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unexpected type %T", r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func iterContent(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass) (*bytes.Buffer, error) {
|
|
||||||
offset := int64(0)
|
|
||||||
limit := int64(1024 * 1024)
|
|
||||||
buff := &bytes.Buffer{}
|
|
||||||
for {
|
|
||||||
r, err := getChunk(ctx, tgClient, location, offset, limit)
|
|
||||||
if err != nil {
|
|
||||||
return buff, err
|
|
||||||
}
|
|
||||||
if len(r) == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
buff.Write(r)
|
|
||||||
offset += int64(limit)
|
|
||||||
}
|
|
||||||
return buff, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetUserAuth(c *gin.Context) (int64, string) {
|
|
||||||
val, _ := c.Get("jwtUser")
|
|
||||||
jwtUser := val.(*types.JWTClaims)
|
|
||||||
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
|
|
||||||
return userId, jwtUser.TgSession
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token string) (*types.BotInfo, error) {
|
|
||||||
client, _ := tgc.BotClient(ctx, KV, config, token, tgc.Middlewares(config, 5)...)
|
|
||||||
var user *tg.User
|
|
||||||
err := tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error {
|
|
||||||
user, _ = client.Self(ctx)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTGMessagesBatch(ctx context.Context, client *telegram.Client, channel *tg.InputChannel, parts []schemas.Part, index int,
|
|
||||||
results chan<- batchResult, errors chan<- error, wg *sync.WaitGroup) {
|
|
||||||
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
ids := funk.Map(parts, func(part schemas.Part) tg.InputMessageClass {
|
|
||||||
return &tg.InputMessageID{ID: int(part.ID)}
|
|
||||||
}).([]tg.InputMessageClass)
|
|
||||||
|
|
||||||
messageRequest := tg.ChannelsGetMessagesRequest{
|
|
||||||
Channel: channel,
|
|
||||||
ID: ids,
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := client.API().ChannelsGetMessages(ctx, &messageRequest)
|
|
||||||
if err != nil {
|
|
||||||
errors <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
messages, ok := res.(*tg.MessagesChannelMessages)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
errors <- fmt.Errorf("unexpected response type: %T", res)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
results <- batchResult{Index: index, Messages: messages}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas.Part, channelId int64, userID string) ([]tg.MessageClass, error) {
|
|
||||||
|
|
||||||
channel, err := GetChannelById(ctx, client, channelId, userID)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
|
|
||||||
batchSize := 200
|
|
||||||
|
|
||||||
batchCount := int(math.Ceil(float64(len(parts)) / float64(batchSize)))
|
|
||||||
|
|
||||||
results := make(chan batchResult, batchCount)
|
|
||||||
|
|
||||||
errors := make(chan error, batchCount)
|
|
||||||
|
|
||||||
for i := range batchCount {
|
|
||||||
wg.Add(1)
|
|
||||||
splitParts := parts[i*batchSize : min((i+1)*batchSize, len(parts))]
|
|
||||||
go getTGMessagesBatch(ctx, client, channel, splitParts, i, results, errors, &wg)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
close(results)
|
|
||||||
close(errors)
|
|
||||||
|
|
||||||
for err := range errors {
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
batchResults := []batchResult{}
|
|
||||||
|
|
||||||
for result := range results {
|
|
||||||
batchResults = append(batchResults, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(batchResults, func(i, j int) bool {
|
|
||||||
return batchResults[i].Index < batchResults[j].Index
|
|
||||||
})
|
|
||||||
|
|
||||||
allMessages := []tg.MessageClass{}
|
|
||||||
|
|
||||||
for _, result := range batchResults {
|
|
||||||
allMessages = append(allMessages, result.Messages.GetMessages()...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return allMessages, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
|
|
||||||
cache := cache.FromContext(ctx)
|
cache := cache.FromContext(ctx)
|
||||||
parts := []types.Part{}
|
parts := []types.Part{}
|
||||||
|
|
||||||
|
@ -220,7 +27,11 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu
|
||||||
return parts, nil
|
return parts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
messages, err := getTGMessages(ctx, client, file.Parts, file.ChannelID, userID)
|
ids := []int{}
|
||||||
|
for _, part := range file.Parts {
|
||||||
|
ids = append(ids, int(part.ID))
|
||||||
|
}
|
||||||
|
messages, err := tgc.GetMessages(ctx, client, ids, file.ChannelID)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -230,12 +41,11 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu
|
||||||
item := message.(*tg.Message)
|
item := message.(*tg.Message)
|
||||||
media := item.Media.(*tg.MessageMediaDocument)
|
media := item.Media.(*tg.MessageMediaDocument)
|
||||||
document := media.Document.(*tg.Document)
|
document := media.Document.(*tg.Document)
|
||||||
location := document.AsInputDocumentFileLocation()
|
|
||||||
|
|
||||||
part := types.Part{
|
part := types.Part{
|
||||||
Location: location,
|
ID: file.Parts[i].ID,
|
||||||
Size: document.Size,
|
Size: document.Size,
|
||||||
Salt: file.Parts[i].Salt,
|
Salt: file.Parts[i].Salt,
|
||||||
}
|
}
|
||||||
if file.Encrypted {
|
if file.Encrypted {
|
||||||
part.DecryptedSize, _ = crypt.DecryptedSize(document.Size)
|
part.DecryptedSize, _ = crypt.DecryptedSize(document.Size)
|
||||||
|
@ -246,27 +56,7 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu
|
||||||
return parts, nil
|
return parts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetChannelById(ctx context.Context, client *telegram.Client, channelId int64, userID string) (*tg.InputChannel, error) {
|
func getDefaultChannel(ctx context.Context, db *gorm.DB, userID int64) (int64, error) {
|
||||||
|
|
||||||
channel := &tg.InputChannel{}
|
|
||||||
inputChannel := &tg.InputChannel{
|
|
||||||
ChannelID: channelId,
|
|
||||||
}
|
|
||||||
channels, err := client.API().ChannelsGetChannels(ctx, []tg.InputChannelClass{inputChannel})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(channels.GetChats()) == 0 {
|
|
||||||
return nil, errors.New("no channels found")
|
|
||||||
}
|
|
||||||
|
|
||||||
channel = channels.GetChats()[0].(*tg.Channel).AsInput()
|
|
||||||
return channel, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetDefaultChannel(ctx context.Context, db *gorm.DB, userID int64) (int64, error) {
|
|
||||||
cache := cache.FromContext(ctx)
|
cache := cache.FromContext(ctx)
|
||||||
var channelId int64
|
var channelId int64
|
||||||
key := fmt.Sprintf("users:channel:%d", userID)
|
key := fmt.Sprintf("users:channel:%d", userID)
|
||||||
|
@ -332,37 +122,3 @@ func getSessionByHash(db *gorm.DB, cache *cache.Cache, hash string) (*models.Ses
|
||||||
return &session, nil
|
return &session, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteTGMessages(ctx context.Context, cnf *config.TGConfig, session string, channelId, userId int64, ids []int) error {
|
|
||||||
|
|
||||||
client, _ := tgc.AuthClient(ctx, cnf, session)
|
|
||||||
|
|
||||||
err := tgc.RunWithAuth(ctx, client, "", func(ctx context.Context) error {
|
|
||||||
channel, err := GetChannelById(ctx, client, channelId, strconv.FormatInt(userId, 10))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
batchSize := 100
|
|
||||||
|
|
||||||
batchCount := int(math.Ceil(float64(len(ids)) / float64(batchSize)))
|
|
||||||
|
|
||||||
g, _ := errgroup.WithContext(ctx)
|
|
||||||
|
|
||||||
g.SetLimit(8)
|
|
||||||
|
|
||||||
for i := 0; i < batchCount; i++ {
|
|
||||||
start := i * batchSize
|
|
||||||
end := min((i+1)*batchSize, len(ids))
|
|
||||||
batchIds := ids[start:end]
|
|
||||||
g.Go(func() error {
|
|
||||||
messageDeleteRequest := tg.ChannelsDeleteMessagesRequest{Channel: channel, ID: batchIds}
|
|
||||||
_, err = client.API().ChannelsDeleteMessages(ctx, &messageDeleteRequest)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return g.Wait()
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
|
@ -2,7 +2,9 @@ package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime"
|
"mime"
|
||||||
|
@ -12,6 +14,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/WinterYukky/gorm-extra-clause-plugin/exclause"
|
"github.com/WinterYukky/gorm-extra-clause-plugin/exclause"
|
||||||
|
"github.com/divyam234/teldrive/internal/auth"
|
||||||
"github.com/divyam234/teldrive/internal/cache"
|
"github.com/divyam234/teldrive/internal/cache"
|
||||||
"github.com/divyam234/teldrive/internal/category"
|
"github.com/divyam234/teldrive/internal/category"
|
||||||
"github.com/divyam234/teldrive/internal/config"
|
"github.com/divyam234/teldrive/internal/config"
|
||||||
|
@ -35,6 +38,36 @@ import (
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type buffer struct {
|
||||||
|
Buf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *buffer) long() (int64, error) {
|
||||||
|
v, err := b.uint64()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return int64(v), nil
|
||||||
|
}
|
||||||
|
func (b *buffer) uint64() (uint64, error) {
|
||||||
|
const size = 8
|
||||||
|
if len(b.Buf) < size {
|
||||||
|
return 0, io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
v := binary.LittleEndian.Uint64(b.Buf)
|
||||||
|
b.Buf = b.Buf[size:]
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func randInt64() (int64, error) {
|
||||||
|
var buf [8]byte
|
||||||
|
if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
b := &buffer{Buf: buf[:]}
|
||||||
|
return b.long()
|
||||||
|
}
|
||||||
|
|
||||||
type FileService struct {
|
type FileService struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
cnf *config.TGConfig
|
cnf *config.TGConfig
|
||||||
|
@ -75,7 +108,7 @@ func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.
|
||||||
channelId := fileIn.ChannelID
|
channelId := fileIn.ChannelID
|
||||||
if fileIn.ChannelID == 0 {
|
if fileIn.ChannelID == 0 {
|
||||||
var err error
|
var err error
|
||||||
channelId, err = GetDefaultChannel(c, fs.db, userId)
|
channelId, err = getDefaultChannel(c, fs.db, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &types.AppError{Error: err, Code: http.StatusNotFound}
|
return nil, &types.AppError{Error: err, Code: http.StatusNotFound}
|
||||||
}
|
}
|
||||||
|
@ -332,7 +365,9 @@ func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Mess
|
||||||
return nil, &types.AppError{Error: err}
|
return nil, &types.AppError{Error: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
userId, session := GetUserAuth(c)
|
_, session := auth.GetUser(c)
|
||||||
|
|
||||||
|
client, _ := tgc.AuthClient(c, fs.cnf, session)
|
||||||
|
|
||||||
ids := []int{}
|
ids := []int{}
|
||||||
|
|
||||||
|
@ -340,7 +375,7 @@ func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Mess
|
||||||
ids = append(ids, int(part.ID))
|
ids = append(ids, int(part.ID))
|
||||||
}
|
}
|
||||||
|
|
||||||
err := DeleteTGMessages(c, fs.cnf, session, *file.ChannelID, userId, ids)
|
err := tgc.DeleteMessages(c, client, *file.ChannelID, ids)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &types.AppError{Error: err}
|
return nil, &types.AppError{Error: err}
|
||||||
|
@ -380,7 +415,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
|
||||||
return nil, &types.AppError{Error: err, Code: http.StatusBadRequest}
|
return nil, &types.AppError{Error: err, Code: http.StatusBadRequest}
|
||||||
}
|
}
|
||||||
|
|
||||||
userId, session := GetUserAuth(c)
|
userId, session := auth.GetUser(c)
|
||||||
|
|
||||||
client, _ := tgc.AuthClient(c, fs.cnf, session)
|
client, _ := tgc.AuthClient(c, fs.cnf, session)
|
||||||
|
|
||||||
|
@ -394,20 +429,24 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
|
||||||
|
|
||||||
newIds := []schemas.Part{}
|
newIds := []schemas.Part{}
|
||||||
|
|
||||||
channelId, err := GetDefaultChannel(c, fs.db, userId)
|
channelId, err := getDefaultChannel(c, fs.db, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &types.AppError{Error: err}
|
return nil, &types.AppError{Error: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tgc.RunWithAuth(c, client, "", func(ctx context.Context) error {
|
err = tgc.RunWithAuth(c, client, "", func(ctx context.Context) error {
|
||||||
user := strconv.FormatInt(userId, 10)
|
ids := []int{}
|
||||||
messages, err := getTGMessages(c, client, file.Parts, file.ChannelID, user)
|
|
||||||
|
for _, part := range file.Parts {
|
||||||
|
ids = append(ids, int(part.ID))
|
||||||
|
}
|
||||||
|
messages, err := tgc.GetMessages(c, client.API(), ids, file.ChannelID)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
channel, err := GetChannelById(ctx, client, channelId, user)
|
channel, err := tgc.GetChannelById(ctx, client.API(), channelId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -482,7 +521,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
|
||||||
return mapper.ToFileOut(dbFile), nil
|
return mapper.ToFileOut(dbFile), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fs *FileService) GetFileStream(c *gin.Context) {
|
func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
||||||
|
|
||||||
w := c.Writer
|
w := c.Writer
|
||||||
|
|
||||||
|
@ -585,7 +624,7 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
|
||||||
|
|
||||||
disposition := "inline"
|
disposition := "inline"
|
||||||
|
|
||||||
if c.Query("d") == "1" {
|
if download {
|
||||||
disposition = "attachment"
|
disposition = "attachment"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -601,14 +640,15 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
channelUser string
|
channelUser string
|
||||||
lr io.ReadCloser
|
lr io.ReadCloser
|
||||||
|
client *tgc.Client
|
||||||
|
multiThreads int
|
||||||
)
|
)
|
||||||
|
|
||||||
var client *tgc.Client
|
multiThreads = fs.cnf.Stream.MultiThreads
|
||||||
|
|
||||||
if fs.cnf.DisableStreamBots || len(tokens) == 0 {
|
if fs.cnf.DisableStreamBots || len(tokens) == 0 {
|
||||||
|
|
||||||
client, err = fs.worker.UserWorker(session.Session, session.UserId)
|
client, err = fs.worker.UserWorker(session.Session, session.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("file stream", zap.Error(err))
|
logger.Error("file stream", zap.Error(err))
|
||||||
|
@ -616,47 +656,36 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channelUser = strconv.FormatInt(session.UserId, 10)
|
channelUser = strconv.FormatInt(session.UserId, 10)
|
||||||
|
multiThreads = 0
|
||||||
logger.Debugw("requesting file", "name", file.Name, "bot", channelUser, "user", channelUser, "start", start,
|
|
||||||
"end", end, "fileSize", file.Size)
|
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
var index int
|
|
||||||
|
|
||||||
limit := min(len(tokens), fs.cnf.BgBotsLimit)
|
limit := min(len(tokens), fs.cnf.BgBotsLimit)
|
||||||
|
|
||||||
fs.worker.Set(tokens[:limit], file.ChannelID)
|
fs.worker.Set(tokens[:limit], file.ChannelID)
|
||||||
|
client, _, err = fs.worker.Next(file.ChannelID)
|
||||||
client, index, err = fs.worker.Next(file.ChannelID)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("file stream", zap.Error(err))
|
logger.Error("file stream", zap.Error(err))
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channelUser = strings.Split(tokens[index], ":")[0]
|
|
||||||
logger.Debugw("requesting file", "name", file.Name, "bot", channelUser, "botNo", index, "start", start,
|
|
||||||
"end", end, "fileSize", file.Size)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method != "HEAD" {
|
if r.Method != "HEAD" {
|
||||||
parts, err := getParts(c, client.Tg, file, channelUser)
|
parts, err := getParts(c, client.Tg.API(), file, channelUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("file stream", err)
|
logger.Error("file stream", err)
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tgClient := client.Tg.API()
|
if download {
|
||||||
|
multiThreads = 0
|
||||||
if fs.cnf.Stream.UsePooling {
|
|
||||||
tgClient = client.Pool.Default(c)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if file.Encrypted {
|
if file.Encrypted {
|
||||||
lr, err = reader.NewDecryptedReader(c, tgClient, parts, start, end, fs.cnf)
|
lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker)
|
||||||
} else {
|
} else {
|
||||||
lr, err = reader.NewLinearReader(c, tgClient, parts, start, end, fs.cnf)
|
lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -669,7 +698,10 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
io.CopyN(w, lr, contentLength)
|
_, err = io.CopyN(w, lr, contentLength)
|
||||||
|
if err != nil {
|
||||||
|
lr.Close()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func setOrderFilter(query *gorm.DB, fquery *schemas.FileQuery) *gorm.DB {
|
func setOrderFilter(query *gorm.DB, fquery *schemas.FileQuery) *gorm.DB {
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/divyam234/teldrive/internal/auth"
|
||||||
"github.com/divyam234/teldrive/internal/crypt"
|
"github.com/divyam234/teldrive/internal/crypt"
|
||||||
"github.com/divyam234/teldrive/internal/kv"
|
"github.com/divyam234/teldrive/internal/kv"
|
||||||
"github.com/divyam234/teldrive/internal/logging"
|
"github.com/divyam234/teldrive/internal/logging"
|
||||||
|
@ -116,7 +117,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
|
||||||
Code: http.StatusBadRequest}
|
Code: http.StatusBadRequest}
|
||||||
}
|
}
|
||||||
|
|
||||||
userId, session := GetUserAuth(c)
|
userId, session := auth.GetUser(c)
|
||||||
|
|
||||||
uploadId := c.Param("id")
|
uploadId := c.Param("id")
|
||||||
|
|
||||||
|
@ -127,7 +128,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
|
||||||
defer fileStream.Close()
|
defer fileStream.Close()
|
||||||
|
|
||||||
if uploadQuery.ChannelID == 0 {
|
if uploadQuery.ChannelID == 0 {
|
||||||
channelId, err = GetDefaultChannel(c, us.db, userId)
|
channelId, err = getDefaultChannel(c, us.db, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &types.AppError{Error: err}
|
return nil, &types.AppError{Error: err}
|
||||||
}
|
}
|
||||||
|
@ -174,7 +175,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
|
||||||
|
|
||||||
err = tgc.RunWithAuth(c, client, token, func(ctx context.Context) error {
|
err = tgc.RunWithAuth(c, client, token, func(ctx context.Context) error {
|
||||||
|
|
||||||
channel, err := GetChannelById(ctx, client, channelId, channelUser)
|
channel, err := tgc.GetChannelById(ctx, client.API(), channelId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/divyam234/teldrive/internal/auth"
|
||||||
"github.com/divyam234/teldrive/internal/cache"
|
"github.com/divyam234/teldrive/internal/cache"
|
||||||
"github.com/divyam234/teldrive/internal/config"
|
"github.com/divyam234/teldrive/internal/config"
|
||||||
"github.com/divyam234/teldrive/internal/kv"
|
"github.com/divyam234/teldrive/internal/kv"
|
||||||
|
@ -41,7 +42,7 @@ func NewUserService(db *gorm.DB, cnf *config.Config, kv kv.KV) *UserService {
|
||||||
return &UserService{db: db, cnf: cnf, kv: kv}
|
return &UserService{db: db, cnf: cnf, kv: kv}
|
||||||
}
|
}
|
||||||
func (us *UserService) GetProfilePhoto(c *gin.Context) {
|
func (us *UserService) GetProfilePhoto(c *gin.Context) {
|
||||||
_, session := GetUserAuth(c)
|
_, session := auth.GetUser(c)
|
||||||
|
|
||||||
client, err := tgc.AuthClient(c, &us.cnf.TG, session)
|
client, err := tgc.AuthClient(c, &us.cnf.TG, session)
|
||||||
|
|
||||||
|
@ -64,7 +65,7 @@ func (us *UserService) GetProfilePhoto(c *gin.Context) {
|
||||||
return errors.New("profile not found")
|
return errors.New("profile not found")
|
||||||
}
|
}
|
||||||
location := &tg.InputPeerPhotoFileLocation{Big: false, Peer: peer, PhotoID: photo.PhotoID}
|
location := &tg.InputPeerPhotoFileLocation{Big: false, Peer: peer, PhotoID: photo.PhotoID}
|
||||||
buff, err := iterContent(c, client, location)
|
buff, err := tgc.GetMediaContent(c, client.API(), location)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -83,13 +84,13 @@ func (us *UserService) GetProfilePhoto(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (us *UserService) GetStats(c *gin.Context) (*schemas.AccountStats, *types.AppError) {
|
func (us *UserService) GetStats(c *gin.Context) (*schemas.AccountStats, *types.AppError) {
|
||||||
userID, _ := GetUserAuth(c)
|
userID, _ := auth.GetUser(c)
|
||||||
var (
|
var (
|
||||||
channelId int64
|
channelId int64
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
channelId, _ = GetDefaultChannel(c, us.db, userID)
|
channelId, _ = getDefaultChannel(c, us.db, userID)
|
||||||
|
|
||||||
tokens, err := getBotsToken(c, us.db, userID, channelId)
|
tokens, err := getBotsToken(c, us.db, userID, channelId)
|
||||||
|
|
||||||
|
@ -103,7 +104,7 @@ func (us *UserService) UpdateChannel(c *gin.Context) (*schemas.Message, *types.A
|
||||||
|
|
||||||
cache := cache.FromContext(c)
|
cache := cache.FromContext(c)
|
||||||
|
|
||||||
userId, _ := GetUserAuth(c)
|
userId, _ := auth.GetUser(c)
|
||||||
|
|
||||||
var payload schemas.Channel
|
var payload schemas.Channel
|
||||||
|
|
||||||
|
@ -130,7 +131,7 @@ func (us *UserService) UpdateChannel(c *gin.Context) (*schemas.Message, *types.A
|
||||||
}
|
}
|
||||||
|
|
||||||
func (us *UserService) ListSessions(c *gin.Context) ([]schemas.SessionOut, *types.AppError) {
|
func (us *UserService) ListSessions(c *gin.Context) ([]schemas.SessionOut, *types.AppError) {
|
||||||
userId, userSession := GetUserAuth(c)
|
userId, userSession := auth.GetUser(c)
|
||||||
|
|
||||||
client, _ := tgc.AuthClient(c, &us.cnf.TG, userSession)
|
client, _ := tgc.AuthClient(c, &us.cnf.TG, userSession)
|
||||||
|
|
||||||
|
@ -185,7 +186,7 @@ func (us *UserService) ListSessions(c *gin.Context) ([]schemas.SessionOut, *type
|
||||||
|
|
||||||
func (us *UserService) RemoveSession(c *gin.Context) (*schemas.Message, *types.AppError) {
|
func (us *UserService) RemoveSession(c *gin.Context) (*schemas.Message, *types.AppError) {
|
||||||
|
|
||||||
userId, _ := GetUserAuth(c)
|
userId, _ := auth.GetUser(c)
|
||||||
|
|
||||||
session := &models.Session{}
|
session := &models.Session{}
|
||||||
|
|
||||||
|
@ -209,7 +210,7 @@ func (us *UserService) RemoveSession(c *gin.Context) (*schemas.Message, *types.A
|
||||||
}
|
}
|
||||||
|
|
||||||
func (us *UserService) ListChannels(c *gin.Context) (interface{}, *types.AppError) {
|
func (us *UserService) ListChannels(c *gin.Context) (interface{}, *types.AppError) {
|
||||||
_, session := GetUserAuth(c)
|
_, session := auth.GetUser(c)
|
||||||
client, _ := tgc.AuthClient(c, &us.cnf.TG, session)
|
client, _ := tgc.AuthClient(c, &us.cnf.TG, session)
|
||||||
|
|
||||||
channels := make(map[int64]*schemas.Channel)
|
channels := make(map[int64]*schemas.Channel)
|
||||||
|
@ -236,7 +237,7 @@ func (us *UserService) ListChannels(c *gin.Context) (interface{}, *types.AppErro
|
||||||
}
|
}
|
||||||
|
|
||||||
func (us *UserService) AddBots(c *gin.Context) (*schemas.Message, *types.AppError) {
|
func (us *UserService) AddBots(c *gin.Context) (*schemas.Message, *types.AppError) {
|
||||||
userId, session := GetUserAuth(c)
|
userId, session := auth.GetUser(c)
|
||||||
client, _ := tgc.AuthClient(c, &us.cnf.TG, session)
|
client, _ := tgc.AuthClient(c, &us.cnf.TG, session)
|
||||||
|
|
||||||
var botsTokens []string
|
var botsTokens []string
|
||||||
|
@ -249,7 +250,7 @@ func (us *UserService) AddBots(c *gin.Context) (*schemas.Message, *types.AppErro
|
||||||
return &schemas.Message{Message: "no bots to add"}, nil
|
return &schemas.Message{Message: "no bots to add"}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
channelId, err := GetDefaultChannel(c, us.db, userId)
|
channelId, err := getDefaultChannel(c, us.db, userId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
|
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
|
||||||
|
@ -263,9 +264,9 @@ func (us *UserService) RemoveBots(c *gin.Context) (*schemas.Message, *types.AppE
|
||||||
|
|
||||||
cache := cache.FromContext(c)
|
cache := cache.FromContext(c)
|
||||||
|
|
||||||
userID, _ := GetUserAuth(c)
|
userID, _ := auth.GetUser(c)
|
||||||
|
|
||||||
channelId, err := GetDefaultChannel(c, us.db, userID)
|
channelId, err := getDefaultChannel(c, us.db, userID)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
|
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
|
||||||
|
@ -294,7 +295,7 @@ func (us *UserService) addBots(c context.Context, client *telegram.Client, userI
|
||||||
|
|
||||||
err := tgc.RunWithAuth(c, client, "", func(ctx context.Context) error {
|
err := tgc.RunWithAuth(c, client, "", func(ctx context.Context) error {
|
||||||
|
|
||||||
channel, err := GetChannelById(ctx, client, channelId, strconv.FormatInt(userId, 10))
|
channel, err := tgc.GetChannelById(ctx, client.API(), channelId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("error", zap.Error(err))
|
logger.Error("error", zap.Error(err))
|
||||||
|
@ -309,7 +310,7 @@ func (us *UserService) addBots(c context.Context, client *telegram.Client, userI
|
||||||
waitChan <- struct{}{}
|
waitChan <- struct{}{}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(t string) {
|
go func(t string) {
|
||||||
info, err := getBotInfo(c, us.kv, &us.cnf.TG, t)
|
info, err := tgc.GetBotInfo(c, client, t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package types
|
||||||
import (
|
import (
|
||||||
"github.com/go-jose/go-jose/v3/jwt"
|
"github.com/go-jose/go-jose/v3/jwt"
|
||||||
"github.com/gotd/td/session"
|
"github.com/gotd/td/session"
|
||||||
"github.com/gotd/td/tg"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type AppError struct {
|
type AppError struct {
|
||||||
|
@ -12,10 +11,10 @@ type AppError struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Part struct {
|
type Part struct {
|
||||||
Location *tg.InputDocumentFileLocation
|
|
||||||
DecryptedSize int64
|
DecryptedSize int64
|
||||||
Size int64
|
Size int64
|
||||||
Salt string
|
Salt string
|
||||||
|
ID int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type JWTClaims struct {
|
type JWTClaims struct {
|
||||||
|
|
Loading…
Reference in a new issue