package cmd import ( "context" "encoding/json" "fmt" "os" "strings" "sync" "time" "github.com/fatih/color" "github.com/gotd/td/telegram" "github.com/gotd/td/telegram/query" "github.com/gotd/td/telegram/query/messages" "github.com/gotd/td/tg" "github.com/jedib0t/go-pretty/v6/progress" "github.com/jedib0t/go-pretty/v6/text" "github.com/manifoldco/promptui" "github.com/spf13/cobra" "github.com/tgdrive/teldrive/internal/api" "github.com/tgdrive/teldrive/internal/config" "github.com/tgdrive/teldrive/internal/crypt" "github.com/tgdrive/teldrive/internal/database" "github.com/tgdrive/teldrive/internal/logging" "github.com/tgdrive/teldrive/internal/tgc" "github.com/tgdrive/teldrive/internal/utils" "github.com/tgdrive/teldrive/pkg/models" "golang.org/x/sync/errgroup" "golang.org/x/term" "gorm.io/datatypes" "gorm.io/gorm" ) var termWidth = func() (width int, err error) { width, _, err = term.GetSize(int(os.Stdout.Fd())) if err == nil { return width, nil } return 0, err } type file struct { ID string Name string Size int64 Encrypted bool Parts datatypes.JSONSlice[api.Part] } type exportFile struct { ID string `json:"id"` Name string `json:"name"` } type channelExport struct { ChannelID int64 `json:"channel_id"` Timestamp string `json:"timestamp"` FileCount int `json:"file_count"` Files []exportFile `json:"files"` } type channelProcessor struct { id int64 files []file missingFiles []file orphanMessages []int totalCount int64 pw progress.Writer tracker *progress.Tracker channelExport *channelExport client *telegram.Client ctx context.Context db *gorm.DB userId int64 clean bool } func NewCheckCmd() *cobra.Command { var cfg config.ServerCmdConfig loader := config.NewConfigLoader() cmd := &cobra.Command{ Use: "check", Short: "Check and purge incomplete files", Run: func(cmd *cobra.Command, args []string) { runCheckCmd(cmd, &cfg) }, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { if err := loader.Load(cmd, &cfg); err != nil { return err } if err := checkRequiredCheckFlags(&cfg); err != nil { return err } return nil }, } loader.RegisterPlags(cmd.Flags(), "", cfg, true) cmd.Flags().Bool("export", true, "Export incomplete files to json file") cmd.Flags().Bool("clean", false, "Clean missing and orphan file parts") cmd.Flags().String("user", "", "Telegram User Name") cmd.Flags().Int("concurrent", 4, "Number of concurrent channel processing") return cmd } func checkRequiredCheckFlags(cfg *config.ServerCmdConfig) error { var missingFields []string if cfg.DB.DataSource == "" { missingFields = append(missingFields, "db-data-source") } if len(missingFields) > 0 { return fmt.Errorf("required configuration values not set: %s", strings.Join(missingFields, ", ")) } return nil } func selectUser(user string, users []models.User) (*models.User, error) { if user != "" { res := utils.Filter(users, func(u models.User) bool { return u.UserName == user }) if len(res) == 0 { return nil, fmt.Errorf("invalid user name: %s", user) } return &res[0], nil } templates := &promptui.SelectTemplates{ Label: "{{ . }}", Active: "{{ .UserName | cyan }}", Inactive: "{{ .UserName | white }}", Selected: "{{ .UserName | red | cyan }}", } prompt := promptui.Select{ Label: "Select User", Items: users, Templates: templates, Size: 50, } index, _, err := prompt.Run() if err != nil { return nil, err } return &users[index], nil } func (cp *channelProcessor) updateStatus(status string, value int64) { cp.tracker.SetValue(value) cp.tracker.UpdateMessage(fmt.Sprintf("Channel %d: %s", cp.id, status)) } func (cp *channelProcessor) process() error { cp.updateStatus("Loading files", 0) files, err := cp.loadFiles() if err != nil { return fmt.Errorf("failed to load files: %w", err) } cp.files = files if len(cp.files) == 0 { cp.updateStatus("No files found", 100) return nil } cp.updateStatus("Loading messages from Telegram", 0) msgs, total, err := cp.loadChannelMessages() if err != nil { return fmt.Errorf("failed to load messages: %w", err) } if total == 0 && len(msgs) == 0 { cp.updateStatus("No messages found", 100) return nil } if len(msgs) < total { return fmt.Errorf("found %d messages out of %d", len(msgs), total) } cp.updateStatus("Processing messages and parts", 0) uploadPartIds := []int{} if err := cp.db.Model(&models.Upload{}). Where("user_id = ?", cp.userId). Where("channel_id = ?", cp.id). Pluck("part_id", &uploadPartIds).Error; err != nil { return err } uploadPartMap := make(map[int]bool) for _, partID := range uploadPartIds { uploadPartMap[partID] = true } msgMap := make(map[int]int64) for _, m := range msgs { id := m.Msg.GetID() _, ok := uploadPartMap[id] if id > 0 && !ok { doc, ok := m.Document() if !ok { msgMap[id] = 0 } else { msgMap[id] = doc.GetSize() } } } cp.updateStatus("Checking file integrity", 0) allPartIDs := make(map[int]bool) for _, f := range cp.files { size := int64(0) for _, p := range f.Parts { if p.ID != 0 { allPartIDs[p.ID] = true } _, ok := msgMap[p.ID] if !ok { cp.missingFiles = append(cp.missingFiles, f) break } if f.Encrypted { d, _ := crypt.DecryptedSize(msgMap[p.ID]) size += d } else { size += msgMap[p.ID] } } if size != f.Size { cp.missingFiles = append(cp.missingFiles, f) } } if len(allPartIDs) == 0 { cp.updateStatus("No parts found", 100) return nil } for msgID := range msgMap { _, ok := allPartIDs[msgID] if !ok { cp.orphanMessages = append(cp.orphanMessages, msgID) } } if len(cp.missingFiles) > 0 { cp.channelExport = &channelExport{ ChannelID: cp.id, Timestamp: time.Now().Format(time.RFC3339), FileCount: len(cp.missingFiles), Files: make([]exportFile, 0, len(cp.missingFiles)), } for _, f := range cp.missingFiles { cp.channelExport.Files = append(cp.channelExport.Files, exportFile{ ID: f.ID, Name: f.Name, }) } if cp.clean { cp.updateStatus("Cleaning files", 0) err = cp.db.Exec("call teldrive.delete_files_bulk($1 , $2)", utils.Map(cp.missingFiles, func(f file) string { return f.ID }), cp.userId).Error if err != nil { return err } } } if cp.clean && len(cp.orphanMessages) > 0 { cp.updateStatus("Cleaning orphan messages", 0) tgc.DeleteMessages(cp.ctx, cp.client, cp.id, cp.orphanMessages) } cp.updateStatus("Complete", 100) return nil } func (cp *channelProcessor) loadFiles() ([]file, error) { var files []file const batchSize = 1000 var totalFiles int64 var lastID string if err := cp.db.Model(&models.File{}). Where("user_id = ?", cp.userId). Where("channel_id = ?", cp.id). Where("type = ?", "file"). Count(&totalFiles).Error; err != nil { return nil, err } if totalFiles == 0 { return nil, nil } processed := int64(0) for { var batch []file query := cp.db.WithContext(cp.ctx).Model(&models.File{}). Where("user_id = ?", cp.userId). Where("channel_id = ?", cp.id). Where("type = ?", "file"). Order("id"). Limit(batchSize) if lastID != "" { query = query.Where("id > ?", lastID) } result := query.Scan(&batch) if result.Error != nil { return nil, result.Error } if len(batch) == 0 { break } files = append(files, batch...) processed += int64(len(batch)) lastID = batch[len(batch)-1].ID progress := (float64(processed) / float64(totalFiles)) * 100 cp.updateStatus(fmt.Sprintf("Loading files: %d/%d", processed, totalFiles), int64(progress)) if len(batch) < batchSize { break } } return files, nil } func (cp *channelProcessor) loadChannelMessages() (msgs []messages.Elem, total int, err error) { err = tgc.RunWithAuth(cp.ctx, cp.client, "", func(ctx context.Context) error { var channel *tg.InputChannel channel, err = tgc.GetChannelById(ctx, cp.client.API(), cp.id) if err != nil { return err } q := query.NewQuery(cp.client.API()).Messages().GetHistory(&tg.InputPeerChannel{ ChannelID: cp.id, AccessHash: channel.AccessHash, }) msgiter := messages.NewIterator(q, 100) total, err = msgiter.Total(ctx) if err != nil { return fmt.Errorf("failed to get total messages: %w", err) } processed := 0 for msgiter.Next(ctx) { msg := msgiter.Value() msgs = append(msgs, msg) processed++ if processed%100 == 0 { progress := (float64(processed) / float64(total)) * 100 cp.updateStatus(fmt.Sprintf("Loading messages: %d/%d", processed, total), int64(progress)) } } return nil }) return } func runCheckCmd(cmd *cobra.Command, cfg *config.ServerCmdConfig) { ctx := cmd.Context() lg := logging.DefaultLogger().Sugar() defer logging.DefaultLogger().Sync() cfg.DB.LogLevel = "fatal" db, err := database.NewDatabase(&cfg.DB, lg) if err != nil { lg.Fatalw("failed to create database", "err", err) } users := []models.User{} if err := db.Model(&models.User{}).Find(&users).Error; err != nil { lg.Fatalw("failed to get users", "err", err) } userName, _ := cmd.Flags().GetString("user") user, err := selectUser(userName, users) if err != nil { lg.Fatalw("failed to select user", "err", err) } session := models.Session{} if err := db.Model(&models.Session{}). Where("user_id = ?", user.UserId). Order("created_at desc"). First(&session).Error; err != nil { lg.Fatalw("failed to get session", "err", err) } channelIds := []int64{} if err := db.Model(&models.Channel{}). Where("user_id = ?", user.UserId). Pluck("channel_id", &channelIds).Error; err != nil { lg.Fatalw("failed to get channels", "err", err) } if len(channelIds) == 0 { lg.Fatalw("no channels found") } middlewares := tgc.NewMiddleware(&cfg.TG, tgc.WithFloodWait(), tgc.WithRateLimit()) export, _ := cmd.Flags().GetBool("export") clean, _ := cmd.Flags().GetBool("clean") concurrent, _ := cmd.Flags().GetInt("concurrent") pw := progress.NewWriter() pw.SetAutoStop(false) width := 75 if size, err := termWidth(); err == nil { width = int((float32(3) / float32(4)) * float32(size)) } pw.SetTrackerLength(width / 5) pw.SetMessageLength(width * 3 / 5) pw.SetStyle(progress.StyleDefault) pw.SetTrackerPosition(progress.PositionRight) pw.SetUpdateFrequency(time.Millisecond * 100) pw.Style().Colors = progress.StyleColorsExample pw.Style().Colors.Message = text.Colors{text.FgBlue} pw.Style().Options.PercentFormat = "%4.1f%%" pw.Style().Visibility.Value = false pw.Style().Options.TimeInProgressPrecision = time.Millisecond pw.Style().Options.ErrorString = color.RedString("failed!") pw.Style().Options.DoneString = color.GreenString("done!") var channelExports []channelExport var mutex sync.Mutex g, ctx := errgroup.WithContext(ctx) g.SetLimit(concurrent) go pw.Render() for _, id := range channelIds { g.Go(func() error { client, err := tgc.AuthClient(ctx, &cfg.TG, session.Session, middlewares...) if err != nil { lg.Errorw("failed to create client", "err", err, "channel", id) return fmt.Errorf("failed to create client for channel %d: %w", id, err) } tracker := &progress.Tracker{ Message: fmt.Sprintf("Channel %d: Initializing", id), Total: 100, Units: progress.UnitsDefault, } pw.AppendTracker(tracker) processor := &channelProcessor{ id: id, client: client, ctx: ctx, db: db, userId: user.UserId, clean: clean, pw: pw, tracker: tracker, totalCount: 100, } if err := processor.process(); err != nil { tracker.MarkAsErrored() return err } if processor.channelExport != nil { mutex.Lock() channelExports = append(channelExports, *processor.channelExport) mutex.Unlock() } return nil }) } if err := g.Wait(); err != nil { lg.Fatal(fmt.Errorf("one or more channels failed to process")) } pw.Stop() if export && len(channelExports) > 0 { jsonData, err := json.MarshalIndent(channelExports, "", " ") if err != nil { lg.Errorw("failed to marshal JSON", "err", err) return } err = os.WriteFile("results.json", jsonData, 0644) if err != nil { lg.Errorw("failed to write JSON file", "err", err) return } lg.Infof("Exported data to results.json") } }