teldrive/pkg/services/common.go

308 lines
7 KiB
Go
Raw Normal View History

2023-09-20 03:20:44 +08:00
package services
import (
"bytes"
"context"
2023-12-03 03:47:23 +08:00
"crypto/rand"
"encoding/binary"
2023-09-20 03:20:44 +08:00
"fmt"
2023-12-03 03:47:23 +08:00
"io"
2023-09-20 03:20:44 +08:00
"strconv"
2023-12-03 03:47:23 +08:00
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/tgc"
"github.com/divyam234/teldrive/pkg/database"
"github.com/divyam234/teldrive/pkg/models"
"github.com/divyam234/teldrive/pkg/schemas"
"github.com/divyam234/teldrive/pkg/types"
2023-09-20 03:20:44 +08:00
"github.com/gin-gonic/gin"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
"github.com/pkg/errors"
"github.com/thoas/go-funk"
)
2023-12-03 03:47:23 +08:00
// buffer decodes little-endian integers from the front of an in-memory byte
// slice. Every successful read advances Buf past the bytes it consumed; a
// failed read leaves Buf untouched.
type buffer struct {
	Buf []byte
}

// long decodes the next 8 bytes of Buf as a little-endian int64 (the bit
// pattern of uint64 reinterpreted as signed). It returns io.ErrUnexpectedEOF
// if fewer than 8 bytes remain.
func (b *buffer) long() (int64, error) {
	u, err := b.uint64()
	// On error u is 0, so the conversion still yields the zero value.
	return int64(u), err
}

// uint64 decodes the next 8 bytes of Buf as a little-endian uint64 and
// consumes them. It returns io.ErrUnexpectedEOF if fewer than 8 bytes remain.
func (b *buffer) uint64() (uint64, error) {
	const width = 8
	if len(b.Buf) < width {
		return 0, io.ErrUnexpectedEOF
	}
	head, rest := b.Buf[:width], b.Buf[width:]
	b.Buf = rest
	return binary.LittleEndian.Uint64(head), nil
}
// randInt64 returns a random int64 built from 8 bytes of crypto/rand output,
// decoded little-endian. Any error from the random source is returned as-is.
func randInt64() (int64, error) {
	raw := make([]byte, 8)
	if _, err := io.ReadFull(rand.Reader, raw); err != nil {
		return 0, err
	}
	return (&buffer{Buf: raw}).long()
}
2023-09-20 03:20:44 +08:00
// getChunk performs a single upload.getFile call for location, asking for
// limit bytes starting at offset, and returns the bytes Telegram sent back.
// Any response other than *tg.UploadFile (for example a CDN redirect) is
// reported as an error.
func getChunk(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass, offset int64, limit int64) ([]byte, error) {
	resp, err := tgClient.API().UploadGetFile(ctx, &tg.UploadGetFileRequest{
		Location: location,
		Offset:   offset,
		Limit:    int(limit),
	})
	if err != nil {
		return nil, err
	}
	file, ok := resp.(*tg.UploadFile)
	if !ok {
		return nil, fmt.Errorf("unexpected type %T", resp)
	}
	return file.Bytes, nil
}
// iterContent downloads the whole file at location into memory using
// sequential 1 MiB upload.getFile requests (see getChunk).
//
// Downloading stops as soon as a chunk comes back shorter than the requested
// limit (including an empty chunk), which marks the end of the file. If a
// request fails, the bytes downloaded so far are returned together with the
// error.
func iterContent(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass) (*bytes.Buffer, error) {
	const limit = int64(1024 * 1024)
	offset := int64(0)
	buff := &bytes.Buffer{}
	for {
		r, err := getChunk(ctx, tgClient, location, offset, limit)
		if err != nil {
			return buff, err
		}
		buff.Write(r)
		// A short (or empty) chunk is the last one; stopping here avoids an
		// extra round-trip at an offset past the end of the file.
		if int64(len(r)) < limit {
			break
		}
		offset += limit
	}
	return buff, nil
}
// getUserAuth extracts the authenticated user from the gin context.
//
// It reads the *types.JWTClaims stored under the "jwtUser" key and returns
// the user ID parsed from the claims' Subject together with the claims'
// TgSession. It panics if the key is absent or holds a different type. A
// Subject that is not a base-10 int64 yields an ID of 0 (the parse error is
// deliberately discarded).
func getUserAuth(c *gin.Context) (int64, string) {
	raw, _ := c.Get("jwtUser")
	claims := raw.(*types.JWTClaims)
	id, _ := strconv.ParseInt(claims.Subject, 10, 64)
	return id, claims.TgSession
}
2023-12-03 05:46:53 +08:00
func getBotInfo(ctx context.Context, token string) (*types.BotInfo, error) {
2023-11-25 12:31:29 +08:00
client, _ := tgc.BotLogin(ctx, token)
2023-09-20 03:20:44 +08:00
var user *tg.User
err := tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error {
user, _ = client.Self(ctx)
return nil
})
if err != nil {
return nil, err
}
2023-12-03 05:46:53 +08:00
return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil
2023-09-20 03:20:44 +08:00
}
2023-12-03 03:47:23 +08:00
func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas.Part, channelId int64, userID string) (*tg.MessagesChannelMessages, error) {
2023-11-02 21:51:30 +08:00
2023-12-03 05:23:06 +08:00
ids := funk.Map(parts, func(part schemas.Part) tg.InputMessageClass {
2023-09-20 03:20:44 +08:00
return tg.InputMessageClass(&tg.InputMessageID{ID: int(part.ID)})
})
2023-11-16 23:21:35 +08:00
channel, err := GetChannelById(ctx, client, channelId, userID)
2023-09-20 03:20:44 +08:00
if err != nil {
return nil, err
}
messageRequest := tg.ChannelsGetMessagesRequest{Channel: channel, ID: ids.([]tg.InputMessageClass)}
res, err := client.API().ChannelsGetMessages(ctx, &messageRequest)
if err != nil {
return nil, err
}
messages := res.(*tg.MessagesChannelMessages)
2023-11-09 19:10:37 +08:00
return messages, nil
}
func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
parts := []types.Part{}
key := fmt.Sprintf("messages:%s:%s", file.ID, userID)
err := cache.GetCache().Get(key, &parts)
if err == nil {
return parts, nil
}
2023-12-03 03:47:23 +08:00
messages, err := getTGMessages(ctx, client, file.Parts, file.ChannelID, userID)
2023-11-09 19:10:37 +08:00
if err != nil {
return nil, err
}
2023-09-20 03:20:44 +08:00
for _, message := range messages.Messages {
item := message.(*tg.Message)
media := item.Media.(*tg.MessageMediaDocument)
document := media.Document.(*tg.Document)
location := document.AsInputDocumentFileLocation()
2023-11-06 02:08:18 +08:00
parts = append(parts, types.Part{Location: location, Start: 0, End: document.Size - 1})
2023-09-20 03:20:44 +08:00
}
2023-11-02 21:51:30 +08:00
cache.GetCache().Set(key, &parts, 3600)
2023-09-20 03:20:44 +08:00
return parts, nil
}
2023-11-06 02:08:18 +08:00
func rangedParts(parts []types.Part, startByte, endByte int64) []types.Part {
chunkSize := parts[0].End + 1
numParts := int64(len(parts))
2023-09-20 03:20:44 +08:00
2023-11-06 02:08:18 +08:00
validParts := []types.Part{}
2023-09-20 03:20:44 +08:00
2023-11-06 02:08:18 +08:00
firstChunk := max(startByte/chunkSize, 0)
2023-09-20 03:20:44 +08:00
2023-11-06 02:08:18 +08:00
lastChunk := min(endByte/chunkSize, numParts)
2023-09-20 03:20:44 +08:00
2023-11-06 02:08:18 +08:00
startInFirstChunk := startByte % chunkSize
endInLastChunk := endByte % chunkSize
if firstChunk == lastChunk {
validParts = append(validParts, types.Part{
Location: parts[firstChunk].Location,
Start: startInFirstChunk,
End: endInLastChunk,
})
} else {
validParts = append(validParts, types.Part{
Location: parts[firstChunk].Location,
Start: startInFirstChunk,
End: parts[firstChunk].End,
})
// Add valid parts from any chunks in between.
for i := firstChunk + 1; i < lastChunk; i++ {
validParts = append(validParts, types.Part{
2023-11-08 17:04:18 +08:00
Location: parts[i].Location,
2023-11-06 02:08:18 +08:00
Start: 0,
2023-11-08 17:04:18 +08:00
End: parts[i].End,
2023-11-06 02:08:18 +08:00
})
}
2023-09-20 03:20:44 +08:00
2023-11-06 02:08:18 +08:00
// Add valid parts from the last chunk.
validParts = append(validParts, types.Part{
2023-11-08 17:04:18 +08:00
Location: parts[lastChunk].Location,
2023-11-06 02:08:18 +08:00
Start: 0,
End: endInLastChunk,
})
2023-09-20 03:20:44 +08:00
}
2023-11-06 02:08:18 +08:00
return validParts
2023-09-20 03:20:44 +08:00
}
2023-11-16 23:21:35 +08:00
func GetChannelById(ctx context.Context, client *telegram.Client, channelId int64, userID string) (*tg.InputChannel, error) {
2023-09-20 03:20:44 +08:00
channel := &tg.InputChannel{}
2023-09-24 04:26:04 +08:00
inputChannel := &tg.InputChannel{
2023-11-16 23:21:35 +08:00
ChannelID: channelId,
2023-09-24 04:26:04 +08:00
}
channels, err := client.API().ChannelsGetChannels(ctx, []tg.InputChannelClass{inputChannel})
2023-09-20 03:20:44 +08:00
if err != nil {
2023-09-24 04:26:04 +08:00
return nil, err
}
2023-09-20 03:20:44 +08:00
2023-09-24 04:26:04 +08:00
if len(channels.GetChats()) == 0 {
return nil, errors.New("no channels found")
2023-09-20 03:20:44 +08:00
}
2023-09-24 04:26:04 +08:00
channel = channels.GetChats()[0].(*tg.Channel).AsInput()
2023-09-20 03:20:44 +08:00
return channel, nil
}
func GetDefaultChannel(ctx context.Context, userID int64) (int64, error) {
2023-11-16 23:21:35 +08:00
var channelId int64
2023-09-20 03:20:44 +08:00
2023-11-02 21:51:30 +08:00
key := fmt.Sprintf("users:channel:%d", userID)
2023-09-20 03:20:44 +08:00
2023-11-16 23:21:35 +08:00
err := cache.GetCache().Get(key, &channelId)
2023-09-20 03:20:44 +08:00
2023-11-02 21:51:30 +08:00
if err == nil {
2023-11-16 23:21:35 +08:00
return channelId, nil
2023-11-02 21:51:30 +08:00
}
2023-09-20 03:20:44 +08:00
2023-11-02 21:51:30 +08:00
var channelIds []int64
database.DB.Model(&models.Channel{}).Where("user_id = ?", userID).Where("selected = ?", true).
Pluck("channel_id", &channelIds)
if len(channelIds) == 1 {
2023-11-16 23:21:35 +08:00
channelId = channelIds[0]
cache.GetCache().Set(key, channelId, 0)
2023-09-20 03:20:44 +08:00
}
2023-11-16 23:21:35 +08:00
if channelId == 0 {
return channelId, errors.New("default channel not set")
2023-09-20 03:20:44 +08:00
}
2023-11-16 23:21:35 +08:00
return channelId, nil
2023-09-20 03:20:44 +08:00
}
2023-12-03 03:47:23 +08:00
func getBotsToken(ctx context.Context, userID, channelId int64) ([]string, error) {
2023-09-20 03:20:44 +08:00
var bots []string
2023-11-02 21:51:30 +08:00
key := fmt.Sprintf("users:bots:%d:%d", userID, channelId)
2023-11-16 23:21:35 +08:00
err := cache.GetCache().Get(key, &bots)
2023-09-20 03:20:44 +08:00
2023-11-02 21:51:30 +08:00
if err == nil {
return bots, nil
2023-09-20 03:20:44 +08:00
}
2023-11-02 21:51:30 +08:00
if err := database.DB.Model(&models.Bot{}).Where("user_id = ?", userID).
Where("channel_id = ?", channelId).Pluck("token", &bots).Error; err != nil {
return nil, err
}
cache.GetCache().Set(key, &bots, 0)
2023-09-20 03:20:44 +08:00
return bots, nil
}
2023-11-02 21:51:30 +08:00
2023-12-03 03:47:23 +08:00
func getSessionByHash(hash string) (*models.Session, error) {
2023-11-02 21:51:30 +08:00
var session models.Session
key := fmt.Sprintf("sessions:%s", hash)
err := cache.GetCache().Get(key, &session)
if err == nil {
return &session, nil
}
if err := database.DB.Model(&models.Session{}).Where("hash = ?", hash).First(&session).Error; err != nil {
return nil, err
}
cache.GetCache().Set(key, &session, 0)
return &session, nil
}