refactor: Update dependencies and refactor multireader and workers
commit c7149701f1 (parent 1532bc0418)
cmd/run.go (16 changed lines)
```diff
@@ -99,8 +99,8 @@ func NewRun() *cobra.Command {
 	runCmd.Flags().StringVar(&config.TG.SystemLangCode, "tg-system-lang-code", "en-US", "System language code")
 	runCmd.Flags().StringVar(&config.TG.LangPack, "tg-lang-pack", "webk", "Language pack")
 	runCmd.Flags().StringVar(&config.TG.Proxy, "tg-proxy", "", "HTTP OR SOCKS5 proxy URL")
-	runCmd.Flags().IntVar(&config.TG.BgBotsLimit, "tg-bg-bots-limit", 5, "Background bots limit")
-	runCmd.Flags().BoolVar(&config.TG.DisableStreamBots, "tg-disable-stream-bots", false, "Disable stream bots")
+	runCmd.Flags().BoolVar(&config.TG.DisableBgBots, "tg-disable-bg-bots", false, "Disable Background bots")
+	runCmd.Flags().BoolVar(&config.TG.DisableStreamBots, "tg-disable-stream-bots", false, "Disable Stream bots")
 	runCmd.Flags().BoolVar(&config.TG.EnableLogging, "tg-enable-logging", false, "Enable telegram client logging")
 	runCmd.Flags().StringVar(&config.TG.Uploads.EncryptionKey, "tg-uploads-encryption-key", "", "Uploads encryption key")
 	runCmd.Flags().IntVar(&config.TG.Uploads.Threads, "tg-uploads-threads", 8, "Uploads threads")
```
```diff
@@ -108,12 +108,11 @@ func NewRun() *cobra.Command {
 	runCmd.Flags().Int64Var(&config.TG.PoolSize, "tg-pool-size", 8, "Telegram Session pool size")
 	duration.DurationVar(runCmd.Flags(), &config.TG.ReconnectTimeout, "tg-reconnect-timeout", 5*time.Minute, "Reconnect Timeout")
 	duration.DurationVar(runCmd.Flags(), &config.TG.Uploads.Retention, "tg-uploads-retention", (24*7)*time.Hour, "Uploads retention duration")
-	duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsCheckInterval, "tg-bg-bots-check-interval", 4*time.Hour, "Interval for checking Idle background bots")
 	runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads")
 	runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers")
-	runCmd.Flags().IntVar(&config.TG.Stream.BotsOffset, "tg-stream-bots-offset", 1, "Stream Bots Offset")
-	duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout")
-	duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsTimeout, "tg-bg-bots-timeout", 30*time.Minute, "Stop Timeout for Idle background bots")
+	duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsCheckInterval, "tg-bg-bots-check-interval", 5*time.Minute, "Interval for checking Idle background bots")
+	runCmd.Flags().IntVar(&config.TG.Stream.BotsLimit, "tg-stream-bots-limit", 5, "Stream bots limit")
+	duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 20*time.Second, "Chunk Fetch Timeout")
 	runCmd.MarkFlagRequired("tg-app-id")
 	runCmd.MarkFlagRequired("tg-app-hash")
 	runCmd.MarkFlagRequired("db-data-source")
```
```diff
@@ -147,13 +146,14 @@ func runApplication(conf *config.Config) {
 			return cacher
 		}),
 		fx.Supply(logging.DefaultLogger().Desugar()),
+		fx.Supply(logging.DefaultLogger()),
 		fx.NopLogger,
 		fx.StopTimeout(conf.Server.GracefulShutdown+time.Second),
 		fx.Provide(
 			database.NewDatabase,
 			kv.NewBoltKV,
-			tgc.NewStreamWorker(ctx),
-			tgc.NewUploadWorker,
+			tgc.NewBotWorker,
+			tgc.NewStreamWorker,
 			services.NewAuthService,
 			services.NewFileService,
 			services.NewUploadService,
```
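The wiring change above is the visible effect of the worker refactor: NewStreamWorker used to be called with a context to produce a constructor closure, whereas now the function itself is handed to fx.Provide and receives its dependencies (including the supplied logger) from the graph. A minimal, self-contained sketch of the pattern — Config and StreamWorker here are stand-ins, not teldrive's types:

```go
package main

import (
	"fmt"

	"go.uber.org/fx"
)

type Config struct{ PoolSize int64 }

type StreamWorker struct{ cnf *Config }

// A plain constructor lets fx inject *Config directly, so the call site is
// fx.Provide(NewStreamWorker) instead of fx.Provide(NewStreamWorker(ctx)).
func NewStreamWorker(cnf *Config) *StreamWorker {
	return &StreamWorker{cnf: cnf}
}

func main() {
	app := fx.New(
		fx.Supply(&Config{PoolSize: 8}),
		fx.Provide(NewStreamWorker),
		// fx runs invocations eagerly inside fx.New, so this prints immediately.
		fx.Invoke(func(w *StreamWorker) { fmt.Println("pool size:", w.cnf.PoolSize) }),
	)
	if err := app.Err(); err != nil {
		panic(err)
	}
}
```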
```diff
@@ -47,9 +47,8 @@ type TGConfig struct {
 	SystemLangCode      string
 	LangPack            string
 	SessionFile         string
-	BgBotsLimit         int
+	DisableBgBots       bool
 	DisableStreamBots   bool
-	BgBotsTimeout       time.Duration
 	BgBotsCheckInterval time.Duration
 	Proxy               string
 	ReconnectTimeout    time.Duration
```
```diff
@@ -62,7 +61,7 @@ type TGConfig struct {
 		Retention time.Duration
 	}
 	Stream struct {
-		BotsOffset   int
+		BotsLimit    int
 		MultiThreads int
 		Buffers      int
 		ChunkTimeout time.Duration
```
```diff
@@ -2,53 +2,55 @@ package reader

 import (
 	"context"
+	"fmt"
 	"io"

 	"github.com/divyam234/teldrive/internal/cache"
 	"github.com/divyam234/teldrive/internal/config"
 	"github.com/divyam234/teldrive/internal/crypt"
 	"github.com/divyam234/teldrive/internal/tgc"
+	"github.com/divyam234/teldrive/pkg/schemas"
 	"github.com/divyam234/teldrive/pkg/types"
+	"github.com/gotd/td/tg"
 )

 type decrpytedReader struct {
 	ctx         context.Context
+	file        *schemas.FileOutFull
 	parts       []types.Part
 	ranges      []types.Range
 	pos         int
 	reader      io.ReadCloser
 	limit       int64
 	config      *config.TGConfig
-	channelId   int64
 	worker      *tgc.StreamWorker
-	client      *tgc.Client
-	fileId      string
+	client      *tg.Client
 	concurrency int
 	cache       cache.Cacher
 }

 func NewDecryptedReader(
 	ctx context.Context,
-	fileId string,
+	client *tg.Client,
+	worker *tgc.StreamWorker,
+	cache cache.Cacher,
+	file *schemas.FileOutFull,
 	parts []types.Part,
-	start, end int64,
-	channelId int64,
+	start,
+	end int64,
 	config *config.TGConfig,
 	concurrency int,
-	client *tgc.Client,
-	worker *tgc.StreamWorker,
-	cache cache.Cacher) (*decrpytedReader, error) {
+) (*decrpytedReader, error) {

 	r := &decrpytedReader{
 		ctx:         ctx,
 		parts:       parts,
+		file:        file,
 		limit:       end - start + 1,
 		ranges:      calculatePartByteRanges(start, end, parts[0].DecryptedSize),
 		config:      config,
 		client:      client,
 		worker:      worker,
-		channelId:   channelId,
-		fileId:      fileId,
 		concurrency: concurrency,
 		cache:       cache,
 	}
```
```diff
@@ -113,14 +115,16 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
 	if underlyingLimit >= 0 {
 		end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
 	}
+	partID := r.parts[r.ranges[r.pos].PartNo].ID
+
 	chunkSrc := &chunkSource{
-		channelID:   r.channelId,
-		worker:      r.worker,
-		fileID:      r.fileId,
-		partID:      r.parts[r.ranges[r.pos].PartNo].ID,
+		channelID:   r.file.ChannelID,
+		partID:      partID,
 		client:      r.client,
 		concurrency: r.concurrency,
 		cache:       r.cache,
+		key:         fmt.Sprintf("files:location:%s:%d", r.file.Id, partID),
+		worker:      r.worker,
 	}
 	if r.concurrency < 2 {
 		return newTGReader(r.ctx, underlyingOffset, end, chunkSrc)
```
```diff
@@ -2,12 +2,15 @@ package reader

 import (
 	"context"
+	"fmt"
 	"io"

 	"github.com/divyam234/teldrive/internal/cache"
 	"github.com/divyam234/teldrive/internal/config"
 	"github.com/divyam234/teldrive/internal/tgc"
+	"github.com/divyam234/teldrive/pkg/schemas"
 	"github.com/divyam234/teldrive/pkg/types"
+	"github.com/gotd/td/tg"
 )

 func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
```
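For orientation, this is what a helper with calculatePartByteRanges' signature computes: it maps an absolute byte range onto per-part (offset, length) slices. The body below is an illustrative reimplementation written for this note under that assumption, not the repository's exact code:

```go
package main

import "fmt"

// Range mirrors the shape the readers consume: a byte window inside one part.
type Range struct {
	Start, End int64 // offsets within the part
	PartNo     int64 // index of the part holding this slice
}

func calculatePartByteRanges(startByte, endByte, partSize int64) []Range {
	var ranges []Range
	for pos := startByte; pos <= endByte; {
		partNo := pos / partSize
		start := pos % partSize
		end := partSize - 1
		if partNo == endByte/partSize { // last part touched by the request
			end = endByte % partSize
		}
		ranges = append(ranges, Range{Start: start, End: end, PartNo: partNo})
		pos += end - start + 1
	}
	return ranges
}

func main() {
	// Bytes 100..2500 over 1024-byte parts span parts 0, 1 and 2.
	fmt.Println(calculatePartByteRanges(100, 2500, 1024))
}
```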
```diff
@@ -41,34 +44,40 @@ func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {

 type LinearReader struct {
 	ctx         context.Context
+	file        *schemas.FileOutFull
 	parts       []types.Part
 	ranges      []types.Range
 	pos         int
 	reader      io.ReadCloser
 	limit       int64
 	config      *config.TGConfig
-	channelID   int64
 	worker      *tgc.StreamWorker
-	client      *tgc.Client
-	fileID      string
+	client      *tg.Client
 	concurrency int
 	cache       cache.Cacher
 }

-func NewLinearReader(ctx context.Context, fileID string, parts []types.Part, start, end int64,
-	channelID int64, config *config.TGConfig, concurrency int, client *tgc.Client,
-	worker *tgc.StreamWorker, cache cache.Cacher) (io.ReadCloser, error) {
+func NewLinearReader(ctx context.Context,
+	client *tg.Client,
+	worker *tgc.StreamWorker,
+	cache cache.Cacher,
+	file *schemas.FileOutFull,
+	parts []types.Part,
+	start,
+	end int64,
+	config *config.TGConfig,
+	concurrency int,
+) (io.ReadCloser, error) {

 	r := &LinearReader{
 		ctx:         ctx,
 		parts:       parts,
+		file:        file,
 		limit:       end - start + 1,
 		ranges:      calculatePartByteRanges(start, end, parts[0].Size),
 		config:      config,
 		client:      client,
 		worker:      worker,
-		channelID:   channelID,
-		fileID:      fileID,
 		concurrency: concurrency,
 		cache:       cache,
 	}
```
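LinearReader keeps the classic sequential multipart-reader shape: hold the computed ranges, keep one underlying reader open, and swap it out when the current part is exhausted. A stripped-down, self-contained sketch of that pattern (openPart stands in for the ranged Telegram fetch):

```go
package main

import (
	"bytes"
	"fmt"
	"io"
)

// multiReader streams several parts as one io.Reader, opening parts lazily.
type multiReader struct {
	parts  [][]byte // stand-in for remote file parts
	pos    int
	reader io.Reader
}

// openPart would return a ranged reader over part i in the real code.
func (m *multiReader) openPart(i int) io.Reader {
	return bytes.NewReader(m.parts[i])
}

func (m *multiReader) Read(p []byte) (int, error) {
	if m.reader == nil {
		if m.pos >= len(m.parts) {
			return 0, io.EOF
		}
		m.reader = m.openPart(m.pos)
	}
	n, err := m.reader.Read(p)
	if err == io.EOF && m.pos < len(m.parts)-1 {
		m.pos++        // current part exhausted: advance...
		m.reader = nil // ...and open the next part on the following Read
		return n, nil
	}
	return n, err
}

func main() {
	r := &multiReader{parts: [][]byte{[]byte("hel"), []byte("lo "), []byte("world")}}
	out, _ := io.ReadAll(r)
	fmt.Println(string(out)) // hello world
}
```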
```diff
@@ -108,14 +117,16 @@ func (r *LinearReader) nextPart() (io.ReadCloser, error) {
 	start := r.ranges[r.pos].Start
 	end := r.ranges[r.pos].End

+	partID := r.parts[r.ranges[r.pos].PartNo].ID
+
 	chunkSrc := &chunkSource{
-		channelID:   r.channelID,
-		worker:      r.worker,
-		fileID:      r.fileID,
-		partID:      r.parts[r.ranges[r.pos].PartNo].ID,
+		channelID:   r.file.ChannelID,
+		partID:      partID,
 		client:      r.client,
 		concurrency: r.concurrency,
 		cache:       r.cache,
+		key:         fmt.Sprintf("files:location:%s:%d", r.file.Id, partID),
+		worker:      r.worker,
 	}

 	if r.concurrency < 2 {
```
```diff
@@ -26,12 +26,12 @@ type ChunkSource interface {

 type chunkSource struct {
 	channelID   int64
-	worker      *tgc.StreamWorker
-	fileID      string
 	partID      int64
 	concurrency int
-	client      *tgc.Client
+	client      *tg.Client
+	key         string
 	cache       cache.Cacher
+	worker      *tgc.StreamWorker
 }

 func (c *chunkSource) ChunkSize(start, end int64) int64 {
```
```diff
@@ -42,27 +42,31 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
 	var (
 		location *tg.InputDocumentFileLocation
 		err      error
-		client   *tgc.Client
+		client   *tg.Client
 	)

+	err = c.cache.Get(c.key, location)
+
 	client = c.client

-	defer func() {
-		if c.concurrency > 0 && client != nil {
-			defer c.worker.Release(client)
-		}
-	}()
-
 	if c.concurrency > 0 {
-		client, _, _ = c.worker.Next(c.channelID)
+		tc, err := c.worker.Next(c.channelID)
+		if err != nil {
+			return nil, err
+		}
+		client = tc.Tg.API()
+
 	}
-	location, err = tgc.GetLocation(ctx, client, c.cache, c.fileID, c.channelID, c.partID)
-
 	if err != nil {
-		return nil, err
+		location, err = tgc.GetLocation(ctx, client, c.channelID, c.partID)
+		if err != nil {
+			return nil, err
+		}
+		c.cache.Set(c.key, location, 30*time.Minute)
 	}

-	return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit)
+	return tgc.GetChunk(ctx, client, location, offset, limit)

 }
```
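The rewritten Chunk makes the location caching explicit at the call site: try the precomputed key first, and only on a miss resolve the location and store it with a 30-minute TTL. A generic, self-contained sketch of that cache-aside flow — Cacher and fetchLocation are stand-ins for teldrive's cache interface and tgc.GetLocation:

```go
package main

import (
	"errors"
	"fmt"
	"sync"
	"time"
)

// Cacher is a stand-in for the project's cache interface.
type Cacher interface {
	Get(key string, value *string) error
	Set(key, value string, ttl time.Duration)
}

type memCache struct {
	mu sync.Mutex
	m  map[string]string
}

func (c *memCache) Get(key string, value *string) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	v, ok := c.m[key]
	if !ok {
		return errors.New("cache miss")
	}
	*value = v
	return nil
}

func (c *memCache) Set(key, value string, _ time.Duration) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.m[key] = value
}

// fetchLocation stands in for resolving a file location via the Telegram API.
func fetchLocation(partID int64) string { return fmt.Sprintf("location-of-part-%d", partID) }

// chunkLocation shows the cache-aside shape used by the new Chunk method.
func chunkLocation(cache Cacher, key string, partID int64) string {
	var loc string
	if err := cache.Get(key, &loc); err != nil {
		loc = fetchLocation(partID)         // miss: resolve once...
		cache.Set(key, loc, 30*time.Minute) // ...and keep it for 30 minutes
	}
	return loc
}

func main() {
	cache := &memCache{m: map[string]string{}}
	key := fmt.Sprintf("files:location:%s:%d", "file-1", 42)
	fmt.Println(chunkLocation(cache, key, 42)) // miss, then fetch
	fmt.Println(chunkLocation(cache, key, 42)) // hit
}
```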
```diff
@@ -46,11 +46,11 @@ type TestSuite struct {

 func (suite *TestSuite) SetupTest() {
 	suite.config = &config.TGConfig{Stream: struct {
-		BotsOffset   int
+		BotsLimit    int
 		MultiThreads int
 		Buffers      int
 		ChunkTimeout time.Duration
-	}{BotsOffset: 1, MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}}
+	}{MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}}
 }

 func (suite *TestSuite) TestFullRead() {
```
```diff
@@ -8,9 +8,7 @@ import (
 	"math"
 	"runtime"
 	"sync"
-	"time"

-	"github.com/divyam234/teldrive/internal/cache"
 	"github.com/divyam234/teldrive/internal/config"
 	"github.com/divyam234/teldrive/internal/kv"
 	"github.com/divyam234/teldrive/pkg/types"
```
```diff
@@ -199,44 +197,39 @@ func GetBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token string)
 	return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil
 }

-func GetLocation(ctx context.Context, client *Client, cache cache.Cacher, fileId string, channelId int64, partId int64) (location *tg.InputDocumentFileLocation, err error) {
-
-	key := fmt.Sprintf("files:location:%s:%s:%d", client.UserID, fileId, partId)
-
-	err = cache.Get(key, location)
-
-	if err != nil {
-		channel, err := GetChannelById(ctx, client.Tg.API(), channelId)
-
-		if err != nil {
-			return nil, err
-		}
-		messageRequest := tg.ChannelsGetMessagesRequest{
-			Channel: channel,
-			ID:      []tg.InputMessageClass{&tg.InputMessageID{ID: int(partId)}},
-		}
-
-		res, err := client.Tg.API().ChannelsGetMessages(ctx, &messageRequest)
-		if err != nil {
-			return nil, err
-		}
-
-		messages, _ := res.(*tg.MessagesChannelMessages)
-
-		if len(messages.Messages) == 0 {
-			return nil, errors.New("no messages found")
-		}
-
-		switch item := messages.Messages[0].(type) {
-		case *tg.MessageEmpty:
-			return nil, errors.New("no messages found")
-		case *tg.Message:
-			media := item.Media.(*tg.MessageMediaDocument)
-			document := media.Document.(*tg.Document)
-			location = document.AsInputDocumentFileLocation()
-			cache.Set(key, location, 30*time.Minute)
-		}
-	}
-	return location, nil
-}
+func GetLocation(ctx context.Context, client *tg.Client, channelId int64, partId int64) (location *tg.InputDocumentFileLocation, err error) {
+
+	channel, err := GetChannelById(ctx, client, channelId)
+
+	if err != nil {
+		return nil, err
+	}
+	messageRequest := tg.ChannelsGetMessagesRequest{
+		Channel: channel,
+		ID:      []tg.InputMessageClass{&tg.InputMessageID{ID: int(partId)}},
+	}
+
+	res, err := client.ChannelsGetMessages(ctx, &messageRequest)
+	if err != nil {
+		return nil, err
+	}
+
+	messages, _ := res.(*tg.MessagesChannelMessages)
+
+	if len(messages.Messages) == 0 {
+		return nil, errors.New("no messages found")
+	}
+
+	switch item := messages.Messages[0].(type) {
+	case *tg.MessageEmpty:
+		return nil, errors.New("no messages found")
+	case *tg.Message:
+		media := item.Media.(*tg.MessageMediaDocument)
+		document := media.Document.(*tg.Document)
+		location = document.AsInputDocumentFileLocation()
+
+	}
+
+	return location, nil
+}
```
```diff
@@ -2,39 +2,37 @@ package tgc

 import (
 	"context"
-	"strconv"
 	"strings"
 	"sync"
 	"time"

 	"github.com/divyam234/teldrive/internal/config"
 	"github.com/divyam234/teldrive/internal/kv"
-	"github.com/divyam234/teldrive/internal/logging"
 	"github.com/gotd/td/telegram"
 	"go.uber.org/zap"
 )

-type UploadWorker struct {
+type BotWorker struct {
 	mu      sync.RWMutex
 	bots    map[int64][]string
 	currIdx map[int64]int
 }

-func NewUploadWorker() *UploadWorker {
-	return &UploadWorker{
+func NewBotWorker() *BotWorker {
+	return &BotWorker{
 		bots:    make(map[int64][]string),
 		currIdx: make(map[int64]int),
 	}
 }

-func (w *UploadWorker) Set(bots []string, channelID int64) {
+func (w *BotWorker) Set(bots []string, channelID int64) {
 	w.mu.Lock()
 	defer w.mu.Unlock()
 	w.bots[channelID] = bots
 	w.currIdx[channelID] = 0
 }

-func (w *UploadWorker) Next(channelID int64) (string, int) {
+func (w *BotWorker) Next(channelID int64) (string, int) {
 	w.mu.RLock()
 	defer w.mu.RUnlock()
 	bots := w.bots[channelID]
```
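BotWorker is UploadWorker under a name that matches its broader role; the per-channel round-robin hand-out is unchanged. A self-contained sketch of the rotation (a full mutex is used here for simplicity, since Next mutates the cursor):

```go
package main

import (
	"fmt"
	"sync"
)

// BotWorker hands out bot tokens per channel in round-robin order.
type BotWorker struct {
	mu      sync.Mutex
	bots    map[int64][]string
	currIdx map[int64]int
}

func NewBotWorker() *BotWorker {
	return &BotWorker{bots: make(map[int64][]string), currIdx: make(map[int64]int)}
}

func (w *BotWorker) Set(bots []string, channelID int64) {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.bots[channelID] = bots
	w.currIdx[channelID] = 0
}

// Next returns the current token and advances the per-channel cursor.
func (w *BotWorker) Next(channelID int64) (string, int) {
	w.mu.Lock()
	defer w.mu.Unlock()
	bots := w.bots[channelID]
	index := w.currIdx[channelID]
	w.currIdx[channelID] = (index + 1) % len(bots)
	return bots[index], index
}

func main() {
	w := NewBotWorker()
	w.Set([]string{"tokenA", "tokenB", "tokenC"}, 1)
	for i := 0; i < 4; i++ {
		token, idx := w.Next(1) // the 4th call wraps back to tokenA
		fmt.Println(idx, token)
	}
}
```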
```diff
@@ -43,40 +41,48 @@ func (w *UploadWorker) Next(channelID int64) (string, int) {
 	return bots[index], index
 }

+type ClientStatus int
+
+const (
+	StatusIdle ClientStatus = iota
+	StatusBusy
+)
+
 type Client struct {
-	Tg          *telegram.Client
-	Stop        StopFunc
-	Status      string
-	UserID      string
-	LastUsed    time.Time
-	Connections int
+	Tg     *telegram.Client
+	Stop   StopFunc
+	Status ClientStatus
+	UserID string
 }

 type StreamWorker struct {
-	mu          sync.RWMutex
-	clients     map[string]*Client
-	currIdx     map[int64]int
-	channelBots map[int64][]string
-	cnf         *config.TGConfig
-	kv          kv.KV
-	ctx         context.Context
-	logger      *zap.SugaredLogger
+	mu            sync.RWMutex
+	clients       map[string]*Client
+	currIdx       map[int64]int
+	channelBots   map[int64][]string
+	cnf           *config.TGConfig
+	kv            kv.KV
+	ctx           context.Context
+	logger        *zap.SugaredLogger
+	activeStreams int
+	cancel        context.CancelFunc
 }

-func NewStreamWorker(ctx context.Context) func(cnf *config.Config, kv kv.KV) *StreamWorker {
-	return func(cnf *config.Config, kv kv.KV) *StreamWorker {
-		worker := &StreamWorker{
-			cnf:         &cnf.TG,
-			kv:          kv,
-			ctx:         ctx,
-			clients:     make(map[string]*Client),
-			currIdx:     make(map[int64]int),
-			channelBots: make(map[int64][]string),
-			logger:      logging.FromContext(ctx),
-		}
-		go worker.startIdleClientMonitor()
-		return worker
-	}
-}
+func NewStreamWorker(cnf *config.Config, kv kv.KV, logger *zap.SugaredLogger) *StreamWorker {
+	ctx, cancel := context.WithCancel(context.Background())
+	worker := &StreamWorker{
+		cnf:         &cnf.TG,
+		kv:          kv,
+		ctx:         ctx,
+		clients:     make(map[string]*Client),
+		currIdx:     make(map[int64]int),
+		channelBots: make(map[int64][]string),
+		logger:      logger,
+		cancel:      cancel,
+	}
+	go worker.startIdleClientMonitor()
+	return worker
+
+}

 func (w *StreamWorker) Set(bots []string, channelID int64) {
```
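Two smaller changes ride along in this hunk: the stringly-typed statuses become a ClientStatus enum, and the worker now derives its own cancelable context instead of borrowing one from the caller. A minimal sketch of the second idea; note that Close is an assumed use of the stored cancel func — the diff only shows it being captured:

```go
package main

import (
	"context"
	"fmt"
)

type StreamWorker struct {
	ctx    context.Context
	cancel context.CancelFunc
}

// NewStreamWorker no longer receives a context: it derives its own, so its
// background goroutines can be torn down independently of any caller.
func NewStreamWorker() *StreamWorker {
	ctx, cancel := context.WithCancel(context.Background())
	return &StreamWorker{ctx: ctx, cancel: cancel}
}

// Close is hypothetical: stopping the worker is the natural place to call
// the cancel func that the commit stores on the struct.
func (w *StreamWorker) Close() { w.cancel() }

func main() {
	w := NewStreamWorker()
	w.Close()
	fmt.Println(w.ctx.Err()) // context canceled
}
```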
```diff
@@ -86,7 +92,7 @@ func (w *StreamWorker) Set(bots []string, channelID int64) {
 	w.currIdx[channelID] = 0
 }

-func (w *StreamWorker) Next(channelID int64) (*Client, int, error) {
+func (w *StreamWorker) Next(channelID int64) (*Client, error) {
 	w.mu.Lock()
 	defer w.mu.Unlock()

```
```diff
@@ -97,75 +103,51 @@ func (w *StreamWorker) Next(channelID int64) (*Client, int, error) {

 	client, err := w.getOrCreateClient(userID, token)
 	if err != nil {
-		return nil, 0, err
+		return nil, err
 	}

 	w.currIdx[channelID] = (index + 1) % len(bots)
-	client.LastUsed = time.Now()
-	client.Connections++
-	if client.Connections == 1 {
-		client.Status = "serving"
-	}
-
-	return client, index, nil
+	return client, nil
 }

+func (w *StreamWorker) IncActiveStream() error {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+
+	w.activeStreams++
+	return nil
+}
+
+func (w *StreamWorker) DecActiveStreams() error {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+
+	if w.activeStreams == 0 {
+		return nil
+	}
+	w.activeStreams--
+	return nil
+}
+
 func (w *StreamWorker) getOrCreateClient(userID, token string) (*Client, error) {
 	client, ok := w.clients[userID]
-	if !ok || (client.Status == "idle" && client.Stop == nil) {
+	if !ok || (client.Status == StatusIdle && client.Stop == nil) {
 		middlewares := Middlewares(w.cnf, 5)
 		tgClient, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
-		client = &Client{Tg: tgClient, Status: "idle", UserID: userID}
+		client = &Client{Tg: tgClient, Status: StatusIdle, UserID: userID}
 		w.clients[userID] = client

 		stop, err := Connect(client.Tg, WithBotToken(token))
 		if err != nil {
 			return nil, err
 		}
 		client.Stop = stop
-		w.logger.Debug("started bg client: ", client.UserID)
+		client.Status = StatusBusy
+		w.logger.Debug("started bg client: ", userID)
 	}
 	return client, nil
 }

-func (w *StreamWorker) Release(client *Client) {
-	w.mu.Lock()
-	defer w.mu.Unlock()
-	client.Connections--
-	if client.Connections == 0 {
-		client.Status = "running"
-	}
-}
-
-func (w *StreamWorker) UserWorker(session string, userID int64) (*Client, error) {
-	w.mu.Lock()
-	defer w.mu.Unlock()
-
-	id := strconv.FormatInt(userID, 10)
-	client, ok := w.clients[id]
-	if !ok || (client.Status == "idle" && client.Stop == nil) {
-		middlewares := Middlewares(w.cnf, 5)
-		tgClient, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
-		client = &Client{Tg: tgClient, Status: "idle", UserID: id}
-		w.clients[id] = client
-
-		stop, err := Connect(client.Tg, WithContext(w.ctx))
-		if err != nil {
-			return nil, err
-		}
-		client.Stop = stop
-		w.logger.Debug("started bg client: ", client.UserID)
-	}
-
-	client.LastUsed = time.Now()
-	client.Connections++
-	if client.Connections == 1 {
-		client.Status = "serving"
-	}
-
-	return client, nil
-}
-
 func (w *StreamWorker) startIdleClientMonitor() {
 	ticker := time.NewTicker(w.cnf.BgBotsCheckInterval)
 	defer ticker.Stop()
```
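Per-client connection counting is replaced by a single worker-wide activeStreams gauge: the file service increments it when a stream starts, decrements it when the handler returns, and the idle monitor only tears clients down while the gauge reads zero. A self-contained sketch of that accounting:

```go
package main

import (
	"fmt"
	"sync"
)

type StreamWorker struct {
	mu            sync.Mutex
	activeStreams int
}

func (w *StreamWorker) IncActiveStream() {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.activeStreams++
}

func (w *StreamWorker) DecActiveStreams() {
	w.mu.Lock()
	defer w.mu.Unlock()
	if w.activeStreams > 0 {
		w.activeStreams--
	}
}

// idle reports whether background clients may be stopped safely.
func (w *StreamWorker) idle() bool {
	w.mu.Lock()
	defer w.mu.Unlock()
	return w.activeStreams == 0
}

func serveStream(w *StreamWorker) {
	w.IncActiveStream()
	defer w.DecActiveStreams() // mirrors the defer in GetFileStream
	fmt.Println("streaming, idle =", w.idle())
}

func main() {
	w := &StreamWorker{}
	serveStream(w)
	fmt.Println("after stream, idle =", w.idle())
}
```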
```diff
@@ -183,15 +165,16 @@ func (w *StreamWorker) startIdleClientMonitor() {
 func (w *StreamWorker) checkIdleClients() {
 	w.mu.Lock()
 	defer w.mu.Unlock()

-	for _, client := range w.clients {
-		if client.Status == "running" && time.Since(client.LastUsed) > w.cnf.BgBotsTimeout {
-			if client.Stop != nil {
+	if w.activeStreams == 0 {
+		for _, client := range w.clients {
+			if client.Status == StatusBusy && client.Stop != nil {
 				client.Stop()
 				client.Stop = nil
-				client.Status = "idle"
+				client.Tg = nil
+				client.Status = StatusIdle
 				w.logger.Debug("stopped bg client: ", client.UserID)
 			}
 		}
 	}

 }
```
```diff
@@ -11,16 +11,17 @@ import (
 	"github.com/divyam234/teldrive/pkg/models"
 	"github.com/divyam234/teldrive/pkg/schemas"
 	"github.com/divyam234/teldrive/pkg/types"
+	"github.com/gotd/td/telegram"
 	"github.com/gotd/td/tg"
 	"github.com/pkg/errors"
 	"gorm.io/gorm"
 )

-func getParts(ctx context.Context, client *tg.Client, cache cache.Cacher, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
+func getParts(ctx context.Context, client *telegram.Client, cache cache.Cacher, file *schemas.FileOutFull) ([]types.Part, error) {

 	parts := []types.Part{}

-	key := fmt.Sprintf("files:messages:%s:%s", file.Id, userID)
+	key := fmt.Sprintf("files:messages:%s", file.Id)

 	err := cache.Get(key, &parts)

```
```diff
@@ -32,7 +33,7 @@ func getParts(ctx context.Context, client *tg.Client, cache cache.Cacher, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
 	for _, part := range file.Parts {
 		ids = append(ids, int(part.ID))
 	}
-	messages, err := tgc.GetMessages(ctx, client, ids, file.ChannelID)
+	messages, err := tgc.GetMessages(ctx, client.API(), ids, file.ChannelID)

 	if err != nil {
 		return nil, err
```
```diff
@@ -22,7 +22,7 @@ import (
 	"github.com/divyam234/teldrive/internal/config"
 	"github.com/divyam234/teldrive/internal/database"
 	"github.com/divyam234/teldrive/internal/http_range"
-	"github.com/divyam234/teldrive/internal/logging"
+	"github.com/divyam234/teldrive/internal/kv"
 	"github.com/divyam234/teldrive/internal/md5"
 	"github.com/divyam234/teldrive/internal/reader"
 	"github.com/divyam234/teldrive/internal/tgc"
```
```diff
@@ -32,6 +33,7 @@ import (
 	"github.com/divyam234/teldrive/pkg/schemas"
 	"github.com/divyam234/teldrive/pkg/types"
 	"github.com/gin-gonic/gin"
+	"github.com/gotd/td/telegram"
 	"github.com/gotd/td/tg"
 	"go.uber.org/zap"
```
```diff
@@ -76,14 +77,24 @@ func randInt64() (int64, error) {
 }

 type FileService struct {
-	db     *gorm.DB
-	cnf    *config.Config
-	worker *tgc.StreamWorker
-	cache  cache.Cacher
+	db        *gorm.DB
+	cnf       *config.Config
+	worker    *tgc.StreamWorker
+	botWorker *tgc.BotWorker
+	cache     cache.Cacher
+	kv        kv.KV
+	logger    *zap.SugaredLogger
 }

-func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker, cache cache.Cacher) *FileService {
-	return &FileService{db: db, cnf: cnf, worker: worker, cache: cache}
+func NewFileService(
+	db *gorm.DB,
+	cnf *config.Config,
+	worker *tgc.StreamWorker,
+	botWorker *tgc.BotWorker,
+	kv kv.KV,
+	cache cache.Cacher,
+	logger *zap.SugaredLogger) *FileService {
+	return &FileService{db: db, cnf: cnf, worker: worker, botWorker: botWorker, cache: cache, kv: kv, logger: logger}
 }

 func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.FileIn) (*schemas.FileOut, *types.AppError) {
```
```diff
@@ -716,84 +727,96 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
 	tokens, err := getBotsToken(fs.db, fs.cache, session.UserId, file.ChannelID)

-	logger := logging.FromContext(c)
 	if err != nil {
-		logger.Error("failed to get bots", zap.Error(err))
-		http.Error(w, err.Error(), http.StatusInternalServerError)
+		fs.handleError(fmt.Errorf("failed to get bots: %w", err), w)
 		return
 	}

 	var (
-		channelUser  string
 		lr           io.ReadCloser
-		client       *tgc.Client
+		client       *telegram.Client
 		multiThreads int
+		token        string
 	)

 	multiThreads = fs.cnf.TG.Stream.MultiThreads

-	defer func() {
-		if client != nil {
-			fs.worker.Release(client)
-		}
-	}()
-
 	if fs.cnf.TG.DisableStreamBots || len(tokens) == 0 {
-		client, err = fs.worker.UserWorker(session.Session, session.UserId)
+		client, err = tgc.AuthClient(c, &fs.cnf.TG, session.Session)
 		if err != nil {
-			logger.Error(ErrorStreamAbandoned, err)
-			http.Error(w, err.Error(), http.StatusInternalServerError)
+			fs.handleError(err, w)
 			return
 		}
-		channelUser = strconv.FormatInt(session.UserId, 10)
 		multiThreads = 0

-	} else {
-		offset := fs.cnf.TG.Stream.BotsOffset - 1
-		limit := min(len(tokens), fs.cnf.TG.BgBotsLimit+offset)
-		fs.worker.Set(tokens[offset:limit], file.ChannelID)
-		client, _, err = fs.worker.Next(file.ChannelID)
+	} else if fs.cnf.TG.DisableBgBots && len(tokens) > 0 {
+		fs.botWorker.Set(tokens, file.ChannelID)
+		token, _ = fs.botWorker.Next(file.ChannelID)
+		client, err = tgc.BotClient(c, fs.kv, &fs.cnf.TG, token)
 		if err != nil {
-			logger.Error(ErrorStreamAbandoned, err)
-			http.Error(w, err.Error(), http.StatusInternalServerError)
+			fs.handleError(err, w)
 		}
+		multiThreads = 0
+	} else {
+		fs.worker.Set(tokens[0:fs.cnf.TG.Stream.BotsLimit], file.ChannelID)
+		c, err := fs.worker.Next(file.ChannelID)
+		if err != nil {
+			fs.handleError(err, w)
+			return
+		}
+		client = c.Tg
 	}

+	if download {
+		multiThreads = 0
+	}
+
 	if r.Method != "HEAD" {
-		parts, err := getParts(c, client.Tg.API(), fs.cache, file, channelUser)
-		if err != nil {
-			logger.Error(ErrorStreamAbandoned, err)
-			http.Error(w, err.Error(), http.StatusInternalServerError)
-			return
-		}
-
-		if download {
-			multiThreads = 0
-		}
-		if file.Encrypted {
-			lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker, fs.cache)
-		} else {
-			lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker, fs.cache)
-		}
-
-		if err != nil {
-			logger.Error(ErrorStreamAbandoned, err)
-			http.Error(w, err.Error(), http.StatusInternalServerError)
-			return
-		}
-		if lr == nil {
-			http.Error(w, "failed to initialise reader", http.StatusInternalServerError)
-			return
-		}
-
-		_, err = io.CopyN(w, lr, contentLength)
-		if err != nil {
-			lr.Close()
-		}
+		handleStream := func() error {
+			parts, err := getParts(c, client, fs.cache, file)
+			if err != nil {
+				fs.handleError(err, w)
+				return nil
+			}
+			if file.Encrypted {
+				lr, err = reader.NewDecryptedReader(c, client.API(), fs.worker, fs.cache, file, parts, start, end, &fs.cnf.TG, multiThreads)
+			} else {
+				lr, err = reader.NewLinearReader(c, client.API(), fs.worker, fs.cache, file, parts, start, end, &fs.cnf.TG, multiThreads)
+			}
+			if err != nil {
+				fs.handleError(err, w)
+				return nil
+			}
+			if lr == nil {
+				fs.handleError(fmt.Errorf("failed to initialise reader"), w)
+				return nil
+			}
+			_, err = io.CopyN(w, lr, contentLength)
+			if err != nil {
+				lr.Close()
+			}
+			return nil
+		}
+		if fs.cnf.TG.DisableBgBots {
+			tgc.RunWithAuth(c, client, token, func(ctx context.Context) error {
+				return handleStream()
+			})
+		} else {
+			fs.worker.IncActiveStream()
+			defer fs.worker.DecActiveStreams()
+			handleStream()
+		}
 	}
 }

+func (fs *FileService) handleError(err error, w http.ResponseWriter) {
+	fs.logger.Error(err)
+	http.Error(w, err.Error(), http.StatusInternalServerError)
+
+}
+
 func getOrder(fquery *schemas.FileQuery) clause.OrderByColumn {
 	sortColumn := utils.CamelToSnake(fquery.Sort)
```
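GetFileStream now chooses between three client strategies instead of always using the pooled worker: the user's own session when stream bots are disabled or absent, a short-lived bot client (run under RunWithAuth) when background bots are disabled, and the pooled StreamWorker otherwise. A condensed, self-contained sketch of the branching — every type and function here is a stand-in for the teldrive equivalent:

```go
package main

import "fmt"

type TGConfig struct {
	DisableStreamBots bool
	DisableBgBots     bool
}

// Stand-ins for tgc.AuthClient, tgc.BotClient+RunWithAuth, and
// StreamWorker.Next respectively.
func userClient() string   { return "user session client" }
func ephemeralBot() string { return "per-request bot client" }
func pooledBot() string    { return "pooled background bot" }

func pickClient(cnf TGConfig, tokens []string) string {
	switch {
	case cnf.DisableStreamBots || len(tokens) == 0:
		return userClient() // no usable bots: stream over the user session
	case cnf.DisableBgBots:
		return ephemeralBot() // bots allowed, but not long-lived ones
	default:
		return pooledBot() // normal path: reuse a connected background bot
	}
}

func main() {
	fmt.Println(pickClient(TGConfig{DisableStreamBots: true}, nil))
	fmt.Println(pickClient(TGConfig{DisableBgBots: true}, []string{"t1"}))
	fmt.Println(pickClient(TGConfig{}, []string{"t1", "t2"}))
}
```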
```diff
@@ -20,7 +20,7 @@ type FileServiceSuite struct {

 func (s *FileServiceSuite) SetupSuite() {
 	s.db = database.NewTestDatabase(s.T(), false)
-	s.srv = NewFileService(s.db, nil, nil, nil)
+	s.srv = NewFileService(s.db, nil, nil, nil, nil, nil, nil)
 }

 func (s *FileServiceSuite) SetupTest() {
```
```diff
@@ -39,13 +39,13 @@ const saltLength = 32

 type UploadService struct {
 	db     *gorm.DB
-	worker *tgc.UploadWorker
+	worker *tgc.BotWorker
 	cnf    *config.TGConfig
 	kv     kv.KV
 	cache  cache.Cacher
 }

-func NewUploadService(db *gorm.DB, cnf *config.Config, worker *tgc.UploadWorker, kv kv.KV, cache cache.Cacher) *UploadService {
+func NewUploadService(db *gorm.DB, cnf *config.Config, worker *tgc.BotWorker, kv kv.KV, cache cache.Cacher) *UploadService {
 	return &UploadService{db: db, worker: worker, cnf: &cnf.TG, kv: kv, cache: cache}
 }
```