chnages upload logic

This commit is contained in:
divyam234 2023-08-24 00:10:40 +05:30
parent a4a8ceca29
commit 14650a466f
7 changed files with 234 additions and 230 deletions

View file

@ -33,7 +33,7 @@ func main() {
cache.CacheInit()
utils.StartBotTgClients()
utils.InitBotClients()
cron.FilesDeleteJob()

View file

@ -23,6 +23,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/gorilla/websocket"
"github.com/gotd/contrib/bg"
"github.com/gotd/td/session"
tgauth "github.com/gotd/td/telegram/auth"
"github.com/gotd/td/telegram/auth/qrlogin"
@ -204,15 +205,15 @@ func (as *AuthService) GetSession(c *gin.Context) *types.Session {
func (as *AuthService) Logout(c *gin.Context) (*schemas.Message, *types.AppError) {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.Atoi(jwtUser.Subject)
tgClient, stop, err := utils.GetAuthClient(jwtUser.TgSession, userId)
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
if err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
}
client, _ := utils.GetAuthClient(c, jwtUser.TgSession, userId)
client.Run(c, func(ctx context.Context) error {
_, err := client.API().AuthLogOut(c)
return err
})
tgClient.Tg.API().AuthLogOut(c)
utils.StopClient(stop, userId)
setCookie(c, as.SessionCookieName, "", -1)
return &schemas.Message{Status: true, Message: "logout success"}, nil
}
@ -246,8 +247,16 @@ func (as *AuthService) HandleMultipleLogin(c *gin.Context) {
dispatcher := tg.NewUpdateDispatcher()
loggedIn := qrlogin.OnLoginToken(dispatcher)
sessionStorage := &session.StorageMemory{}
tgClient, stop, _ := utils.GetNonAuthClient(dispatcher, sessionStorage)
tgClient := utils.GetNonAuthClient(dispatcher, sessionStorage)
stop, err := bg.Connect(tgClient)
defer stop()
if err != nil {
return
}
for {
message := &SocketMessage{}
err := conn.ReadJSON(message)
@ -336,7 +345,6 @@ func (as *AuthService) HandleMultipleLogin(c *gin.Context) {
}()
}
if err != nil {
log.Println(err)
return
}
}

View file

@ -304,27 +304,10 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
w := c.Writer
r := c.Request
config := utils.GetConfig()
fileID := c.Param("fileID")
var tgClient *utils.Client
var err error
if config.MultiClient {
tgClient = utils.GetBotClient()
tgClient.Workload++
} else {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.Atoi(jwtUser.Subject)
tgClient, _, err = utils.GetAuthClient(jwtUser.TgSession, userId)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
res, err := cache.CachedFunction(fs.GetFileByID, fmt.Sprintf("files:%s", fileID))(c)
@ -365,34 +348,29 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
w.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=\"%s\"", file.Name))
parts, err := fs.getParts(c, tgClient.Tg, file)
client, idx := utils.GetDownloadClient(c)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
parts = rangedParts(parts, int64(start), int64(end))
defer func() {
utils.Workloads[idx]--
}()
ir, iw := io.Pipe()
parts, err := fs.getParts(c, client, file)
if err != nil {
return
}
parts = rangedParts(parts, int64(start), int64(end))
go func() {
defer iw.Close()
for _, part := range parts {
streamFilePart(c, tgClient.Tg, iw, &part, part.Start, part.End, 1024*1024)
streamFilePart(c, client, iw, &part, part.Start, part.End, 1024*1024)
}
}()
if r.Method != "HEAD" {
io.CopyN(w, ir, contentLength)
}
defer func() {
if config.MultiClient {
tgClient.Workload--
}
}()
}
func (fs *FileService) getParts(ctx context.Context, tgClient *telegram.Client, file *schemas.FileOutFull) ([]types.Part, error) {

View file

@ -1,10 +1,10 @@
package services
import (
"context"
"errors"
"fmt"
"net/http"
"strconv"
"github.com/divyam234/teldrive/cache"
"github.com/divyam234/teldrive/schemas"
@ -69,65 +69,71 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
out := mapSchema(&uploadPart[0])
return out, nil
}
config := utils.GetConfig()
var tgClient *utils.Client
var err error
if config.MultiClient {
tgClient = utils.GetBotClient()
tgClient.Workload++
} else {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.Atoi(jwtUser.Subject)
tgClient, _, err = utils.GetAuthClient(jwtUser.TgSession, userId)
if err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
}
}
client, idx := utils.GetUploadClient(c)
file := c.Request.Body
fileSize := c.Request.ContentLength
api := tgClient.Tg.API()
u := uploader.NewUploader(api).WithThreads(8).WithPartSize(512 * 1024)
sender := message.NewSender(api).WithUploader(u)
fileName := uploadQuery.Filename
upload, err := u.Upload(c, uploader.NewUpload(fileName, file, fileSize))
var msgId int
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer func() {
if idx != -1 {
utils.Workloads[idx]--
}
cancel()
}()
err := client.Run(ctx, func(ctx context.Context) error {
api := client.API()
u := uploader.NewUploader(api).WithThreads(8).WithPartSize(512 * 1024)
upload, err := u.Upload(c, uploader.NewUpload(fileName, file, fileSize))
if err != nil {
return err
}
document := message.UploadedDocument(upload).Filename(fileName).ForceFile(true)
res, err := cache.CachedFunction(utils.GetChannelById, fmt.Sprintf("channels:%d", us.ChannelID))(c, client.API(), us.ChannelID)
if err != nil {
return err
}
channel := res.(*tg.Channel)
sender := message.NewSender(client.API())
target := sender.To(&tg.InputPeerChannel{ChannelID: channel.ID, AccessHash: channel.AccessHash})
res, err = target.Media(c, document)
if err != nil {
return err
}
updates := res.(*tg.Updates)
msgId = updates.Updates[0].(*tg.UpdateMessageID).ID
return nil
})
if err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
}
document := message.UploadedDocument(upload).Filename(fileName).ForceFile(true)
res, err := cache.CachedFunction(utils.GetChannelById, fmt.Sprintf("channels:%d", us.ChannelID))(c, api, us.ChannelID)
if err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
}
channel := res.(*tg.Channel)
target := sender.To(&tg.InputPeerChannel{ChannelID: channel.ID, AccessHash: channel.AccessHash})
res, err = target.Media(c, document)
if err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
}
updates := res.(*tg.Updates)
msgId := updates.Updates[0].(*tg.UpdateMessageID).ID
partUpload := &models.Upload{
Name: fileName,
UploadId: uploadId,

View file

@ -7,12 +7,11 @@ import (
"net/http"
"strconv"
"github.com/divyam234/teldrive/types"
"github.com/divyam234/teldrive/utils"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
"github.com/divyam234/teldrive/types"
"github.com/gin-gonic/gin"
)
@ -62,29 +61,31 @@ func iterContent(ctx context.Context, tgClient *telegram.Client, location tg.Inp
func (us *UserService) GetProfilePhoto(c *gin.Context) {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.Atoi(jwtUser.Subject)
tgClient, _, err := utils.GetAuthClient(jwtUser.TgSession, userId)
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
client, _ := utils.GetAuthClient(c, jwtUser.TgSession, userId)
err := client.Run(c, func(ctx context.Context) error {
self, err := client.Self(c)
if err != nil {
return err
}
peer := self.AsInputPeer()
photo, _ := self.Photo.AsNotEmpty()
location := &tg.InputPeerPhotoFileLocation{Big: false, Peer: peer, PhotoID: photo.PhotoID}
buff, err := iterContent(c, client, location)
if err != nil {
return err
}
content := buff.Bytes()
c.Writer.Header().Set("Content-Type", "image/jpeg")
c.Writer.Header().Set("Cache-Control", "public, max-age=86400")
c.Writer.Header().Set("Content-Length", strconv.Itoa(len(content)))
c.Writer.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=\"%s\"", "profile.jpeg"))
c.Writer.Write(content)
return nil
})
if err != nil {
http.Error(c.Writer, err.Error(), http.StatusBadRequest)
return
}
self, err := tgClient.Tg.Self(c)
if err != nil {
http.Error(c.Writer, err.Error(), http.StatusBadRequest)
return
}
peer := self.AsInputPeer()
photo, _ := self.Photo.AsNotEmpty()
location := &tg.InputPeerPhotoFileLocation{Big: false, Peer: peer, PhotoID: photo.PhotoID}
buff, err := iterContent(c, tgClient.Tg, location)
if err != nil {
http.Error(c.Writer, err.Error(), http.StatusBadRequest)
return
}
content := buff.Bytes()
c.Writer.Header().Set("Content-Type", "image/jpeg")
c.Writer.Header().Set("Cache-Control", "public, max-age=86400")
c.Writer.Header().Set("Content-Length", strconv.Itoa(len(content)))
c.Writer.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=\"%s\"", "profile.jpeg"))
c.Writer.Write(content)
}

View file

@ -15,7 +15,7 @@ type Result struct {
ID string
Parts models.Parts
TgSession string
UserId int
UserId int64
ChannelId int64
}
@ -56,17 +56,14 @@ func FilesDeleteJob() {
}
for _, file := range results {
client, stop, err := utils.GetAuthClient(file.TgSession, file.UserId)
client, err := utils.GetAuthClient(ctx, file.TgSession, file.UserId)
if err != nil {
break
}
if stop != nil {
defer func() {
utils.StopClient(stop, file.UserId)
}()
}
err = deleteTGMessage(ctx, client.Tg.API(), file)
err = client.Run(ctx, func(ctx context.Context) error {
err = deleteTGMessage(ctx, client.API(), file)
return err
})
if err == nil {
db.Where("id = ?", file.ID).Delete(&models.File{})

View file

@ -6,28 +6,25 @@ import (
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/divyam234/teldrive/types"
"github.com/gin-gonic/gin"
"github.com/gotd/contrib/bg"
"github.com/gotd/contrib/middleware/floodwait"
"github.com/gotd/contrib/middleware/ratelimit"
tdclock "github.com/gotd/td/clock"
"github.com/gotd/td/session"
"github.com/gotd/td/telegram"
"github.com/pkg/errors"
"go.uber.org/zap"
"golang.org/x/time/rate"
)
type Client struct {
Tg *telegram.Client
Token string
Workload int
}
var clients map[int64]*telegram.Client
var clients map[int]*Client
var Workloads map[int]int
func getDeviceConfig() telegram.DeviceConfig {
appConfig := GetConfig()
@ -51,10 +48,11 @@ func reconnectionBackoff() backoff.BackOff {
return b
}
func getBotClient(appID int, appHash, clientName, sessionDir string) *telegram.Client {
func GetBotClient(clientName string) *telegram.Client {
config := GetConfig()
sessionStorage := &telegram.FileSessionStorage{
Path: filepath.Join(sessionDir, clientName+".json"),
Path: filepath.Join("sessions", clientName+".json"),
}
middlewares := []telegram.Middleware{floodwait.NewSimpleWaiter()}
@ -68,37 +66,103 @@ func getBotClient(appID int, appHash, clientName, sessionDir string) *telegram.C
Clock: tdclock.System,
}
client := telegram.NewClient(appID, appHash, options)
client := telegram.NewClient(config.AppId, config.AppHash, options)
return client
}
func startClient(ctx context.Context, client *Client) (bg.StopFunc, error) {
func GetAuthClient(ctx context.Context, sessionStr string, userId int64) (*telegram.Client, error) {
stop, err := bg.Connect(client.Tg)
data, err := session.TelethonSession(sessionStr)
if err != nil {
return nil, err
}
var (
storage = new(session.StorageMemory)
loader = session.Loader{Storage: storage}
)
if err := loader.Save(ctx, data); err != nil {
return nil, err
}
middlewares := []telegram.Middleware{floodwait.NewSimpleWaiter()}
client := telegram.NewClient(config.AppId, config.AppHash, telegram.Options{
SessionStorage: storage,
Middlewares: middlewares,
ReconnectionBackoff: reconnectionBackoff,
RetryInterval: 5 * time.Second,
MaxRetries: 5,
Device: getDeviceConfig(),
Clock: tdclock.System,
})
return client, nil
}
func GetNonAuthClient(handler telegram.UpdateHandler, storage telegram.SessionStorage) *telegram.Client {
client := telegram.NewClient(config.AppId, config.AppHash, telegram.Options{
SessionStorage: storage,
Device: getDeviceConfig(),
UpdateHandler: handler,
ReconnectionBackoff: reconnectionBackoff,
RetryInterval: 5 * time.Second,
MaxRetries: 5,
})
return client
}
func startBotClient(ctx context.Context, client *telegram.Client, token string) (bg.StopFunc, error) {
stop, err := bg.Connect(client)
if err != nil {
return nil, errors.Wrap(err, "failed to start client")
}
tguser, err := client.Tg.Self(ctx)
tguser, err := client.Self(ctx)
if err != nil {
if _, err := client.Tg.Auth().Bot(ctx, client.Token); err != nil {
if _, err := client.Auth().Bot(ctx, token); err != nil {
return nil, err
}
tguser, _ = client.Tg.Self(ctx)
tguser, _ = client.Self(ctx)
}
Logger.Info("started Client", zap.String("user", tguser.Username))
return stop, nil
}
func StartBotTgClients() {
func startAuthClient(c *gin.Context, client *telegram.Client) (bg.StopFunc, error) {
stop, err := bg.Connect(client)
clients = make(map[int]*Client)
if err != nil {
return nil, err
}
tguser, err := client.Self(c)
if err != nil {
return nil, err
}
Logger.Info("started Client", zap.String("user", tguser.Username))
clients[tguser.GetID()] = client
return stop, nil
}
func InitBotClients() {
ctx := context.Background()
clients = make(map[int64]*telegram.Client)
Workloads = make(map[int]int)
if config.MultiClient {
sessionDir := "sessions"
@ -120,107 +184,57 @@ func StartBotTgClients() {
sort.Strings(keysToSort)
for idx, key := range keysToSort {
client := getBotClient(config.AppId, config.AppHash, fmt.Sprintf("client%d", idx), sessionDir)
clients[idx] = &Client{Tg: client, Token: os.Getenv(key)}
client := GetBotClient(fmt.Sprintf("client%d", idx))
Workloads[idx] = 0
clients[int64(idx)] = client
go func(k string) {
startBotClient(ctx, client, os.Getenv(k))
}(key)
}
ctx := context.Background()
}
}
for _, client := range clients {
go startClient(ctx, client)
func getMinWorkloadIndex() int {
smallest := Workloads[0]
idx := 0
for i, workload := range Workloads {
if workload < smallest {
smallest = workload
idx = i
}
}
return idx
}
func GetAuthClient(sessionStr string, userId int) (*Client, bg.StopFunc, error) {
if client, ok := clients[userId]; ok {
return client, nil, nil
func GetUploadClient(c *gin.Context) (*telegram.Client, int) {
if config.MultiClient {
idx := getMinWorkloadIndex()
Workloads[idx]++
return GetBotClient(fmt.Sprintf("client%d", idx)), idx
} else {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
client, _ := GetAuthClient(c, jwtUser.TgSession, userId)
return client, -1
}
ctx := context.Background()
data, err := session.TelethonSession(sessionStr)
if err != nil {
return nil, nil, err
}
var (
storage = new(session.StorageMemory)
loader = session.Loader{Storage: storage}
)
if err := loader.Save(ctx, data); err != nil {
return nil, nil, err
}
middlewares := []telegram.Middleware{floodwait.NewSimpleWaiter()}
client := telegram.NewClient(config.AppId, config.AppHash, telegram.Options{
SessionStorage: storage,
Middlewares: middlewares,
ReconnectionBackoff: reconnectionBackoff,
RetryInterval: 5 * time.Second,
MaxRetries: 5,
Device: getDeviceConfig(),
Clock: tdclock.System,
})
stop, err := bg.Connect(client)
if err != nil {
return nil, nil, err
}
tguser, err := client.Self(ctx)
if err != nil {
return nil, nil, err
}
Logger.Info("started Client", zap.String("user", tguser.Username))
tgClient := &Client{Tg: client}
clients[int(tguser.GetID())] = tgClient
return tgClient, stop, nil
}
func GetBotClient() *Client {
smallest := clients[0]
for _, client := range clients {
if client.Workload < smallest.Workload {
smallest = client
func GetDownloadClient(c *gin.Context) (*telegram.Client, int) {
if config.MultiClient {
idx := getMinWorkloadIndex()
Workloads[idx]++
return clients[int64(idx)], idx
} else {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
if client, ok := clients[userId]; ok {
return client, -1
}
client, _ := GetAuthClient(c, jwtUser.TgSession, userId)
startAuthClient(c, client)
return client, -1
}
return smallest
}
func GetNonAuthClient(handler telegram.UpdateHandler, storage telegram.SessionStorage) (*telegram.Client, bg.StopFunc, error) {
middlewares := []telegram.Middleware{}
if config.RateLimit {
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*100), 5))
}
client := telegram.NewClient(config.AppId, config.AppHash, telegram.Options{
SessionStorage: storage,
Middlewares: middlewares,
Device: getDeviceConfig(),
UpdateHandler: handler,
})
stop, err := bg.Connect(client)
if err != nil {
return nil, nil, err
}
return client, stop, nil
}
func StopClient(stop bg.StopFunc, key int) {
if stop != nil {
stop()
}
delete(clients, key)
}