chore: authenticate streams by cookies

This commit is contained in:
divyam234 2024-06-29 15:35:23 +05:30
parent c5cd24bbb3
commit ae280eb532
3 changed files with 66 additions and 53 deletions

View file

@ -2,11 +2,15 @@ package auth
import ( import (
"encoding/json" "encoding/json"
"fmt"
"strconv" "strconv"
"strings"
"time"
"github.com/divyam234/teldrive/pkg/types" "github.com/divyam234/teldrive/pkg/types"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
) )
func Encode(secret string, payload *types.JWTClaims) (string, error) { func Encode(secret string, payload *types.JWTClaims) (string, error) {
@ -68,3 +72,33 @@ func GetUser(c *gin.Context) (int64, string) {
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64) userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
return userId, jwtUser.TgSession return userId, jwtUser.TgSession
} }
func VerifyUser(c *gin.Context, secret string) (*types.JWTClaims, error) {
var token string
cookie, err := c.Request.Cookie("user-session")
if err != nil {
authHeader := c.GetHeader("Authorization")
bearerToken := strings.Split(authHeader, "Bearer ")
if len(bearerToken) != 2 {
return nil, fmt.Errorf("missing auth token")
}
token = bearerToken[1]
} else {
token = cookie.Value
}
now := time.Now().UTC()
jwePayload, err := Decode(secret, token)
if err != nil {
return nil, err
}
if *jwePayload.Expiry < *jwt.NewNumericDate(now) {
return nil, fmt.Errorf("token expired")
}
return jwePayload, nil
}

View file

@ -3,13 +3,11 @@ package middleware
import ( import (
"context" "context"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/divyam234/cors" "github.com/divyam234/cors"
"github.com/divyam234/teldrive/internal/auth" "github.com/divyam234/teldrive/internal/auth"
"github.com/gin-contrib/secure" "github.com/gin-contrib/secure"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -42,41 +40,12 @@ func Cors() gin.HandlerFunc {
func Authmiddleware(secret string) gin.HandlerFunc { func Authmiddleware(secret string) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
var token string user, err := auth.VerifyUser(c, secret)
cookie, err := c.Request.Cookie("user-session")
if err != nil { if err != nil {
authHeader := c.GetHeader("Authorization") c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
bearerToken := strings.Split(authHeader, "Bearer ")
if len(bearerToken) != 2 {
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing auth token"})
c.Abort()
return return
} }
token = bearerToken[1] c.Set("jwtUser", user)
} else {
token = cookie.Value
}
now := time.Now().UTC()
jwePayload, err := auth.Decode(secret, token)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
c.Abort()
return
}
if *jwePayload.Expiry < *jwt.NewNumericDate(now) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "token expired"})
c.Abort()
return
}
c.Set("jwtUser", jwePayload)
c.Next() c.Next()
} }
} }

View file

@ -70,12 +70,12 @@ func randInt64() (int64, error) {
type FileService struct { type FileService struct {
db *gorm.DB db *gorm.DB
cnf *config.TGConfig cnf *config.Config
worker *tgc.StreamWorker worker *tgc.StreamWorker
} }
func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker) *FileService { func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker) *FileService {
return &FileService{db: db, cnf: &cnf.TG, worker: worker} return &FileService{db: db, cnf: cnf, worker: worker}
} }
func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.FileIn) (*schemas.FileOut, *types.AppError) { func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.FileIn) (*schemas.FileOut, *types.AppError) {
@ -367,7 +367,7 @@ func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Mess
_, session := auth.GetUser(c) _, session := auth.GetUser(c)
client, _ := tgc.AuthClient(c, fs.cnf, session) client, _ := tgc.AuthClient(c, &fs.cnf.TG, session)
ids := []int{} ids := []int{}
@ -417,7 +417,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
userId, session := auth.GetUser(c) userId, session := auth.GetUser(c)
client, _ := tgc.AuthClient(c, fs.cnf, session) client, _ := tgc.AuthClient(c, &fs.cnf.TG, session)
var res []models.File var res []models.File
@ -531,18 +531,30 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
authHash := c.Query("hash") authHash := c.Query("hash")
if authHash == "" {
http.Error(w, "missing hash param", http.StatusBadRequest)
return
}
cache := cache.FromContext(c) cache := cache.FromContext(c)
session, err := getSessionByHash(fs.db, cache, authHash) var (
session *models.Session
err error
appErr *types.AppError
user *types.JWTClaims
)
if authHash == "" {
user, err = auth.VerifyUser(c, fs.cnf.JWT.Secret)
if err != nil {
http.Error(w, "missing session or authash", http.StatusUnauthorized)
return
}
userId, _ := strconv.ParseInt(user.Subject, 10, 64)
session = &models.Session{UserId: userId, Session: user.TgSession}
} else {
session, err = getSessionByHash(fs.db, cache, authHash)
if err != nil { if err != nil {
http.Error(w, "invalid hash", http.StatusBadRequest) http.Error(w, "invalid hash", http.StatusBadRequest)
return return
} }
}
file := &schemas.FileOutFull{} file := &schemas.FileOutFull{}
@ -550,8 +562,6 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
err = cache.Get(key, file) err = cache.Get(key, file)
var appErr *types.AppError
if err != nil { if err != nil {
file, appErr = fs.GetFileByID(fileID) file, appErr = fs.GetFileByID(fileID)
if appErr != nil { if appErr != nil {
@ -646,7 +656,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
multiThreads int multiThreads int
) )
multiThreads = fs.cnf.Stream.MultiThreads multiThreads = fs.cnf.TG.Stream.MultiThreads
defer func() { defer func() {
if client != nil { if client != nil {
@ -654,7 +664,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
} }
}() }()
if fs.cnf.DisableStreamBots || len(tokens) == 0 { if fs.cnf.TG.DisableStreamBots || len(tokens) == 0 {
client, err = fs.worker.UserWorker(session.Session, session.UserId) client, err = fs.worker.UserWorker(session.Session, session.UserId)
if err != nil { if err != nil {
logger.Error("file stream", zap.Error(err)) logger.Error("file stream", zap.Error(err))
@ -666,7 +676,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
} else { } else {
limit := min(len(tokens), fs.cnf.BgBotsLimit) limit := min(len(tokens), fs.cnf.TG.BgBotsLimit)
fs.worker.Set(tokens[:limit], file.ChannelID) fs.worker.Set(tokens[:limit], file.ChannelID)
client, _, err = fs.worker.Next(file.ChannelID) client, _, err = fs.worker.Next(file.ChannelID)
@ -689,9 +699,9 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
multiThreads = 0 multiThreads = 0
} }
if file.Encrypted { if file.Encrypted {
lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker) lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker)
} else { } else {
lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker) lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker)
} }
if err != nil { if err != nil {