mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-01-08 08:10:05 +08:00
336 lines
7.6 KiB
Go
336 lines
7.6 KiB
Go
package services
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"sort"
|
|
"strconv"
|
|
"sync"
|
|
|
|
"github.com/divyam234/teldrive/internal/cache"
|
|
"github.com/divyam234/teldrive/internal/config"
|
|
"github.com/divyam234/teldrive/internal/crypt"
|
|
"github.com/divyam234/teldrive/internal/kv"
|
|
"github.com/divyam234/teldrive/internal/tgc"
|
|
"github.com/divyam234/teldrive/pkg/models"
|
|
"github.com/divyam234/teldrive/pkg/schemas"
|
|
"github.com/divyam234/teldrive/pkg/types"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gotd/td/telegram"
|
|
"github.com/gotd/td/tg"
|
|
"github.com/pkg/errors"
|
|
"github.com/thoas/go-funk"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// buffer is a minimal cursor over a byte slice. Each successful read
// consumes the bytes it decoded from the front of Buf.
type buffer struct {
	// Buf holds the bytes that have not been consumed yet.
	Buf []byte
}

// long decodes the next 8 bytes as a little-endian value and returns it
// reinterpreted as a signed 64-bit integer. It fails with
// io.ErrUnexpectedEOF when fewer than 8 bytes remain.
func (b *buffer) long() (int64, error) {
	// uint64 yields 0 alongside any error, so the conversion below still
	// produces (0, err) on failure.
	raw, err := b.uint64()
	return int64(raw), err
}

// uint64 decodes the next 8 bytes as a little-endian unsigned integer and
// advances Buf past them. When fewer than 8 bytes remain it returns
// io.ErrUnexpectedEOF and leaves Buf untouched.
func (b *buffer) uint64() (uint64, error) {
	const width = 8
	if len(b.Buf) < width {
		return 0, io.ErrUnexpectedEOF
	}
	head, rest := b.Buf[:width], b.Buf[width:]
	b.Buf = rest
	return binary.LittleEndian.Uint64(head), nil
}
|
|
|
|
func randInt64() (int64, error) {
|
|
var buf [8]byte
|
|
if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil {
|
|
return 0, err
|
|
}
|
|
b := &buffer{Buf: buf[:]}
|
|
return b.long()
|
|
}
|
|
|
|
type batchResult struct {
|
|
Index int
|
|
Messages *tg.MessagesChannelMessages
|
|
}
|
|
|
|
func getChunk(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass, offset int64, limit int64) ([]byte, error) {
|
|
|
|
req := &tg.UploadGetFileRequest{
|
|
Offset: offset,
|
|
Limit: int(limit),
|
|
Location: location,
|
|
}
|
|
|
|
r, err := tgClient.API().UploadGetFile(ctx, req)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch result := r.(type) {
|
|
case *tg.UploadFile:
|
|
return result.Bytes, nil
|
|
default:
|
|
return nil, fmt.Errorf("unexpected type %T", r)
|
|
}
|
|
}
|
|
|
|
func iterContent(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass) (*bytes.Buffer, error) {
|
|
offset := int64(0)
|
|
limit := int64(1024 * 1024)
|
|
buff := &bytes.Buffer{}
|
|
for {
|
|
r, err := getChunk(ctx, tgClient, location, offset, limit)
|
|
if err != nil {
|
|
return buff, err
|
|
}
|
|
if len(r) == 0 {
|
|
break
|
|
}
|
|
buff.Write(r)
|
|
offset += int64(limit)
|
|
}
|
|
return buff, nil
|
|
}
|
|
|
|
func GetUserAuth(c *gin.Context) (int64, string) {
|
|
val, _ := c.Get("jwtUser")
|
|
jwtUser := val.(*types.JWTClaims)
|
|
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
|
|
return userId, jwtUser.TgSession
|
|
}
|
|
|
|
func getBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token string) (*types.BotInfo, error) {
|
|
client, _ := tgc.BotClient(ctx, KV, config, token)
|
|
var user *tg.User
|
|
err := tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error {
|
|
user, _ = client.Self(ctx)
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil
|
|
}
|
|
|
|
func getTGMessagesBatch(ctx context.Context, client *telegram.Client, channel *tg.InputChannel, parts []schemas.Part, userID string, index int,
|
|
results chan<- batchResult, errors chan<- error, wg *sync.WaitGroup) {
|
|
|
|
defer wg.Done()
|
|
|
|
ids := funk.Map(parts, func(part schemas.Part) tg.InputMessageClass {
|
|
return &tg.InputMessageID{ID: int(part.ID)}
|
|
}).([]tg.InputMessageClass)
|
|
|
|
messageRequest := tg.ChannelsGetMessagesRequest{
|
|
Channel: channel,
|
|
ID: ids,
|
|
}
|
|
|
|
res, err := client.API().ChannelsGetMessages(ctx, &messageRequest)
|
|
if err != nil {
|
|
errors <- err
|
|
return
|
|
}
|
|
|
|
messages, ok := res.(*tg.MessagesChannelMessages)
|
|
|
|
if !ok {
|
|
errors <- fmt.Errorf("unexpected response type: %T", res)
|
|
return
|
|
}
|
|
|
|
results <- batchResult{Index: index, Messages: messages}
|
|
}
|
|
|
|
func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas.Part, channelId int64, userID string) ([]tg.MessageClass, error) {
|
|
|
|
channel, err := GetChannelById(ctx, client, channelId, userID)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
batchSize := 200
|
|
|
|
batchCount := int(math.Ceil(float64(len(parts)) / float64(batchSize)))
|
|
|
|
results := make(chan batchResult, batchCount)
|
|
|
|
errors := make(chan error, batchCount)
|
|
|
|
for i := 0; i < batchCount; i++ {
|
|
wg.Add(1)
|
|
splitParts := parts[i*batchSize : min((i+1)*batchSize, len(parts))]
|
|
go getTGMessagesBatch(ctx, client, channel, splitParts, userID, i, results, errors, &wg)
|
|
}
|
|
|
|
wg.Wait()
|
|
close(results)
|
|
close(errors)
|
|
|
|
for err := range errors {
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
channelResult := []batchResult{}
|
|
|
|
for result := range results {
|
|
channelResult = append(channelResult, result)
|
|
}
|
|
|
|
sort.Slice(channelResult, func(i, j int) bool {
|
|
return channelResult[i].Index < channelResult[j].Index
|
|
})
|
|
|
|
allMessages := []tg.MessageClass{}
|
|
|
|
for _, result := range channelResult {
|
|
allMessages = append(allMessages, result.Messages.GetMessages()...)
|
|
}
|
|
|
|
return allMessages, nil
|
|
}
|
|
|
|
func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
|
|
cache := cache.FromContext(ctx)
|
|
parts := []types.Part{}
|
|
|
|
key := fmt.Sprintf("messages:%s:%s", file.ID, userID)
|
|
|
|
err := cache.Get(key, &parts)
|
|
|
|
if err == nil {
|
|
return parts, nil
|
|
}
|
|
|
|
messages, err := getTGMessages(ctx, client, file.Parts, file.ChannelID, userID)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for i, message := range messages {
|
|
item := message.(*tg.Message)
|
|
media := item.Media.(*tg.MessageMediaDocument)
|
|
document := media.Document.(*tg.Document)
|
|
location := document.AsInputDocumentFileLocation()
|
|
|
|
part := types.Part{
|
|
Location: location,
|
|
Size: document.Size,
|
|
Salt: file.Parts[i].Salt,
|
|
}
|
|
if file.Encrypted {
|
|
part.DecryptedSize, _ = crypt.DecryptedSize(document.Size)
|
|
}
|
|
parts = append(parts, part)
|
|
}
|
|
cache.Set(key, &parts, 3600)
|
|
return parts, nil
|
|
}
|
|
|
|
func GetChannelById(ctx context.Context, client *telegram.Client, channelId int64, userID string) (*tg.InputChannel, error) {
|
|
|
|
channel := &tg.InputChannel{}
|
|
inputChannel := &tg.InputChannel{
|
|
ChannelID: channelId,
|
|
}
|
|
channels, err := client.API().ChannelsGetChannels(ctx, []tg.InputChannelClass{inputChannel})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(channels.GetChats()) == 0 {
|
|
return nil, errors.New("no channels found")
|
|
}
|
|
|
|
channel = channels.GetChats()[0].(*tg.Channel).AsInput()
|
|
return channel, nil
|
|
}
|
|
|
|
func GetDefaultChannel(ctx context.Context, db *gorm.DB, userID int64) (int64, error) {
|
|
cache := cache.FromContext(ctx)
|
|
var channelId int64
|
|
key := fmt.Sprintf("users:channel:%d", userID)
|
|
|
|
err := cache.Get(key, &channelId)
|
|
|
|
if err == nil {
|
|
return channelId, nil
|
|
}
|
|
|
|
var channelIds []int64
|
|
db.Model(&models.Channel{}).Where("user_id = ?", userID).Where("selected = ?", true).
|
|
Pluck("channel_id", &channelIds)
|
|
|
|
if len(channelIds) == 1 {
|
|
channelId = channelIds[0]
|
|
cache.Set(key, channelId, 0)
|
|
}
|
|
|
|
if channelId == 0 {
|
|
return channelId, errors.New("default channel not set")
|
|
}
|
|
|
|
return channelId, nil
|
|
}
|
|
|
|
func getBotsToken(ctx context.Context, db *gorm.DB, userID, channelId int64) ([]string, error) {
|
|
cache := cache.FromContext(ctx)
|
|
var bots []string
|
|
|
|
key := fmt.Sprintf("users:bots:%d:%d", userID, channelId)
|
|
|
|
err := cache.Get(key, &bots)
|
|
|
|
if err == nil {
|
|
return bots, nil
|
|
}
|
|
|
|
if err := db.Model(&models.Bot{}).Where("user_id = ?", userID).
|
|
Where("channel_id = ?", channelId).Pluck("token", &bots).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cache.Set(key, &bots, 0)
|
|
return bots, nil
|
|
|
|
}
|
|
|
|
func getSessionByHash(db *gorm.DB, cache *cache.Cache, hash string) (*models.Session, error) {
|
|
var session models.Session
|
|
|
|
key := fmt.Sprintf("sessions:%s", hash)
|
|
|
|
err := cache.Get(key, &session)
|
|
|
|
if err == nil {
|
|
return &session, nil
|
|
}
|
|
|
|
if err := db.Model(&models.Session{}).Where("hash = ?", hash).First(&session).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cache.Set(key, &session, 0)
|
|
|
|
return &session, nil
|
|
|
|
}
|