diff --git a/pkg/downloader/downloader.go b/pkg/downloader/downloader.go new file mode 100644 index 0000000..43e3cb7 --- /dev/null +++ b/pkg/downloader/downloader.go @@ -0,0 +1,112 @@ +package downloader + +import ( + "context" + "fmt" + "github.com/fatih/color" + "github.com/gabriel-vasile/mimetype" + "github.com/gotd/td/telegram/downloader" + "github.com/gotd/td/tg" + "github.com/iyear/tdl/pkg/consts" + "github.com/iyear/tdl/pkg/prog" + "github.com/iyear/tdl/pkg/utils" + "github.com/jedib0t/go-pretty/v6/progress" + "golang.org/x/sync/errgroup" + "os" + "path/filepath" + "strings" + "time" +) + +const TempExt = ".tmp" + +type Downloader struct { + client *tg.Client + pw progress.Writer + partSize int + threads int + iter Iter +} + +func New(client *tg.Client, partSize int, threads int, iter Iter) *Downloader { + return &Downloader{ + client: client, + pw: prog.New(), + partSize: partSize, + threads: threads, + iter: iter, + } +} + +func (d *Downloader) Download(ctx context.Context, limit int) error { + d.pw.SetNumTrackersExpected(d.iter.Total(ctx)) + + go d.pw.Render() + + wg, errctx := errgroup.WithContext(ctx) + wg.SetLimit(limit) + + for d.iter.Next(ctx) { + item, err := d.iter.Value(ctx) + if err != nil { + d.pw.Log(color.RedString("[ERROR] Get item failed: %v", err)) + continue + } + + wg.Go(func() error { + // d.pw.Log(color.MagentaString("name: %s,size: %s", item.Name, utils.Byte.FormatBytes(item.Size))) + return d.download(errctx, item) + }) + } + + err := wg.Wait() + if err != nil { + + return err + } + + for d.pw.IsRenderInProgress() { + if d.pw.LengthActive() == 0 { + d.pw.Stop() + } + time.Sleep(100 * time.Millisecond) + } + + return nil +} + +func (d *Downloader) download(ctx context.Context, item *Item) error { + tracker := prog.AppendTracker(d.pw, item.Name, item.Size) + filename := fmt.Sprintf("%s%s", utils.FS.GetNameWithoutExt(item.Name), TempExt) + path := filepath.Join(consts.DownloadPath, filename) + + f, err := os.Create(path) + if err != nil { + return err + } + + _, err = downloader.NewDownloader().WithPartSize(d.partSize). + Download(d.client, item.InputFileLoc).WithThreads(d.threads). + Parallel(ctx, &writeAt{ + f: f, + tracker: tracker, + }) + if err = f.Close(); err != nil { + return err + } + if err != nil { + return err + } + + mime, err := mimetype.DetectFile(path) + if err != nil { + return err + } + + newfile := fmt.Sprintf("%s%s", strings.TrimSuffix(filename, TempExt), mime.Extension()) + if err = os.Rename(path, filepath.Join(consts.DownloadPath, newfile)); err != nil { + return err + } + + return nil +} diff --git a/pkg/downloader/iter.go b/pkg/downloader/iter.go new file mode 100644 index 0000000..4b78d51 --- /dev/null +++ b/pkg/downloader/iter.go @@ -0,0 +1,18 @@ +package downloader + +import ( + "context" + "github.com/gotd/td/tg" +) + +type Iter interface { + Next(ctx context.Context) bool + Value(ctx context.Context) (*Item, error) + Total(ctx context.Context) int +} + +type Item struct { + InputFileLoc tg.InputFileLocationClass + Name string + Size int64 +} diff --git a/pkg/downloader/write_at.go b/pkg/downloader/write_at.go new file mode 100644 index 0000000..d64b66a --- /dev/null +++ b/pkg/downloader/write_at.go @@ -0,0 +1,22 @@ +package downloader + +import ( + "github.com/jedib0t/go-pretty/v6/progress" + "os" +) + +// writeAt wrapper for file to use progress bar +type writeAt struct { + f *os.File + tracker *progress.Tracker +} + +func (w *writeAt) WriteAt(p []byte, off int64) (int, error) { + at, err := w.f.WriteAt(p, off) + if err != nil { + w.tracker.MarkAsErrored() + return 0, err + } + w.tracker.Increment(int64(at)) + return at, nil +}