chore: authenticate streams by cookies

This commit is contained in:
divyam234 2024-06-29 15:35:23 +05:30
parent c5cd24bbb3
commit ae280eb532
3 changed files with 66 additions and 53 deletions

View file

@ -2,11 +2,15 @@ package auth
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
)
func Encode(secret string, payload *types.JWTClaims) (string, error) {
@ -68,3 +72,33 @@ func GetUser(c *gin.Context) (int64, string) {
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
return userId, jwtUser.TgSession
}
func VerifyUser(c *gin.Context, secret string) (*types.JWTClaims, error) {
var token string
cookie, err := c.Request.Cookie("user-session")
if err != nil {
authHeader := c.GetHeader("Authorization")
bearerToken := strings.Split(authHeader, "Bearer ")
if len(bearerToken) != 2 {
return nil, fmt.Errorf("missing auth token")
}
token = bearerToken[1]
} else {
token = cookie.Value
}
now := time.Now().UTC()
jwePayload, err := Decode(secret, token)
if err != nil {
return nil, err
}
if *jwePayload.Expiry < *jwt.NewNumericDate(now) {
return nil, fmt.Errorf("token expired")
}
return jwePayload, nil
}

View file

@ -3,13 +3,11 @@ package middleware
import (
"context"
"net/http"
"strings"
"time"
"github.com/divyam234/cors"
"github.com/divyam234/teldrive/internal/auth"
"github.com/gin-contrib/secure"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/gin-gonic/gin"
)
@ -42,41 +40,12 @@ func Cors() gin.HandlerFunc {
func Authmiddleware(secret string) gin.HandlerFunc {
return func(c *gin.Context) {
var token string
cookie, err := c.Request.Cookie("user-session")
user, err := auth.VerifyUser(c, secret)
if err != nil {
authHeader := c.GetHeader("Authorization")
bearerToken := strings.Split(authHeader, "Bearer ")
if len(bearerToken) != 2 {
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing auth token"})
c.Abort()
return
}
token = bearerToken[1]
} else {
token = cookie.Value
}
now := time.Now().UTC()
jwePayload, err := auth.Decode(secret, token)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
c.Abort()
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
if *jwePayload.Expiry < *jwt.NewNumericDate(now) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "token expired"})
c.Abort()
return
}
c.Set("jwtUser", jwePayload)
c.Set("jwtUser", user)
c.Next()
}
}

View file

@ -70,12 +70,12 @@ func randInt64() (int64, error) {
type FileService struct {
db *gorm.DB
cnf *config.TGConfig
cnf *config.Config
worker *tgc.StreamWorker
}
func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker) *FileService {
return &FileService{db: db, cnf: &cnf.TG, worker: worker}
return &FileService{db: db, cnf: cnf, worker: worker}
}
func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.FileIn) (*schemas.FileOut, *types.AppError) {
@ -367,7 +367,7 @@ func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Mess
_, session := auth.GetUser(c)
client, _ := tgc.AuthClient(c, fs.cnf, session)
client, _ := tgc.AuthClient(c, &fs.cnf.TG, session)
ids := []int{}
@ -417,7 +417,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
userId, session := auth.GetUser(c)
client, _ := tgc.AuthClient(c, fs.cnf, session)
client, _ := tgc.AuthClient(c, &fs.cnf.TG, session)
var res []models.File
@ -531,17 +531,29 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
authHash := c.Query("hash")
if authHash == "" {
http.Error(w, "missing hash param", http.StatusBadRequest)
return
}
cache := cache.FromContext(c)
session, err := getSessionByHash(fs.db, cache, authHash)
var (
session *models.Session
err error
appErr *types.AppError
user *types.JWTClaims
)
if err != nil {
http.Error(w, "invalid hash", http.StatusBadRequest)
return
if authHash == "" {
user, err = auth.VerifyUser(c, fs.cnf.JWT.Secret)
if err != nil {
http.Error(w, "missing session or authash", http.StatusUnauthorized)
return
}
userId, _ := strconv.ParseInt(user.Subject, 10, 64)
session = &models.Session{UserId: userId, Session: user.TgSession}
} else {
session, err = getSessionByHash(fs.db, cache, authHash)
if err != nil {
http.Error(w, "invalid hash", http.StatusBadRequest)
return
}
}
file := &schemas.FileOutFull{}
@ -550,8 +562,6 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
err = cache.Get(key, file)
var appErr *types.AppError
if err != nil {
file, appErr = fs.GetFileByID(fileID)
if appErr != nil {
@ -646,7 +656,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
multiThreads int
)
multiThreads = fs.cnf.Stream.MultiThreads
multiThreads = fs.cnf.TG.Stream.MultiThreads
defer func() {
if client != nil {
@ -654,7 +664,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
}
}()
if fs.cnf.DisableStreamBots || len(tokens) == 0 {
if fs.cnf.TG.DisableStreamBots || len(tokens) == 0 {
client, err = fs.worker.UserWorker(session.Session, session.UserId)
if err != nil {
logger.Error("file stream", zap.Error(err))
@ -666,7 +676,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
} else {
limit := min(len(tokens), fs.cnf.BgBotsLimit)
limit := min(len(tokens), fs.cnf.TG.BgBotsLimit)
fs.worker.Set(tokens[:limit], file.ChannelID)
client, _, err = fs.worker.Next(file.ChannelID)
@ -689,9 +699,9 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
multiThreads = 0
}
if file.Encrypted {
lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker)
lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker)
} else {
lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker)
lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker)
}
if err != nil {