mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-09-11 08:54:35 +08:00
chore: authenticate streams by cookies
This commit is contained in:
parent
c5cd24bbb3
commit
ae280eb532
3 changed files with 66 additions and 53 deletions
|
@ -2,11 +2,15 @@ package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/divyam234/teldrive/pkg/types"
|
"github.com/divyam234/teldrive/pkg/types"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-jose/go-jose/v3"
|
"github.com/go-jose/go-jose/v3"
|
||||||
|
"github.com/go-jose/go-jose/v3/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Encode(secret string, payload *types.JWTClaims) (string, error) {
|
func Encode(secret string, payload *types.JWTClaims) (string, error) {
|
||||||
|
@ -68,3 +72,33 @@ func GetUser(c *gin.Context) (int64, string) {
|
||||||
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
|
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
|
||||||
return userId, jwtUser.TgSession
|
return userId, jwtUser.TgSession
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func VerifyUser(c *gin.Context, secret string) (*types.JWTClaims, error) {
|
||||||
|
var token string
|
||||||
|
cookie, err := c.Request.Cookie("user-session")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
authHeader := c.GetHeader("Authorization")
|
||||||
|
bearerToken := strings.Split(authHeader, "Bearer ")
|
||||||
|
if len(bearerToken) != 2 {
|
||||||
|
return nil, fmt.Errorf("missing auth token")
|
||||||
|
}
|
||||||
|
token = bearerToken[1]
|
||||||
|
} else {
|
||||||
|
token = cookie.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
jwePayload, err := Decode(secret, token)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if *jwePayload.Expiry < *jwt.NewNumericDate(now) {
|
||||||
|
return nil, fmt.Errorf("token expired")
|
||||||
|
|
||||||
|
}
|
||||||
|
return jwePayload, nil
|
||||||
|
}
|
||||||
|
|
|
@ -3,13 +3,11 @@ package middleware
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/divyam234/cors"
|
"github.com/divyam234/cors"
|
||||||
"github.com/divyam234/teldrive/internal/auth"
|
"github.com/divyam234/teldrive/internal/auth"
|
||||||
"github.com/gin-contrib/secure"
|
"github.com/gin-contrib/secure"
|
||||||
"github.com/go-jose/go-jose/v3/jwt"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
@ -42,41 +40,12 @@ func Cors() gin.HandlerFunc {
|
||||||
|
|
||||||
func Authmiddleware(secret string) gin.HandlerFunc {
|
func Authmiddleware(secret string) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
var token string
|
user, err := auth.VerifyUser(c, secret)
|
||||||
|
|
||||||
cookie, err := c.Request.Cookie("user-session")
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
authHeader := c.GetHeader("Authorization")
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||||
bearerToken := strings.Split(authHeader, "Bearer ")
|
|
||||||
if len(bearerToken) != 2 {
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing auth token"})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
token = bearerToken[1]
|
c.Set("jwtUser", user)
|
||||||
} else {
|
|
||||||
token = cookie.Value
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now().UTC()
|
|
||||||
|
|
||||||
jwePayload, err := auth.Decode(secret, token)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if *jwePayload.Expiry < *jwt.NewNumericDate(now) {
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "token expired"})
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Set("jwtUser", jwePayload)
|
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,12 +70,12 @@ func randInt64() (int64, error) {
|
||||||
|
|
||||||
type FileService struct {
|
type FileService struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
cnf *config.TGConfig
|
cnf *config.Config
|
||||||
worker *tgc.StreamWorker
|
worker *tgc.StreamWorker
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker) *FileService {
|
func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker) *FileService {
|
||||||
return &FileService{db: db, cnf: &cnf.TG, worker: worker}
|
return &FileService{db: db, cnf: cnf, worker: worker}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.FileIn) (*schemas.FileOut, *types.AppError) {
|
func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.FileIn) (*schemas.FileOut, *types.AppError) {
|
||||||
|
@ -367,7 +367,7 @@ func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Mess
|
||||||
|
|
||||||
_, session := auth.GetUser(c)
|
_, session := auth.GetUser(c)
|
||||||
|
|
||||||
client, _ := tgc.AuthClient(c, fs.cnf, session)
|
client, _ := tgc.AuthClient(c, &fs.cnf.TG, session)
|
||||||
|
|
||||||
ids := []int{}
|
ids := []int{}
|
||||||
|
|
||||||
|
@ -417,7 +417,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
|
||||||
|
|
||||||
userId, session := auth.GetUser(c)
|
userId, session := auth.GetUser(c)
|
||||||
|
|
||||||
client, _ := tgc.AuthClient(c, fs.cnf, session)
|
client, _ := tgc.AuthClient(c, &fs.cnf.TG, session)
|
||||||
|
|
||||||
var res []models.File
|
var res []models.File
|
||||||
|
|
||||||
|
@ -531,18 +531,30 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
||||||
|
|
||||||
authHash := c.Query("hash")
|
authHash := c.Query("hash")
|
||||||
|
|
||||||
if authHash == "" {
|
|
||||||
http.Error(w, "missing hash param", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cache := cache.FromContext(c)
|
cache := cache.FromContext(c)
|
||||||
|
|
||||||
session, err := getSessionByHash(fs.db, cache, authHash)
|
var (
|
||||||
|
session *models.Session
|
||||||
|
err error
|
||||||
|
appErr *types.AppError
|
||||||
|
user *types.JWTClaims
|
||||||
|
)
|
||||||
|
|
||||||
|
if authHash == "" {
|
||||||
|
user, err = auth.VerifyUser(c, fs.cnf.JWT.Secret)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "missing session or authash", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userId, _ := strconv.ParseInt(user.Subject, 10, 64)
|
||||||
|
session = &models.Session{UserId: userId, Session: user.TgSession}
|
||||||
|
} else {
|
||||||
|
session, err = getSessionByHash(fs.db, cache, authHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "invalid hash", http.StatusBadRequest)
|
http.Error(w, "invalid hash", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
file := &schemas.FileOutFull{}
|
file := &schemas.FileOutFull{}
|
||||||
|
|
||||||
|
@ -550,8 +562,6 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
||||||
|
|
||||||
err = cache.Get(key, file)
|
err = cache.Get(key, file)
|
||||||
|
|
||||||
var appErr *types.AppError
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
file, appErr = fs.GetFileByID(fileID)
|
file, appErr = fs.GetFileByID(fileID)
|
||||||
if appErr != nil {
|
if appErr != nil {
|
||||||
|
@ -646,7 +656,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
||||||
multiThreads int
|
multiThreads int
|
||||||
)
|
)
|
||||||
|
|
||||||
multiThreads = fs.cnf.Stream.MultiThreads
|
multiThreads = fs.cnf.TG.Stream.MultiThreads
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if client != nil {
|
if client != nil {
|
||||||
|
@ -654,7 +664,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if fs.cnf.DisableStreamBots || len(tokens) == 0 {
|
if fs.cnf.TG.DisableStreamBots || len(tokens) == 0 {
|
||||||
client, err = fs.worker.UserWorker(session.Session, session.UserId)
|
client, err = fs.worker.UserWorker(session.Session, session.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("file stream", zap.Error(err))
|
logger.Error("file stream", zap.Error(err))
|
||||||
|
@ -666,7 +676,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
limit := min(len(tokens), fs.cnf.BgBotsLimit)
|
limit := min(len(tokens), fs.cnf.TG.BgBotsLimit)
|
||||||
|
|
||||||
fs.worker.Set(tokens[:limit], file.ChannelID)
|
fs.worker.Set(tokens[:limit], file.ChannelID)
|
||||||
client, _, err = fs.worker.Next(file.ChannelID)
|
client, _, err = fs.worker.Next(file.ChannelID)
|
||||||
|
@ -689,9 +699,9 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
||||||
multiThreads = 0
|
multiThreads = 0
|
||||||
}
|
}
|
||||||
if file.Encrypted {
|
if file.Encrypted {
|
||||||
lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker)
|
lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker)
|
||||||
} else {
|
} else {
|
||||||
lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker)
|
lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Add table
Reference in a new issue