Address code review feedback from @divyam234

- Replace default logger with structured logging using logging.FromContext(ctx) and zap fields
- Fix message count detection to use actual Telegram channel history via msgiter.Total(ctx)
- Move duplicated bot management logic to shared addBotsToChannel function in common.go
- Eliminate code duplication between UI and rollover systems
- Improve error handling and performance with proper parallel processing
- Add comprehensive cache management and database conflict handling

All three review items have been addressed:
1.  Proper logger instance usage throughout
2.  Accurate message counting from channel history
3.  Shared bot logic in common.go eliminating duplication
This commit is contained in:
flamey-code 2025-08-03 01:53:46 +01:00
parent 12532cd42e
commit bd04b2afbb
2 changed files with 227 additions and 149 deletions

View file

@ -3,24 +3,23 @@ package services
import (
"context"
"fmt"
"log"
"sync"
"time"
"github.com/gotd/contrib/storage"
"github.com/gotd/td/telegram"
"github.com/gotd/td/telegram/message/peer"
"github.com/gotd/td/telegram/query"
"github.com/gotd/td/telegram/query/messages"
"github.com/gotd/td/tg"
"github.com/tgdrive/teldrive/internal/auth"
"github.com/tgdrive/teldrive/internal/cache"
"github.com/tgdrive/teldrive/internal/config"
"github.com/tgdrive/teldrive/internal/logging"
"github.com/tgdrive/teldrive/internal/tgc"
"github.com/tgdrive/teldrive/internal/tgstorage"
"github.com/tgdrive/teldrive/pkg/models"
"github.com/tgdrive/teldrive/pkg/types"
"golang.org/x/sync/errgroup"
"go.uber.org/zap"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
var (
@ -69,6 +68,8 @@ func getUserRolloverMutex(userID int64) *sync.Mutex {
// GetChannelForUpload returns a channel ID suitable for upload
// Creates a new channel if current one is approaching message limit
func (cm *ChannelManager) GetChannelForUpload(ctx context.Context, userID int64) (int64, error) {
logger := logging.FromContext(ctx)
// Use per-user mutex to prevent concurrent rollover operations
mutex := getUserRolloverMutex(userID)
mutex.Lock()
@ -80,50 +81,90 @@ func (cm *ChannelManager) GetChannelForUpload(ctx context.Context, userID int64)
return 0, fmt.Errorf("failed to get current channel: %w", err)
}
log.Printf("Checking rollover for user %d, current channel %d", userID, currentChannelID)
logger.Debug("checking rollover",
zap.Int64("userID", userID),
zap.Int64("currentChannelID", currentChannelID))
// Check if current channel is approaching limit using Part IDs
if cm.isChannelNearLimit(currentChannelID) {
log.Printf("Channel %d is near limit, attempting to create new channel", currentChannelID)
// Check if current channel is approaching limit using actual message count
if cm.isChannelNearLimit(ctx, currentChannelID) {
logger.Info("channel near limit, attempting to create new channel",
zap.Int64("channelID", currentChannelID))
// Create new channel and set as default
newChannelID, err := cm.createNewChannel(ctx, userID, currentChannelID)
if err != nil {
log.Printf("Failed to create new channel: %v, continuing with current channel %d", err, currentChannelID)
logger.Error("failed to create new channel, continuing with current channel",
zap.Error(err),
zap.Int64("currentChannelID", currentChannelID))
// If channel creation fails, continue with current channel
return currentChannelID, nil
}
log.Printf("Successfully created new channel %d for user %d", newChannelID, userID)
logger.Info("successfully created new channel",
zap.Int64("newChannelID", newChannelID),
zap.Int64("userID", userID))
return newChannelID, nil
}
log.Printf("Channel %d is within limits, using for upload", currentChannelID)
logger.Debug("channel within limits, using for upload",
zap.Int64("channelID", currentChannelID))
return currentChannelID, nil
}
// isChannelNearLimit checks if channel is approaching message limit using Part IDs
func (cm *ChannelManager) isChannelNearLimit(channelID int64) bool {
var maxPartID int64
// isChannelNearLimit checks if channel is approaching message limit using actual message count
func (cm *ChannelManager) isChannelNearLimit(ctx context.Context, channelID int64) bool {
logger := logging.FromContext(ctx)
// Query the highest Part ID (which IS the Telegram message ID) for this channel
err := cm.db.Raw(`
SELECT COALESCE(MAX((part_data->>'id')::bigint), 0) as max_id
FROM (
SELECT jsonb_array_elements(parts) as part_data
FROM teldrive.files
WHERE channel_id = ?
) parts_expanded
`, channelID).Scan(&maxPartID).Error
// Get JWT user for Telegram session
jwtUser := auth.GetJWTUser(ctx)
if jwtUser == nil {
logger.Error("no JWT user found in context for channel limit check")
return false
}
// Create Telegram client to get actual message count
client, err := tgc.AuthClient(ctx, cm.cnf, jwtUser.TgSession, cm.middlewares...)
if err != nil {
logger.Error("failed to create Telegram client for channel limit check", zap.Error(err))
return false
}
var totalMessages int
err = client.Run(ctx, func(ctx context.Context) error {
channel, err := tgc.GetChannelById(ctx, client.API(), channelID)
if err != nil {
return fmt.Errorf("failed to get channel: %w", err)
}
q := query.NewQuery(client.API()).Messages().GetHistory(&tg.InputPeerChannel{
ChannelID: channelID,
AccessHash: channel.AccessHash,
})
msgiter := messages.NewIterator(q, 100)
total, err := msgiter.Total(ctx)
if err != nil {
return fmt.Errorf("failed to get total messages: %w", err)
}
totalMessages = total
return nil
})
if err != nil {
log.Printf("Error checking channel limit for channel %d: %v", channelID, err)
logger.Error("error checking channel message limit",
zap.Int64("channelID", channelID),
zap.Error(err))
// On error, assume not near limit to avoid disruption
return false
}
log.Printf("Channel %d: highest message ID = %d, limit = %d, near limit = %t",
channelID, maxPartID, cm.cnf.MessageLimit, maxPartID >= cm.cnf.MessageLimit)
nearLimit := int64(totalMessages) >= cm.cnf.MessageLimit
logger.Debug("channel limit check",
zap.Int64("channelID", channelID),
zap.Int("totalMessages", totalMessages),
zap.Int64("messageLimit", cm.cnf.MessageLimit),
zap.Bool("nearLimit", nearLimit))
return maxPartID >= cm.cnf.MessageLimit
return nearLimit
}
// getCurrentChannel gets user's current default channel
@ -142,6 +183,8 @@ func (cm *ChannelManager) getCurrentChannel(userID int64) (int64, error) {
// createNewChannel creates a new Telegram channel and sets it as user's default
func (cm *ChannelManager) createNewChannel(ctx context.Context, userID, currentChannelID int64) (int64, error) {
logger := logging.FromContext(ctx)
// Get current channel name to create a rollover name
var currentChannel models.Channel
err := cm.db.Where("channel_id = ? AND user_id = ?", currentChannelID, userID).First(&currentChannel).Error
@ -209,9 +252,14 @@ func (cm *ChannelManager) createNewChannel(ctx context.Context, userID, currentC
var currentBots []models.Bot
err = cm.db.Where("channel_id = ? AND user_id = ?", currentChannelID, userID).Find(&currentBots).Error
if err != nil {
log.Printf("Warning: failed to get bots from current channel %d: %v", currentChannelID, err)
logger.Warn("failed to get bots from current channel",
zap.Int64("currentChannelID", currentChannelID),
zap.Error(err))
} else if len(currentBots) > 0 {
log.Printf("Found %d bots in current channel %d, copying to new channel %d", len(currentBots), currentChannelID, newChannelID)
logger.Info("found bots in current channel, copying to new channel",
zap.Int("botCount", len(currentBots)),
zap.Int64("currentChannelID", currentChannelID),
zap.Int64("newChannelID", newChannelID))
// Extract bot tokens
botTokens := make([]string, len(currentBots))
@ -222,11 +270,14 @@ func (cm *ChannelManager) createNewChannel(ctx context.Context, userID, currentC
// Copy bots to new channel
err = cm.addBotsToChannel(ctx, userID, newChannelID, botTokens)
if err != nil {
log.Printf("Warning: failed to copy bots to new channel %d: %v", newChannelID, err)
logger.Warn("failed to copy bots to new channel",
zap.Int64("newChannelID", newChannelID),
zap.Error(err))
// Don't fail the whole operation
}
} else {
log.Printf("No bots found in current channel %d", currentChannelID)
logger.Debug("no bots found in current channel",
zap.Int64("currentChannelID", currentChannelID))
}
// Add new channel to database
@ -258,19 +309,16 @@ func (cm *ChannelManager) createNewChannel(ctx context.Context, userID, currentC
cm.cache.Delete(cache.Key("users", "channel", userID))
cm.cache.Delete(cache.Key("users", "bots", userID, currentChannelID))
log.Printf("Successfully created rollover channel %d (%s) for user %d", newChannelID, newChannelName, userID)
logger.Info("successfully created rollover channel",
zap.Int64("newChannelID", newChannelID),
zap.String("newChannelName", newChannelName),
zap.Int64("userID", userID))
return newChannelID, nil
}
// addBotsToChannel adds bots to both the Telegram channel and database
// Uses the same logic as the UI's UsersAddBots function
func (cm *ChannelManager) addBotsToChannel(ctx context.Context, userID, channelID int64, botTokens []string) error {
if len(botTokens) == 0 {
return nil
}
log.Printf("Adding %d bots to channel %d", len(botTokens), channelID)
// Get JWT user for creating a fresh Telegram client (same as UI)
jwtUser := auth.GetJWTUser(ctx)
if jwtUser == nil {
@ -283,113 +331,6 @@ func (cm *ChannelManager) addBotsToChannel(ctx context.Context, userID, channelI
return fmt.Errorf("failed to create Telegram client for bot operations: %w", err)
}
// Use the exact same logic as the UI's addBots function
botInfoMap := make(map[string]*types.BotInfo)
err = tgc.RunWithAuth(ctx, client, "", func(botCtx context.Context) error {
channel, err := tgc.GetChannelById(botCtx, client.API(), channelID)
if err != nil {
return fmt.Errorf("failed to get channel: %w", err)
}
g, _ := errgroup.WithContext(botCtx)
g.SetLimit(8)
mapMu := sync.Mutex{}
// Fetch bot info in parallel (same as UI)
for _, token := range botTokens {
token := token // capture loop variable
g.Go(func() error {
info, err := tgc.GetBotInfo(ctx, cm.tgdb, cm.cnf, token)
if err != nil {
log.Printf("Warning: failed to get bot info for token %s: %v", token, err)
return err
}
// Resolve bot domain to get access hash (same as UI)
botPeerClass, err := peer.DefaultResolver(client.API()).ResolveDomain(botCtx, info.UserName)
if err != nil {
log.Printf("Warning: failed to resolve bot domain for %s: %v", info.UserName, err)
return err
}
botPeer := botPeerClass.(*tg.InputPeerUser)
info.AccessHash = botPeer.AccessHash
mapMu.Lock()
botInfoMap[token] = info
mapMu.Unlock()
return nil
})
}
if err = g.Wait(); err != nil {
return err
}
// Only proceed if we got info for all bots (same validation as UI)
if len(botTokens) == len(botInfoMap) {
users := []tg.InputUser{}
for _, info := range botInfoMap {
users = append(users, tg.InputUser{UserID: info.Id, AccessHash: info.AccessHash})
}
// Add each bot as admin to the channel (same as UI)
for _, user := range users {
payload := &tg.ChannelsEditAdminRequest{
Channel: channel,
UserID: tg.InputUserClass(&user),
AdminRights: tg.ChatAdminRights{
ChangeInfo: true,
PostMessages: true,
EditMessages: true,
DeleteMessages: true,
BanUsers: true,
InviteUsers: true,
PinMessages: true,
ManageCall: true,
Other: true,
ManageTopics: true,
},
Rank: "bot",
}
_, err := client.API().ChannelsEditAdmin(botCtx, payload)
if err != nil {
log.Printf("Warning: failed to add bot as admin to channel %d: %v", channelID, err)
return err
}
}
} else {
return fmt.Errorf("failed to fetch info for all bots: got %d out of %d", len(botInfoMap), len(botTokens))
}
return nil
})
if err != nil {
return fmt.Errorf("failed to add bots to Telegram channel: %w", err)
}
// Save bots to database (same as UI)
payload := []models.Bot{}
for _, info := range botInfoMap {
payload = append(payload, models.Bot{
UserId: userID,
Token: info.Token,
BotId: info.Id,
BotUserName: info.UserName,
ChannelId: channelID,
})
}
// Clear bot cache for this channel (same as UI)
cm.cache.Delete(cache.Key("users", "bots", userID, channelID))
// Insert bots with conflict handling (same as UI)
if err := cm.db.Clauses(clause.OnConflict{DoNothing: true}).Create(&payload).Error; err != nil {
log.Printf("Warning: failed to save bots to database: %v", err)
return fmt.Errorf("failed to save bots to database: %w", err)
}
log.Printf("Successfully added %d bots to channel %d", len(botInfoMap), channelID)
return nil
// Use the shared addBotsToChannel function
return addBotsToChannel(ctx, cm.db, cm.tgdb, cm.cache, cm.cnf, client, userID, channelID, botTokens)
}

View file

@ -4,12 +4,15 @@ import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/gotd/td/telegram"
"github.com/gotd/td/telegram/message/peer"
"github.com/gotd/td/tg"
"github.com/tgdrive/teldrive/internal/api"
"github.com/tgdrive/teldrive/internal/cache"
"github.com/tgdrive/teldrive/internal/config"
"github.com/tgdrive/teldrive/internal/crypt"
"github.com/tgdrive/teldrive/internal/logging"
"github.com/tgdrive/teldrive/internal/tgc"
@ -17,7 +20,9 @@ import (
"github.com/tgdrive/teldrive/pkg/models"
"github.com/tgdrive/teldrive/pkg/types"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
func getParts(ctx context.Context, client *telegram.Client, c cache.Cacher, file *models.File) ([]types.Part, error) {
@ -85,5 +90,137 @@ func getBotsToken(db *gorm.DB, c cache.Cacher, userId, channelId int64) ([]strin
}
return bots, nil
})
}
// addBotsToChannel adds bots to both the Telegram channel and database
// This is a common function used by both the UI and rollover system
func addBotsToChannel(ctx context.Context, db *gorm.DB, tgdb *gorm.DB, cacher cache.Cacher, cnf *config.TGConfig,
client *telegram.Client, userID, channelID int64, botTokens []string) error {
logger := logging.FromContext(ctx)
if len(botTokens) == 0 {
return nil
}
logger.Debug("adding bots to channel",
zap.Int("botCount", len(botTokens)),
zap.Int64("channelID", channelID))
botInfoMap := make(map[string]*types.BotInfo)
err := tgc.RunWithAuth(ctx, client, "", func(botCtx context.Context) error {
channel, err := tgc.GetChannelById(botCtx, client.API(), channelID)
if err != nil {
return fmt.Errorf("failed to get channel: %w", err)
}
g, _ := errgroup.WithContext(botCtx)
g.SetLimit(8)
mapMu := sync.Mutex{}
// Fetch bot info in parallel
for _, token := range botTokens {
token := token // capture loop variable
g.Go(func() error {
info, err := tgc.GetBotInfo(ctx, tgdb, cnf, token)
if err != nil {
logger.Warn("failed to get bot info",
zap.String("token", token),
zap.Error(err))
return err
}
// Resolve bot domain to get access hash
botPeerClass, err := peer.DefaultResolver(client.API()).ResolveDomain(botCtx, info.UserName)
if err != nil {
logger.Warn("failed to resolve bot domain",
zap.String("userName", info.UserName),
zap.Error(err))
return err
}
botPeer := botPeerClass.(*tg.InputPeerUser)
info.AccessHash = botPeer.AccessHash
mapMu.Lock()
botInfoMap[token] = info
mapMu.Unlock()
return nil
})
}
if err = g.Wait(); err != nil {
return err
}
// Only proceed if we got info for all bots
if len(botTokens) == len(botInfoMap) {
users := []tg.InputUser{}
for _, info := range botInfoMap {
users = append(users, tg.InputUser{UserID: info.Id, AccessHash: info.AccessHash})
}
// Add each bot as admin to the channel
for _, user := range users {
payload := &tg.ChannelsEditAdminRequest{
Channel: channel,
UserID: tg.InputUserClass(&user),
AdminRights: tg.ChatAdminRights{
ChangeInfo: true,
PostMessages: true,
EditMessages: true,
DeleteMessages: true,
BanUsers: true,
InviteUsers: true,
PinMessages: true,
ManageCall: true,
Other: true,
ManageTopics: true,
},
Rank: "bot",
}
_, err := client.API().ChannelsEditAdmin(botCtx, payload)
if err != nil {
logger.Warn("failed to add bot as admin to channel",
zap.Int64("channelID", channelID),
zap.Error(err))
return err
}
}
} else {
return fmt.Errorf("failed to fetch info for all bots: got %d out of %d", len(botInfoMap), len(botTokens))
}
return nil
})
if err != nil {
return fmt.Errorf("failed to add bots to Telegram channel: %w", err)
}
// Save bots to database
payload := []models.Bot{}
for _, info := range botInfoMap {
payload = append(payload, models.Bot{
UserId: userID,
Token: info.Token,
BotId: info.Id,
BotUserName: info.UserName,
ChannelId: channelID,
})
}
// Clear bot cache for this channel
cacher.Delete(cache.Key("users", "bots", userID, channelID))
// Insert bots with conflict handling
if err := db.Clauses(clause.OnConflict{DoNothing: true}).Create(&payload).Error; err != nil {
logger.Warn("failed to save bots to database", zap.Error(err))
return fmt.Errorf("failed to save bots to database: %w", err)
}
logger.Info("successfully added bots to channel",
zap.Int("botCount", len(botInfoMap)),
zap.Int64("channelID", channelID))
return nil
}