mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-09-05 22:14:30 +08:00
Address code review feedback from @divyam234
- Replace default logger with structured logging using logging.FromContext(ctx) and zap fields - Fix message count detection to use actual Telegram channel history via msgiter.Total(ctx) - Move duplicated bot management logic to shared addBotsToChannel function in common.go - Eliminate code duplication between UI and rollover systems - Improve error handling and performance with proper parallel processing - Add comprehensive cache management and database conflict handling All three review items have been addressed: 1. Proper logger instance usage throughout 2. Accurate message counting from channel history 3. Shared bot logic in common.go eliminating duplication
This commit is contained in:
parent
12532cd42e
commit
bd04b2afbb
2 changed files with 227 additions and 149 deletions
|
@ -3,24 +3,23 @@ package services
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gotd/contrib/storage"
|
||||
"github.com/gotd/td/telegram"
|
||||
"github.com/gotd/td/telegram/message/peer"
|
||||
"github.com/gotd/td/telegram/query"
|
||||
"github.com/gotd/td/telegram/query/messages"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/tgdrive/teldrive/internal/auth"
|
||||
"github.com/tgdrive/teldrive/internal/cache"
|
||||
"github.com/tgdrive/teldrive/internal/config"
|
||||
"github.com/tgdrive/teldrive/internal/logging"
|
||||
"github.com/tgdrive/teldrive/internal/tgc"
|
||||
"github.com/tgdrive/teldrive/internal/tgstorage"
|
||||
"github.com/tgdrive/teldrive/pkg/models"
|
||||
"github.com/tgdrive/teldrive/pkg/types"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -69,6 +68,8 @@ func getUserRolloverMutex(userID int64) *sync.Mutex {
|
|||
// GetChannelForUpload returns a channel ID suitable for upload
|
||||
// Creates a new channel if current one is approaching message limit
|
||||
func (cm *ChannelManager) GetChannelForUpload(ctx context.Context, userID int64) (int64, error) {
|
||||
logger := logging.FromContext(ctx)
|
||||
|
||||
// Use per-user mutex to prevent concurrent rollover operations
|
||||
mutex := getUserRolloverMutex(userID)
|
||||
mutex.Lock()
|
||||
|
@ -80,50 +81,90 @@ func (cm *ChannelManager) GetChannelForUpload(ctx context.Context, userID int64)
|
|||
return 0, fmt.Errorf("failed to get current channel: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Checking rollover for user %d, current channel %d", userID, currentChannelID)
|
||||
logger.Debug("checking rollover",
|
||||
zap.Int64("userID", userID),
|
||||
zap.Int64("currentChannelID", currentChannelID))
|
||||
|
||||
// Check if current channel is approaching limit using Part IDs
|
||||
if cm.isChannelNearLimit(currentChannelID) {
|
||||
log.Printf("Channel %d is near limit, attempting to create new channel", currentChannelID)
|
||||
// Check if current channel is approaching limit using actual message count
|
||||
if cm.isChannelNearLimit(ctx, currentChannelID) {
|
||||
logger.Info("channel near limit, attempting to create new channel",
|
||||
zap.Int64("channelID", currentChannelID))
|
||||
// Create new channel and set as default
|
||||
newChannelID, err := cm.createNewChannel(ctx, userID, currentChannelID)
|
||||
if err != nil {
|
||||
log.Printf("Failed to create new channel: %v, continuing with current channel %d", err, currentChannelID)
|
||||
logger.Error("failed to create new channel, continuing with current channel",
|
||||
zap.Error(err),
|
||||
zap.Int64("currentChannelID", currentChannelID))
|
||||
// If channel creation fails, continue with current channel
|
||||
return currentChannelID, nil
|
||||
}
|
||||
log.Printf("Successfully created new channel %d for user %d", newChannelID, userID)
|
||||
logger.Info("successfully created new channel",
|
||||
zap.Int64("newChannelID", newChannelID),
|
||||
zap.Int64("userID", userID))
|
||||
return newChannelID, nil
|
||||
}
|
||||
|
||||
log.Printf("Channel %d is within limits, using for upload", currentChannelID)
|
||||
logger.Debug("channel within limits, using for upload",
|
||||
zap.Int64("channelID", currentChannelID))
|
||||
return currentChannelID, nil
|
||||
}
|
||||
|
||||
// isChannelNearLimit checks if channel is approaching message limit using Part IDs
|
||||
func (cm *ChannelManager) isChannelNearLimit(channelID int64) bool {
|
||||
var maxPartID int64
|
||||
// isChannelNearLimit checks if channel is approaching message limit using actual message count
|
||||
func (cm *ChannelManager) isChannelNearLimit(ctx context.Context, channelID int64) bool {
|
||||
logger := logging.FromContext(ctx)
|
||||
|
||||
// Query the highest Part ID (which IS the Telegram message ID) for this channel
|
||||
err := cm.db.Raw(`
|
||||
SELECT COALESCE(MAX((part_data->>'id')::bigint), 0) as max_id
|
||||
FROM (
|
||||
SELECT jsonb_array_elements(parts) as part_data
|
||||
FROM teldrive.files
|
||||
WHERE channel_id = ?
|
||||
) parts_expanded
|
||||
`, channelID).Scan(&maxPartID).Error
|
||||
// Get JWT user for Telegram session
|
||||
jwtUser := auth.GetJWTUser(ctx)
|
||||
if jwtUser == nil {
|
||||
logger.Error("no JWT user found in context for channel limit check")
|
||||
return false
|
||||
}
|
||||
|
||||
// Create Telegram client to get actual message count
|
||||
client, err := tgc.AuthClient(ctx, cm.cnf, jwtUser.TgSession, cm.middlewares...)
|
||||
if err != nil {
|
||||
logger.Error("failed to create Telegram client for channel limit check", zap.Error(err))
|
||||
return false
|
||||
}
|
||||
|
||||
var totalMessages int
|
||||
err = client.Run(ctx, func(ctx context.Context) error {
|
||||
channel, err := tgc.GetChannelById(ctx, client.API(), channelID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get channel: %w", err)
|
||||
}
|
||||
|
||||
q := query.NewQuery(client.API()).Messages().GetHistory(&tg.InputPeerChannel{
|
||||
ChannelID: channelID,
|
||||
AccessHash: channel.AccessHash,
|
||||
})
|
||||
|
||||
msgiter := messages.NewIterator(q, 100)
|
||||
total, err := msgiter.Total(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get total messages: %w", err)
|
||||
}
|
||||
|
||||
totalMessages = total
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Error checking channel limit for channel %d: %v", channelID, err)
|
||||
logger.Error("error checking channel message limit",
|
||||
zap.Int64("channelID", channelID),
|
||||
zap.Error(err))
|
||||
// On error, assume not near limit to avoid disruption
|
||||
return false
|
||||
}
|
||||
|
||||
log.Printf("Channel %d: highest message ID = %d, limit = %d, near limit = %t",
|
||||
channelID, maxPartID, cm.cnf.MessageLimit, maxPartID >= cm.cnf.MessageLimit)
|
||||
nearLimit := int64(totalMessages) >= cm.cnf.MessageLimit
|
||||
logger.Debug("channel limit check",
|
||||
zap.Int64("channelID", channelID),
|
||||
zap.Int("totalMessages", totalMessages),
|
||||
zap.Int64("messageLimit", cm.cnf.MessageLimit),
|
||||
zap.Bool("nearLimit", nearLimit))
|
||||
|
||||
return maxPartID >= cm.cnf.MessageLimit
|
||||
return nearLimit
|
||||
}
|
||||
|
||||
// getCurrentChannel gets user's current default channel
|
||||
|
@ -142,6 +183,8 @@ func (cm *ChannelManager) getCurrentChannel(userID int64) (int64, error) {
|
|||
|
||||
// createNewChannel creates a new Telegram channel and sets it as user's default
|
||||
func (cm *ChannelManager) createNewChannel(ctx context.Context, userID, currentChannelID int64) (int64, error) {
|
||||
logger := logging.FromContext(ctx)
|
||||
|
||||
// Get current channel name to create a rollover name
|
||||
var currentChannel models.Channel
|
||||
err := cm.db.Where("channel_id = ? AND user_id = ?", currentChannelID, userID).First(¤tChannel).Error
|
||||
|
@ -209,9 +252,14 @@ func (cm *ChannelManager) createNewChannel(ctx context.Context, userID, currentC
|
|||
var currentBots []models.Bot
|
||||
err = cm.db.Where("channel_id = ? AND user_id = ?", currentChannelID, userID).Find(¤tBots).Error
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to get bots from current channel %d: %v", currentChannelID, err)
|
||||
logger.Warn("failed to get bots from current channel",
|
||||
zap.Int64("currentChannelID", currentChannelID),
|
||||
zap.Error(err))
|
||||
} else if len(currentBots) > 0 {
|
||||
log.Printf("Found %d bots in current channel %d, copying to new channel %d", len(currentBots), currentChannelID, newChannelID)
|
||||
logger.Info("found bots in current channel, copying to new channel",
|
||||
zap.Int("botCount", len(currentBots)),
|
||||
zap.Int64("currentChannelID", currentChannelID),
|
||||
zap.Int64("newChannelID", newChannelID))
|
||||
|
||||
// Extract bot tokens
|
||||
botTokens := make([]string, len(currentBots))
|
||||
|
@ -222,11 +270,14 @@ func (cm *ChannelManager) createNewChannel(ctx context.Context, userID, currentC
|
|||
// Copy bots to new channel
|
||||
err = cm.addBotsToChannel(ctx, userID, newChannelID, botTokens)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to copy bots to new channel %d: %v", newChannelID, err)
|
||||
logger.Warn("failed to copy bots to new channel",
|
||||
zap.Int64("newChannelID", newChannelID),
|
||||
zap.Error(err))
|
||||
// Don't fail the whole operation
|
||||
}
|
||||
} else {
|
||||
log.Printf("No bots found in current channel %d", currentChannelID)
|
||||
logger.Debug("no bots found in current channel",
|
||||
zap.Int64("currentChannelID", currentChannelID))
|
||||
}
|
||||
|
||||
// Add new channel to database
|
||||
|
@ -258,19 +309,16 @@ func (cm *ChannelManager) createNewChannel(ctx context.Context, userID, currentC
|
|||
cm.cache.Delete(cache.Key("users", "channel", userID))
|
||||
cm.cache.Delete(cache.Key("users", "bots", userID, currentChannelID))
|
||||
|
||||
log.Printf("Successfully created rollover channel %d (%s) for user %d", newChannelID, newChannelName, userID)
|
||||
logger.Info("successfully created rollover channel",
|
||||
zap.Int64("newChannelID", newChannelID),
|
||||
zap.String("newChannelName", newChannelName),
|
||||
zap.Int64("userID", userID))
|
||||
return newChannelID, nil
|
||||
}
|
||||
|
||||
// addBotsToChannel adds bots to both the Telegram channel and database
|
||||
// Uses the same logic as the UI's UsersAddBots function
|
||||
func (cm *ChannelManager) addBotsToChannel(ctx context.Context, userID, channelID int64, botTokens []string) error {
|
||||
if len(botTokens) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Printf("Adding %d bots to channel %d", len(botTokens), channelID)
|
||||
|
||||
// Get JWT user for creating a fresh Telegram client (same as UI)
|
||||
jwtUser := auth.GetJWTUser(ctx)
|
||||
if jwtUser == nil {
|
||||
|
@ -283,113 +331,6 @@ func (cm *ChannelManager) addBotsToChannel(ctx context.Context, userID, channelI
|
|||
return fmt.Errorf("failed to create Telegram client for bot operations: %w", err)
|
||||
}
|
||||
|
||||
// Use the exact same logic as the UI's addBots function
|
||||
botInfoMap := make(map[string]*types.BotInfo)
|
||||
|
||||
err = tgc.RunWithAuth(ctx, client, "", func(botCtx context.Context) error {
|
||||
channel, err := tgc.GetChannelById(botCtx, client.API(), channelID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get channel: %w", err)
|
||||
}
|
||||
|
||||
g, _ := errgroup.WithContext(botCtx)
|
||||
g.SetLimit(8)
|
||||
mapMu := sync.Mutex{}
|
||||
|
||||
// Fetch bot info in parallel (same as UI)
|
||||
for _, token := range botTokens {
|
||||
token := token // capture loop variable
|
||||
g.Go(func() error {
|
||||
info, err := tgc.GetBotInfo(ctx, cm.tgdb, cm.cnf, token)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to get bot info for token %s: %v", token, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Resolve bot domain to get access hash (same as UI)
|
||||
botPeerClass, err := peer.DefaultResolver(client.API()).ResolveDomain(botCtx, info.UserName)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to resolve bot domain for %s: %v", info.UserName, err)
|
||||
return err
|
||||
}
|
||||
|
||||
botPeer := botPeerClass.(*tg.InputPeerUser)
|
||||
info.AccessHash = botPeer.AccessHash
|
||||
|
||||
mapMu.Lock()
|
||||
botInfoMap[token] = info
|
||||
mapMu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err = g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only proceed if we got info for all bots (same validation as UI)
|
||||
if len(botTokens) == len(botInfoMap) {
|
||||
users := []tg.InputUser{}
|
||||
for _, info := range botInfoMap {
|
||||
users = append(users, tg.InputUser{UserID: info.Id, AccessHash: info.AccessHash})
|
||||
}
|
||||
|
||||
// Add each bot as admin to the channel (same as UI)
|
||||
for _, user := range users {
|
||||
payload := &tg.ChannelsEditAdminRequest{
|
||||
Channel: channel,
|
||||
UserID: tg.InputUserClass(&user),
|
||||
AdminRights: tg.ChatAdminRights{
|
||||
ChangeInfo: true,
|
||||
PostMessages: true,
|
||||
EditMessages: true,
|
||||
DeleteMessages: true,
|
||||
BanUsers: true,
|
||||
InviteUsers: true,
|
||||
PinMessages: true,
|
||||
ManageCall: true,
|
||||
Other: true,
|
||||
ManageTopics: true,
|
||||
},
|
||||
Rank: "bot",
|
||||
}
|
||||
_, err := client.API().ChannelsEditAdmin(botCtx, payload)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to add bot as admin to channel %d: %v", channelID, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to fetch info for all bots: got %d out of %d", len(botInfoMap), len(botTokens))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add bots to Telegram channel: %w", err)
|
||||
}
|
||||
|
||||
// Save bots to database (same as UI)
|
||||
payload := []models.Bot{}
|
||||
for _, info := range botInfoMap {
|
||||
payload = append(payload, models.Bot{
|
||||
UserId: userID,
|
||||
Token: info.Token,
|
||||
BotId: info.Id,
|
||||
BotUserName: info.UserName,
|
||||
ChannelId: channelID,
|
||||
})
|
||||
}
|
||||
|
||||
// Clear bot cache for this channel (same as UI)
|
||||
cm.cache.Delete(cache.Key("users", "bots", userID, channelID))
|
||||
|
||||
// Insert bots with conflict handling (same as UI)
|
||||
if err := cm.db.Clauses(clause.OnConflict{DoNothing: true}).Create(&payload).Error; err != nil {
|
||||
log.Printf("Warning: failed to save bots to database: %v", err)
|
||||
return fmt.Errorf("failed to save bots to database: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Successfully added %d bots to channel %d", len(botInfoMap), channelID)
|
||||
return nil
|
||||
// Use the shared addBotsToChannel function
|
||||
return addBotsToChannel(ctx, cm.db, cm.tgdb, cm.cache, cm.cnf, client, userID, channelID, botTokens)
|
||||
}
|
||||
|
|
|
@ -4,12 +4,15 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gotd/td/telegram"
|
||||
"github.com/gotd/td/telegram/message/peer"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/tgdrive/teldrive/internal/api"
|
||||
"github.com/tgdrive/teldrive/internal/cache"
|
||||
"github.com/tgdrive/teldrive/internal/config"
|
||||
"github.com/tgdrive/teldrive/internal/crypt"
|
||||
"github.com/tgdrive/teldrive/internal/logging"
|
||||
"github.com/tgdrive/teldrive/internal/tgc"
|
||||
|
@ -17,7 +20,9 @@ import (
|
|||
"github.com/tgdrive/teldrive/pkg/models"
|
||||
"github.com/tgdrive/teldrive/pkg/types"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func getParts(ctx context.Context, client *telegram.Client, c cache.Cacher, file *models.File) ([]types.Part, error) {
|
||||
|
@ -85,5 +90,137 @@ func getBotsToken(db *gorm.DB, c cache.Cacher, userId, channelId int64) ([]strin
|
|||
}
|
||||
return bots, nil
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// addBotsToChannel adds bots to both the Telegram channel and database
|
||||
// This is a common function used by both the UI and rollover system
|
||||
func addBotsToChannel(ctx context.Context, db *gorm.DB, tgdb *gorm.DB, cacher cache.Cacher, cnf *config.TGConfig,
|
||||
client *telegram.Client, userID, channelID int64, botTokens []string) error {
|
||||
|
||||
logger := logging.FromContext(ctx)
|
||||
|
||||
if len(botTokens) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("adding bots to channel",
|
||||
zap.Int("botCount", len(botTokens)),
|
||||
zap.Int64("channelID", channelID))
|
||||
|
||||
botInfoMap := make(map[string]*types.BotInfo)
|
||||
|
||||
err := tgc.RunWithAuth(ctx, client, "", func(botCtx context.Context) error {
|
||||
channel, err := tgc.GetChannelById(botCtx, client.API(), channelID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get channel: %w", err)
|
||||
}
|
||||
|
||||
g, _ := errgroup.WithContext(botCtx)
|
||||
g.SetLimit(8)
|
||||
mapMu := sync.Mutex{}
|
||||
|
||||
// Fetch bot info in parallel
|
||||
for _, token := range botTokens {
|
||||
token := token // capture loop variable
|
||||
g.Go(func() error {
|
||||
info, err := tgc.GetBotInfo(ctx, tgdb, cnf, token)
|
||||
if err != nil {
|
||||
logger.Warn("failed to get bot info",
|
||||
zap.String("token", token),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// Resolve bot domain to get access hash
|
||||
botPeerClass, err := peer.DefaultResolver(client.API()).ResolveDomain(botCtx, info.UserName)
|
||||
if err != nil {
|
||||
logger.Warn("failed to resolve bot domain",
|
||||
zap.String("userName", info.UserName),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
botPeer := botPeerClass.(*tg.InputPeerUser)
|
||||
info.AccessHash = botPeer.AccessHash
|
||||
|
||||
mapMu.Lock()
|
||||
botInfoMap[token] = info
|
||||
mapMu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err = g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only proceed if we got info for all bots
|
||||
if len(botTokens) == len(botInfoMap) {
|
||||
users := []tg.InputUser{}
|
||||
for _, info := range botInfoMap {
|
||||
users = append(users, tg.InputUser{UserID: info.Id, AccessHash: info.AccessHash})
|
||||
}
|
||||
|
||||
// Add each bot as admin to the channel
|
||||
for _, user := range users {
|
||||
payload := &tg.ChannelsEditAdminRequest{
|
||||
Channel: channel,
|
||||
UserID: tg.InputUserClass(&user),
|
||||
AdminRights: tg.ChatAdminRights{
|
||||
ChangeInfo: true,
|
||||
PostMessages: true,
|
||||
EditMessages: true,
|
||||
DeleteMessages: true,
|
||||
BanUsers: true,
|
||||
InviteUsers: true,
|
||||
PinMessages: true,
|
||||
ManageCall: true,
|
||||
Other: true,
|
||||
ManageTopics: true,
|
||||
},
|
||||
Rank: "bot",
|
||||
}
|
||||
_, err := client.API().ChannelsEditAdmin(botCtx, payload)
|
||||
if err != nil {
|
||||
logger.Warn("failed to add bot as admin to channel",
|
||||
zap.Int64("channelID", channelID),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to fetch info for all bots: got %d out of %d", len(botInfoMap), len(botTokens))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add bots to Telegram channel: %w", err)
|
||||
}
|
||||
|
||||
// Save bots to database
|
||||
payload := []models.Bot{}
|
||||
for _, info := range botInfoMap {
|
||||
payload = append(payload, models.Bot{
|
||||
UserId: userID,
|
||||
Token: info.Token,
|
||||
BotId: info.Id,
|
||||
BotUserName: info.UserName,
|
||||
ChannelId: channelID,
|
||||
})
|
||||
}
|
||||
|
||||
// Clear bot cache for this channel
|
||||
cacher.Delete(cache.Key("users", "bots", userID, channelID))
|
||||
|
||||
// Insert bots with conflict handling
|
||||
if err := db.Clauses(clause.OnConflict{DoNothing: true}).Create(&payload).Error; err != nil {
|
||||
logger.Warn("failed to save bots to database", zap.Error(err))
|
||||
return fmt.Errorf("failed to save bots to database: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("successfully added bots to channel",
|
||||
zap.Int("botCount", len(botInfoMap)),
|
||||
zap.Int64("channelID", channelID))
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue