mirror of
https://github.com/tgdrive/teldrive.git
synced 2024-09-20 08:15:55 +08:00
feat: MultiThreaded stream support
This commit is contained in:
parent
8c9ca99a93
commit
825dc11fe1
36
.github/workflows/build-dev.yml
vendored
Normal file
36
.github/workflows/build-dev.yml
vendored
Normal file
|
@ -0,0 +1,36 @@
|
|||
name: Build Dev
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- dev
|
||||
|
||||
permissions: write-all
|
||||
|
||||
jobs:
|
||||
goreleaser:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Set Vars
|
||||
run: |
|
||||
echo "IMAGE=ghcr.io/${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||
|
||||
- name: Build Image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: ${{ env.IMAGE }}:dev
|
24
Dockerfile
Normal file
24
Dockerfile
Normal file
|
@ -0,0 +1,24 @@
|
|||
|
||||
FROM golang:alpine AS builder
|
||||
|
||||
RUN apk add --no-cache git unzip curl make bash
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN make build
|
||||
|
||||
FROM scratch
|
||||
|
||||
WORKDIR /
|
||||
|
||||
COPY --from=builder /app/bin/teldrive /teldrive
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
ENTRYPOINT ["/teldrive","run","--tg-session-file","/session.db"]
|
12
Makefile
12
Makefile
|
@ -11,19 +11,16 @@ FRONTEND_ASSET := https://github.com/divyam234/teldrive-ui/releases/download/v1/
|
|||
GIT_TAG := $(shell git describe --tags --abbrev=0)
|
||||
GIT_COMMIT := $(shell git rev-parse --short HEAD)
|
||||
GIT_LINK := $(shell git remote get-url origin)
|
||||
ENV_FILE := $(FRONTEND_DIR)/.env
|
||||
MODULE_PATH := $(shell go list -m)
|
||||
BUILD_DATE := $(shell $(BUILD_DATE))
|
||||
|
||||
GOOS ?= $(shell go env GOOS)
|
||||
GOARCH ?= $(shell go env GOARCH)
|
||||
VERSION:= $(GIT_TAG)
|
||||
BINARY_EXTENSION :=
|
||||
|
||||
.PHONY: all build run clean frontend backend run sync-ui retag patch-version minor-version
|
||||
|
||||
all: build
|
||||
|
||||
|
||||
|
||||
frontend:
|
||||
@echo "Extract UI"
|
||||
ifeq ($(OS),Windows_NT)
|
||||
|
@ -40,10 +37,13 @@ else
|
|||
rm -rf teldrive-ui.zip
|
||||
endif
|
||||
|
||||
ifeq ($(OS),Windows_NT)
|
||||
BINARY_EXTENSION := .exe
|
||||
endif
|
||||
|
||||
backend:
|
||||
@echo "Building backend for $(GOOS)/$(GOARCH)..."
|
||||
go build -trimpath -ldflags "-s -w -X $(MODULE_PATH)/internal/config.Version=$(GIT_TAG) -extldflags=-static" -o $(BUILD_DIR)/$(APP_NAME)$(BINARY_EXTENSION)
|
||||
go build -trimpath -ldflags "-s -w -X $(MODULE_PATH)/internal/config.Version=$(VERSION) -extldflags=-static" -o $(BUILD_DIR)/$(APP_NAME)$(BINARY_EXTENSION)
|
||||
|
||||
build: frontend backend
|
||||
@echo "Building complete."
|
||||
|
|
|
@ -28,6 +28,8 @@ func InitRouter(r *gin.Engine, c *controller.Controller, cnf *config.Config) *gi
|
|||
files.PATCH(":fileID", authmiddleware, c.UpdateFile)
|
||||
files.HEAD(":fileID/stream/:fileName", c.GetFileStream)
|
||||
files.GET(":fileID/stream/:fileName", c.GetFileStream)
|
||||
files.HEAD(":fileID/download/:fileName", c.GetFileDownload)
|
||||
files.GET(":fileID/download/:fileName", c.GetFileDownload)
|
||||
files.DELETE(":fileID/parts", authmiddleware, c.DeleteFileParts)
|
||||
files.GET("/category/stats", authmiddleware, c.GetCategoryStats)
|
||||
files.POST("/move", authmiddleware, c.MoveFiles)
|
||||
|
|
13
cmd/run.go
13
cmd/run.go
|
@ -89,12 +89,11 @@ func NewRun() *cobra.Command {
|
|||
runCmd.Flags().IntVar(&config.TG.Uploads.Threads, "tg-uploads-threads", 8, "Uploads threads")
|
||||
runCmd.Flags().IntVar(&config.TG.Uploads.MaxRetries, "tg-uploads-max-retries", 10, "Uploads Retries")
|
||||
runCmd.Flags().Int64Var(&config.TG.PoolSize, "tg-pool-size", 8, "Telegram Session pool size")
|
||||
runCmd.Flags().BoolVar(&config.TG.Stream.BufferReader, "tg-stream-buffer-reader", false, "Async Buffered reader for fast streaming")
|
||||
runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 16, "No of Stream buffers")
|
||||
runCmd.Flags().BoolVar(&config.TG.Stream.UseMmap, "tg-stream-use-mmap", false, "Use mmap for stream buffers")
|
||||
runCmd.Flags().BoolVar(&config.TG.Stream.UsePooling, "tg-stream-use-pooling", false, "Use session pooling for stream workers")
|
||||
duration.DurationVar(runCmd.Flags(), &config.TG.ReconnectTimeout, "tg-reconnect-timeout", 5*time.Minute, "Reconnect Timeout")
|
||||
duration.DurationVar(runCmd.Flags(), &config.TG.Uploads.Retention, "tg-uploads-retention", (24*7)*time.Hour, "Uploads retention duration")
|
||||
runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads")
|
||||
runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers")
|
||||
duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout")
|
||||
runCmd.MarkFlagRequired("tg-app-id")
|
||||
runCmd.MarkFlagRequired("tg-app-hash")
|
||||
runCmd.MarkFlagRequired("db-data-source")
|
||||
|
@ -162,11 +161,11 @@ func initViperConfig(cmd *cobra.Command) error {
|
|||
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
|
||||
viper.AutomaticEnv()
|
||||
viper.ReadInConfig()
|
||||
bindFlagsRecursive(cmd.Flags(), "", reflect.ValueOf(config.Config{}))
|
||||
bindFlags(cmd.Flags(), "", reflect.ValueOf(config.Config{}))
|
||||
return nil
|
||||
|
||||
}
|
||||
func bindFlagsRecursive(flags *pflag.FlagSet, prefix string, v reflect.Value) {
|
||||
func bindFlags(flags *pflag.FlagSet, prefix string, v reflect.Value) {
|
||||
t := v.Type()
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
|
@ -175,7 +174,7 @@ func bindFlagsRecursive(flags *pflag.FlagSet, prefix string, v reflect.Value) {
|
|||
field := t.Field(i)
|
||||
switch field.Type.Kind() {
|
||||
case reflect.Struct:
|
||||
bindFlagsRecursive(flags, fmt.Sprintf("%s.%s", prefix, strings.ToLower(field.Name)), v.Field(i))
|
||||
bindFlags(flags, fmt.Sprintf("%s.%s", prefix, strings.ToLower(field.Name)), v.Field(i))
|
||||
default:
|
||||
newPrefix := prefix[1:]
|
||||
newName := modifyFlag(field.Name)
|
||||
|
|
|
@ -43,8 +43,6 @@
|
|||
retention = "7d"
|
||||
threads = 8
|
||||
[tg.stream]
|
||||
buffer-reader = false
|
||||
buffers = 6
|
||||
use-mmap = false
|
||||
use-pooling= false
|
||||
multi-threads = 0
|
||||
buffers = 16
|
||||
|
||||
|
|
3
go.mod
3
go.mod
|
@ -16,7 +16,6 @@ require (
|
|||
github.com/magiconair/properties v1.8.7
|
||||
github.com/mitchellh/go-homedir v1.1.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/rclone/rclone v1.67.0
|
||||
github.com/spf13/cobra v1.8.1
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/spf13/viper v1.19.0
|
||||
|
@ -97,7 +96,7 @@ require (
|
|||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pressly/goose/v3 v3.20.0
|
||||
github.com/pressly/goose/v3 v3.21.1
|
||||
github.com/segmentio/asm v1.2.0 // indirect
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
|
|
10
go.sum
10
go.sum
|
@ -145,8 +145,8 @@ github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbW
|
|||
github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
|
||||
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
|
||||
github.com/microsoft/go-mssqldb v1.7.0 h1:sgMPW0HA6Ihd37Yx0MzHyKD726C2kY/8KJsQtXHNaAs=
|
||||
github.com/microsoft/go-mssqldb v1.7.0/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA=
|
||||
github.com/microsoft/go-mssqldb v1.7.1 h1:KU/g8aWeM3Hx7IMOFpiwYiUkU+9zeISb4+tx3ScVfsM=
|
||||
github.com/microsoft/go-mssqldb v1.7.1/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA=
|
||||
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
|
||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
|
@ -166,10 +166,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
|
|||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pressly/goose/v3 v3.20.0 h1:uPJdOxF/Ipj7ABVNOAMJXSxwFXZGwMGHNqjC8e61VA0=
|
||||
github.com/pressly/goose/v3 v3.20.0/go.mod h1:BRfF2GcG4FTG12QfdBVy3q1yveaf4ckL9vWwEcIO3lA=
|
||||
github.com/rclone/rclone v1.67.0 h1:yLRNgHEG2vQ60HCuzFqd0hYwKCRuWuvPUhvhMJ2jI5E=
|
||||
github.com/rclone/rclone v1.67.0/go.mod h1:Cb3Ar47M/SvwfhAjZTbVXdtrP/JLtPFCq2tkdtBVC6w=
|
||||
github.com/pressly/goose/v3 v3.21.1 h1:5SSAKKWej8LVVzNLuT6KIvP1eFDuPvxa+B6H0w78buQ=
|
||||
github.com/pressly/goose/v3 v3.21.1/go.mod h1:sqthmzV8PitchEkjecFJII//l43dLOCzfWh8pHEe+vE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
|
|
|
@ -2,8 +2,10 @@ package auth
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
|
||||
"github.com/divyam234/teldrive/pkg/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
)
|
||||
|
||||
|
@ -59,3 +61,10 @@ func Decode(secret string, token string) (*types.JWTClaims, error) {
|
|||
return jwtToken, nil
|
||||
|
||||
}
|
||||
|
||||
func GetUser(c *gin.Context) (int64, string) {
|
||||
val, _ := c.Get("jwtUser")
|
||||
jwtUser := val.(*types.JWTClaims)
|
||||
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
|
||||
return userId, jwtUser.TgSession
|
||||
}
|
||||
|
|
|
@ -43,10 +43,9 @@ type TGConfig struct {
|
|||
Retention time.Duration
|
||||
}
|
||||
Stream struct {
|
||||
BufferReader bool
|
||||
MultiThreads int
|
||||
Buffers int
|
||||
UseMmap bool
|
||||
UsePooling bool
|
||||
ChunkTimeout time.Duration
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,206 +0,0 @@
|
|||
// taken from rclone async reader implmentation
|
||||
package reader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rclone/rclone/lib/pool"
|
||||
)
|
||||
|
||||
const (
|
||||
BufferSize = 1024 * 1024
|
||||
softStartInitial = 4 * 1024
|
||||
bufferCacheSize = 64
|
||||
bufferCacheFlushTime = 5 * time.Second
|
||||
)
|
||||
|
||||
var ErrorStreamAbandoned = errors.New("stream abandoned")
|
||||
|
||||
type AsyncReader struct {
|
||||
in io.ReadCloser
|
||||
ready chan *buffer
|
||||
token chan struct{}
|
||||
exit chan struct{}
|
||||
buffers int
|
||||
err error
|
||||
cur *buffer
|
||||
exited chan struct{}
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewAsyncReader(ctx context.Context, rd io.ReadCloser, buffers int) (*AsyncReader, error) {
|
||||
if buffers <= 0 {
|
||||
return nil, errors.New("number of buffers too small")
|
||||
}
|
||||
if rd == nil {
|
||||
return nil, errors.New("nil reader supplied")
|
||||
}
|
||||
a := &AsyncReader{}
|
||||
a.init(rd, buffers)
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *AsyncReader) init(rd io.ReadCloser, buffers int) {
|
||||
a.in = rd
|
||||
a.ready = make(chan *buffer, buffers)
|
||||
a.token = make(chan struct{}, buffers)
|
||||
a.exit = make(chan struct{})
|
||||
a.exited = make(chan struct{})
|
||||
a.buffers = buffers
|
||||
a.cur = nil
|
||||
|
||||
for i := 0; i < buffers; i++ {
|
||||
a.token <- struct{}{}
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(a.exited)
|
||||
defer close(a.ready)
|
||||
for {
|
||||
select {
|
||||
case <-a.token:
|
||||
b := a.getBuffer()
|
||||
err := b.read(a.in)
|
||||
a.ready <- b
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case <-a.exit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var bufferPool *pool.Pool
|
||||
var bufferPoolOnce sync.Once
|
||||
|
||||
func (a *AsyncReader) putBuffer(b *buffer) {
|
||||
bufferPool.Put(b.buf)
|
||||
b.buf = nil
|
||||
}
|
||||
|
||||
func (a *AsyncReader) getBuffer() *buffer {
|
||||
bufferPoolOnce.Do(func() {
|
||||
bufferPool = pool.New(bufferCacheFlushTime, BufferSize, bufferCacheSize, false)
|
||||
})
|
||||
return &buffer{
|
||||
buf: bufferPool.Get(),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AsyncReader) fill() (err error) {
|
||||
if a.cur.isEmpty() {
|
||||
if a.cur != nil {
|
||||
a.putBuffer(a.cur)
|
||||
a.token <- struct{}{}
|
||||
a.cur = nil
|
||||
}
|
||||
b, ok := <-a.ready
|
||||
if !ok {
|
||||
if a.err == nil {
|
||||
return ErrorStreamAbandoned
|
||||
}
|
||||
return a.err
|
||||
}
|
||||
a.cur = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AsyncReader) Read(p []byte) (n int, err error) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
err = a.fill()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
n = copy(p, a.cur.buffer())
|
||||
a.cur.increment(n)
|
||||
|
||||
if a.cur.isEmpty() {
|
||||
a.err = a.cur.err
|
||||
return n, a.err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (a *AsyncReader) StopBuffering() {
|
||||
select {
|
||||
case <-a.exit:
|
||||
return
|
||||
default:
|
||||
}
|
||||
close(a.exit)
|
||||
<-a.exited
|
||||
}
|
||||
|
||||
func (a *AsyncReader) Abandon() {
|
||||
a.StopBuffering()
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.cur != nil {
|
||||
a.putBuffer(a.cur)
|
||||
a.cur = nil
|
||||
}
|
||||
for b := range a.ready {
|
||||
a.putBuffer(b)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AsyncReader) Close() (err error) {
|
||||
a.Abandon()
|
||||
if a.closed {
|
||||
return nil
|
||||
}
|
||||
a.closed = true
|
||||
return a.in.Close()
|
||||
}
|
||||
|
||||
type buffer struct {
|
||||
buf []byte
|
||||
err error
|
||||
offset int
|
||||
}
|
||||
|
||||
func (b *buffer) isEmpty() bool {
|
||||
if b == nil {
|
||||
return true
|
||||
}
|
||||
if len(b.buf)-b.offset <= 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *buffer) readFill(r io.Reader, buf []byte) (n int, err error) {
|
||||
var nn int
|
||||
for n < len(buf) && err == nil {
|
||||
nn, err = r.Read(buf[n:])
|
||||
n += nn
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (b *buffer) read(rd io.Reader) error {
|
||||
var n int
|
||||
n, b.err = b.readFill(rd, b.buf)
|
||||
b.buf = b.buf[0:n]
|
||||
b.offset = 0
|
||||
return b.err
|
||||
}
|
||||
|
||||
func (b *buffer) buffer() []byte {
|
||||
return b.buf[b.offset:]
|
||||
}
|
||||
|
||||
func (b *buffer) increment(n int) {
|
||||
b.offset += n
|
||||
}
|
|
@ -6,8 +6,8 @@ import (
|
|||
|
||||
"github.com/divyam234/teldrive/internal/config"
|
||||
"github.com/divyam234/teldrive/internal/crypt"
|
||||
"github.com/divyam234/teldrive/internal/tgc"
|
||||
"github.com/divyam234/teldrive/pkg/types"
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
type decrpytedReader struct {
|
||||
|
@ -15,27 +15,38 @@ type decrpytedReader struct {
|
|||
parts []types.Part
|
||||
ranges []types.Range
|
||||
pos int
|
||||
client *tg.Client
|
||||
reader io.ReadCloser
|
||||
limit int64
|
||||
err error
|
||||
config *config.TGConfig
|
||||
channelId int64
|
||||
worker *tgc.StreamWorker
|
||||
client *tgc.Client
|
||||
fileId string
|
||||
concurrency int
|
||||
}
|
||||
|
||||
func NewDecryptedReader(
|
||||
ctx context.Context,
|
||||
client *tg.Client,
|
||||
fileId string,
|
||||
parts []types.Part,
|
||||
start, end int64,
|
||||
config *config.TGConfig) (io.ReadCloser, error) {
|
||||
channelId int64,
|
||||
config *config.TGConfig,
|
||||
concurrency int,
|
||||
client *tgc.Client,
|
||||
worker *tgc.StreamWorker) (io.ReadCloser, error) {
|
||||
|
||||
r := &decrpytedReader{
|
||||
ctx: ctx,
|
||||
parts: parts,
|
||||
client: client,
|
||||
limit: end - start + 1,
|
||||
ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize),
|
||||
config: config,
|
||||
client: client,
|
||||
worker: worker,
|
||||
channelId: channelId,
|
||||
fileId: fileId,
|
||||
concurrency: concurrency,
|
||||
}
|
||||
res, err := r.nextPart()
|
||||
|
||||
|
@ -51,30 +62,24 @@ func NewDecryptedReader(
|
|||
|
||||
func (r *decrpytedReader) Read(p []byte) (n int, err error) {
|
||||
|
||||
if r.err != nil {
|
||||
return 0, r.err
|
||||
}
|
||||
|
||||
if r.limit <= 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n, err = r.reader.Read(p)
|
||||
|
||||
if err == nil {
|
||||
r.limit -= int64(n)
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
if r.limit > 0 {
|
||||
err = nil
|
||||
if r.reader != nil {
|
||||
r.reader.Close()
|
||||
}
|
||||
}
|
||||
r.pos++
|
||||
if r.pos < len(r.ranges) {
|
||||
r.reader, err = r.nextPart()
|
||||
}
|
||||
}
|
||||
r.err = err
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -89,7 +94,6 @@ func (r *decrpytedReader) Close() (err error) {
|
|||
|
||||
func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
|
||||
|
||||
location := r.parts[r.ranges[r.pos].PartNo].Location
|
||||
start := r.ranges[r.pos].Start
|
||||
end := r.ranges[r.pos].End
|
||||
salt := r.parts[r.ranges[r.pos].PartNo].Salt
|
||||
|
@ -99,21 +103,15 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
|
|||
func(ctx context.Context,
|
||||
underlyingOffset,
|
||||
underlyingLimit int64) (io.ReadCloser, error) {
|
||||
|
||||
var end int64
|
||||
|
||||
if underlyingLimit >= 0 {
|
||||
end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
|
||||
}
|
||||
rd, err := newTGReader(r.ctx, r.client, location, underlyingOffset, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.config.Stream.BufferReader {
|
||||
return NewAsyncReader(r.ctx, rd, r.config.Stream.Buffers)
|
||||
|
||||
}
|
||||
return rd, nil
|
||||
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
|
||||
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
|
||||
client: r.client, concurrency: r.concurrency}
|
||||
return newTGReader(r.ctx, start, end, r.config, chunkSrc)
|
||||
|
||||
}, start, end-start+1)
|
||||
|
|
@ -5,8 +5,8 @@ import (
|
|||
"io"
|
||||
|
||||
"github.com/divyam234/teldrive/internal/config"
|
||||
"github.com/divyam234/teldrive/internal/tgc"
|
||||
"github.com/divyam234/teldrive/pkg/types"
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
|
||||
|
@ -41,27 +41,38 @@ type linearReader struct {
|
|||
parts []types.Part
|
||||
ranges []types.Range
|
||||
pos int
|
||||
client *tg.Client
|
||||
reader io.ReadCloser
|
||||
limit int64
|
||||
err error
|
||||
config *config.TGConfig
|
||||
channelId int64
|
||||
worker *tgc.StreamWorker
|
||||
client *tgc.Client
|
||||
fileId string
|
||||
concurrency int
|
||||
}
|
||||
|
||||
func NewLinearReader(ctx context.Context,
|
||||
client *tg.Client,
|
||||
fileId string,
|
||||
parts []types.Part,
|
||||
start, end int64,
|
||||
channelId int64,
|
||||
config *config.TGConfig,
|
||||
concurrency int,
|
||||
client *tgc.Client,
|
||||
worker *tgc.StreamWorker,
|
||||
) (reader io.ReadCloser, err error) {
|
||||
|
||||
r := &linearReader{
|
||||
ctx: ctx,
|
||||
parts: parts,
|
||||
client: client,
|
||||
limit: end - start + 1,
|
||||
ranges: calculatePartByteRanges(start, end, parts[0].Size),
|
||||
config: config,
|
||||
client: client,
|
||||
worker: worker,
|
||||
channelId: channelId,
|
||||
fileId: fileId,
|
||||
concurrency: concurrency,
|
||||
}
|
||||
|
||||
r.reader, err = r.nextPart()
|
||||
|
@ -73,25 +84,22 @@ func NewLinearReader(ctx context.Context,
|
|||
return r, nil
|
||||
}
|
||||
|
||||
func (r *linearReader) Read(p []byte) (n int, err error) {
|
||||
|
||||
if r.err != nil {
|
||||
return 0, r.err
|
||||
}
|
||||
func (r *linearReader) Read(p []byte) (int, error) {
|
||||
|
||||
if r.limit <= 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n, err = r.reader.Read(p)
|
||||
n, err := r.reader.Read(p)
|
||||
|
||||
if err == nil {
|
||||
r.limit -= int64(n)
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
if r.limit > 0 {
|
||||
err = nil
|
||||
if r.reader != nil {
|
||||
r.reader.Close()
|
||||
}
|
||||
}
|
||||
r.pos++
|
||||
if r.pos < len(r.ranges) {
|
||||
|
@ -99,24 +107,18 @@ func (r *linearReader) Read(p []byte) (n int, err error) {
|
|||
|
||||
}
|
||||
}
|
||||
r.err = err
|
||||
return
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *linearReader) nextPart() (io.ReadCloser, error) {
|
||||
|
||||
location := r.parts[r.ranges[r.pos].PartNo].Location
|
||||
startByte := r.ranges[r.pos].Start
|
||||
endByte := r.ranges[r.pos].End
|
||||
rd, err := newTGReader(r.ctx, r.client, location, startByte, endByte)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.config.Stream.BufferReader {
|
||||
return NewAsyncReader(r.ctx, rd, r.config.Stream.Buffers)
|
||||
start := r.ranges[r.pos].Start
|
||||
end := r.ranges[r.pos].End
|
||||
|
||||
}
|
||||
return rd, nil
|
||||
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
|
||||
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
|
||||
client: r.client, concurrency: r.concurrency}
|
||||
return newTGReader(r.ctx, start, end, r.config, chunkSrc)
|
||||
|
||||
}
|
||||
|
||||
|
|
311
internal/reader/tg_reader.go
Normal file
311
internal/reader/tg_reader.go
Normal file
|
@ -0,0 +1,311 @@
|
|||
package reader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/divyam234/teldrive/internal/config"
|
||||
"github.com/divyam234/teldrive/internal/tgc"
|
||||
"github.com/gotd/td/tg"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var ErrorStreamAbandoned = errors.New("stream abandoned")
|
||||
|
||||
type ChunkSource interface {
|
||||
Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error)
|
||||
ChunkSize(start, end int64) int64
|
||||
}
|
||||
|
||||
type chunkSource struct {
|
||||
channelId int64
|
||||
worker *tgc.StreamWorker
|
||||
fileId string
|
||||
partId int64
|
||||
concurrency int
|
||||
client *tgc.Client
|
||||
}
|
||||
|
||||
func (c *chunkSource) ChunkSize(start, end int64) int64 {
|
||||
return tgc.CalculateChunkSize(start, end)
|
||||
}
|
||||
|
||||
func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
|
||||
var (
|
||||
location *tg.InputDocumentFileLocation
|
||||
err error
|
||||
client *tgc.Client
|
||||
)
|
||||
|
||||
client = c.client
|
||||
|
||||
if c.concurrency > 0 {
|
||||
client, _, _ = c.worker.Next(c.channelId)
|
||||
}
|
||||
location, err = tgc.GetLocation(ctx, client, c.fileId, c.channelId, c.partId)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit)
|
||||
|
||||
}
|
||||
|
||||
type tgReader struct {
|
||||
ctx context.Context
|
||||
offset int64
|
||||
limit int64
|
||||
chunkSize int64
|
||||
bufferChan chan *buffer
|
||||
done chan struct{}
|
||||
cur *buffer
|
||||
err chan error
|
||||
mu sync.Mutex
|
||||
concurrency int
|
||||
leftCut int64
|
||||
rightCut int64
|
||||
totalParts int
|
||||
currentPart int
|
||||
closed bool
|
||||
timeout time.Duration
|
||||
chunkSrc ChunkSource
|
||||
}
|
||||
|
||||
func newTGReader(
|
||||
ctx context.Context,
|
||||
start int64,
|
||||
end int64,
|
||||
config *config.TGConfig,
|
||||
chunkSrc ChunkSource,
|
||||
|
||||
) (*tgReader, error) {
|
||||
|
||||
chunkSize := chunkSrc.ChunkSize(start, end)
|
||||
|
||||
offset := start - (start % chunkSize)
|
||||
|
||||
r := &tgReader{
|
||||
ctx: ctx,
|
||||
limit: end - start + 1,
|
||||
bufferChan: make(chan *buffer, config.Stream.Buffers),
|
||||
concurrency: config.Stream.MultiThreads,
|
||||
leftCut: start - offset,
|
||||
rightCut: (end % chunkSize) + 1,
|
||||
totalParts: int((end - offset + chunkSize) / chunkSize),
|
||||
offset: offset,
|
||||
chunkSize: chunkSize,
|
||||
chunkSrc: chunkSrc,
|
||||
timeout: config.Stream.ChunkTimeout,
|
||||
done: make(chan struct{}, 1),
|
||||
err: make(chan error, 1),
|
||||
}
|
||||
|
||||
if r.concurrency == 0 {
|
||||
r.currentPart = 1
|
||||
go r.fillBufferSequentially()
|
||||
} else {
|
||||
go r.fillBufferConcurrently()
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *tgReader) Close() error {
|
||||
close(r.done)
|
||||
close(r.err)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *tgReader) Read(p []byte) (int, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.limit <= 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if r.cur.isEmpty() {
|
||||
if r.cur != nil {
|
||||
r.cur = nil
|
||||
}
|
||||
select {
|
||||
case cur, ok := <-r.bufferChan:
|
||||
if !ok && r.limit > 0 {
|
||||
return 0, ErrorStreamAbandoned
|
||||
}
|
||||
r.cur = cur
|
||||
|
||||
case err := <-r.err:
|
||||
return 0, fmt.Errorf("error reading chunk: %w", err)
|
||||
case <-r.ctx.Done():
|
||||
return 0, r.ctx.Err()
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
n := copy(p, r.cur.buffer())
|
||||
r.cur.increment(n)
|
||||
r.limit -= int64(n)
|
||||
|
||||
if r.limit <= 0 {
|
||||
return n, io.EOF
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *tgReader) fillBufferConcurrently() error {
|
||||
|
||||
var mapMu sync.Mutex
|
||||
|
||||
bufferMap := make(map[int]*buffer)
|
||||
|
||||
defer func() {
|
||||
close(r.bufferChan)
|
||||
r.closed = true
|
||||
for i := range bufferMap {
|
||||
delete(bufferMap, i)
|
||||
}
|
||||
}()
|
||||
|
||||
cb := func(ctx context.Context, i int) func() error {
|
||||
return func() error {
|
||||
|
||||
chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.totalParts == 1 {
|
||||
chunk = chunk[r.leftCut:r.rightCut]
|
||||
} else if r.currentPart+i+1 == 1 {
|
||||
chunk = chunk[r.leftCut:]
|
||||
} else if r.currentPart+i+1 == r.totalParts {
|
||||
chunk = chunk[:r.rightCut]
|
||||
}
|
||||
buf := &buffer{buf: chunk}
|
||||
mapMu.Lock()
|
||||
bufferMap[i] = buf
|
||||
mapMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
|
||||
g := errgroup.Group{}
|
||||
|
||||
g.SetLimit(r.concurrency)
|
||||
|
||||
for i := range r.concurrency {
|
||||
if r.currentPart+i+1 <= r.totalParts {
|
||||
g.Go(cb(r.ctx, i))
|
||||
}
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
done <- g.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
r.err <- err
|
||||
return nil
|
||||
} else {
|
||||
for i := range r.concurrency {
|
||||
if r.currentPart+i+1 <= r.totalParts {
|
||||
r.bufferChan <- bufferMap[i]
|
||||
}
|
||||
}
|
||||
r.currentPart += r.concurrency
|
||||
r.offset += r.chunkSize * int64(r.concurrency)
|
||||
for i := range bufferMap {
|
||||
delete(bufferMap, i)
|
||||
}
|
||||
if r.currentPart >= r.totalParts {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
case <-time.After(r.timeout):
|
||||
return nil
|
||||
case <-r.done:
|
||||
return nil
|
||||
case <-r.ctx.Done():
|
||||
return r.ctx.Err()
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (r *tgReader) fillBufferSequentially() error {
|
||||
|
||||
defer close(r.bufferChan)
|
||||
|
||||
fetchChunk := func(ctx context.Context) (*buffer, error) {
|
||||
chunk, err := r.chunkSrc.Chunk(ctx, r.offset, r.chunkSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.totalParts == 1 {
|
||||
chunk = chunk[r.leftCut:r.rightCut]
|
||||
} else if r.currentPart == 1 {
|
||||
chunk = chunk[r.leftCut:]
|
||||
} else if r.currentPart == r.totalParts {
|
||||
chunk = chunk[:r.rightCut]
|
||||
}
|
||||
return &buffer{buf: chunk}, nil
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.done:
|
||||
return nil
|
||||
case <-r.ctx.Done():
|
||||
return r.ctx.Err()
|
||||
case <-time.After(r.timeout):
|
||||
return nil
|
||||
default:
|
||||
buf, err := fetchChunk(r.ctx)
|
||||
if err != nil {
|
||||
r.err <- err
|
||||
return nil
|
||||
}
|
||||
r.bufferChan <- buf
|
||||
r.currentPart++
|
||||
r.offset += r.chunkSize
|
||||
if r.currentPart > r.totalParts {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type buffer struct {
|
||||
buf []byte
|
||||
offset int
|
||||
}
|
||||
|
||||
func (b *buffer) isEmpty() bool {
|
||||
if b == nil {
|
||||
return true
|
||||
}
|
||||
if len(b.buf)-b.offset <= 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *buffer) buffer() []byte {
|
||||
return b.buf[b.offset:]
|
||||
}
|
||||
|
||||
func (b *buffer) increment(n int) {
|
||||
b.offset += n
|
||||
}
|
142
internal/reader/tg_reader_test.go
Normal file
142
internal/reader/tg_reader_test.go
Normal file
|
@ -0,0 +1,142 @@
|
|||
package reader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/divyam234/teldrive/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type testChunkSource struct {
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
func (m *testChunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
|
||||
return m.buffer[offset : offset+limit], nil
|
||||
}
|
||||
|
||||
func (m *testChunkSource) ChunkSize(start, end int64) int64 {
|
||||
return 1
|
||||
}
|
||||
|
||||
type testChunkSourceTimeout struct {
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
func (m *testChunkSourceTimeout) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
|
||||
if offset == 8 {
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
return m.buffer[offset : offset+limit], nil
|
||||
}
|
||||
|
||||
func (m *testChunkSourceTimeout) ChunkSize(start, end int64) int64 {
|
||||
return 1
|
||||
}
|
||||
|
||||
type TestSuite struct {
|
||||
suite.Suite
|
||||
config *config.TGConfig
|
||||
}
|
||||
|
||||
func (suite *TestSuite) SetupTest() {
|
||||
suite.config = &config.TGConfig{Stream: struct {
|
||||
MultiThreads int
|
||||
Buffers int
|
||||
ChunkTimeout time.Duration
|
||||
}{MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}}
|
||||
}
|
||||
|
||||
func (suite *TestSuite) TestFullRead() {
|
||||
ctx := context.Background()
|
||||
start := int64(0)
|
||||
end := int64(99)
|
||||
data := make([]byte, 100)
|
||||
rand.Read(data)
|
||||
chunkSrc := &testChunkSource{buffer: data}
|
||||
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
||||
assert.NoError(suite.T(), err)
|
||||
test_data, err := io.ReadAll(reader)
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
assert.Equal(suite.T(), data[start:end+1], test_data)
|
||||
}
|
||||
|
||||
func (suite *TestSuite) TestPartialRead() {
|
||||
ctx := context.Background()
|
||||
start := int64(0)
|
||||
end := int64(65)
|
||||
data := make([]byte, 100)
|
||||
rand.Read(data)
|
||||
chunkSrc := &testChunkSource{buffer: data}
|
||||
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
||||
assert.NoError(suite.T(), err)
|
||||
test_data, err := io.ReadAll(reader)
|
||||
assert.NoError(suite.T(), err)
|
||||
assert.Equal(suite.T(), data[start:end+1], test_data)
|
||||
}
|
||||
|
||||
func (suite *TestSuite) TestTimeout() {
|
||||
ctx := context.Background()
|
||||
start := int64(0)
|
||||
end := int64(65)
|
||||
data := make([]byte, 100)
|
||||
rand.Read(data)
|
||||
chunkSrc := &testChunkSourceTimeout{buffer: data}
|
||||
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
||||
assert.NoError(suite.T(), err)
|
||||
test_data, err := io.ReadAll(reader)
|
||||
assert.Greater(suite.T(), len(test_data), 0)
|
||||
assert.Equal(suite.T(), err, ErrorStreamAbandoned)
|
||||
}
|
||||
|
||||
func (suite *TestSuite) TestClose() {
|
||||
ctx := context.Background()
|
||||
start := int64(0)
|
||||
end := int64(65)
|
||||
data := make([]byte, 100)
|
||||
rand.Read(data)
|
||||
chunkSrc := &testChunkSource{buffer: data}
|
||||
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
||||
assert.NoError(suite.T(), err)
|
||||
_, err = io.ReadAll(reader)
|
||||
assert.NoError(suite.T(), err)
|
||||
assert.NoError(suite.T(), reader.Close())
|
||||
}
|
||||
|
||||
func (suite *TestSuite) TestCancellation() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
start := int64(0)
|
||||
end := int64(65)
|
||||
data := make([]byte, 100)
|
||||
rand.Read(data)
|
||||
chunkSrc := &testChunkSource{buffer: data}
|
||||
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
||||
assert.NoError(suite.T(), err)
|
||||
cancel()
|
||||
_, err = io.ReadAll(reader)
|
||||
assert.Equal(suite.T(), err, context.Canceled)
|
||||
assert.Equal(suite.T(), len(reader.bufferChan), 0)
|
||||
}
|
||||
|
||||
func (suite *TestSuite) TestCancellationWithTimeout() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
_ = cancel
|
||||
start := int64(0)
|
||||
end := int64(65)
|
||||
data := make([]byte, 100)
|
||||
rand.Read(data)
|
||||
chunkSrc := &testChunkSourceTimeout{buffer: data}
|
||||
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
||||
assert.NoError(suite.T(), err)
|
||||
_, err = io.ReadAll(reader)
|
||||
assert.Equal(suite.T(), err, context.DeadlineExceeded)
|
||||
assert.Equal(suite.T(), len(reader.bufferChan), 0)
|
||||
}
|
||||
func Test(t *testing.T) {
|
||||
suite.Run(t, new(TestSuite))
|
||||
}
|
|
@ -1,144 +0,0 @@
|
|||
package reader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
type tgReader struct {
|
||||
ctx context.Context
|
||||
client *tg.Client
|
||||
location *tg.InputDocumentFileLocation
|
||||
start int64
|
||||
end int64
|
||||
next func() ([]byte, error)
|
||||
buffer []byte
|
||||
limit int64
|
||||
chunkSize int64
|
||||
i int64
|
||||
}
|
||||
|
||||
func calculateChunkSize(start, end int64) int64 {
|
||||
chunkSize := int64(1024 * 1024)
|
||||
|
||||
for chunkSize > 1024 && chunkSize > (end-start) {
|
||||
chunkSize /= 2
|
||||
}
|
||||
|
||||
return chunkSize
|
||||
}
|
||||
|
||||
func newTGReader(
|
||||
ctx context.Context,
|
||||
client *tg.Client,
|
||||
location *tg.InputDocumentFileLocation,
|
||||
start int64,
|
||||
end int64,
|
||||
|
||||
) (io.ReadCloser, error) {
|
||||
|
||||
r := &tgReader{
|
||||
ctx: ctx,
|
||||
location: location,
|
||||
client: client,
|
||||
start: start,
|
||||
end: end,
|
||||
chunkSize: calculateChunkSize(start, end),
|
||||
limit: end - start + 1,
|
||||
}
|
||||
r.next = r.partStream()
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *tgReader) Read(p []byte) (n int, err error) {
|
||||
|
||||
if r.limit <= 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if r.i >= int64(len(r.buffer)) {
|
||||
r.buffer, err = r.next()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(r.buffer) == 0 {
|
||||
r.next = r.partStream()
|
||||
r.buffer, err = r.next()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
}
|
||||
r.i = 0
|
||||
}
|
||||
n = copy(p, r.buffer[r.i:])
|
||||
r.i += int64(n)
|
||||
r.limit -= int64(n)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (*tgReader) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *tgReader) chunk(offset int64, limit int64) ([]byte, error) {
|
||||
|
||||
req := &tg.UploadGetFileRequest{
|
||||
Offset: offset,
|
||||
Limit: int(limit),
|
||||
Location: r.location,
|
||||
Precise: true,
|
||||
}
|
||||
|
||||
res, err := r.client.UploadGetFile(r.ctx, req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch result := res.(type) {
|
||||
case *tg.UploadFile:
|
||||
return result.Bytes, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected type %T", r)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *tgReader) partStream() func() ([]byte, error) {
|
||||
|
||||
start := r.start
|
||||
end := r.end
|
||||
offset := start - (start % r.chunkSize)
|
||||
|
||||
leftCut := start - offset
|
||||
rightCut := (end % r.chunkSize) + 1
|
||||
totalParts := int((end - offset + r.chunkSize) / r.chunkSize)
|
||||
currentPart := 1
|
||||
|
||||
return func() ([]byte, error) {
|
||||
if currentPart > totalParts {
|
||||
return make([]byte, 0), nil
|
||||
}
|
||||
res, err := r.chunk(offset, r.chunkSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res) == 0 {
|
||||
return res, nil
|
||||
} else if totalParts == 1 {
|
||||
res = res[leftCut:rightCut]
|
||||
} else if currentPart == 1 {
|
||||
res = res[leftCut:]
|
||||
} else if currentPart == totalParts {
|
||||
res = res[:rightCut]
|
||||
}
|
||||
|
||||
currentPart++
|
||||
offset += r.chunkSize
|
||||
return res, nil
|
||||
}
|
||||
}
|
239
internal/tgc/helpers.go
Normal file
239
internal/tgc/helpers.go
Normal file
|
@ -0,0 +1,239 @@
|
|||
package tgc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/divyam234/teldrive/internal/cache"
|
||||
"github.com/divyam234/teldrive/pkg/types"
|
||||
"github.com/gotd/td/telegram"
|
||||
"github.com/gotd/td/tg"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInValidChannelID = errors.New("invalid channel id")
|
||||
ErrInvalidChannelMessages = errors.New("invalid channel messages")
|
||||
)
|
||||
|
||||
func GetChannelById(ctx context.Context, client *tg.Client, channelId int64) (*tg.InputChannel, error) {
|
||||
inputChannel := &tg.InputChannel{
|
||||
ChannelID: channelId,
|
||||
}
|
||||
channels, err := client.ChannelsGetChannels(ctx, []tg.InputChannelClass{inputChannel})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(channels.GetChats()) == 0 {
|
||||
return nil, ErrInValidChannelID
|
||||
}
|
||||
return channels.GetChats()[0].(*tg.Channel).AsInput(), nil
|
||||
}
|
||||
|
||||
func DeleteMessages(ctx context.Context, client *telegram.Client, channelId int64, ids []int) error {
|
||||
|
||||
return RunWithAuth(ctx, client, "", func(ctx context.Context) error {
|
||||
channel, err := GetChannelById(ctx, client.API(), channelId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
batchSize := 100
|
||||
|
||||
batchCount := int(math.Ceil(float64(len(ids)) / float64(batchSize)))
|
||||
|
||||
g, _ := errgroup.WithContext(ctx)
|
||||
|
||||
g.SetLimit(runtime.NumCPU())
|
||||
|
||||
for i := 0; i < batchCount; i++ {
|
||||
start := i * batchSize
|
||||
end := min((i+1)*batchSize, len(ids))
|
||||
batchIds := ids[start:end]
|
||||
g.Go(func() error {
|
||||
messageDeleteRequest := tg.ChannelsDeleteMessagesRequest{Channel: channel, ID: batchIds}
|
||||
_, err = client.API().ChannelsDeleteMessages(ctx, &messageDeleteRequest)
|
||||
return err
|
||||
})
|
||||
}
|
||||
return g.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func getTGMessagesBatch(ctx context.Context, client *tg.Client, channel *tg.InputChannel, ids []int) (tg.MessagesMessagesClass, error) {
|
||||
|
||||
msgIds := []tg.InputMessageClass{}
|
||||
|
||||
for _, id := range ids {
|
||||
msgIds = append(msgIds, &tg.InputMessageID{ID: id})
|
||||
}
|
||||
|
||||
messageRequest := tg.ChannelsGetMessagesRequest{
|
||||
Channel: channel,
|
||||
ID: msgIds,
|
||||
}
|
||||
|
||||
res, err := client.ChannelsGetMessages(ctx, &messageRequest)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res, nil
|
||||
|
||||
}
|
||||
|
||||
func GetMessages(ctx context.Context, client *tg.Client, ids []int, channelId int64) ([]tg.MessageClass, error) {
|
||||
|
||||
channel, err := GetChannelById(ctx, client, channelId)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
batchSize := 200
|
||||
|
||||
batchCount := int(math.Ceil(float64(len(ids)) / float64(batchSize)))
|
||||
|
||||
g, _ := errgroup.WithContext(ctx)
|
||||
|
||||
g.SetLimit(runtime.NumCPU())
|
||||
|
||||
messageMap := make(map[int]*tg.MessagesChannelMessages)
|
||||
|
||||
var mapMu sync.Mutex
|
||||
|
||||
for i := range batchCount {
|
||||
g.Go(func() error {
|
||||
splitIds := ids[i*batchSize : min((i+1)*batchSize, len(ids))]
|
||||
res, err := getTGMessagesBatch(ctx, client, channel, splitIds)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
messages, ok := res.(*tg.MessagesChannelMessages)
|
||||
if !ok {
|
||||
return ErrInvalidChannelMessages
|
||||
}
|
||||
mapMu.Lock()
|
||||
messageMap[i] = messages
|
||||
mapMu.Unlock()
|
||||
return nil
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
if err = g.Wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allMessages := []tg.MessageClass{}
|
||||
|
||||
for i := range batchCount {
|
||||
allMessages = append(allMessages, messageMap[i].Messages...)
|
||||
}
|
||||
|
||||
return allMessages, nil
|
||||
}
|
||||
|
||||
func GetChunk(ctx context.Context, client *tg.Client, location tg.InputFileLocationClass, offset int64, limit int64) ([]byte, error) {
|
||||
req := &tg.UploadGetFileRequest{
|
||||
Offset: offset,
|
||||
Limit: int(limit),
|
||||
Location: location,
|
||||
Precise: true,
|
||||
}
|
||||
|
||||
r, err := client.UploadGetFile(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch result := r.(type) {
|
||||
case *tg.UploadFile:
|
||||
return result.Bytes, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected type %T", r)
|
||||
}
|
||||
}
|
||||
|
||||
func GetMediaContent(ctx context.Context, client *tg.Client, location tg.InputFileLocationClass) (*bytes.Buffer, error) {
|
||||
offset := int64(0)
|
||||
limit := int64(1024 * 1024)
|
||||
buff := &bytes.Buffer{}
|
||||
for {
|
||||
r, err := GetChunk(ctx, client, location, offset, limit)
|
||||
if err != nil {
|
||||
return buff, err
|
||||
}
|
||||
if len(r) == 0 {
|
||||
break
|
||||
}
|
||||
buff.Write(r)
|
||||
offset += int64(limit)
|
||||
}
|
||||
return buff, nil
|
||||
}
|
||||
|
||||
func GetBotInfo(ctx context.Context, client *telegram.Client, token string) (*types.BotInfo, error) {
|
||||
var user *tg.User
|
||||
err := RunWithAuth(ctx, client, token, func(ctx context.Context) error {
|
||||
user, _ = client.Self(ctx)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil
|
||||
}
|
||||
|
||||
func GetLocation(ctx context.Context, client *Client, fileId string, channelId int64, partId int64) (location *tg.InputDocumentFileLocation, err error) {
|
||||
|
||||
cache := cache.FromContext(ctx)
|
||||
|
||||
key := fmt.Sprintf("location:%s:%s:%d", client.UserId, fileId, partId)
|
||||
|
||||
err = cache.Get(key, location)
|
||||
|
||||
if err != nil {
|
||||
channel, err := GetChannelById(ctx, client.Tg.API(), channelId)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messageRequest := tg.ChannelsGetMessagesRequest{
|
||||
Channel: channel,
|
||||
ID: []tg.InputMessageClass{&tg.InputMessageID{ID: int(partId)}},
|
||||
}
|
||||
|
||||
res, err := client.Tg.API().ChannelsGetMessages(ctx, &messageRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages, _ := res.(*tg.MessagesChannelMessages)
|
||||
item := messages.Messages[0].(*tg.Message)
|
||||
media := item.Media.(*tg.MessageMediaDocument)
|
||||
document := media.Document.(*tg.Document)
|
||||
location = document.AsInputDocumentFileLocation()
|
||||
cache.Set(key, location, 3600)
|
||||
}
|
||||
return location, nil
|
||||
}
|
||||
|
||||
func CalculateChunkSize(start, end int64) int64 {
|
||||
chunkSize := int64(1024 * 1024)
|
||||
|
||||
for chunkSize > 1024 && chunkSize > (end-start) {
|
||||
chunkSize /= 2
|
||||
}
|
||||
return chunkSize
|
||||
}
|
|
@ -55,8 +55,8 @@ func New(ctx context.Context, config *config.TGConfig, handler telegram.UpdateHa
|
|||
LangCode: config.LangCode,
|
||||
},
|
||||
SessionStorage: storage,
|
||||
RetryInterval: 5 * time.Second,
|
||||
MaxRetries: 5,
|
||||
RetryInterval: 2 * time.Second,
|
||||
MaxRetries: 10,
|
||||
DialTimeout: 10 * time.Second,
|
||||
Middlewares: middlewares,
|
||||
UpdateHandler: handler,
|
||||
|
|
|
@ -2,11 +2,11 @@ package tgc
|
|||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/divyam234/teldrive/internal/config"
|
||||
"github.com/divyam234/teldrive/internal/kv"
|
||||
"github.com/divyam234/teldrive/internal/pool"
|
||||
"github.com/gotd/td/telegram"
|
||||
)
|
||||
|
||||
|
@ -42,9 +42,9 @@ func NewUploadWorker() *UploadWorker {
|
|||
|
||||
type Client struct {
|
||||
Tg *telegram.Client
|
||||
Pool pool.Pool
|
||||
Stop StopFunc
|
||||
Status string
|
||||
UserId string
|
||||
}
|
||||
|
||||
type StreamWorker struct {
|
||||
|
@ -69,10 +69,7 @@ func (w *StreamWorker) Set(bots []string, channelId int64) {
|
|||
for _, token := range bots {
|
||||
middlewares := Middlewares(w.cnf, 5)
|
||||
client, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
|
||||
c := &Client{Tg: client, Status: "idle"}
|
||||
if w.cnf.Stream.UsePooling {
|
||||
c.Pool = pool.NewPool(client, int64(w.cnf.PoolSize), middlewares...)
|
||||
}
|
||||
c := &Client{Tg: client, Status: "idle", UserId: strings.Split(token, ":")[0]}
|
||||
w.clients[channelId] = append(w.clients[channelId], c)
|
||||
}
|
||||
w.currIdx[channelId] = 0
|
||||
|
@ -108,9 +105,6 @@ func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error)
|
|||
middlewares := Middlewares(w.cnf, 5)
|
||||
client, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
|
||||
c := &Client{Tg: client, Status: "idle"}
|
||||
if w.cnf.Stream.UsePooling {
|
||||
c.Pool = pool.NewPool(client, int64(w.cnf.PoolSize), middlewares...)
|
||||
}
|
||||
w.clients[userId] = append(w.clients[userId], c)
|
||||
}
|
||||
nextClient := w.clients[userId][0]
|
||||
|
|
|
@ -3,11 +3,11 @@ package controller
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/divyam234/teldrive/internal/auth"
|
||||
"github.com/divyam234/teldrive/internal/cache"
|
||||
"github.com/divyam234/teldrive/internal/logging"
|
||||
"github.com/divyam234/teldrive/pkg/httputil"
|
||||
"github.com/divyam234/teldrive/pkg/schemas"
|
||||
"github.com/divyam234/teldrive/pkg/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
@ -23,7 +23,7 @@ func (fc *Controller) CreateFile(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
userId, _ := services.GetUserAuth(c)
|
||||
userId, _ := auth.GetUser(c)
|
||||
|
||||
res, err := fc.FileService.CreateFile(c, userId, &fileIn)
|
||||
if err != nil {
|
||||
|
@ -36,7 +36,7 @@ func (fc *Controller) CreateFile(c *gin.Context) {
|
|||
|
||||
func (fc *Controller) UpdateFile(c *gin.Context) {
|
||||
|
||||
userId, _ := services.GetUserAuth(c)
|
||||
userId, _ := auth.GetUser(c)
|
||||
|
||||
var fileUpdate schemas.FileUpdate
|
||||
|
||||
|
@ -65,7 +65,7 @@ func (fc *Controller) GetFileByID(c *gin.Context) {
|
|||
|
||||
func (fc *Controller) ListFiles(c *gin.Context) {
|
||||
|
||||
userId, _ := services.GetUserAuth(c)
|
||||
userId, _ := auth.GetUser(c)
|
||||
|
||||
fquery := schemas.FileQuery{
|
||||
PerPage: 500,
|
||||
|
@ -90,7 +90,7 @@ func (fc *Controller) ListFiles(c *gin.Context) {
|
|||
|
||||
func (fc *Controller) MakeDirectory(c *gin.Context) {
|
||||
|
||||
userId, _ := services.GetUserAuth(c)
|
||||
userId, _ := auth.GetUser(c)
|
||||
|
||||
var payload schemas.MkDir
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
|
@ -118,7 +118,7 @@ func (fc *Controller) CopyFile(c *gin.Context) {
|
|||
|
||||
func (fc *Controller) MoveFiles(c *gin.Context) {
|
||||
|
||||
userId, _ := services.GetUserAuth(c)
|
||||
userId, _ := auth.GetUser(c)
|
||||
|
||||
var payload schemas.FileOperation
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
|
@ -136,7 +136,7 @@ func (fc *Controller) MoveFiles(c *gin.Context) {
|
|||
|
||||
func (fc *Controller) DeleteFiles(c *gin.Context) {
|
||||
|
||||
userId, _ := services.GetUserAuth(c)
|
||||
userId, _ := auth.GetUser(c)
|
||||
|
||||
var payload schemas.DeleteOperation
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
|
@ -164,7 +164,7 @@ func (fc *Controller) DeleteFileParts(c *gin.Context) {
|
|||
}
|
||||
|
||||
func (fc *Controller) MoveDirectory(c *gin.Context) {
|
||||
userId, _ := services.GetUserAuth(c)
|
||||
userId, _ := auth.GetUser(c)
|
||||
|
||||
var payload schemas.DirMove
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
|
@ -181,7 +181,7 @@ func (fc *Controller) MoveDirectory(c *gin.Context) {
|
|||
}
|
||||
|
||||
func (fc *Controller) GetCategoryStats(c *gin.Context) {
|
||||
userId, _ := services.GetUserAuth(c)
|
||||
userId, _ := auth.GetUser(c)
|
||||
|
||||
res, err := fc.FileService.GetCategoryStats(userId)
|
||||
if err != nil {
|
||||
|
@ -193,5 +193,9 @@ func (fc *Controller) GetCategoryStats(c *gin.Context) {
|
|||
}
|
||||
|
||||
func (fc *Controller) GetFileStream(c *gin.Context) {
|
||||
fc.FileService.GetFileStream(c)
|
||||
fc.FileService.GetFileStream(c, false)
|
||||
}
|
||||
|
||||
func (fc *Controller) GetFileDownload(c *gin.Context) {
|
||||
fc.FileService.GetFileStream(c, true)
|
||||
}
|
||||
|
|
|
@ -4,8 +4,8 @@ import (
|
|||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/divyam234/teldrive/internal/auth"
|
||||
"github.com/divyam234/teldrive/pkg/httputil"
|
||||
"github.com/divyam234/teldrive/pkg/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
@ -40,7 +40,7 @@ func (uc *Controller) UploadFile(c *gin.Context) {
|
|||
}
|
||||
|
||||
func (uc *Controller) UploadStats(c *gin.Context) {
|
||||
userId, _ := services.GetUserAuth(c)
|
||||
userId, _ := auth.GetUser(c)
|
||||
|
||||
days := 7
|
||||
|
||||
|
|
|
@ -6,9 +6,9 @@ import (
|
|||
|
||||
"github.com/divyam234/teldrive/internal/config"
|
||||
"github.com/divyam234/teldrive/internal/logging"
|
||||
"github.com/divyam234/teldrive/internal/tgc"
|
||||
"github.com/divyam234/teldrive/pkg/models"
|
||||
"github.com/divyam234/teldrive/pkg/schemas"
|
||||
"github.com/divyam234/teldrive/pkg/services"
|
||||
"github.com/go-co-op/gocron"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"go.uber.org/zap"
|
||||
|
@ -87,7 +87,8 @@ func (c *CronService) CleanFiles(ctx context.Context) {
|
|||
}
|
||||
|
||||
}
|
||||
err := services.DeleteTGMessages(ctx, &c.cnf.TG, row.Session, row.ChannelId, row.UserId, ids)
|
||||
client, _ := tgc.AuthClient(ctx, &c.cnf.TG, row.Session)
|
||||
err := tgc.DeleteMessages(ctx, client, row.ChannelId, ids)
|
||||
|
||||
if err != nil {
|
||||
c.logger.Errorw("failed to delete messages", err)
|
||||
|
@ -122,7 +123,9 @@ func (c *CronService) CleanUploads(ctx context.Context) {
|
|||
for _, result := range upResults {
|
||||
|
||||
if result.Session != "" && len(result.Parts) > 0 {
|
||||
err := services.DeleteTGMessages(ctx, &c.cnf.TG, result.Session, result.ChannelId, result.UserId, result.Parts)
|
||||
client, _ := tgc.AuthClient(ctx, &c.cnf.TG, result.Session)
|
||||
|
||||
err := tgc.DeleteMessages(ctx, client, result.ChannelId, result.Parts)
|
||||
if err != nil {
|
||||
c.logger.Errorw("failed to delete messages", err)
|
||||
return
|
||||
|
|
|
@ -1,214 +1,21 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/divyam234/teldrive/internal/cache"
|
||||
"github.com/divyam234/teldrive/internal/config"
|
||||
"github.com/divyam234/teldrive/internal/crypt"
|
||||
"github.com/divyam234/teldrive/internal/kv"
|
||||
"github.com/divyam234/teldrive/internal/tgc"
|
||||
"github.com/divyam234/teldrive/pkg/models"
|
||||
"github.com/divyam234/teldrive/pkg/schemas"
|
||||
"github.com/divyam234/teldrive/pkg/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gotd/td/telegram"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/thoas/go-funk"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type buffer struct {
|
||||
Buf []byte
|
||||
}
|
||||
|
||||
func (b *buffer) long() (int64, error) {
|
||||
v, err := b.uint64()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(v), nil
|
||||
}
|
||||
func (b *buffer) uint64() (uint64, error) {
|
||||
const size = 8
|
||||
if len(b.Buf) < size {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
v := binary.LittleEndian.Uint64(b.Buf)
|
||||
b.Buf = b.Buf[size:]
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func randInt64() (int64, error) {
|
||||
var buf [8]byte
|
||||
if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b := &buffer{Buf: buf[:]}
|
||||
return b.long()
|
||||
}
|
||||
|
||||
type batchResult struct {
|
||||
Index int
|
||||
Messages *tg.MessagesChannelMessages
|
||||
}
|
||||
|
||||
func getChunk(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass, offset int64, limit int64) ([]byte, error) {
|
||||
|
||||
req := &tg.UploadGetFileRequest{
|
||||
Offset: offset,
|
||||
Limit: int(limit),
|
||||
Location: location,
|
||||
}
|
||||
|
||||
r, err := tgClient.API().UploadGetFile(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch result := r.(type) {
|
||||
case *tg.UploadFile:
|
||||
return result.Bytes, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected type %T", r)
|
||||
}
|
||||
}
|
||||
|
||||
func iterContent(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass) (*bytes.Buffer, error) {
|
||||
offset := int64(0)
|
||||
limit := int64(1024 * 1024)
|
||||
buff := &bytes.Buffer{}
|
||||
for {
|
||||
r, err := getChunk(ctx, tgClient, location, offset, limit)
|
||||
if err != nil {
|
||||
return buff, err
|
||||
}
|
||||
if len(r) == 0 {
|
||||
break
|
||||
}
|
||||
buff.Write(r)
|
||||
offset += int64(limit)
|
||||
}
|
||||
return buff, nil
|
||||
}
|
||||
|
||||
func GetUserAuth(c *gin.Context) (int64, string) {
|
||||
val, _ := c.Get("jwtUser")
|
||||
jwtUser := val.(*types.JWTClaims)
|
||||
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
|
||||
return userId, jwtUser.TgSession
|
||||
}
|
||||
|
||||
func getBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token string) (*types.BotInfo, error) {
|
||||
client, _ := tgc.BotClient(ctx, KV, config, token, tgc.Middlewares(config, 5)...)
|
||||
var user *tg.User
|
||||
err := tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error {
|
||||
user, _ = client.Self(ctx)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil
|
||||
}
|
||||
|
||||
func getTGMessagesBatch(ctx context.Context, client *telegram.Client, channel *tg.InputChannel, parts []schemas.Part, index int,
|
||||
results chan<- batchResult, errors chan<- error, wg *sync.WaitGroup) {
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
ids := funk.Map(parts, func(part schemas.Part) tg.InputMessageClass {
|
||||
return &tg.InputMessageID{ID: int(part.ID)}
|
||||
}).([]tg.InputMessageClass)
|
||||
|
||||
messageRequest := tg.ChannelsGetMessagesRequest{
|
||||
Channel: channel,
|
||||
ID: ids,
|
||||
}
|
||||
|
||||
res, err := client.API().ChannelsGetMessages(ctx, &messageRequest)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
messages, ok := res.(*tg.MessagesChannelMessages)
|
||||
|
||||
if !ok {
|
||||
errors <- fmt.Errorf("unexpected response type: %T", res)
|
||||
return
|
||||
}
|
||||
|
||||
results <- batchResult{Index: index, Messages: messages}
|
||||
}
|
||||
|
||||
func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas.Part, channelId int64, userID string) ([]tg.MessageClass, error) {
|
||||
|
||||
channel, err := GetChannelById(ctx, client, channelId, userID)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
batchSize := 200
|
||||
|
||||
batchCount := int(math.Ceil(float64(len(parts)) / float64(batchSize)))
|
||||
|
||||
results := make(chan batchResult, batchCount)
|
||||
|
||||
errors := make(chan error, batchCount)
|
||||
|
||||
for i := range batchCount {
|
||||
wg.Add(1)
|
||||
splitParts := parts[i*batchSize : min((i+1)*batchSize, len(parts))]
|
||||
go getTGMessagesBatch(ctx, client, channel, splitParts, i, results, errors, &wg)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
close(errors)
|
||||
|
||||
for err := range errors {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
batchResults := []batchResult{}
|
||||
|
||||
for result := range results {
|
||||
batchResults = append(batchResults, result)
|
||||
}
|
||||
|
||||
sort.Slice(batchResults, func(i, j int) bool {
|
||||
return batchResults[i].Index < batchResults[j].Index
|
||||
})
|
||||
|
||||
allMessages := []tg.MessageClass{}
|
||||
|
||||
for _, result := range batchResults {
|
||||
allMessages = append(allMessages, result.Messages.GetMessages()...)
|
||||
}
|
||||
|
||||
return allMessages, nil
|
||||
}
|
||||
|
||||
func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
|
||||
func getParts(ctx context.Context, client *tg.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
|
||||
cache := cache.FromContext(ctx)
|
||||
parts := []types.Part{}
|
||||
|
||||
|
@ -220,7 +27,11 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu
|
|||
return parts, nil
|
||||
}
|
||||
|
||||
messages, err := getTGMessages(ctx, client, file.Parts, file.ChannelID, userID)
|
||||
ids := []int{}
|
||||
for _, part := range file.Parts {
|
||||
ids = append(ids, int(part.ID))
|
||||
}
|
||||
messages, err := tgc.GetMessages(ctx, client, ids, file.ChannelID)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -230,10 +41,9 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu
|
|||
item := message.(*tg.Message)
|
||||
media := item.Media.(*tg.MessageMediaDocument)
|
||||
document := media.Document.(*tg.Document)
|
||||
location := document.AsInputDocumentFileLocation()
|
||||
|
||||
part := types.Part{
|
||||
Location: location,
|
||||
ID: file.Parts[i].ID,
|
||||
Size: document.Size,
|
||||
Salt: file.Parts[i].Salt,
|
||||
}
|
||||
|
@ -246,27 +56,7 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu
|
|||
return parts, nil
|
||||
}
|
||||
|
||||
func GetChannelById(ctx context.Context, client *telegram.Client, channelId int64, userID string) (*tg.InputChannel, error) {
|
||||
|
||||
channel := &tg.InputChannel{}
|
||||
inputChannel := &tg.InputChannel{
|
||||
ChannelID: channelId,
|
||||
}
|
||||
channels, err := client.API().ChannelsGetChannels(ctx, []tg.InputChannelClass{inputChannel})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(channels.GetChats()) == 0 {
|
||||
return nil, errors.New("no channels found")
|
||||
}
|
||||
|
||||
channel = channels.GetChats()[0].(*tg.Channel).AsInput()
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
func GetDefaultChannel(ctx context.Context, db *gorm.DB, userID int64) (int64, error) {
|
||||
func getDefaultChannel(ctx context.Context, db *gorm.DB, userID int64) (int64, error) {
|
||||
cache := cache.FromContext(ctx)
|
||||
var channelId int64
|
||||
key := fmt.Sprintf("users:channel:%d", userID)
|
||||
|
@ -332,37 +122,3 @@ func getSessionByHash(db *gorm.DB, cache *cache.Cache, hash string) (*models.Ses
|
|||
return &session, nil
|
||||
|
||||
}
|
||||
|
||||
func DeleteTGMessages(ctx context.Context, cnf *config.TGConfig, session string, channelId, userId int64, ids []int) error {
|
||||
|
||||
client, _ := tgc.AuthClient(ctx, cnf, session)
|
||||
|
||||
err := tgc.RunWithAuth(ctx, client, "", func(ctx context.Context) error {
|
||||
channel, err := GetChannelById(ctx, client, channelId, strconv.FormatInt(userId, 10))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
batchSize := 100
|
||||
|
||||
batchCount := int(math.Ceil(float64(len(ids)) / float64(batchSize)))
|
||||
|
||||
g, _ := errgroup.WithContext(ctx)
|
||||
|
||||
g.SetLimit(8)
|
||||
|
||||
for i := 0; i < batchCount; i++ {
|
||||
start := i * batchSize
|
||||
end := min((i+1)*batchSize, len(ids))
|
||||
batchIds := ids[start:end]
|
||||
g.Go(func() error {
|
||||
messageDeleteRequest := tg.ChannelsDeleteMessagesRequest{Channel: channel, ID: batchIds}
|
||||
_, err = client.API().ChannelsDeleteMessages(ctx, &messageDeleteRequest)
|
||||
return err
|
||||
})
|
||||
}
|
||||
return g.Wait()
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -2,7 +2,9 @@ package services
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
|
@ -12,6 +14,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/WinterYukky/gorm-extra-clause-plugin/exclause"
|
||||
"github.com/divyam234/teldrive/internal/auth"
|
||||
"github.com/divyam234/teldrive/internal/cache"
|
||||
"github.com/divyam234/teldrive/internal/category"
|
||||
"github.com/divyam234/teldrive/internal/config"
|
||||
|
@ -35,6 +38,36 @@ import (
|
|||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type buffer struct {
|
||||
Buf []byte
|
||||
}
|
||||
|
||||
func (b *buffer) long() (int64, error) {
|
||||
v, err := b.uint64()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(v), nil
|
||||
}
|
||||
func (b *buffer) uint64() (uint64, error) {
|
||||
const size = 8
|
||||
if len(b.Buf) < size {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
v := binary.LittleEndian.Uint64(b.Buf)
|
||||
b.Buf = b.Buf[size:]
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func randInt64() (int64, error) {
|
||||
var buf [8]byte
|
||||
if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b := &buffer{Buf: buf[:]}
|
||||
return b.long()
|
||||
}
|
||||
|
||||
type FileService struct {
|
||||
db *gorm.DB
|
||||
cnf *config.TGConfig
|
||||
|
@ -75,7 +108,7 @@ func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.
|
|||
channelId := fileIn.ChannelID
|
||||
if fileIn.ChannelID == 0 {
|
||||
var err error
|
||||
channelId, err = GetDefaultChannel(c, fs.db, userId)
|
||||
channelId, err = getDefaultChannel(c, fs.db, userId)
|
||||
if err != nil {
|
||||
return nil, &types.AppError{Error: err, Code: http.StatusNotFound}
|
||||
}
|
||||
|
@ -332,7 +365,9 @@ func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Mess
|
|||
return nil, &types.AppError{Error: err}
|
||||
}
|
||||
|
||||
userId, session := GetUserAuth(c)
|
||||
_, session := auth.GetUser(c)
|
||||
|
||||
client, _ := tgc.AuthClient(c, fs.cnf, session)
|
||||
|
||||
ids := []int{}
|
||||
|
||||
|
@ -340,7 +375,7 @@ func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Mess
|
|||
ids = append(ids, int(part.ID))
|
||||
}
|
||||
|
||||
err := DeleteTGMessages(c, fs.cnf, session, *file.ChannelID, userId, ids)
|
||||
err := tgc.DeleteMessages(c, client, *file.ChannelID, ids)
|
||||
|
||||
if err != nil {
|
||||
return nil, &types.AppError{Error: err}
|
||||
|
@ -380,7 +415,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
|
|||
return nil, &types.AppError{Error: err, Code: http.StatusBadRequest}
|
||||
}
|
||||
|
||||
userId, session := GetUserAuth(c)
|
||||
userId, session := auth.GetUser(c)
|
||||
|
||||
client, _ := tgc.AuthClient(c, fs.cnf, session)
|
||||
|
||||
|
@ -394,20 +429,24 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
|
|||
|
||||
newIds := []schemas.Part{}
|
||||
|
||||
channelId, err := GetDefaultChannel(c, fs.db, userId)
|
||||
channelId, err := getDefaultChannel(c, fs.db, userId)
|
||||
if err != nil {
|
||||
return nil, &types.AppError{Error: err}
|
||||
}
|
||||
|
||||
err = tgc.RunWithAuth(c, client, "", func(ctx context.Context) error {
|
||||
user := strconv.FormatInt(userId, 10)
|
||||
messages, err := getTGMessages(c, client, file.Parts, file.ChannelID, user)
|
||||
ids := []int{}
|
||||
|
||||
for _, part := range file.Parts {
|
||||
ids = append(ids, int(part.ID))
|
||||
}
|
||||
messages, err := tgc.GetMessages(c, client.API(), ids, file.ChannelID)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
channel, err := GetChannelById(ctx, client, channelId, user)
|
||||
channel, err := tgc.GetChannelById(ctx, client.API(), channelId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -482,7 +521,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
|
|||
return mapper.ToFileOut(dbFile), nil
|
||||
}
|
||||
|
||||
func (fs *FileService) GetFileStream(c *gin.Context) {
|
||||
func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
||||
|
||||
w := c.Writer
|
||||
|
||||
|
@ -585,7 +624,7 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
|
|||
|
||||
disposition := "inline"
|
||||
|
||||
if c.Query("d") == "1" {
|
||||
if download {
|
||||
disposition = "attachment"
|
||||
}
|
||||
|
||||
|
@ -603,12 +642,13 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
|
|||
var (
|
||||
channelUser string
|
||||
lr io.ReadCloser
|
||||
client *tgc.Client
|
||||
multiThreads int
|
||||
)
|
||||
|
||||
var client *tgc.Client
|
||||
multiThreads = fs.cnf.Stream.MultiThreads
|
||||
|
||||
if fs.cnf.DisableStreamBots || len(tokens) == 0 {
|
||||
|
||||
client, err = fs.worker.UserWorker(session.Session, session.UserId)
|
||||
if err != nil {
|
||||
logger.Error("file stream", zap.Error(err))
|
||||
|
@ -616,47 +656,36 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
channelUser = strconv.FormatInt(session.UserId, 10)
|
||||
|
||||
logger.Debugw("requesting file", "name", file.Name, "bot", channelUser, "user", channelUser, "start", start,
|
||||
"end", end, "fileSize", file.Size)
|
||||
multiThreads = 0
|
||||
|
||||
} else {
|
||||
var index int
|
||||
|
||||
limit := min(len(tokens), fs.cnf.BgBotsLimit)
|
||||
|
||||
fs.worker.Set(tokens[:limit], file.ChannelID)
|
||||
|
||||
client, index, err = fs.worker.Next(file.ChannelID)
|
||||
|
||||
client, _, err = fs.worker.Next(file.ChannelID)
|
||||
if err != nil {
|
||||
logger.Error("file stream", zap.Error(err))
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
channelUser = strings.Split(tokens[index], ":")[0]
|
||||
logger.Debugw("requesting file", "name", file.Name, "bot", channelUser, "botNo", index, "start", start,
|
||||
"end", end, "fileSize", file.Size)
|
||||
}
|
||||
|
||||
if r.Method != "HEAD" {
|
||||
parts, err := getParts(c, client.Tg, file, channelUser)
|
||||
parts, err := getParts(c, client.Tg.API(), file, channelUser)
|
||||
if err != nil {
|
||||
logger.Error("file stream", err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
tgClient := client.Tg.API()
|
||||
|
||||
if fs.cnf.Stream.UsePooling {
|
||||
tgClient = client.Pool.Default(c)
|
||||
if download {
|
||||
multiThreads = 0
|
||||
}
|
||||
|
||||
if file.Encrypted {
|
||||
lr, err = reader.NewDecryptedReader(c, tgClient, parts, start, end, fs.cnf)
|
||||
lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker)
|
||||
} else {
|
||||
lr, err = reader.NewLinearReader(c, tgClient, parts, start, end, fs.cnf)
|
||||
lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -669,7 +698,10 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
io.CopyN(w, lr, contentLength)
|
||||
_, err = io.CopyN(w, lr, contentLength)
|
||||
if err != nil {
|
||||
lr.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
func setOrderFilter(query *gorm.DB, fquery *schemas.FileQuery) *gorm.DB {
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/divyam234/teldrive/internal/auth"
|
||||
"github.com/divyam234/teldrive/internal/crypt"
|
||||
"github.com/divyam234/teldrive/internal/kv"
|
||||
"github.com/divyam234/teldrive/internal/logging"
|
||||
|
@ -116,7 +117,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
|
|||
Code: http.StatusBadRequest}
|
||||
}
|
||||
|
||||
userId, session := GetUserAuth(c)
|
||||
userId, session := auth.GetUser(c)
|
||||
|
||||
uploadId := c.Param("id")
|
||||
|
||||
|
@ -127,7 +128,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
|
|||
defer fileStream.Close()
|
||||
|
||||
if uploadQuery.ChannelID == 0 {
|
||||
channelId, err = GetDefaultChannel(c, us.db, userId)
|
||||
channelId, err = getDefaultChannel(c, us.db, userId)
|
||||
if err != nil {
|
||||
return nil, &types.AppError{Error: err}
|
||||
}
|
||||
|
@ -174,7 +175,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
|
|||
|
||||
err = tgc.RunWithAuth(c, client, token, func(ctx context.Context) error {
|
||||
|
||||
channel, err := GetChannelById(ctx, client, channelId, channelUser)
|
||||
channel, err := tgc.GetChannelById(ctx, client.API(), channelId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/divyam234/teldrive/internal/auth"
|
||||
"github.com/divyam234/teldrive/internal/cache"
|
||||
"github.com/divyam234/teldrive/internal/config"
|
||||
"github.com/divyam234/teldrive/internal/kv"
|
||||
|
@ -41,7 +42,7 @@ func NewUserService(db *gorm.DB, cnf *config.Config, kv kv.KV) *UserService {
|
|||
return &UserService{db: db, cnf: cnf, kv: kv}
|
||||
}
|
||||
func (us *UserService) GetProfilePhoto(c *gin.Context) {
|
||||
_, session := GetUserAuth(c)
|
||||
_, session := auth.GetUser(c)
|
||||
|
||||
client, err := tgc.AuthClient(c, &us.cnf.TG, session)
|
||||
|
||||
|
@ -64,7 +65,7 @@ func (us *UserService) GetProfilePhoto(c *gin.Context) {
|
|||
return errors.New("profile not found")
|
||||
}
|
||||
location := &tg.InputPeerPhotoFileLocation{Big: false, Peer: peer, PhotoID: photo.PhotoID}
|
||||
buff, err := iterContent(c, client, location)
|
||||
buff, err := tgc.GetMediaContent(c, client.API(), location)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -83,13 +84,13 @@ func (us *UserService) GetProfilePhoto(c *gin.Context) {
|
|||
}
|
||||
|
||||
func (us *UserService) GetStats(c *gin.Context) (*schemas.AccountStats, *types.AppError) {
|
||||
userID, _ := GetUserAuth(c)
|
||||
userID, _ := auth.GetUser(c)
|
||||
var (
|
||||
channelId int64
|
||||
err error
|
||||
)
|
||||
|
||||
channelId, _ = GetDefaultChannel(c, us.db, userID)
|
||||
channelId, _ = getDefaultChannel(c, us.db, userID)
|
||||
|
||||
tokens, err := getBotsToken(c, us.db, userID, channelId)
|
||||
|
||||
|
@ -103,7 +104,7 @@ func (us *UserService) UpdateChannel(c *gin.Context) (*schemas.Message, *types.A
|
|||
|
||||
cache := cache.FromContext(c)
|
||||
|
||||
userId, _ := GetUserAuth(c)
|
||||
userId, _ := auth.GetUser(c)
|
||||
|
||||
var payload schemas.Channel
|
||||
|
||||
|
@ -130,7 +131,7 @@ func (us *UserService) UpdateChannel(c *gin.Context) (*schemas.Message, *types.A
|
|||
}
|
||||
|
||||
func (us *UserService) ListSessions(c *gin.Context) ([]schemas.SessionOut, *types.AppError) {
|
||||
userId, userSession := GetUserAuth(c)
|
||||
userId, userSession := auth.GetUser(c)
|
||||
|
||||
client, _ := tgc.AuthClient(c, &us.cnf.TG, userSession)
|
||||
|
||||
|
@ -185,7 +186,7 @@ func (us *UserService) ListSessions(c *gin.Context) ([]schemas.SessionOut, *type
|
|||
|
||||
func (us *UserService) RemoveSession(c *gin.Context) (*schemas.Message, *types.AppError) {
|
||||
|
||||
userId, _ := GetUserAuth(c)
|
||||
userId, _ := auth.GetUser(c)
|
||||
|
||||
session := &models.Session{}
|
||||
|
||||
|
@ -209,7 +210,7 @@ func (us *UserService) RemoveSession(c *gin.Context) (*schemas.Message, *types.A
|
|||
}
|
||||
|
||||
func (us *UserService) ListChannels(c *gin.Context) (interface{}, *types.AppError) {
|
||||
_, session := GetUserAuth(c)
|
||||
_, session := auth.GetUser(c)
|
||||
client, _ := tgc.AuthClient(c, &us.cnf.TG, session)
|
||||
|
||||
channels := make(map[int64]*schemas.Channel)
|
||||
|
@ -236,7 +237,7 @@ func (us *UserService) ListChannels(c *gin.Context) (interface{}, *types.AppErro
|
|||
}
|
||||
|
||||
func (us *UserService) AddBots(c *gin.Context) (*schemas.Message, *types.AppError) {
|
||||
userId, session := GetUserAuth(c)
|
||||
userId, session := auth.GetUser(c)
|
||||
client, _ := tgc.AuthClient(c, &us.cnf.TG, session)
|
||||
|
||||
var botsTokens []string
|
||||
|
@ -249,7 +250,7 @@ func (us *UserService) AddBots(c *gin.Context) (*schemas.Message, *types.AppErro
|
|||
return &schemas.Message{Message: "no bots to add"}, nil
|
||||
}
|
||||
|
||||
channelId, err := GetDefaultChannel(c, us.db, userId)
|
||||
channelId, err := getDefaultChannel(c, us.db, userId)
|
||||
|
||||
if err != nil {
|
||||
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
|
||||
|
@ -263,9 +264,9 @@ func (us *UserService) RemoveBots(c *gin.Context) (*schemas.Message, *types.AppE
|
|||
|
||||
cache := cache.FromContext(c)
|
||||
|
||||
userID, _ := GetUserAuth(c)
|
||||
userID, _ := auth.GetUser(c)
|
||||
|
||||
channelId, err := GetDefaultChannel(c, us.db, userID)
|
||||
channelId, err := getDefaultChannel(c, us.db, userID)
|
||||
|
||||
if err != nil {
|
||||
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
|
||||
|
@ -294,7 +295,7 @@ func (us *UserService) addBots(c context.Context, client *telegram.Client, userI
|
|||
|
||||
err := tgc.RunWithAuth(c, client, "", func(ctx context.Context) error {
|
||||
|
||||
channel, err := GetChannelById(ctx, client, channelId, strconv.FormatInt(userId, 10))
|
||||
channel, err := tgc.GetChannelById(ctx, client.API(), channelId)
|
||||
|
||||
if err != nil {
|
||||
logger.Error("error", zap.Error(err))
|
||||
|
@ -309,7 +310,7 @@ func (us *UserService) addBots(c context.Context, client *telegram.Client, userI
|
|||
waitChan <- struct{}{}
|
||||
wg.Add(1)
|
||||
go func(t string) {
|
||||
info, err := getBotInfo(c, us.kv, &us.cnf.TG, t)
|
||||
info, err := tgc.GetBotInfo(c, client, t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package types
|
|||
import (
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/gotd/td/session"
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
type AppError struct {
|
||||
|
@ -12,10 +11,10 @@ type AppError struct {
|
|||
}
|
||||
|
||||
type Part struct {
|
||||
Location *tg.InputDocumentFileLocation
|
||||
DecryptedSize int64
|
||||
Size int64
|
||||
Salt string
|
||||
ID int64
|
||||
}
|
||||
|
||||
type JWTClaims struct {
|
||||
|
|
Loading…
Reference in a new issue