mirror of https://github.com/tgdrive/teldrive.git (synced 2024-09-20 08:15:55 +08:00)

refactor: multireader and workers

parent 81c4bd775a
commit 3605ea4193
@@ -113,10 +113,15 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
 	if underlyingLimit >= 0 {
 		end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
 	}
-	chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
-		fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
-		client: r.client, concurrency: r.concurrency, cache: r.cache}
+	chunkSrc := &chunkSource{
+		channelID:   r.channelId,
+		worker:      r.worker,
+		fileID:      r.fileId,
+		partID:      r.parts[r.ranges[r.pos].PartNo].ID,
+		client:      r.client,
+		concurrency: r.concurrency,
+		cache:       r.cache,
+	}
 	if r.concurrency < 2 {
 		return newTGReader(r.ctx, underlyingOffset, end, chunkSrc)
 	}

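Both the decrypted reader above and the linear reader below now build the same `chunkSource` value and hand it to `newTGReader` or `newTGMultiReader`. A minimal in-memory `ChunkSource` sketch, useful for exercising the readers in tests, could look like this (the stub name and fixed chunk size are illustrative; it assumes the interface also carries the `ChunkSize(start, end int64) int64` method that `newTGMultiReader` calls):

```go
package readers

import (
	"context"
	"io"
)

// memSource is a hypothetical in-memory ChunkSource used only for illustration.
type memSource struct {
	data      []byte
	chunkSize int64
}

// Chunk returns up to limit bytes starting at offset, mimicking a remote fetch.
func (m *memSource) Chunk(_ context.Context, offset, limit int64) ([]byte, error) {
	if offset >= int64(len(m.data)) {
		return nil, io.EOF
	}
	return m.data[offset:min(offset+limit, int64(len(m.data)))], nil
}

// ChunkSize reports a fixed chunk size regardless of the requested range.
func (m *memSource) ChunkSize(start, end int64) int64 { return m.chunkSize }
```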
@@ -11,25 +11,27 @@ import (
 )

 func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
 	partByteRanges := []types.Range{}
 	startPart := startByte / partSize
 	endPart := endByte / partSize
 	startOffset := startByte % partSize

 	for part := startPart; part <= endPart; part++ {
 		partStartByte := int64(0)
 		partEndByte := partSize - 1
 		if part == startPart {
 			partStartByte = startOffset
 		}
 		if part == endPart {
-			partEndByte = int64(endByte % partSize)
+			partEndByte = endByte % partSize
 		}
-		partByteRanges = append(partByteRanges, types.Range{Start: partStartByte, End: partEndByte, PartNo: part})
+		partByteRanges = append(partByteRanges, types.Range{
+			Start:  partStartByte,
+			End:    partEndByte,
+			PartNo: part,
+		})

 		startOffset = 0
 	}

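To see what this computes, take a request for bytes 1000 to 5000 of a file stored in 2048-byte parts: the function walks parts 0 through 2 and clamps the first and last ranges to the requested window. A standalone sketch with a local `Range` type standing in for `types.Range`:

```go
package main

import "fmt"

type Range struct{ Start, End, PartNo int64 }

// partRanges mirrors the logic of calculatePartByteRanges.
func partRanges(startByte, endByte, partSize int64) []Range {
	var out []Range
	startPart, endPart := startByte/partSize, endByte/partSize
	startOffset := startByte % partSize
	for part := startPart; part <= endPart; part++ {
		s, e := int64(0), partSize-1
		if part == startPart {
			s = startOffset // clamp the first part to the requested start
		}
		if part == endPart {
			e = endByte % partSize // clamp the last part to the requested end
		}
		out = append(out, Range{Start: s, End: e, PartNo: part})
		startOffset = 0
	}
	return out
}

func main() {
	// Bytes 1000-5000 in 2048-byte parts:
	// part 0 -> 1000..2047, part 1 -> 0..2047, part 2 -> 0..904
	fmt.Println(partRanges(1000, 5000, 2048))
}
```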
@@ -37,7 +39,7 @@ func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
 	return partByteRanges
 }

-type linearReader struct {
+type LinearReader struct {
 	ctx    context.Context
 	parts  []types.Part
 	ranges []types.Range

@@ -45,27 +47,19 @@ type linearReader struct {
 	reader      io.ReadCloser
 	limit       int64
 	config      *config.TGConfig
-	channelId   int64
+	channelID   int64
 	worker      *tgc.StreamWorker
 	client      *tgc.Client
-	fileId      string
+	fileID      string
 	concurrency int
 	cache       cache.Cacher
 }

-func NewLinearReader(ctx context.Context,
-	fileId string,
-	parts []types.Part,
-	start, end int64,
-	channelId int64,
-	config *config.TGConfig,
-	concurrency int,
-	client *tgc.Client,
-	worker *tgc.StreamWorker,
-	cache cache.Cacher,
-) (reader io.ReadCloser, err error) {
+func NewLinearReader(ctx context.Context, fileID string, parts []types.Part, start, end int64,
+	channelID int64, config *config.TGConfig, concurrency int, client *tgc.Client,
+	worker *tgc.StreamWorker, cache cache.Cacher) (io.ReadCloser, error) {

-	r := &linearReader{
+	r := &LinearReader{
 		ctx:   ctx,
 		parts: parts,
 		limit: end - start + 1,

@@ -73,14 +67,14 @@ func NewLinearReader(ctx context.Context,
 		config:      config,
 		client:      client,
 		worker:      worker,
-		channelId:   channelId,
-		fileId:      fileId,
+		channelID:   channelID,
+		fileID:      fileID,
 		concurrency: concurrency,
 		cache:       cache,
 	}

+	var err error
 	r.reader, err = r.nextPart()
 	if err != nil {
 		return nil, err
 	}

@@ -88,49 +82,51 @@ func NewLinearReader(ctx context.Context,
 	return r, nil
 }

-func (r *linearReader) Read(p []byte) (int, error) {
+func (r *LinearReader) Read(p []byte) (int, error) {
 	if r.limit <= 0 {
 		return 0, io.EOF
 	}

 	n, err := r.reader.Read(p)

-	if err == io.EOF {
-		if r.limit > 0 {
-			err = nil
-			if r.reader != nil {
-				r.reader.Close()
-			}
+	if err == io.EOF && r.limit > 0 {
+		err = nil
+		if r.reader != nil {
+			r.reader.Close()
+		}
 		r.pos++
 		if r.pos < len(r.ranges) {
 			r.reader, err = r.nextPart()
 		}
 	}

 	r.limit -= int64(n)
 	return n, err
 }

-func (r *linearReader) nextPart() (io.ReadCloser, error) {
+func (r *LinearReader) nextPart() (io.ReadCloser, error) {
 	start := r.ranges[r.pos].Start
 	end := r.ranges[r.pos].End

-	chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
-		fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
-		client: r.client, concurrency: r.concurrency, cache: r.cache}
+	chunkSrc := &chunkSource{
+		channelID:   r.channelID,
+		worker:      r.worker,
+		fileID:      r.fileID,
+		partID:      r.parts[r.ranges[r.pos].PartNo].ID,
+		client:      r.client,
+		concurrency: r.concurrency,
+		cache:       r.cache,
+	}

 	if r.concurrency < 2 {
 		return newTGReader(r.ctx, start, end, chunkSrc)
 	}
 	return newTGMultiReader(r.ctx, start, end, r.config, chunkSrc)
 }

-func (r *linearReader) Close() (err error) {
+func (r *LinearReader) Close() error {
 	if r.reader != nil {
-		err = r.reader.Close()
+		err := r.reader.Close()
 		r.reader = nil
 		return err
 	}

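With the flattened signature, a caller streams a byte range by constructing the reader and copying from it. A minimal sketch, assuming it sits in the same package as NewLinearReader (the handler shape and variable names are illustrative, not part of the commit):

```go
// serveRange streams bytes [start, end] of a file into w. parts, cfg,
// client, worker and cacher are assumed to come from the surrounding
// service, already initialized.
func serveRange(ctx context.Context, w io.Writer, fileID string, parts []types.Part,
	start, end, channelID int64, cfg *config.TGConfig, concurrency int,
	client *tgc.Client, worker *tgc.StreamWorker, cacher cache.Cacher) error {

	rc, err := NewLinearReader(ctx, fileID, parts, start, end,
		channelID, cfg, concurrency, client, worker, cacher)
	if err != nil {
		return err
	}
	defer rc.Close()

	// io.Copy drives Read, which advances part by part via nextPart.
	_, err = io.Copy(w, rc)
	return err
}
```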
@@ -5,7 +5,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
 	"sync"
 	"time"

 	"github.com/divyam234/teldrive/internal/cache"

@@ -15,7 +14,10 @@ import (
 	"golang.org/x/sync/errgroup"
 )

-var ErrorStreamAbandoned = errors.New("stream abandoned")
+var (
+	ErrStreamAbandoned = errors.New("stream abandoned")
+	ErrChunkTimeout    = errors.New("chunk fetch timed out")
+)

 type ChunkSource interface {
 	Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error)

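Renaming `ErrorStreamAbandoned` to `ErrStreamAbandoned` follows Go's `Err*` convention, and the new `ErrChunkTimeout` sentinel lets callers tell a slow chunk fetch apart from other failures. Because `fillBatch` below wraps it with `fmt.Errorf("chunk %d: %w", ...)`, `errors.Is` still matches through the wrapping; a small self-contained demonstration:

```go
package main

import (
	"errors"
	"fmt"
)

var ErrChunkTimeout = errors.New("chunk fetch timed out")

func main() {
	// Wrapping with %w, as fillBatch does, keeps the sentinel matchable.
	err := fmt.Errorf("chunk %d: %w", 3, ErrChunkTimeout)
	fmt.Println(errors.Is(err, ErrChunkTimeout)) // true
}
```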
@@ -23,10 +25,10 @@ type ChunkSource interface {
 }

 type chunkSource struct {
-	channelId   int64
+	channelID   int64
 	worker      *tgc.StreamWorker
-	fileId      string
-	partId      int64
+	fileID      string
+	partID      int64
 	concurrency int
 	client      *tgc.Client
 	cache       cache.Cacher

@@ -52,9 +54,9 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b
 	}()

 	if c.concurrency > 0 {
-		client, _, _ = c.worker.Next(c.channelId)
+		client, _, _ = c.worker.Next(c.channelID)
 	}
-	location, err = tgc.GetLocation(ctx, client, c.cache, c.fileId, c.channelId, c.partId)
+	location, err = tgc.GetLocation(ctx, client, c.cache, c.fileID, c.channelID, c.partID)

 	if err != nil {
 		return nil, err

@@ -66,22 +68,19 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b

 type tgMultiReader struct {
 	ctx         context.Context
 	cancel      context.CancelFunc
 	offset      int64
 	limit       int64
 	chunkSize   int64
 	bufferChan  chan *buffer
-	done        chan struct{}
 	cur         *buffer
-	err         chan error
 	mu          sync.Mutex
 	concurrency int
 	leftCut     int64
 	rightCut    int64
 	totalParts  int
 	currentPart int
-	closed      bool
-	timeout     time.Duration
 	chunkSrc    ChunkSource
+	timeout     time.Duration
 }

 func newTGMultiReader(

@@ -90,15 +89,15 @@ func newTGMultiReader(
 	end int64,
 	config *config.TGConfig,
 	chunkSrc ChunkSource,
 ) (*tgMultiReader, error) {

 	chunkSize := chunkSrc.ChunkSize(start, end)

 	offset := start - (start % chunkSize)

 	ctx, cancel := context.WithCancel(ctx)

 	r := &tgMultiReader{
 		ctx:         ctx,
 		cancel:      cancel,
 		limit:       end - start + 1,
 		bufferChan:  make(chan *buffer, config.Stream.Buffers),
 		concurrency: config.Stream.MultiThreads,

@@ -109,8 +108,6 @@ func newTGMultiReader(
 		chunkSize: chunkSize,
 		chunkSrc:  chunkSrc,
 		timeout:   config.Stream.ChunkTimeout,
-		done:      make(chan struct{}, 1),
-		err:       make(chan error, 1),
 	}

 	go r.fillBufferConcurrently()

@@ -118,45 +115,24 @@ func newTGMultiReader(
 }

 func (r *tgMultiReader) Close() error {
-	close(r.done)
-	close(r.bufferChan)
-	r.closed = true
-	for b := range r.bufferChan {
-		if b != nil {
-			b = nil
-		}
-	}
-	if r.cur != nil {
-		r.cur = nil
-	}
-	close(r.err)
 	r.cancel()
 	return nil
 }

 func (r *tgMultiReader) Read(p []byte) (int, error) {
 	r.mu.Lock()
 	defer r.mu.Unlock()

 	if r.limit <= 0 {
 		return 0, io.EOF
 	}

-	if r.cur.isEmpty() {
-		if r.cur != nil {
-			r.cur = nil
-		}
+	if r.cur == nil || r.cur.isEmpty() {
 		select {
 		case cur, ok := <-r.bufferChan:
-			if !ok && r.limit > 0 {
-				return 0, ErrorStreamAbandoned
+			if !ok {
+				return 0, ErrStreamAbandoned
 			}
 			r.cur = cur
-		case err := <-r.err:
-			return 0, fmt.Errorf("error reading chunk: %w", err)
+		case <-r.ctx.Done():
+			return 0, r.ctx.Err()
 		}
 	}

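The refactor drops the `done` and `err` channels in favor of a single cancellation signal: the producer cancels the shared context on failure and closes `bufferChan` when it exits, so the consumer only has to watch the channel and `ctx.Done()`. A minimal sketch of the same producer/consumer shutdown pattern in isolation (names are illustrative):

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

var errAbandoned = errors.New("stream abandoned")

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	ch := make(chan []byte, 4)

	// Producer: cancel on failure, always close the channel on exit.
	go func() {
		defer close(ch)
		for i := 0; i < 3; i++ {
			select {
			case ch <- []byte{byte(i)}:
			case <-ctx.Done():
				return
			}
		}
		cancel() // simulate a mid-stream failure after three chunks
	}()

	// Consumer: a closed channel or a canceled context both end the read.
	// As in tgMultiReader.Read, it may observe either signal first.
	for {
		select {
		case b, ok := <-ch:
			if !ok {
				fmt.Println(errAbandoned)
				return
			}
			fmt.Println("chunk", b)
		case <-ctx.Done():
			fmt.Println(ctx.Err())
			return
		}
	}
}
```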
@@ -171,91 +147,90 @@ func (r *tgMultiReader) Read(p []byte) (int, error) {
 	return n, nil
 }

-func (r *tgMultiReader) fillBufferConcurrently() error {
-
-	var mapMu sync.Mutex
-
-	bufferMap := make(map[int]*buffer)
-
-	defer func() {
-		for i := range bufferMap {
-			delete(bufferMap, i)
+func (r *tgMultiReader) fillBufferConcurrently() {
+	defer close(r.bufferChan)
+
+	for r.currentPart < r.totalParts {
+		if err := r.fillBatch(); err != nil {
+			r.cancel()
+			return
 		}
-	}()
+	}
+}

-	cb := func(ctx context.Context, i int) func() error {
-		return func() error {
+func (r *tgMultiReader) fillBatch() error {
+	g, ctx := errgroup.WithContext(r.ctx)
+	g.SetLimit(r.concurrency)

-			chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize)
+	buffers := make([]*buffer, r.concurrency)
+
+	for i := 0; i < r.concurrency && r.currentPart+i < r.totalParts; i++ {
+		g.Go(func() error {
+			chunkCtx, cancel := context.WithTimeout(ctx, r.timeout)
+			defer cancel()
+
+			chunk, err := r.fetchChunkWithTimeout(chunkCtx, int64(i))
 			if err != nil {
-				return err
+				if errors.Is(err, context.DeadlineExceeded) {
+					return fmt.Errorf("chunk %d: %w", r.currentPart+i, ErrChunkTimeout)
+				}
+				return fmt.Errorf("chunk %d: %w", r.currentPart+i, err)
 			}

 			if r.totalParts == 1 {
 				chunk = chunk[r.leftCut:r.rightCut]
-			} else if r.currentPart+i+1 == 1 {
+			} else if r.currentPart+i == 0 {
 				chunk = chunk[r.leftCut:]
 			} else if r.currentPart+i+1 == r.totalParts {
 				chunk = chunk[:r.rightCut]
 			}
-			buf := &buffer{buf: chunk}
-			mapMu.Lock()
-			bufferMap[i] = buf
-			mapMu.Unlock()
+
+			buffers[i] = &buffer{buf: chunk}
 			return nil
-		}
+		})
 	}

-	for {
-		g := errgroup.Group{}
-		g.SetLimit(r.concurrency)
-		for i := range r.concurrency {
-			if r.currentPart+i+1 <= r.totalParts {
-				g.Go(cb(r.ctx, i))
-			}
-		}
+	if err := g.Wait(); err != nil {
+		return err
+	}

-		done := make(chan error, 1)
-		go func() {
-			done <- g.Wait()
-		}()
+	for _, buf := range buffers {
+		if buf == nil {
+			break
+		}

 		select {
-		case err := <-done:
-			if err != nil {
-				r.err <- err
-				return nil
-			} else {
-				for i := range r.concurrency {
-					if r.currentPart+i+1 <= r.totalParts {
-						select {
-						case <-r.done:
-							return nil
-						case r.bufferChan <- bufferMap[i]:
-						}
-					}
-				}
-				r.currentPart += r.concurrency
-				r.offset += r.chunkSize * int64(r.concurrency)
-				for i := range bufferMap {
-					delete(bufferMap, i)
-				}
-				if r.currentPart >= r.totalParts {
-					return nil
-				}
-			}
-		case <-time.After(r.timeout):
-			return nil
-		case <-r.done:
-			return nil
+		case r.bufferChan <- buf:
+		case <-r.ctx.Done():
+			return r.ctx.Err()
 		}
 	}
+
+	r.currentPart += r.concurrency
+	r.offset += r.chunkSize * int64(r.concurrency)
+
+	return nil
 }
+
+func (r *tgMultiReader) fetchChunkWithTimeout(ctx context.Context, i int64) ([]byte, error) {
+	chunkChan := make(chan []byte, 1)
+	errChan := make(chan error, 1)
+
+	go func() {
+		chunk, err := r.chunkSrc.Chunk(ctx, r.offset+i*r.chunkSize, r.chunkSize)
+		if err != nil {
+			errChan <- err
+		} else {
+			chunkChan <- chunk
+		}
+	}()
+
+	select {
+	case chunk := <-chunkChan:
+		return chunk, nil
+	case err := <-errChan:
+		return nil, err
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	}
+}

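The batch arithmetic is easiest to see with numbers. With `offset` aligned down by `start - (start % chunkSize)`, each worker `i` in a batch fetches the chunk at `offset + i*chunkSize`, and after the batch completes `fillBatch` advances `offset` by `chunkSize * concurrency`. A small standalone reproduction of that arithmetic (the sample values are made up):

```go
package main

import "fmt"

// Reproduces the offset arithmetic from newTGMultiReader and fillBatch
// for start=1000, chunkSize=2048, concurrency=4.
func main() {
	start, chunkSize := int64(1000), int64(2048)
	concurrency := int64(4)

	// newTGMultiReader: align the first fetch to a chunk boundary.
	offset := start - (start % chunkSize) // 0

	for batch := 0; batch < 2; batch++ {
		// fillBatch: worker i fetches one chunk of the batch in parallel.
		for i := int64(0); i < concurrency; i++ {
			fmt.Printf("batch %d, chunk %d -> offset %d\n", batch, i, offset+i*chunkSize)
		}
		// After the batch, advance by a full batch of chunks: 4*2048 = 8192.
		offset += chunkSize * concurrency
	}
}
```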
@@ -265,13 +240,7 @@ type buffer struct {
 }

 func (b *buffer) isEmpty() bool {
-	if b == nil {
-		return true
-	}
-	if len(b.buf)-b.offset <= 0 {
-		return true
-	}
-	return false
+	return b == nil || len(b.buf)-b.offset <= 0
 }

 func (b *buffer) buffer() []byte {

@@ -92,7 +92,7 @@ func (suite *TestSuite) TestTimeout() {
 	assert.NoError(suite.T(), err)
 	test_data, err := io.ReadAll(reader)
 	assert.Greater(suite.T(), len(test_data), 0)
-	assert.Equal(suite.T(), err, ErrorStreamAbandoned)
+	assert.Equal(suite.T(), err, ErrStreamAbandoned)
 }

 func (suite *TestSuite) TestClose() {

@@ -201,7 +201,7 @@ func GetBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token st

 func GetLocation(ctx context.Context, client *Client, cache cache.Cacher, fileId string, channelId int64, partId int64) (location *tg.InputDocumentFileLocation, err error) {

-	key := fmt.Sprintf("files:location:%s:%s:%d", client.UserId, fileId, partId)
+	key := fmt.Sprintf("files:location:%s:%s:%d", client.UserID, fileId, partId)

 	err = cache.Get(key, location)

@@ -15,46 +15,45 @@ import (
 )

 type UploadWorker struct {
-	mu      sync.Mutex
+	mu      sync.RWMutex
 	bots    map[int64][]string
 	currIdx map[int64]int
 }

-func (w *UploadWorker) Set(bots []string, channelId int64) {
-	w.mu.Lock()
-	defer w.mu.Unlock()
-	_, ok := w.bots[channelId]
-	if !ok {
-		w.bots = make(map[int64][]string)
-		w.currIdx = make(map[int64]int)
-		w.bots[channelId] = bots
-		w.currIdx[channelId] = 0
+func NewUploadWorker() *UploadWorker {
+	return &UploadWorker{
+		bots:    make(map[int64][]string),
+		currIdx: make(map[int64]int),
 	}
 }

-func (w *UploadWorker) Next(channelId int64) (string, int) {
+func (w *UploadWorker) Set(bots []string, channelID int64) {
 	w.mu.Lock()
 	defer w.mu.Unlock()
-	index := w.currIdx[channelId]
-	w.currIdx[channelId] = (index + 1) % len(w.bots[channelId])
-	return w.bots[channelId][index], index
+	w.bots[channelID] = bots
+	w.currIdx[channelID] = 0
 }

-func NewUploadWorker() *UploadWorker {
-	return &UploadWorker{}
+func (w *UploadWorker) Next(channelID int64) (string, int) {
+	w.mu.RLock()
+	defer w.mu.RUnlock()
+	bots := w.bots[channelID]
+	index := w.currIdx[channelID]
+	w.currIdx[channelID] = (index + 1) % len(bots)
+	return bots[index], index
 }

 type Client struct {
 	Tg     *telegram.Client
 	Stop   StopFunc
 	Status string
-	UserId      string
-	lastUsed    time.Time
-	connections int
+	UserID      string
+	LastUsed    time.Time
+	Connections int
 }

 type StreamWorker struct {
-	mu          sync.Mutex
+	mu          sync.RWMutex
 	clients     map[string]*Client
 	currIdx     map[int64]int
 	channelBots map[int64][]string

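The constructor now owns map initialization, so `Set` no longer lazily re-creates (and previously clobbered) the maps on first use. A usage sketch, assuming the workers live in the `tgc` package (the import path, bot tokens and channel ID are illustrative):

```go
package main

import (
	"fmt"

	// Assumed import path, patterned after the internal/cache path above.
	"github.com/divyam234/teldrive/internal/tgc"
)

func main() {
	w := tgc.NewUploadWorker()
	w.Set([]string{"botA", "botB", "botC"}, 12345)

	// Round-robin over the three tokens: botA, botB, botC, botA.
	for i := 0; i < 4; i++ {
		token, idx := w.Next(12345)
		fmt.Println(idx, token)
	}
}
```

One caveat worth noting: `Next` mutates `currIdx` while holding only the read lock, so two concurrent `Next` calls on the same channel can race on the rotation counter; a full `Lock` or an atomic counter would make it safe under concurrency.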
@@ -64,112 +63,135 @@ type StreamWorker struct {
 	logger      *zap.SugaredLogger
 }

-func (w *StreamWorker) Set(bots []string, channelId int64) {
-	w.mu.Lock()
-	defer w.mu.Unlock()
-	_, ok := w.channelBots[channelId]
-	if !ok {
-		w.channelBots[channelId] = bots
-		w.currIdx[channelId] = 0
+func NewStreamWorker(ctx context.Context) func(cnf *config.Config, kv kv.KV) *StreamWorker {
+	return func(cnf *config.Config, kv kv.KV) *StreamWorker {
+		worker := &StreamWorker{
+			cnf:         &cnf.TG,
+			kv:          kv,
+			ctx:         ctx,
+			clients:     make(map[string]*Client),
+			currIdx:     make(map[int64]int),
+			channelBots: make(map[int64][]string),
+			logger:      logging.FromContext(ctx),
+		}
+		go worker.startIdleClientMonitor()
+		return worker
 	}
 }

-func (w *StreamWorker) Next(channelId int64) (*Client, int, error) {
+func (w *StreamWorker) Set(bots []string, channelID int64) {
 	w.mu.Lock()
 	defer w.mu.Unlock()
-	index := w.currIdx[channelId]
-	token := w.channelBots[channelId][index]
-	userId := strings.Split(token, ":")[0]
-	client, ok := w.clients[userId]
+	w.channelBots[channelID] = bots
+	w.currIdx[channelID] = 0
+}
+
+func (w *StreamWorker) Next(channelID int64) (*Client, int, error) {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+
+	bots := w.channelBots[channelID]
+	index := w.currIdx[channelID]
+	token := bots[index]
+	userID := strings.Split(token, ":")[0]
+
+	client, err := w.getOrCreateClient(userID, token)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	w.currIdx[channelID] = (index + 1) % len(bots)
+	client.LastUsed = time.Now()
+	client.Connections++
+	if client.Connections == 1 {
+		client.Status = "serving"
+	}
+
+	return client, index, nil
+}
+
+func (w *StreamWorker) getOrCreateClient(userID, token string) (*Client, error) {
+	client, ok := w.clients[userID]
 	if !ok || (client.Status == "idle" && client.Stop == nil) {
 		middlewares := Middlewares(w.cnf, 5)
 		tgClient, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
-		client = &Client{Tg: tgClient, Status: "idle", UserId: userId}
-		w.clients[userId] = client
+		client = &Client{Tg: tgClient, Status: "idle", UserID: userID}
+		w.clients[userID] = client

 		stop, err := Connect(client.Tg, WithBotToken(token))
 		if err != nil {
-			return nil, 0, err
+			return nil, err
 		}
 		client.Stop = stop
-		w.logger.Debug("started bg client: ", client.UserId)
+		w.logger.Debug("started bg client: ", client.UserID)
 	}
-	w.currIdx[channelId] = (index + 1) % len(w.channelBots[channelId])
-	client.lastUsed = time.Now()
-	if client.connections == 0 {
-		client.Status = "serving"
-	}
-	client.connections++
-	return client, index, nil
+	return client, nil
 }

 func (w *StreamWorker) Release(client *Client) {
 	w.mu.Lock()
 	defer w.mu.Unlock()
-	client.connections--
-	if client.connections == 0 {
+	client.Connections--
+	if client.Connections == 0 {
 		client.Status = "running"
 	}
 }

-func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error) {
+func (w *StreamWorker) UserWorker(session string, userID int64) (*Client, error) {
 	w.mu.Lock()
 	defer w.mu.Unlock()

-	id := strconv.FormatInt(userId, 10)
+	id := strconv.FormatInt(userID, 10)
 	client, ok := w.clients[id]
 	if !ok || (client.Status == "idle" && client.Stop == nil) {
 		middlewares := Middlewares(w.cnf, 5)
 		tgClient, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
-		client = &Client{Tg: tgClient, Status: "idle", UserId: id}
+		client = &Client{Tg: tgClient, Status: "idle", UserID: id}
 		w.clients[id] = client

 		stop, err := Connect(client.Tg, WithContext(w.ctx))
 		if err != nil {
 			return nil, err
 		}
 		client.Stop = stop
-		w.logger.Debug("started bg client: ", client.UserId)
+		w.logger.Debug("started bg client: ", client.UserID)
 	}
-	client.lastUsed = time.Now()
-	if client.connections == 0 {
+
+	client.LastUsed = time.Now()
+	client.Connections++
+	if client.Connections == 1 {
 		client.Status = "serving"
 	}
-	client.connections++

 	return client, nil
 }

 func (w *StreamWorker) startIdleClientMonitor() {
 	ticker := time.NewTicker(w.cnf.BgBotsCheckInterval)
 	defer ticker.Stop()

 	for {
 		select {
 		case <-ticker.C:
-			w.mu.Lock()
-			for _, client := range w.clients {
-				if client.Status == "running" && time.Since(client.lastUsed) > w.cnf.BgBotsTimeout {
-					if client.Stop != nil {
-						client.Stop()
-						client.Stop = nil
-						client.Status = "idle"
-						w.logger.Debug("stopped bg client: ", client.UserId)
-					}
-				}
-			}
-			w.mu.Unlock()
+			w.checkIdleClients()
 		case <-w.ctx.Done():
 			return
 		}
 	}
 }

-func NewStreamWorker(ctx context.Context) func(cnf *config.Config, kv kv.KV) *StreamWorker {
-	return func(cnf *config.Config, kv kv.KV) *StreamWorker {
-		worker := &StreamWorker{cnf: &cnf.TG, kv: kv, ctx: ctx,
-			clients: make(map[string]*Client), currIdx: make(map[int64]int),
-			channelBots: make(map[int64][]string), logger: logging.FromContext(ctx)}
-		go worker.startIdleClientMonitor()
-		return worker
+func (w *StreamWorker) checkIdleClients() {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+
+	for _, client := range w.clients {
+		if client.Status == "running" && time.Since(client.LastUsed) > w.cnf.BgBotsTimeout {
+			if client.Stop != nil {
+				client.Stop()
+				client.Stop = nil
+				client.Status = "idle"
+				w.logger.Debug("stopped bg client: ", client.UserID)
+			}
+		}
 	}
 }

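Taken together, the worker now has a clear lifecycle: `NewStreamWorker` builds a fully initialized worker and starts the idle monitor, `Next` hands out a connected bot client and marks it serving, `Release` decrements its connection count, and `checkIdleClients` stops clients unused past `BgBotsTimeout`. A caller-side sketch (the function, variable names and channel ID are illustrative; `cnf`, `kvStore` and `botTokens` are assumed to come from the application's wiring):

```go
// Hypothetical wiring around the tgc.StreamWorker API shown above.
func useStreamWorker(ctx context.Context, cnf *config.Config, kvStore kv.KV, botTokens []string) error {
	worker := tgc.NewStreamWorker(ctx)(cnf, kvStore)
	worker.Set(botTokens, 12345) // register bot tokens for a channel

	client, idx, err := worker.Next(12345) // round-robin; connects on demand
	if err != nil {
		return err
	}
	defer worker.Release(client) // decrement Connections when done

	_ = idx // index of the bot chosen for this request
	// ... use client.Tg for Telegram API calls ...
	return nil
}
```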