refactor: Update dependencies and refactor multireader and workers

This commit is contained in:
divyam234 2024-08-10 00:07:56 +05:30
parent 1532bc0418
commit c7149701f1
12 changed files with 257 additions and 239 deletions

View file

@ -99,8 +99,8 @@ func NewRun() *cobra.Command {
runCmd.Flags().StringVar(&config.TG.SystemLangCode, "tg-system-lang-code", "en-US", "System language code")
runCmd.Flags().StringVar(&config.TG.LangPack, "tg-lang-pack", "webk", "Language pack")
runCmd.Flags().StringVar(&config.TG.Proxy, "tg-proxy", "", "HTTP OR SOCKS5 proxy URL")
runCmd.Flags().IntVar(&config.TG.BgBotsLimit, "tg-bg-bots-limit", 5, "Background bots limit")
runCmd.Flags().BoolVar(&config.TG.DisableStreamBots, "tg-disable-stream-bots", false, "Disable stream bots")
runCmd.Flags().BoolVar(&config.TG.DisableBgBots, "tg-disable-bg-bots", false, "Disable Background bots")
runCmd.Flags().BoolVar(&config.TG.DisableStreamBots, "tg-disable-stream-bots", false, "Disable Stream bots")
runCmd.Flags().BoolVar(&config.TG.EnableLogging, "tg-enable-logging", false, "Enable telegram client logging")
runCmd.Flags().StringVar(&config.TG.Uploads.EncryptionKey, "tg-uploads-encryption-key", "", "Uploads encryption key")
runCmd.Flags().IntVar(&config.TG.Uploads.Threads, "tg-uploads-threads", 8, "Uploads threads")
@ -108,12 +108,11 @@ func NewRun() *cobra.Command {
runCmd.Flags().Int64Var(&config.TG.PoolSize, "tg-pool-size", 8, "Telegram Session pool size")
duration.DurationVar(runCmd.Flags(), &config.TG.ReconnectTimeout, "tg-reconnect-timeout", 5*time.Minute, "Reconnect Timeout")
duration.DurationVar(runCmd.Flags(), &config.TG.Uploads.Retention, "tg-uploads-retention", (24*7)*time.Hour, "Uploads retention duration")
duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsCheckInterval, "tg-bg-bots-check-interval", 4*time.Hour, "Interval for checking Idle background bots")
runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads")
runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers")
runCmd.Flags().IntVar(&config.TG.Stream.BotsOffset, "tg-stream-bots-offset", 1, "Stream Bots Offset")
duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout")
duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsTimeout, "tg-bg-bots-timeout", 30*time.Minute, "Stop Timeout for Idle background bots")
duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsCheckInterval, "tg-bg-bots-check-interval", 5*time.Minute, "Interval for checking Idle background bots")
runCmd.Flags().IntVar(&config.TG.Stream.BotsLimit, "tg-stream-bots-limit", 5, "Stream bots limit")
duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 20*time.Second, "Chunk Fetch Timeout")
runCmd.MarkFlagRequired("tg-app-id")
runCmd.MarkFlagRequired("tg-app-hash")
runCmd.MarkFlagRequired("db-data-source")
@ -147,13 +146,14 @@ func runApplication(conf *config.Config) {
return cacher
}),
fx.Supply(logging.DefaultLogger().Desugar()),
fx.Supply(logging.DefaultLogger()),
fx.NopLogger,
fx.StopTimeout(conf.Server.GracefulShutdown+time.Second),
fx.Provide(
database.NewDatabase,
kv.NewBoltKV,
tgc.NewStreamWorker(ctx),
tgc.NewUploadWorker,
tgc.NewBotWorker,
tgc.NewStreamWorker,
services.NewAuthService,
services.NewFileService,
services.NewUploadService,

View file

@ -47,9 +47,8 @@ type TGConfig struct {
SystemLangCode string
LangPack string
SessionFile string
BgBotsLimit int
DisableBgBots bool
DisableStreamBots bool
BgBotsTimeout time.Duration
BgBotsCheckInterval time.Duration
Proxy string
ReconnectTimeout time.Duration
@ -62,7 +61,7 @@ type TGConfig struct {
Retention time.Duration
}
Stream struct {
BotsOffset int
BotsLimit int
MultiThreads int
Buffers int
ChunkTimeout time.Duration

View file

@ -2,53 +2,55 @@ package reader
import (
"context"
"fmt"
"io"
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/crypt"
"github.com/divyam234/teldrive/internal/tgc"
"github.com/divyam234/teldrive/pkg/schemas"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gotd/td/tg"
)
type decrpytedReader struct {
ctx context.Context
file *schemas.FileOutFull
parts []types.Part
ranges []types.Range
pos int
reader io.ReadCloser
limit int64
config *config.TGConfig
channelId int64
worker *tgc.StreamWorker
client *tgc.Client
fileId string
client *tg.Client
concurrency int
cache cache.Cacher
}
func NewDecryptedReader(
ctx context.Context,
fileId string,
client *tg.Client,
worker *tgc.StreamWorker,
cache cache.Cacher,
file *schemas.FileOutFull,
parts []types.Part,
start, end int64,
channelId int64,
start,
end int64,
config *config.TGConfig,
concurrency int,
client *tgc.Client,
worker *tgc.StreamWorker,
cache cache.Cacher) (*decrpytedReader, error) {
) (*decrpytedReader, error) {
r := &decrpytedReader{
ctx: ctx,
parts: parts,
file: file,
limit: end - start + 1,
ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize),
config: config,
client: client,
worker: worker,
channelId: channelId,
fileId: fileId,
concurrency: concurrency,
cache: cache,
}
@ -113,14 +115,16 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
if underlyingLimit >= 0 {
end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
}
partID := r.parts[r.ranges[r.pos].PartNo].ID
chunkSrc := &chunkSource{
channelID: r.channelId,
worker: r.worker,
fileID: r.fileId,
partID: r.parts[r.ranges[r.pos].PartNo].ID,
channelID: r.file.ChannelID,
partID: partID,
client: r.client,
concurrency: r.concurrency,
cache: r.cache,
key: fmt.Sprintf("files:location:%s:%d", r.file.Id, partID),
worker: r.worker,
}
if r.concurrency < 2 {
return newTGReader(r.ctx, underlyingOffset, end, chunkSrc)

View file

@ -2,12 +2,15 @@ package reader
import (
"context"
"fmt"
"io"
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/tgc"
"github.com/divyam234/teldrive/pkg/schemas"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gotd/td/tg"
)
func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
@ -41,34 +44,40 @@ func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
type LinearReader struct {
ctx context.Context
file *schemas.FileOutFull
parts []types.Part
ranges []types.Range
pos int
reader io.ReadCloser
limit int64
config *config.TGConfig
channelID int64
worker *tgc.StreamWorker
client *tgc.Client
fileID string
client *tg.Client
concurrency int
cache cache.Cacher
}
func NewLinearReader(ctx context.Context, fileID string, parts []types.Part, start, end int64,
channelID int64, config *config.TGConfig, concurrency int, client *tgc.Client,
worker *tgc.StreamWorker, cache cache.Cacher) (io.ReadCloser, error) {
func NewLinearReader(ctx context.Context,
client *tg.Client,
worker *tgc.StreamWorker,
cache cache.Cacher,
file *schemas.FileOutFull,
parts []types.Part,
start,
end int64,
config *config.TGConfig,
concurrency int,
) (io.ReadCloser, error) {
r := &LinearReader{
ctx: ctx,
parts: parts,
file: file,
limit: end - start + 1,
ranges: calculatePartByteRanges(start, end, parts[0].Size),
config: config,
client: client,
worker: worker,
channelID: channelID,
fileID: fileID,
concurrency: concurrency,
cache: cache,
}
@ -108,14 +117,16 @@ func (r *LinearReader) nextPart() (io.ReadCloser, error) {
start := r.ranges[r.pos].Start
end := r.ranges[r.pos].End
partID := r.parts[r.ranges[r.pos].PartNo].ID
chunkSrc := &chunkSource{
channelID: r.channelID,
worker: r.worker,
fileID: r.fileID,
partID: r.parts[r.ranges[r.pos].PartNo].ID,
channelID: r.file.ChannelID,
partID: partID,
client: r.client,
concurrency: r.concurrency,
cache: r.cache,
key: fmt.Sprintf("files:location:%s:%d", r.file.Id, partID),
worker: r.worker,
}
if r.concurrency < 2 {

View file

@ -26,12 +26,12 @@ type ChunkSource interface {
type chunkSource struct {
channelID int64
worker *tgc.StreamWorker
fileID string
partID int64
concurrency int
client *tgc.Client
client *tg.Client
key string
cache cache.Cacher
worker *tgc.StreamWorker
}
func (c *chunkSource) ChunkSize(start, end int64) int64 {
@ -42,27 +42,31 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b
var (
location *tg.InputDocumentFileLocation
err error
client *tgc.Client
client *tg.Client
)
err = c.cache.Get(c.key, location)
client = c.client
defer func() {
if c.concurrency > 0 && client != nil {
defer c.worker.Release(client)
}
}()
if c.concurrency > 0 {
client, _, _ = c.worker.Next(c.channelID)
tc, err := c.worker.Next(c.channelID)
if err != nil {
return nil, err
}
client = tc.Tg.API()
}
location, err = tgc.GetLocation(ctx, client, c.cache, c.fileID, c.channelID, c.partID)
if err != nil {
return nil, err
location, err = tgc.GetLocation(ctx, client, c.channelID, c.partID)
if err != nil {
return nil, err
}
c.cache.Set(c.key, location, 30*time.Minute)
}
return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit)
return tgc.GetChunk(ctx, client, location, offset, limit)
}

View file

@ -46,11 +46,11 @@ type TestSuite struct {
func (suite *TestSuite) SetupTest() {
suite.config = &config.TGConfig{Stream: struct {
BotsOffset int
BotsLimit int
MultiThreads int
Buffers int
ChunkTimeout time.Duration
}{BotsOffset: 1, MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}}
}{MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}}
}
func (suite *TestSuite) TestFullRead() {

View file

@ -8,9 +8,7 @@ import (
"math"
"runtime"
"sync"
"time"
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/kv"
"github.com/divyam234/teldrive/pkg/types"
@ -199,44 +197,39 @@ func GetBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token st
return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil
}
func GetLocation(ctx context.Context, client *Client, cache cache.Cacher, fileId string, channelId int64, partId int64) (location *tg.InputDocumentFileLocation, err error) {
func GetLocation(ctx context.Context, client *tg.Client, channelId int64, partId int64) (location *tg.InputDocumentFileLocation, err error) {
key := fmt.Sprintf("files:location:%s:%s:%d", client.UserID, fileId, partId)
err = cache.Get(key, location)
channel, err := GetChannelById(ctx, client, channelId)
if err != nil {
channel, err := GetChannelById(ctx, client.Tg.API(), channelId)
if err != nil {
return nil, err
}
messageRequest := tg.ChannelsGetMessagesRequest{
Channel: channel,
ID: []tg.InputMessageClass{&tg.InputMessageID{ID: int(partId)}},
}
res, err := client.Tg.API().ChannelsGetMessages(ctx, &messageRequest)
if err != nil {
return nil, err
}
messages, _ := res.(*tg.MessagesChannelMessages)
if len(messages.Messages) == 0 {
return nil, errors.New("no messages found")
}
switch item := messages.Messages[0].(type) {
case *tg.MessageEmpty:
return nil, errors.New("no messages found")
case *tg.Message:
media := item.Media.(*tg.MessageMediaDocument)
document := media.Document.(*tg.Document)
location = document.AsInputDocumentFileLocation()
cache.Set(key, location, 30*time.Minute)
}
return nil, err
}
messageRequest := tg.ChannelsGetMessagesRequest{
Channel: channel,
ID: []tg.InputMessageClass{&tg.InputMessageID{ID: int(partId)}},
}
res, err := client.ChannelsGetMessages(ctx, &messageRequest)
if err != nil {
return nil, err
}
messages, _ := res.(*tg.MessagesChannelMessages)
if len(messages.Messages) == 0 {
return nil, errors.New("no messages found")
}
switch item := messages.Messages[0].(type) {
case *tg.MessageEmpty:
return nil, errors.New("no messages found")
case *tg.Message:
media := item.Media.(*tg.MessageMediaDocument)
document := media.Document.(*tg.Document)
location = document.AsInputDocumentFileLocation()
}
return location, nil
}

View file

@ -2,39 +2,37 @@ package tgc
import (
"context"
"strconv"
"strings"
"sync"
"time"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/kv"
"github.com/divyam234/teldrive/internal/logging"
"github.com/gotd/td/telegram"
"go.uber.org/zap"
)
type UploadWorker struct {
type BotWorker struct {
mu sync.RWMutex
bots map[int64][]string
currIdx map[int64]int
}
func NewUploadWorker() *UploadWorker {
return &UploadWorker{
func NewBotWorker() *BotWorker {
return &BotWorker{
bots: make(map[int64][]string),
currIdx: make(map[int64]int),
}
}
func (w *UploadWorker) Set(bots []string, channelID int64) {
func (w *BotWorker) Set(bots []string, channelID int64) {
w.mu.Lock()
defer w.mu.Unlock()
w.bots[channelID] = bots
w.currIdx[channelID] = 0
}
func (w *UploadWorker) Next(channelID int64) (string, int) {
func (w *BotWorker) Next(channelID int64) (string, int) {
w.mu.RLock()
defer w.mu.RUnlock()
bots := w.bots[channelID]
@ -43,40 +41,48 @@ func (w *UploadWorker) Next(channelID int64) (string, int) {
return bots[index], index
}
type ClientStatus int
const (
StatusIdle ClientStatus = iota
StatusBusy
)
type Client struct {
Tg *telegram.Client
Stop StopFunc
Status string
UserID string
LastUsed time.Time
Connections int
Tg *telegram.Client
Stop StopFunc
Status ClientStatus
UserID string
}
type StreamWorker struct {
mu sync.RWMutex
clients map[string]*Client
currIdx map[int64]int
channelBots map[int64][]string
cnf *config.TGConfig
kv kv.KV
ctx context.Context
logger *zap.SugaredLogger
mu sync.RWMutex
clients map[string]*Client
currIdx map[int64]int
channelBots map[int64][]string
cnf *config.TGConfig
kv kv.KV
ctx context.Context
logger *zap.SugaredLogger
activeStreams int
cancel context.CancelFunc
}
func NewStreamWorker(ctx context.Context) func(cnf *config.Config, kv kv.KV) *StreamWorker {
return func(cnf *config.Config, kv kv.KV) *StreamWorker {
worker := &StreamWorker{
cnf: &cnf.TG,
kv: kv,
ctx: ctx,
clients: make(map[string]*Client),
currIdx: make(map[int64]int),
channelBots: make(map[int64][]string),
logger: logging.FromContext(ctx),
}
go worker.startIdleClientMonitor()
return worker
func NewStreamWorker(cnf *config.Config, kv kv.KV, logger *zap.SugaredLogger) *StreamWorker {
ctx, cancel := context.WithCancel(context.Background())
worker := &StreamWorker{
cnf: &cnf.TG,
kv: kv,
ctx: ctx,
clients: make(map[string]*Client),
currIdx: make(map[int64]int),
channelBots: make(map[int64][]string),
logger: logger,
cancel: cancel,
}
go worker.startIdleClientMonitor()
return worker
}
func (w *StreamWorker) Set(bots []string, channelID int64) {
@ -86,7 +92,7 @@ func (w *StreamWorker) Set(bots []string, channelID int64) {
w.currIdx[channelID] = 0
}
func (w *StreamWorker) Next(channelID int64) (*Client, int, error) {
func (w *StreamWorker) Next(channelID int64) (*Client, error) {
w.mu.Lock()
defer w.mu.Unlock()
@ -97,75 +103,51 @@ func (w *StreamWorker) Next(channelID int64) (*Client, int, error) {
client, err := w.getOrCreateClient(userID, token)
if err != nil {
return nil, 0, err
return nil, err
}
w.currIdx[channelID] = (index + 1) % len(bots)
client.LastUsed = time.Now()
client.Connections++
if client.Connections == 1 {
client.Status = "serving"
}
return client, index, nil
return client, nil
}
func (w *StreamWorker) IncActiveStream() error {
w.mu.Lock()
defer w.mu.Unlock()
w.activeStreams++
return nil
}
func (w *StreamWorker) DecActiveStreams() error {
w.mu.Lock()
defer w.mu.Unlock()
if w.activeStreams == 0 {
return nil
}
w.activeStreams--
return nil
}
func (w *StreamWorker) getOrCreateClient(userID, token string) (*Client, error) {
client, ok := w.clients[userID]
if !ok || (client.Status == "idle" && client.Stop == nil) {
if !ok || (client.Status == StatusIdle && client.Stop == nil) {
middlewares := Middlewares(w.cnf, 5)
tgClient, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
client = &Client{Tg: tgClient, Status: "idle", UserID: userID}
client = &Client{Tg: tgClient, Status: StatusIdle, UserID: userID}
w.clients[userID] = client
stop, err := Connect(client.Tg, WithBotToken(token))
if err != nil {
return nil, err
}
client.Stop = stop
w.logger.Debug("started bg client: ", client.UserID)
client.Status = StatusBusy
w.logger.Debug("started bg client: ", userID)
}
return client, nil
}
func (w *StreamWorker) Release(client *Client) {
w.mu.Lock()
defer w.mu.Unlock()
client.Connections--
if client.Connections == 0 {
client.Status = "running"
}
}
func (w *StreamWorker) UserWorker(session string, userID int64) (*Client, error) {
w.mu.Lock()
defer w.mu.Unlock()
id := strconv.FormatInt(userID, 10)
client, ok := w.clients[id]
if !ok || (client.Status == "idle" && client.Stop == nil) {
middlewares := Middlewares(w.cnf, 5)
tgClient, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
client = &Client{Tg: tgClient, Status: "idle", UserID: id}
w.clients[id] = client
stop, err := Connect(client.Tg, WithContext(w.ctx))
if err != nil {
return nil, err
}
client.Stop = stop
w.logger.Debug("started bg client: ", client.UserID)
}
client.LastUsed = time.Now()
client.Connections++
if client.Connections == 1 {
client.Status = "serving"
}
return client, nil
}
func (w *StreamWorker) startIdleClientMonitor() {
ticker := time.NewTicker(w.cnf.BgBotsCheckInterval)
defer ticker.Stop()
@ -183,15 +165,16 @@ func (w *StreamWorker) startIdleClientMonitor() {
func (w *StreamWorker) checkIdleClients() {
w.mu.Lock()
defer w.mu.Unlock()
for _, client := range w.clients {
if client.Status == "running" && time.Since(client.LastUsed) > w.cnf.BgBotsTimeout {
if client.Stop != nil {
if w.activeStreams == 0 {
for _, client := range w.clients {
if client.Status == StatusBusy && client.Stop != nil {
client.Stop()
client.Stop = nil
client.Status = "idle"
client.Tg = nil
client.Status = StatusIdle
w.logger.Debug("stopped bg client: ", client.UserID)
}
}
}
}

View file

@ -11,16 +11,17 @@ import (
"github.com/divyam234/teldrive/pkg/models"
"github.com/divyam234/teldrive/pkg/schemas"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
"github.com/pkg/errors"
"gorm.io/gorm"
)
func getParts(ctx context.Context, client *tg.Client, cache cache.Cacher, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
func getParts(ctx context.Context, client *telegram.Client, cache cache.Cacher, file *schemas.FileOutFull) ([]types.Part, error) {
parts := []types.Part{}
key := fmt.Sprintf("files:messages:%s:%s", file.Id, userID)
key := fmt.Sprintf("files:messages:%s", file.Id)
err := cache.Get(key, &parts)
@ -32,7 +33,7 @@ func getParts(ctx context.Context, client *tg.Client, cache cache.Cacher, file *
for _, part := range file.Parts {
ids = append(ids, int(part.ID))
}
messages, err := tgc.GetMessages(ctx, client, ids, file.ChannelID)
messages, err := tgc.GetMessages(ctx, client.API(), ids, file.ChannelID)
if err != nil {
return nil, err

View file

@ -22,7 +22,7 @@ import (
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/database"
"github.com/divyam234/teldrive/internal/http_range"
"github.com/divyam234/teldrive/internal/logging"
"github.com/divyam234/teldrive/internal/kv"
"github.com/divyam234/teldrive/internal/md5"
"github.com/divyam234/teldrive/internal/reader"
"github.com/divyam234/teldrive/internal/tgc"
@ -32,6 +32,7 @@ import (
"github.com/divyam234/teldrive/pkg/schemas"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gin-gonic/gin"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
"go.uber.org/zap"
@ -76,14 +77,24 @@ func randInt64() (int64, error) {
}
type FileService struct {
db *gorm.DB
cnf *config.Config
worker *tgc.StreamWorker
cache cache.Cacher
db *gorm.DB
cnf *config.Config
worker *tgc.StreamWorker
botWorker *tgc.BotWorker
cache cache.Cacher
kv kv.KV
logger *zap.SugaredLogger
}
func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker, cache cache.Cacher) *FileService {
return &FileService{db: db, cnf: cnf, worker: worker, cache: cache}
func NewFileService(
db *gorm.DB,
cnf *config.Config,
worker *tgc.StreamWorker,
botWorker *tgc.BotWorker,
kv kv.KV,
cache cache.Cacher,
logger *zap.SugaredLogger) *FileService {
return &FileService{db: db, cnf: cnf, worker: worker, botWorker: botWorker, cache: cache, kv: kv, logger: logger}
}
func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.FileIn) (*schemas.FileOut, *types.AppError) {
@ -716,84 +727,96 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
tokens, err := getBotsToken(fs.db, fs.cache, session.UserId, file.ChannelID)
logger := logging.FromContext(c)
if err != nil {
logger.Error("failed to get bots", zap.Error(err))
http.Error(w, err.Error(), http.StatusInternalServerError)
fs.handleError(fmt.Errorf("failed to get bots: %w", err), w)
return
}
var (
channelUser string
lr io.ReadCloser
client *tgc.Client
client *telegram.Client
multiThreads int
token string
)
multiThreads = fs.cnf.TG.Stream.MultiThreads
defer func() {
if client != nil {
fs.worker.Release(client)
}
}()
if fs.cnf.TG.DisableStreamBots || len(tokens) == 0 {
client, err = fs.worker.UserWorker(session.Session, session.UserId)
client, err = tgc.AuthClient(c, &fs.cnf.TG, session.Session)
if err != nil {
logger.Error(ErrorStreamAbandoned, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
fs.handleError(err, w)
return
}
channelUser = strconv.FormatInt(session.UserId, 10)
multiThreads = 0
} else {
offset := fs.cnf.TG.Stream.BotsOffset - 1
limit := min(len(tokens), fs.cnf.TG.BgBotsLimit+offset)
fs.worker.Set(tokens[offset:limit], file.ChannelID)
client, _, err = fs.worker.Next(file.ChannelID)
} else if fs.cnf.TG.DisableBgBots && len(tokens) > 0 {
fs.botWorker.Set(tokens, file.ChannelID)
token, _ = fs.botWorker.Next(file.ChannelID)
client, err = tgc.BotClient(c, fs.kv, &fs.cnf.TG, token)
if err != nil {
logger.Error(ErrorStreamAbandoned, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
fs.handleError(err, w)
}
multiThreads = 0
} else {
fs.worker.Set(tokens[0:fs.cnf.TG.Stream.BotsLimit], file.ChannelID)
c, err := fs.worker.Next(file.ChannelID)
if err != nil {
fs.handleError(err, w)
return
}
client = c.Tg
}
if download {
multiThreads = 0
}
if r.Method != "HEAD" {
parts, err := getParts(c, client.Tg.API(), fs.cache, file, channelUser)
if err != nil {
logger.Error(ErrorStreamAbandoned, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
handleStream := func() error {
parts, err := getParts(c, client, fs.cache, file)
if err != nil {
fs.handleError(err, w)
return nil
}
if file.Encrypted {
lr, err = reader.NewDecryptedReader(c, client.API(), fs.worker, fs.cache, file, parts, start, end, &fs.cnf.TG, multiThreads)
} else {
lr, err = reader.NewLinearReader(c, client.API(), fs.worker, fs.cache, file, parts, start, end, &fs.cnf.TG, multiThreads)
}
if download {
multiThreads = 0
if err != nil {
fs.handleError(err, w)
return nil
}
if lr == nil {
fs.handleError(fmt.Errorf("failed to initialise reader"), w)
return nil
}
_, err = io.CopyN(w, lr, contentLength)
if err != nil {
lr.Close()
}
return nil
}
if file.Encrypted {
lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker, fs.cache)
if fs.cnf.TG.DisableBgBots {
tgc.RunWithAuth(c, client, token, func(ctx context.Context) error {
return handleStream()
})
} else {
lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker, fs.cache)
fs.worker.IncActiveStream()
defer fs.worker.DecActiveStreams()
handleStream()
}
if err != nil {
logger.Error(ErrorStreamAbandoned, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if lr == nil {
http.Error(w, "failed to initialise reader", http.StatusInternalServerError)
return
}
_, err = io.CopyN(w, lr, contentLength)
if err != nil {
lr.Close()
}
}
}
func (fs *FileService) handleError(err error, w http.ResponseWriter) {
fs.logger.Error(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
}
func getOrder(fquery *schemas.FileQuery) clause.OrderByColumn {
sortColumn := utils.CamelToSnake(fquery.Sort)

View file

@ -20,7 +20,7 @@ type FileServiceSuite struct {
func (s *FileServiceSuite) SetupSuite() {
s.db = database.NewTestDatabase(s.T(), false)
s.srv = NewFileService(s.db, nil, nil, nil)
s.srv = NewFileService(s.db, nil, nil, nil, nil, nil, nil)
}
func (s *FileServiceSuite) SetupTest() {

View file

@ -39,13 +39,13 @@ const saltLength = 32
type UploadService struct {
db *gorm.DB
worker *tgc.UploadWorker
worker *tgc.BotWorker
cnf *config.TGConfig
kv kv.KV
cache cache.Cacher
}
func NewUploadService(db *gorm.DB, cnf *config.Config, worker *tgc.UploadWorker, kv kv.KV, cache cache.Cacher) *UploadService {
func NewUploadService(db *gorm.DB, cnf *config.Config, worker *tgc.BotWorker, kv kv.KV, cache cache.Cacher) *UploadService {
return &UploadService{db: db, worker: worker, cnf: &cnf.TG, kv: kv, cache: cache}
}