refactor: error handling

This commit is contained in:
divyam234 2024-03-19 22:31:56 +05:30
parent acbf8b71d4
commit b1b5dcdc4c
8 changed files with 98 additions and 90 deletions

View file

@ -12,6 +12,7 @@ import (
type decrpytedReader struct {
ctx context.Context
parts []types.Part
ranges []types.Range
pos int
client *telegram.Client
reader io.ReadCloser
@ -24,14 +25,15 @@ func NewDecryptedReader(
ctx context.Context,
client *telegram.Client,
parts []types.Part,
limit int64,
start, end int64,
encryptionKey string) (io.ReadCloser, error) {
r := &decrpytedReader{
ctx: ctx,
parts: parts,
client: client,
limit: limit,
limit: end - start + 1,
ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize),
encryptionKey: encryptionKey,
}
res, err := r.nextPart()
@ -68,7 +70,7 @@ func (r *decrpytedReader) Read(p []byte) (n int, err error) {
}
r.pos++
if r.pos < len(r.parts) {
r.reader, err = newTGReader(r.ctx, r.client, r.parts[r.pos])
r.reader, err = r.nextPart()
}
}
r.err = err
@ -86,7 +88,11 @@ func (r *decrpytedReader) Close() (err error) {
func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
cipher, _ := crypt.NewCipher(r.encryptionKey, r.parts[r.pos].Salt)
location := r.parts[r.ranges[r.pos].PartNo].Location
start := r.ranges[r.pos].Start
end := r.ranges[r.pos].End
salt := r.parts[r.ranges[r.pos].PartNo].Salt
cipher, _ := crypt.NewCipher(r.encryptionKey, salt)
return cipher.DecryptDataSeek(r.ctx,
func(ctx context.Context,
@ -96,14 +102,10 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
var end int64
if underlyingLimit >= 0 {
end = min(r.parts[r.pos].Size-1, underlyingOffset+underlyingLimit-1)
end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
}
return newTGReader(r.ctx, r.client, types.Part{
Start: underlyingOffset,
End: end,
Location: r.parts[r.pos].Location,
})
}, r.parts[r.pos].Start, r.parts[r.pos].End-r.parts[r.pos].Start+1)
return newTGReader(r.ctx, r.client, location, underlyingOffset, end)
}, start, end-start+1)
}

View file

@ -8,9 +8,37 @@ import (
"github.com/gotd/td/telegram"
)
func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
partByteRanges := []types.Range{}
startPart := startByte / partSize
endPart := endByte / partSize
startOffset := startByte % partSize
for part := startPart; part <= endPart; part++ {
partStartByte := int64(0)
partEndByte := partSize - 1
if part == startPart {
partStartByte = startOffset
}
if part == endPart {
partEndByte = int64(endByte % partSize)
}
partByteRanges = append(partByteRanges, types.Range{Start: partStartByte, End: partEndByte, PartNo: part})
startOffset = 0
}
return partByteRanges
}
type linearReader struct {
ctx context.Context
parts []types.Part
ranges []types.Range
pos int
client *telegram.Client
reader io.ReadCloser
@ -21,22 +49,22 @@ type linearReader struct {
func NewLinearReader(ctx context.Context,
client *telegram.Client,
parts []types.Part,
limit int64,
start, end int64,
) (reader io.ReadCloser, err error) {
r := &linearReader{
ctx: ctx,
parts: parts,
client: client,
limit: limit,
limit: end - start + 1,
ranges: calculatePartByteRanges(start, end, parts[0].Size),
}
reader, err = newTGReader(r.ctx, r.client, r.parts[r.pos])
r.reader, err = r.nextPart()
if err != nil {
return nil, err
}
r.reader = reader
return r, nil
}
@ -62,14 +90,24 @@ func (r *linearReader) Read(p []byte) (n int, err error) {
err = nil
}
r.pos++
if r.pos < len(r.parts) {
r.reader, err = newTGReader(r.ctx, r.client, r.parts[r.pos])
if r.pos < len(r.ranges) {
r.reader, err = r.nextPart()
}
}
r.err = err
return
}
func (r *linearReader) nextPart() (io.ReadCloser, error) {
location := r.parts[r.ranges[r.pos].PartNo].Location
startByte := r.ranges[r.pos].Start
endByte := r.ranges[r.pos].End
return newTGReader(r.ctx, r.client, location, startByte, endByte)
}
func (r *linearReader) Close() (err error) {
if r.reader != nil {
err = r.reader.Close()

View file

@ -5,7 +5,6 @@ import (
"fmt"
"io"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
)
@ -36,18 +35,20 @@ func calculateChunkSize(start, end int64) int64 {
func newTGReader(
ctx context.Context,
client *telegram.Client,
part types.Part,
location *tg.InputDocumentFileLocation,
start int64,
end int64,
) (io.ReadCloser, error) {
r := &tgReader{
ctx: ctx,
location: part.Location,
location: location,
client: client,
start: part.Start,
end: part.End,
chunkSize: calculateChunkSize(part.Start, part.End),
limit: part.End - part.Start + 1,
start: start,
end: end,
chunkSize: calculateChunkSize(start, end),
limit: end - start + 1,
}
r.next = r.partStream()
return r, nil

View file

@ -17,6 +17,7 @@ var internalErrors = []string{
"RPC_CALL_FAIL",
"RPC_MCGET_FAIL",
"WORKER_BUSY_TOO_LONG_RETRY",
"memory limit exit",
}
type retry struct {

View file

@ -163,64 +163,21 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu
media := item.Media.(*tg.MessageMediaDocument)
document := media.Document.(*tg.Document)
location := document.AsInputDocumentFileLocation()
end := document.Size - 1
if file.Encrypted {
end, _ = crypt.DecryptedSize(document.Size)
end -= 1
}
parts = append(parts, types.Part{
part := types.Part{
Location: location,
End: end,
Size: document.Size,
Salt: file.Parts[i].Salt,
})
}
if file.Encrypted {
part.DecryptedSize, _ = crypt.DecryptedSize(document.Size)
}
parts = append(parts, part)
}
cache.Set(key, &parts, 3600)
return parts, nil
}
func rangedParts(parts []types.Part, startByte, endByte int64) []types.Part {
chunkSize := parts[0].End + 1
numParts := int64(len(parts))
validParts := []types.Part{}
firstChunk := max(startByte/chunkSize, 0)
lastChunk := min(endByte/chunkSize, numParts)
startInFirstChunk := startByte % chunkSize
endInLastChunk := endByte % chunkSize
if firstChunk == lastChunk {
part := parts[firstChunk]
part.Start = startInFirstChunk
part.End = endInLastChunk
validParts = append(validParts, part)
} else {
part := parts[firstChunk]
part.Start = startInFirstChunk
validParts = append(validParts, part)
// Add valid parts from any chunks in between.
for i := firstChunk + 1; i < lastChunk; i++ {
part := parts[i]
part.Start = 0
validParts = append(validParts, part)
}
// Add valid parts from the last chunk.
endPart := parts[lastChunk]
endPart.Start = 0
endPart.End = endInLastChunk
validParts = append(validParts, endPart)
}
return validParts
}
func GetChannelById(ctx context.Context, client *telegram.Client, channelId int64, userID string) (*tg.InputChannel, error) {
channel := &tg.InputChannel{}

View file

@ -524,12 +524,20 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
return
}
parts = rangedParts(parts, start, end)
if file.Encrypted {
lr, _ = reader.NewDecryptedReader(c, client.Tg, parts, contentLength, fs.cnf.Uploads.EncryptionKey)
lr, err = reader.NewDecryptedReader(c, client.Tg, parts, start, end, fs.cnf.Uploads.EncryptionKey)
} else {
lr, _ = reader.NewLinearReader(c, client.Tg, parts, contentLength)
lr, err = reader.NewLinearReader(c, client.Tg, parts, start, end)
}
if err != nil {
logger.Error("file stream", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if lr == nil {
http.Error(w, "failed to initialise reader", http.StatusInternalServerError)
return
}
io.CopyN(w, lr, contentLength)

View file

@ -234,10 +234,6 @@ func (us *UserService) addBots(c context.Context, client *telegram.Client, userI
return err
}
if err != nil {
return err
}
botInfoChannel := make(chan *types.BotInfo, len(botsTokens))
waitChan := make(chan struct{}, 6)
@ -261,9 +257,9 @@ func (us *UserService) addBots(c context.Context, client *telegram.Client, userI
<-waitChan
wg.Done()
}()
if err == nil {
botInfoChannel <- info
}
botInfoChannel <- info
}(token)
}

View file

@ -12,11 +12,10 @@ type AppError struct {
}
type Part struct {
Location *tg.InputDocumentFileLocation
Start int64
End int64
Size int64
Salt string
Location *tg.InputDocumentFileLocation
DecryptedSize int64
Size int64
Salt string
}
type JWTClaims struct {
@ -48,3 +47,9 @@ type BotInfo struct {
AccessHash int64
Token string
}
type Range struct {
Start int64
End int64
PartNo int64
}