refactor: multireader and workers

divyam234 2024-08-04 17:13:40 +05:30
parent 81c4bd775a
commit 3605ea4193
6 changed files with 225 additions and 233 deletions

View file

@ -113,10 +113,15 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
if underlyingLimit >= 0 {
end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
}
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
client: r.client, concurrency: r.concurrency, cache: r.cache}
chunkSrc := &chunkSource{
channelID: r.channelId,
worker: r.worker,
fileID: r.fileId,
partID: r.parts[r.ranges[r.pos].PartNo].ID,
client: r.client,
concurrency: r.concurrency,
cache: r.cache,
}
if r.concurrency < 2 {
return newTGReader(r.ctx, underlyingOffset, end, chunkSrc)
}
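The hunk above clamps the decrypted reader's window to the current part and then picks the plain sequential reader when concurrency is below 2, the multi-threaded one otherwise. A minimal sketch of the clamp with made-up numbers (part size, offset and limit are illustrative, not values from the commit):

```go
// Illustrative values only: a 1 MiB part, reading 512 KiB starting 768 KiB in.
partSize := int64(1 << 20)
underlyingOffset := int64(768 << 10)
underlyingLimit := int64(512 << 10)

end := partSize - 1 // default: read to the end of the part
if underlyingLimit >= 0 {
	end = min(partSize-1, underlyingOffset+underlyingLimit-1) // never cross the part boundary
}
// end == partSize-1 here, because offset+limit would overshoot the part.
```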

View file

@ -11,25 +11,27 @@ import (
)
func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
partByteRanges := []types.Range{}
startPart := startByte / partSize
endPart := endByte / partSize
startOffset := startByte % partSize
for part := startPart; part <= endPart; part++ {
partStartByte := int64(0)
partEndByte := partSize - 1
if part == startPart {
partStartByte = startOffset
}
if part == endPart {
partEndByte = int64(endByte % partSize)
partEndByte = endByte % partSize
}
partByteRanges = append(partByteRanges, types.Range{Start: partStartByte, End: partEndByte, PartNo: part})
partByteRanges = append(partByteRanges, types.Range{
Start: partStartByte,
End: partEndByte,
PartNo: part,
})
startOffset = 0
}
@ -37,7 +39,7 @@ func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
return partByteRanges
}
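A worked example of the range splitting above, with illustrative numbers: partSize = 1024, startByte = 100 and endByte = 2500 give startPart = 0, endPart = 2 and startOffset = 100, so three per-part ranges come back:

```go
ranges := calculatePartByteRanges(100, 2500, 1024)
// []types.Range{
//     {Start: 100, End: 1023, PartNo: 0}, // startOffset applies only to the first part
//     {Start: 0,   End: 1023, PartNo: 1}, // middle parts span the whole part
//     {Start: 0,   End: 452,  PartNo: 2}, // 2500 % 1024 == 452
// }
_ = ranges
```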
type linearReader struct {
type LinearReader struct {
ctx context.Context
parts []types.Part
ranges []types.Range
@ -45,27 +47,19 @@ type linearReader struct {
reader io.ReadCloser
limit int64
config *config.TGConfig
channelId int64
channelID int64
worker *tgc.StreamWorker
client *tgc.Client
fileId string
fileID string
concurrency int
cache cache.Cacher
}
func NewLinearReader(ctx context.Context,
fileId string,
parts []types.Part,
start, end int64,
channelId int64,
config *config.TGConfig,
concurrency int,
client *tgc.Client,
worker *tgc.StreamWorker,
cache cache.Cacher,
) (reader io.ReadCloser, err error) {
func NewLinearReader(ctx context.Context, fileID string, parts []types.Part, start, end int64,
channelID int64, config *config.TGConfig, concurrency int, client *tgc.Client,
worker *tgc.StreamWorker, cache cache.Cacher) (io.ReadCloser, error) {
r := &linearReader{
r := &LinearReader{
ctx: ctx,
parts: parts,
limit: end - start + 1,
@ -73,14 +67,14 @@ func NewLinearReader(ctx context.Context,
config: config,
client: client,
worker: worker,
channelId: channelId,
fileId: fileId,
channelID: channelID,
fileID: fileID,
concurrency: concurrency,
cache: cache,
}
var err error
r.reader, err = r.nextPart()
if err != nil {
return nil, err
}
@ -88,49 +82,51 @@ func NewLinearReader(ctx context.Context,
return r, nil
}
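For context, a sketch of how a caller might wire up the now-flattened constructor signature; the surrounding variables (cfg, parts, client, worker, cacher and the writer w) are assumptions, not code from this commit:

```go
// Hypothetical caller: stream bytes [start, end] of a file to a response writer.
rd, err := NewLinearReader(ctx, fileID, parts, start, end,
	channelID, &cfg.TG, cfg.TG.Stream.MultiThreads, client, worker, cacher)
if err != nil {
	return err
}
defer rd.Close()
_, err = io.Copy(w, rd)
```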
func (r *linearReader) Read(p []byte) (int, error) {
func (r *LinearReader) Read(p []byte) (int, error) {
if r.limit <= 0 {
return 0, io.EOF
}
n, err := r.reader.Read(p)
if err == io.EOF {
if r.limit > 0 {
err = nil
if r.reader != nil {
r.reader.Close()
}
if err == io.EOF && r.limit > 0 {
err = nil
if r.reader != nil {
r.reader.Close()
}
r.pos++
if r.pos < len(r.ranges) {
r.reader, err = r.nextPart()
}
}
r.limit -= int64(n)
return n, err
}
func (r *linearReader) nextPart() (io.ReadCloser, error) {
func (r *LinearReader) nextPart() (io.ReadCloser, error) {
start := r.ranges[r.pos].Start
end := r.ranges[r.pos].End
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
client: r.client, concurrency: r.concurrency, cache: r.cache}
chunkSrc := &chunkSource{
channelID: r.channelID,
worker: r.worker,
fileID: r.fileID,
partID: r.parts[r.ranges[r.pos].PartNo].ID,
client: r.client,
concurrency: r.concurrency,
cache: r.cache,
}
if r.concurrency < 2 {
return newTGReader(r.ctx, start, end, chunkSrc)
}
return newTGMultiReader(r.ctx, start, end, r.config, chunkSrc)
}
func (r *linearReader) Close() (err error) {
func (r *LinearReader) Close() error {
if r.reader != nil {
err = r.reader.Close()
err := r.reader.Close()
r.reader = nil
return err
}

View file

@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/divyam234/teldrive/internal/cache"
@ -15,7 +14,10 @@ import (
"golang.org/x/sync/errgroup"
)
var ErrorStreamAbandoned = errors.New("stream abandoned")
var (
ErrStreamAbandoned = errors.New("stream abandoned")
ErrChunkTimeout = errors.New("chunk fetch timed out")
)
type ChunkSource interface {
Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error)
@ -23,10 +25,10 @@ type ChunkSource interface {
}
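Since ChunkSource is the seam between the readers and Telegram, an in-memory stub makes the readers testable without a live client. A rough sketch, assuming the interface's second method is ChunkSize(start, end int64) int64 as its call site in newTGMultiReader suggests:

```go
// memChunkSource serves chunks out of a byte slice; test-only sketch.
type memChunkSource struct {
	data      []byte
	chunkSize int64
}

func (m *memChunkSource) Chunk(ctx context.Context, offset, limit int64) ([]byte, error) {
	end := min(offset+limit, int64(len(m.data)))
	if offset >= end {
		return nil, io.EOF
	}
	return m.data[offset:end], nil
}

func (m *memChunkSource) ChunkSize(start, end int64) int64 { return m.chunkSize }
```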
type chunkSource struct {
channelId int64
channelID int64
worker *tgc.StreamWorker
fileId string
partId int64
fileID string
partID int64
concurrency int
client *tgc.Client
cache cache.Cacher
@ -52,9 +54,9 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b
}()
if c.concurrency > 0 {
client, _, _ = c.worker.Next(c.channelId)
client, _, _ = c.worker.Next(c.channelID)
}
location, err = tgc.GetLocation(ctx, client, c.cache, c.fileId, c.channelId, c.partId)
location, err = tgc.GetLocation(ctx, client, c.cache, c.fileID, c.channelID, c.partID)
if err != nil {
return nil, err
@ -66,22 +68,19 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b
type tgMultiReader struct {
ctx context.Context
cancel context.CancelFunc
offset int64
limit int64
chunkSize int64
bufferChan chan *buffer
done chan struct{}
cur *buffer
err chan error
mu sync.Mutex
concurrency int
leftCut int64
rightCut int64
totalParts int
currentPart int
closed bool
timeout time.Duration
chunkSrc ChunkSource
timeout time.Duration
}
func newTGMultiReader(
@ -90,15 +89,15 @@ func newTGMultiReader(
end int64,
config *config.TGConfig,
chunkSrc ChunkSource,
) (*tgMultiReader, error) {
chunkSize := chunkSrc.ChunkSize(start, end)
offset := start - (start % chunkSize)
ctx, cancel := context.WithCancel(ctx)
r := &tgMultiReader{
ctx: ctx,
cancel: cancel,
limit: end - start + 1,
bufferChan: make(chan *buffer, config.Stream.Buffers),
concurrency: config.Stream.MultiThreads,
@ -109,8 +108,6 @@ func newTGMultiReader(
chunkSize: chunkSize,
chunkSrc: chunkSrc,
timeout: config.Stream.ChunkTimeout,
done: make(chan struct{}, 1),
err: make(chan error, 1),
}
go r.fillBufferConcurrently()
@ -118,45 +115,24 @@ func newTGMultiReader(
}
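The constructor aligns the first fetch to a chunk boundary (offset = start - start % chunkSize) and sizes the buffered channel from config.Stream.Buffers. With illustrative numbers:

```go
// Illustrative only: a read starting at byte 1500 with 1024-byte chunks.
start, chunkSize := int64(1500), int64(1024)
offset := start - (start % chunkSize) // 1024: first byte of the chunk containing start
skipped := start - offset             // 476 bytes of that chunk precede the requested start
_, _ = offset, skipped
```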
func (r *tgMultiReader) Close() error {
close(r.done)
close(r.bufferChan)
r.closed = true
for b := range r.bufferChan {
if b != nil {
b = nil
}
}
if r.cur != nil {
r.cur = nil
}
close(r.err)
r.cancel()
return nil
}
func (r *tgMultiReader) Read(p []byte) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.limit <= 0 {
return 0, io.EOF
}
if r.cur.isEmpty() {
if r.cur != nil {
r.cur = nil
}
if r.cur == nil || r.cur.isEmpty() {
select {
case cur, ok := <-r.bufferChan:
if !ok && r.limit > 0 {
return 0, ErrorStreamAbandoned
if !ok {
return 0, ErrStreamAbandoned
}
r.cur = cur
case err := <-r.err:
return 0, fmt.Errorf("error reading chunk: %w", err)
case <-r.ctx.Done():
return 0, r.ctx.Err()
}
}
@ -171,91 +147,90 @@ func (r *tgMultiReader) Read(p []byte) (int, error) {
return n, nil
}
func (r *tgMultiReader) fillBufferConcurrently() error {
func (r *tgMultiReader) fillBufferConcurrently() {
defer close(r.bufferChan)
var mapMu sync.Mutex
bufferMap := make(map[int]*buffer)
defer func() {
for i := range bufferMap {
delete(bufferMap, i)
for r.currentPart < r.totalParts {
if err := r.fillBatch(); err != nil {
r.cancel()
return
}
}()
}
}
cb := func(ctx context.Context, i int) func() error {
return func() error {
func (r *tgMultiReader) fillBatch() error {
g, ctx := errgroup.WithContext(r.ctx)
g.SetLimit(r.concurrency)
chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize)
buffers := make([]*buffer, r.concurrency)
for i := 0; i < r.concurrency && r.currentPart+i < r.totalParts; i++ {
g.Go(func() error {
chunkCtx, cancel := context.WithTimeout(ctx, r.timeout)
defer cancel()
chunk, err := r.fetchChunkWithTimeout(chunkCtx, int64(i))
if err != nil {
return err
if errors.Is(err, context.DeadlineExceeded) {
return fmt.Errorf("chunk %d: %w", r.currentPart+i, ErrChunkTimeout)
}
return fmt.Errorf("chunk %d: %w", r.currentPart+i, err)
}
if r.totalParts == 1 {
chunk = chunk[r.leftCut:r.rightCut]
} else if r.currentPart+i+1 == 1 {
} else if r.currentPart+i == 0 {
chunk = chunk[r.leftCut:]
} else if r.currentPart+i+1 == r.totalParts {
chunk = chunk[:r.rightCut]
}
buf := &buffer{buf: chunk}
mapMu.Lock()
bufferMap[i] = buf
mapMu.Unlock()
buffers[i] = &buffer{buf: chunk}
return nil
}
})
}
for {
if err := g.Wait(); err != nil {
return err
}
g := errgroup.Group{}
g.SetLimit(r.concurrency)
for i := range r.concurrency {
if r.currentPart+i+1 <= r.totalParts {
g.Go(cb(r.ctx, i))
}
for _, buf := range buffers {
if buf == nil {
break
}
done := make(chan error, 1)
go func() {
done <- g.Wait()
}()
select {
case err := <-done:
if err != nil {
r.err <- err
return nil
} else {
for i := range r.concurrency {
if r.currentPart+i+1 <= r.totalParts {
select {
case <-r.done:
return nil
case r.bufferChan <- bufferMap[i]:
}
}
}
r.currentPart += r.concurrency
r.offset += r.chunkSize * int64(r.concurrency)
for i := range bufferMap {
delete(bufferMap, i)
}
if r.currentPart >= r.totalParts {
return nil
}
}
case <-time.After(r.timeout):
return nil
case <-r.done:
return nil
case r.bufferChan <- buf:
case <-r.ctx.Done():
return r.ctx.Err()
}
}
r.currentPart += r.concurrency
r.offset += r.chunkSize * int64(r.concurrency)
return nil
}
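fillBatch trims the first and last chunks so only the requested window is forwarded. The actual leftCut/rightCut assignments live outside the hunks shown, so the cut values below are made up purely to illustrate the three slicing branches:

```go
chunk := []byte("0123456789")
leftCut, rightCut := int64(3), int64(7)

onlyChunk := chunk[leftCut:rightCut] // single-chunk stream: "3456"
firstChunk := chunk[leftCut:]        // first of several:    "3456789"
lastChunk := chunk[:rightCut]        // last of several:     "0123456"
_, _, _ = onlyChunk, firstChunk, lastChunk
```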
func (r *tgMultiReader) fetchChunkWithTimeout(ctx context.Context, i int64) ([]byte, error) {
chunkChan := make(chan []byte, 1)
errChan := make(chan error, 1)
go func() {
chunk, err := r.chunkSrc.Chunk(ctx, r.offset+i*r.chunkSize, r.chunkSize)
if err != nil {
errChan <- err
} else {
chunkChan <- chunk
}
}()
select {
case chunk := <-chunkChan:
return chunk, nil
case err := <-errChan:
return nil, err
case <-ctx.Done():
return nil, ctx.Err()
}
}
@ -265,13 +240,7 @@ type buffer struct {
}
func (b *buffer) isEmpty() bool {
if b == nil {
return true
}
if len(b.buf)-b.offset <= 0 {
return true
}
return false
return b == nil || len(b.buf)-b.offset <= 0
}
func (b *buffer) buffer() []byte {

View file

@ -92,7 +92,7 @@ func (suite *TestSuite) TestTimeout() {
assert.NoError(suite.T(), err)
test_data, err := io.ReadAll(reader)
assert.Greater(suite.T(), len(test_data), 0)
assert.Equal(suite.T(), err, ErrorStreamAbandoned)
assert.Equal(suite.T(), err, ErrStreamAbandoned)
}
func (suite *TestSuite) TestClose() {

View file

@ -201,7 +201,7 @@ func GetBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token st
func GetLocation(ctx context.Context, client *Client, cache cache.Cacher, fileId string, channelId int64, partId int64) (location *tg.InputDocumentFileLocation, err error) {
key := fmt.Sprintf("files:location:%s:%s:%d", client.UserId, fileId, partId)
key := fmt.Sprintf("files:location:%s:%s:%d", client.UserID, fileId, partId)
err = cache.Get(key, location)

View file

@ -15,46 +15,45 @@ import (
)
type UploadWorker struct {
mu sync.Mutex
mu sync.RWMutex
bots map[int64][]string
currIdx map[int64]int
}
func (w *UploadWorker) Set(bots []string, channelId int64) {
w.mu.Lock()
defer w.mu.Unlock()
_, ok := w.bots[channelId]
if !ok {
w.bots = make(map[int64][]string)
w.currIdx = make(map[int64]int)
w.bots[channelId] = bots
w.currIdx[channelId] = 0
func NewUploadWorker() *UploadWorker {
return &UploadWorker{
bots: make(map[int64][]string),
currIdx: make(map[int64]int),
}
}
func (w *UploadWorker) Next(channelId int64) (string, int) {
func (w *UploadWorker) Set(bots []string, channelID int64) {
w.mu.Lock()
defer w.mu.Unlock()
index := w.currIdx[channelId]
w.currIdx[channelId] = (index + 1) % len(w.bots[channelId])
return w.bots[channelId][index], index
w.bots[channelID] = bots
w.currIdx[channelID] = 0
}
func NewUploadWorker() *UploadWorker {
return &UploadWorker{}
func (w *UploadWorker) Next(channelID int64) (string, int) {
w.mu.RLock()
defer w.mu.RUnlock()
bots := w.bots[channelID]
index := w.currIdx[channelID]
w.currIdx[channelID] = (index + 1) % len(bots)
return bots[index], index
}
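Usage of the reworked UploadWorker keeps the same round-robin pattern; a small sketch with placeholder bot tokens and channel ID:

```go
w := NewUploadWorker()
w.Set([]string{"111111:aaa", "222222:bbb"}, 123456789)

token, idx := w.Next(123456789) // "111111:aaa", 0
token, idx = w.Next(123456789)  // "222222:bbb", 1
token, idx = w.Next(123456789)  // wraps back to "111111:aaa", 0
_, _ = token, idx
```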
type Client struct {
Tg *telegram.Client
Stop StopFunc
Status string
UserId string
lastUsed time.Time
connections int
UserID string
LastUsed time.Time
Connections int
}
type StreamWorker struct {
mu sync.Mutex
mu sync.RWMutex
clients map[string]*Client
currIdx map[int64]int
channelBots map[int64][]string
@ -64,112 +63,135 @@ type StreamWorker struct {
logger *zap.SugaredLogger
}
func (w *StreamWorker) Set(bots []string, channelId int64) {
w.mu.Lock()
defer w.mu.Unlock()
_, ok := w.channelBots[channelId]
if !ok {
w.channelBots[channelId] = bots
w.currIdx[channelId] = 0
func NewStreamWorker(ctx context.Context) func(cnf *config.Config, kv kv.KV) *StreamWorker {
return func(cnf *config.Config, kv kv.KV) *StreamWorker {
worker := &StreamWorker{
cnf: &cnf.TG,
kv: kv,
ctx: ctx,
clients: make(map[string]*Client),
currIdx: make(map[int64]int),
channelBots: make(map[int64][]string),
logger: logging.FromContext(ctx),
}
go worker.startIdleClientMonitor()
return worker
}
}
func (w *StreamWorker) Next(channelId int64) (*Client, int, error) {
func (w *StreamWorker) Set(bots []string, channelID int64) {
w.mu.Lock()
defer w.mu.Unlock()
index := w.currIdx[channelId]
token := w.channelBots[channelId][index]
userId := strings.Split(token, ":")[0]
client, ok := w.clients[userId]
w.channelBots[channelID] = bots
w.currIdx[channelID] = 0
}
func (w *StreamWorker) Next(channelID int64) (*Client, int, error) {
w.mu.Lock()
defer w.mu.Unlock()
bots := w.channelBots[channelID]
index := w.currIdx[channelID]
token := bots[index]
userID := strings.Split(token, ":")[0]
client, err := w.getOrCreateClient(userID, token)
if err != nil {
return nil, 0, err
}
w.currIdx[channelID] = (index + 1) % len(bots)
client.LastUsed = time.Now()
client.Connections++
if client.Connections == 1 {
client.Status = "serving"
}
return client, index, nil
}
func (w *StreamWorker) getOrCreateClient(userID, token string) (*Client, error) {
client, ok := w.clients[userID]
if !ok || (client.Status == "idle" && client.Stop == nil) {
middlewares := Middlewares(w.cnf, 5)
tgClient, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
client = &Client{Tg: tgClient, Status: "idle", UserId: userId}
w.clients[userId] = client
client = &Client{Tg: tgClient, Status: "idle", UserID: userID}
w.clients[userID] = client
stop, err := Connect(client.Tg, WithBotToken(token))
if err != nil {
return nil, 0, err
return nil, err
}
client.Stop = stop
w.logger.Debug("started bg client: ", client.UserId)
w.logger.Debug("started bg client: ", client.UserID)
}
w.currIdx[channelId] = (index + 1) % len(w.channelBots[channelId])
client.lastUsed = time.Now()
if client.connections == 0 {
client.Status = "serving"
}
client.connections++
return client, index, nil
return client, nil
}
func (w *StreamWorker) Release(client *Client) {
w.mu.Lock()
defer w.mu.Unlock()
client.connections--
if client.connections == 0 {
client.Connections--
if client.Connections == 0 {
client.Status = "running"
}
}
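And the typical lifecycle on the stream side, as the refactored API reads; the config, kv store, bot tokens and channel ID here are placeholders:

```go
// Hypothetical wiring: build the worker, register bots, then borrow and release clients.
worker := NewStreamWorker(ctx)(cnf, kvStore)
worker.Set(botTokens, channelID)

client, _, err := worker.Next(channelID)
if err != nil {
	return err
}
defer worker.Release(client) // decrements Connections; status drops back to "running" at zero
// ... use client.Tg for the download ...
```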
func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error) {
func (w *StreamWorker) UserWorker(session string, userID int64) (*Client, error) {
w.mu.Lock()
defer w.mu.Unlock()
id := strconv.FormatInt(userId, 10)
id := strconv.FormatInt(userID, 10)
client, ok := w.clients[id]
if !ok || (client.Status == "idle" && client.Stop == nil) {
middlewares := Middlewares(w.cnf, 5)
tgClient, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
client = &Client{Tg: tgClient, Status: "idle", UserId: id}
client = &Client{Tg: tgClient, Status: "idle", UserID: id}
w.clients[id] = client
stop, err := Connect(client.Tg, WithContext(w.ctx))
if err != nil {
return nil, err
}
client.Stop = stop
w.logger.Debug("started bg client: ", client.UserId)
w.logger.Debug("started bg client: ", client.UserID)
}
client.lastUsed = time.Now()
if client.connections == 0 {
client.LastUsed = time.Now()
client.Connections++
if client.Connections == 1 {
client.Status = "serving"
}
client.connections++
return client, nil
}
func (w *StreamWorker) startIdleClientMonitor() {
ticker := time.NewTicker(w.cnf.BgBotsCheckInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
w.mu.Lock()
for _, client := range w.clients {
if client.Status == "running" && time.Since(client.lastUsed) > w.cnf.BgBotsTimeout {
if client.Stop != nil {
client.Stop()
client.Stop = nil
client.Status = "idle"
w.logger.Debug("stopped bg client: ", client.UserId)
}
}
}
w.mu.Unlock()
w.checkIdleClients()
case <-w.ctx.Done():
return
}
}
}
func NewStreamWorker(ctx context.Context) func(cnf *config.Config, kv kv.KV) *StreamWorker {
return func(cnf *config.Config, kv kv.KV) *StreamWorker {
worker := &StreamWorker{cnf: &cnf.TG, kv: kv, ctx: ctx,
clients: make(map[string]*Client), currIdx: make(map[int64]int),
channelBots: make(map[int64][]string), logger: logging.FromContext(ctx)}
go worker.startIdleClientMonitor()
return worker
func (w *StreamWorker) checkIdleClients() {
w.mu.Lock()
defer w.mu.Unlock()
for _, client := range w.clients {
if client.Status == "running" && time.Since(client.LastUsed) > w.cnf.BgBotsTimeout {
if client.Stop != nil {
client.Stop()
client.Stop = nil
client.Status = "idle"
w.logger.Debug("stopped bg client: ", client.UserID)
}
}
}
}