mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-09-08 23:46:22 +08:00
chore: authenticate streams by cookies
This commit is contained in:
parent
c5cd24bbb3
commit
ae280eb532
3 changed files with 66 additions and 53 deletions
|
@ -2,11 +2,15 @@ package auth
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/divyam234/teldrive/pkg/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
)
|
||||
|
||||
func Encode(secret string, payload *types.JWTClaims) (string, error) {
|
||||
|
@ -68,3 +72,33 @@ func GetUser(c *gin.Context) (int64, string) {
|
|||
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
|
||||
return userId, jwtUser.TgSession
|
||||
}
|
||||
|
||||
func VerifyUser(c *gin.Context, secret string) (*types.JWTClaims, error) {
|
||||
var token string
|
||||
cookie, err := c.Request.Cookie("user-session")
|
||||
|
||||
if err != nil {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
bearerToken := strings.Split(authHeader, "Bearer ")
|
||||
if len(bearerToken) != 2 {
|
||||
return nil, fmt.Errorf("missing auth token")
|
||||
}
|
||||
token = bearerToken[1]
|
||||
} else {
|
||||
token = cookie.Value
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
jwePayload, err := Decode(secret, token)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if *jwePayload.Expiry < *jwt.NewNumericDate(now) {
|
||||
return nil, fmt.Errorf("token expired")
|
||||
|
||||
}
|
||||
return jwePayload, nil
|
||||
}
|
||||
|
|
|
@ -3,13 +3,11 @@ package middleware
|
|||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/divyam234/cors"
|
||||
"github.com/divyam234/teldrive/internal/auth"
|
||||
"github.com/gin-contrib/secure"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
@ -42,41 +40,12 @@ func Cors() gin.HandlerFunc {
|
|||
|
||||
func Authmiddleware(secret string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var token string
|
||||
|
||||
cookie, err := c.Request.Cookie("user-session")
|
||||
|
||||
user, err := auth.VerifyUser(c, secret)
|
||||
if err != nil {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
bearerToken := strings.Split(authHeader, "Bearer ")
|
||||
if len(bearerToken) != 2 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing auth token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
token = bearerToken[1]
|
||||
} else {
|
||||
token = cookie.Value
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
jwePayload, err := auth.Decode(secret, token)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
c.Abort()
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if *jwePayload.Expiry < *jwt.NewNumericDate(now) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "token expired"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("jwtUser", jwePayload)
|
||||
|
||||
c.Set("jwtUser", user)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -70,12 +70,12 @@ func randInt64() (int64, error) {
|
|||
|
||||
type FileService struct {
|
||||
db *gorm.DB
|
||||
cnf *config.TGConfig
|
||||
cnf *config.Config
|
||||
worker *tgc.StreamWorker
|
||||
}
|
||||
|
||||
func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker) *FileService {
|
||||
return &FileService{db: db, cnf: &cnf.TG, worker: worker}
|
||||
return &FileService{db: db, cnf: cnf, worker: worker}
|
||||
}
|
||||
|
||||
func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.FileIn) (*schemas.FileOut, *types.AppError) {
|
||||
|
@ -367,7 +367,7 @@ func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Mess
|
|||
|
||||
_, session := auth.GetUser(c)
|
||||
|
||||
client, _ := tgc.AuthClient(c, fs.cnf, session)
|
||||
client, _ := tgc.AuthClient(c, &fs.cnf.TG, session)
|
||||
|
||||
ids := []int{}
|
||||
|
||||
|
@ -417,7 +417,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
|
|||
|
||||
userId, session := auth.GetUser(c)
|
||||
|
||||
client, _ := tgc.AuthClient(c, fs.cnf, session)
|
||||
client, _ := tgc.AuthClient(c, &fs.cnf.TG, session)
|
||||
|
||||
var res []models.File
|
||||
|
||||
|
@ -531,17 +531,29 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
|
||||
authHash := c.Query("hash")
|
||||
|
||||
if authHash == "" {
|
||||
http.Error(w, "missing hash param", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
cache := cache.FromContext(c)
|
||||
|
||||
session, err := getSessionByHash(fs.db, cache, authHash)
|
||||
var (
|
||||
session *models.Session
|
||||
err error
|
||||
appErr *types.AppError
|
||||
user *types.JWTClaims
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, "invalid hash", http.StatusBadRequest)
|
||||
return
|
||||
if authHash == "" {
|
||||
user, err = auth.VerifyUser(c, fs.cnf.JWT.Secret)
|
||||
if err != nil {
|
||||
http.Error(w, "missing session or authash", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
userId, _ := strconv.ParseInt(user.Subject, 10, 64)
|
||||
session = &models.Session{UserId: userId, Session: user.TgSession}
|
||||
} else {
|
||||
session, err = getSessionByHash(fs.db, cache, authHash)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid hash", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
file := &schemas.FileOutFull{}
|
||||
|
@ -550,8 +562,6 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
|
||||
err = cache.Get(key, file)
|
||||
|
||||
var appErr *types.AppError
|
||||
|
||||
if err != nil {
|
||||
file, appErr = fs.GetFileByID(fileID)
|
||||
if appErr != nil {
|
||||
|
@ -646,7 +656,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
multiThreads int
|
||||
)
|
||||
|
||||
multiThreads = fs.cnf.Stream.MultiThreads
|
||||
multiThreads = fs.cnf.TG.Stream.MultiThreads
|
||||
|
||||
defer func() {
|
||||
if client != nil {
|
||||
|
@ -654,7 +664,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
}
|
||||
}()
|
||||
|
||||
if fs.cnf.DisableStreamBots || len(tokens) == 0 {
|
||||
if fs.cnf.TG.DisableStreamBots || len(tokens) == 0 {
|
||||
client, err = fs.worker.UserWorker(session.Session, session.UserId)
|
||||
if err != nil {
|
||||
logger.Error("file stream", zap.Error(err))
|
||||
|
@ -666,7 +676,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
|
||||
} else {
|
||||
|
||||
limit := min(len(tokens), fs.cnf.BgBotsLimit)
|
||||
limit := min(len(tokens), fs.cnf.TG.BgBotsLimit)
|
||||
|
||||
fs.worker.Set(tokens[:limit], file.ChannelID)
|
||||
client, _, err = fs.worker.Next(file.ChannelID)
|
||||
|
@ -689,9 +699,9 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
multiThreads = 0
|
||||
}
|
||||
if file.Encrypted {
|
||||
lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker)
|
||||
lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker)
|
||||
} else {
|
||||
lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker)
|
||||
lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
|
Loading…
Add table
Reference in a new issue