// tdl/pkg/downloader/downloader.go
//
// 123 lines
// 2.6 KiB
// Go

package downloader
import (
"context"
"errors"
"fmt"
"github.com/fatih/color"
"github.com/gabriel-vasile/mimetype"
"github.com/gotd/td/telegram/downloader"
"github.com/gotd/td/tg"
"github.com/iyear/tdl/pkg/consts"
"github.com/iyear/tdl/pkg/prog"
"github.com/iyear/tdl/pkg/utils"
"github.com/jedib0t/go-pretty/v6/progress"
"golang.org/x/sync/errgroup"
"os"
"path/filepath"
"strings"
"time"
)
// TempExt is the extension given to a file while it is being downloaded.
// The download routine in this file writes to "<name>" + TempExt and strips
// the suffix again when renaming the finished file.
const TempExt = ".tmp"
// Downloader downloads the items produced by an Iter and reports progress
// through a progress.Writer.
type Downloader struct {
	// client is handed to the gotd downloader for every transfer.
	client *tg.Client
	// pw receives log lines and per-file trackers.
	pw progress.Writer
	// partSize and threads are forwarded to the gotd downloader via
	// WithPartSize and WithThreads respectively.
	partSize, threads int
	// iter yields the items to download.
	iter Iter
}
// New returns a Downloader that pulls items from iter and fetches them
// through client. partSize and threads are stored as given and applied to
// each transfer; the progress writer is obtained from prog.New.
func New(client *tg.Client, partSize int, threads int, iter Iter) *Downloader {
	d := new(Downloader)

	d.client = client
	d.pw = prog.New()
	d.partSize = partSize
	d.threads = threads
	d.iter = iter

	return d
}
// Download walks the iterator and downloads every item it yields.
//
// limit is passed unchanged to errgroup.Group.SetLimit and therefore bounds
// how many downloads run at the same time. Items whose Value call fails are
// logged and skipped. The first error returned by a download cancels the
// shared context, stops further scheduling, and is returned once all
// in-flight downloads have finished. On success Download returns nil after
// the progress renderer has stopped.
func (d *Downloader) Download(ctx context.Context, limit int) error {
	d.pw.Log(color.GreenString("All files will be downloaded to '%s' dir", consts.DownloadPath))
	d.pw.SetNumTrackersExpected(d.iter.Total(ctx))

	go d.pw.Render()

	wg, errctx := errgroup.WithContext(ctx)
	wg.SetLimit(limit)

	// errctx is done as soon as one download has failed or ctx is canceled.
	// Scheduling more work after that point would only fetch further items
	// and start downloads that are already doomed, so stop iterating.
	for errctx.Err() == nil && d.iter.Next(ctx) {
		item, err := d.iter.Value(ctx)
		if err != nil {
			d.pw.Log(color.RedString("Get item failed: %v, skip...", err))
			continue
		}

		// wg.Go blocks while limit downloads are already running.
		wg.Go(func() error {
			return d.download(errctx, item)
		})
	}

	if err := wg.Wait(); err != nil {
		if errors.Is(err, context.Canceled) {
			d.pw.Log(color.RedString("Download aborted."))
		}

		// Stop rendering right away and wait for the renderer to exit so
		// it does not keep writing after Download has returned.
		d.pw.Stop()
		for d.pw.IsRenderInProgress() {
			time.Sleep(10 * time.Millisecond)
		}
		return err
	}

	// All downloads succeeded: let the renderer finish the remaining
	// trackers, then stop it once none are active.
	for d.pw.IsRenderInProgress() {
		if d.pw.LengthActive() == 0 {
			d.pw.Stop()
		}
		time.Sleep(10 * time.Millisecond)
	}
	return nil
}
// download fetches a single item into consts.DownloadPath.
//
// The data is written to a temporary file named after the item (its name
// without extension plus TempExt). When the transfer succeeds,
// mimetype.DetectFile is run on the temporary file and the file is renamed
// to the same base name plus the detected extension.
//
// If the transfer fails, its error is returned in preference to any error
// from closing the file, so callers see the root cause (for example a
// canceled context), and the partial temporary file is removed on a
// best-effort basis. All returned errors wrap the underlying error.
func (d *Downloader) download(ctx context.Context, item *Item) error {
	tracker := prog.AppendTracker(d.pw, item.Name, item.Size)

	filename := fmt.Sprintf("%s%s", utils.FS.GetNameWithoutExt(item.Name), TempExt)
	path := filepath.Join(consts.DownloadPath, filename)

	f, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("create temp file %q: %w", path, err)
	}

	_, err = downloader.NewDownloader().WithPartSize(d.partSize).
		Download(d.client, item.InputFileLoc).WithThreads(d.threads).
		Parallel(ctx, &writeAt{
			f:       f,
			tracker: tracker,
		})

	// Always close the file, but keep the close error separate so it cannot
	// hide a transfer error.
	closeErr := f.Close()

	if err != nil {
		// The file is incomplete and os.Create truncates on the next
		// attempt anyway, so nothing is lost by removing it. The removal
		// error is deliberately ignored: the transfer error is what matters.
		_ = os.Remove(path)
		return fmt.Errorf("download %q: %w", item.Name, err)
	}
	if closeErr != nil {
		return fmt.Errorf("close temp file %q: %w", path, closeErr)
	}

	mime, err := mimetype.DetectFile(path)
	if err != nil {
		return fmt.Errorf("detect mime type of %q: %w", path, err)
	}

	newfile := fmt.Sprintf("%s%s", strings.TrimSuffix(filename, TempExt), mime.Extension())
	if err := os.Rename(path, filepath.Join(consts.DownloadPath, newfile)); err != nil {
		return fmt.Errorf("rename %q to %q: %w", path, newfile, err)
	}
	return nil
}