teldrive/pkg/services/auth.go

391 lines
11 KiB
Go
Raw Normal View History

2023-08-12 19:21:42 +08:00
package services
import (
	"bytes"
	"context"
	"crypto/md5"
	"encoding/base64"
	"encoding/binary"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"math/big"
	"net"
	"net/http"
	"strconv"
	"sync"
	"time"

	cnf "github.com/divyam234/teldrive/config"
	"github.com/divyam234/teldrive/internal/auth"
	"github.com/divyam234/teldrive/internal/tgc"
	"github.com/divyam234/teldrive/internal/utils"
	"github.com/divyam234/teldrive/pkg/models"
	"github.com/divyam234/teldrive/pkg/schemas"
	"github.com/divyam234/teldrive/pkg/types"
	"github.com/gin-gonic/gin"
	"github.com/go-jose/go-jose/v3/jwt"
	"github.com/gorilla/websocket"
	"github.com/gotd/td/session"
	tgauth "github.com/gotd/td/telegram/auth"
	"github.com/gotd/td/telegram/auth/qrlogin"
	"github.com/gotd/td/tg"
	"github.com/gotd/td/tgerr"
	"gorm.io/gorm"
)
type AuthService struct {
Db *gorm.DB
SessionMaxAge int
SessionCookieName string
2023-08-12 19:21:42 +08:00
}
2023-12-03 05:46:53 +08:00
func NewAuthService(db *gorm.DB) *AuthService {
return &AuthService{
Db: db,
SessionMaxAge: 30 * 24 * 60 * 60,
SessionCookieName: "user-session"}
2023-08-15 03:36:24 +08:00
}
2023-12-03 05:46:53 +08:00
func ip4toInt(IPv4Address net.IP) int64 {
2023-08-12 19:21:42 +08:00
IPv4Int := big.NewInt(0)
IPv4Int.SetBytes(IPv4Address.To4())
return IPv4Int.Int64()
}
2023-12-03 05:46:53 +08:00
// pack32BinaryIP4 parses a dotted-quad IPv4 string and returns its 4-byte
// big-endian encoding. Strings that do not parse as IPv4 encode as 0.0.0.0.
func pack32BinaryIP4(ip4Address string) []byte {
	var packed bytes.Buffer
	value := uint32(ip4toInt(net.ParseIP(ip4Address)))
	// Writing a fixed-size uint32 into an in-memory buffer cannot fail.
	_ = binary.Write(&packed, binary.BigEndian, value)
	return packed.Bytes()
}
2023-08-15 03:36:24 +08:00
// generateTgSession builds a version-"1" string session from a Telegram DC id,
// auth key and port. The session is URL-safe base64 of the following packet:
//
//	dc id (1 byte) | DC IPv4 address (4 bytes) | port (2 bytes, big-endian) | auth key
//
// Only DC ids 1-5 have a known address. Any other id packs the address as
// 0.0.0.0.
func generateTgSession(dcID int, authKey []byte, port int) string {
	var dcAddress string
	switch dcID {
	case 1:
		dcAddress = "149.154.175.53"
	case 2:
		dcAddress = "149.154.167.51"
	case 3:
		dcAddress = "149.154.175.100"
	case 4:
		dcAddress = "149.154.167.91"
	case 5:
		dcAddress = "91.108.56.130"
	}

	portBytes := make([]byte, 2)
	binary.BigEndian.PutUint16(portBytes, uint16(port))

	packet := make([]byte, 0, 1+4+len(portBytes)+len(authKey))
	packet = append(packet, byte(dcID))
	packet = append(packet, pack32BinaryIP4(dcAddress)...)
	packet = append(packet, portBytes...)
	packet = append(packet, authKey...)

	return "1" + base64.URLEncoding.EncodeToString(packet)
}
2023-08-17 23:32:40 +08:00
func setCookie(c *gin.Context, key string, value string, age int) {
2023-12-03 03:47:23 +08:00
config := cnf.GetConfig()
2023-08-17 23:32:40 +08:00
if config.CookieSameSite {
c.SetSameSite(2)
} else {
c.SetSameSite(4)
}
c.SetCookie(key, value, age, "/", c.Request.Host, config.Https, true)
}
2023-09-08 22:51:54 +08:00
func checkUserIsAllowed(userName string) bool {
2023-12-03 03:47:23 +08:00
config := cnf.GetConfig()
2023-09-08 22:51:54 +08:00
found := false
if len(config.AllowedUsers) > 0 {
for _, user := range config.AllowedUsers {
if user == userName {
found = true
break
}
}
} else {
found = true
}
return found
}
2023-08-14 04:58:06 +08:00
// LogIn creates a web session from a Telegram session posted as JSON.
//
// Steps, with the status returned when each step fails:
//   - Bind the body into a schemas.TgSession (400 if binding fails).
//   - Check the username against the configured allowlist (401 if rejected).
//   - Build a token with auth.Encode that carries the session details and
//     expires after SessionMaxAge seconds.
//   - On first login, create the user and their root folder.
//   - Store a sessions row keyed by the MD5 hex digest of the Telegram
//     session string.
//   - Only then write the token to the session cookie.
//
// Database and token-encoding failures return 500. No cookie is set unless
// every step succeeds.
//
// NOTE(review): UserID, UserName and the session string all come from the
// request body and are not verified against Telegram here. The allowlist
// check and the identity in the token therefore trust client input — confirm
// this is validated elsewhere.
func (as *AuthService) LogIn(c *gin.Context) (*schemas.Message, *types.AppError) {
	var session schemas.TgSession
	if err := c.ShouldBindJSON(&session); err != nil {
		return nil, &types.AppError{Error: errors.New("invalid request payload"), Code: http.StatusBadRequest}
	}

	if !checkUserIsAllowed(session.UserName) {
		return nil, &types.AppError{Error: errors.New("user not allowed"), Code: http.StatusUnauthorized}
	}

	now := time.Now().UTC()
	jwtClaims := &types.JWTClaims{
		Claims: jwt.Claims{
			Subject:  strconv.FormatInt(session.UserID, 10),
			IssuedAt: jwt.NewNumericDate(now),
			Expiry:   jwt.NewNumericDate(now.Add(time.Duration(as.SessionMaxAge) * time.Second)),
		},
		TgSession: session.Sesssion,
		Name:      session.Name,
		UserName:  session.UserName,
		Bot:       session.Bot,
		IsPremium: session.IsPremium,
	}

	// hexToken identifies this Telegram session. It is stored both in the
	// token claims and in the sessions row.
	tokenhash := md5.Sum([]byte(session.Sesssion))
	hexToken := hex.EncodeToString(tokenhash[:])
	jwtClaims.Hash = hexToken

	jweToken, err := auth.Encode(jwtClaims)
	if err != nil {
		// Failing to encode our own claims is a server-side error, not a bad request.
		return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
	}

	if err := as.createUserIfMissing(&session); err != nil {
		return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
	}

	if err := as.Db.Create(&models.Session{UserId: session.UserID, Hash: hexToken, Session: session.Sesssion}).Error; err != nil {
		return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
	}

	// Hand out the cookie only after the session row exists, so a failed login
	// never leaves the client holding a token for an unrecorded session.
	setCookie(c, as.SessionCookieName, jweToken, as.SessionMaxAge)

	return &schemas.Message{Message: "login success"}, nil
}

// createUserIfMissing inserts the users row and the "/" root folder the first
// time session.UserID logs in. It does nothing if the user already exists.
// Both inserts share one transaction. Otherwise a failed folder insert would
// leave a user without a root folder, and later logins would skip creating it
// because the user already exists.
func (as *AuthService) createUserIfMissing(session *schemas.TgSession) error {
	var result []models.User
	if err := as.Db.Model(&models.User{}).Where("user_id = ?", session.UserID).
		Find(&result).Error; err != nil {
		return err
	}
	if len(result) > 0 {
		return nil
	}

	return as.Db.Transaction(func(tx *gorm.DB) error {
		user := models.User{
			UserId:    session.UserID,
			Name:      session.Name,
			UserName:  session.UserName,
			IsPremium: session.IsPremium,
		}
		if err := tx.Create(&user).Error; err != nil {
			return err
		}

		rootFolder := &models.File{
			Name:     "root",
			Type:     "folder",
			MimeType: "drive/folder",
			Path:     "/",
			Depth:    utils.IntPointer(0),
			UserID:   session.UserID,
			Status:   "active",
			ParentID: "root",
		}
		return tx.Create(rootFolder).Error
	})
}
2023-12-03 03:47:23 +08:00
func (as *AuthService) GetSession(c *gin.Context) *schemas.Session {
2023-08-12 19:21:42 +08:00
cookie, err := c.Request.Cookie(as.SessionCookieName)
2023-08-12 19:21:42 +08:00
if err != nil {
return nil
}
jwePayload, err := auth.Decode(cookie.Value)
if err != nil {
return nil
}
now := time.Now().UTC()
newExpires := now.Add(time.Duration(as.SessionMaxAge) * time.Second)
2023-12-03 03:47:23 +08:00
session := &schemas.Session{Name: jwePayload.Name,
2023-09-20 03:20:44 +08:00
UserName: jwePayload.UserName,
Hash: jwePayload.Hash,
Expires: newExpires.Format(time.RFC3339)}
2023-08-12 19:21:42 +08:00
jwePayload.IssuedAt = jwt.NewNumericDate(now)
jwePayload.Expiry = jwt.NewNumericDate(newExpires)
jweToken, err := auth.Encode(jwePayload)
if err != nil {
return nil
}
setCookie(c, as.SessionCookieName, jweToken, as.SessionMaxAge)
2023-08-12 19:21:42 +08:00
return session
}
2023-08-13 04:15:19 +08:00
2023-08-14 04:58:06 +08:00
func (as *AuthService) Logout(c *gin.Context) (*schemas.Message, *types.AppError) {
2023-08-13 04:15:19 +08:00
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
2023-11-25 12:31:29 +08:00
client, _ := tgc.UserLogin(c, jwtUser.TgSession)
2023-08-24 02:40:40 +08:00
2023-09-20 03:20:44 +08:00
tgc.RunWithAuth(c, client, "", func(ctx context.Context) error {
2023-08-24 02:40:40 +08:00
_, err := client.API().AuthLogOut(c)
return err
})
2023-08-13 04:15:19 +08:00
setCookie(c, as.SessionCookieName, "", -1)
2023-11-02 21:51:30 +08:00
as.Db.Where("session = ?", jwtUser.TgSession).Delete(&models.Session{})
2023-12-03 03:47:23 +08:00
return &schemas.Message{Message: "logout success"}, nil
2023-08-13 04:15:19 +08:00
}
2023-08-15 03:36:24 +08:00
2023-12-03 03:47:23 +08:00
// prepareSession packs a logged-in Telegram user and their session data into
// the schemas.TgSession sent to the web client. The session string comes from
// generateTgSession, using data.DC and data.AuthKey with port 443.
//
// The display name joins the first and last name with one space. The space is
// omitted when either part is empty (a Telegram user may have no last name),
// so the name never carries a stray leading or trailing space.
func prepareSession(user *tg.User, data *session.Data) *schemas.TgSession {
	var name string
	switch {
	case user.LastName == "":
		name = user.FirstName
	case user.FirstName == "":
		name = user.LastName
	default:
		name = fmt.Sprintf("%s %s", user.FirstName, user.LastName)
	}

	return &schemas.TgSession{
		Sesssion:  generateTgSession(data.DC, data.AuthKey, 443),
		UserID:    user.ID,
		Bot:       user.Bot,
		UserName:  user.Username,
		Name:      name,
		IsPremium: user.Premium,
	}
}
2023-08-22 23:00:14 +08:00
// HandleMultipleLogin upgrades the request to a WebSocket and drives an
// interactive Telegram login over it, using an in-memory session store.
//
// Each JSON message read from the socket is decoded into a
// types.SocketMessage and dispatched on AuthType:
//   - "qr": start QR login and push each login-token URL to the client.
//   - "phone" with Message "sendcode": request a login code for PhoneNo.
//   - "phone" with Message "signin": sign in with PhoneNo, PhoneCode and
//     PhoneCodeHash.
//   - "2fa" with a non-empty Password: complete two-factor login.
//
// Replies are JSON objects whose "type" is "auth" or "error". On success the
// payload is the schemas.TgSession built by prepareSession. Accounts rejected
// by checkUserIsAllowed are logged out again.
//
// The handler returns when reading from the socket fails (e.g. the client
// disconnects) or the Telegram client stops. Telegram calls use Run's context,
// so in-flight requests are cancelled at that point.
func (as *AuthService) HandleMultipleLogin(c *gin.Context) {
	upgrader := websocket.Upgrader{
		// NOTE(review): every Origin is accepted, so any web page can open this
		// socket; consider restricting it to the UI's origin.
		CheckOrigin: func(r *http.Request) bool {
			return true
		},
	}
	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		return
	}
	defer conn.Close()

	// gorilla/websocket allows only one concurrent writer per connection, but
	// replies are sent from several goroutines below, so serialize them.
	var writeMu sync.Mutex
	send := func(v interface{}) {
		writeMu.Lock()
		defer writeMu.Unlock()
		// A write fails only once the client is gone. The read loop then
		// ends the handler, so there is nothing more to do here.
		_ = conn.WriteJSON(v)
	}
	sendError := func(message string) {
		send(map[string]interface{}{"type": "error", "message": message})
	}

	dispatcher := tg.NewUpdateDispatcher()
	loggedIn := qrlogin.OnLoginToken(dispatcher)
	sessionStorage := &session.StorageMemory{}
	tgClient := tgc.NoLogin(c, dispatcher, sessionStorage)

	// completeLogin handles a successful authorization from any login method.
	// It rejects empty or disallowed users and otherwise sends the resulting
	// session to the client.
	completeLogin := func(ctx context.Context, authorization *tg.AuthAuthorization) {
		user, ok := authorization.User.AsNotEmpty()
		if !ok {
			sendError("auth failed")
			return
		}
		if !checkUserIsAllowed(user.Username) {
			sendError("user not allowed")
			// Best-effort revocation of the authorization that was just created.
			tgClient.API().AuthLogOut(ctx)
			return
		}
		raw, err := sessionStorage.LoadSession(ctx)
		if err != nil {
			sendError(err.Error())
			return
		}
		sessionData := &types.SessionData{}
		if err := json.Unmarshal(raw, sessionData); err != nil {
			sendError(err.Error())
			return
		}
		tgSession := prepareSession(user, &sessionData.Data)
		send(map[string]interface{}{"type": "auth", "payload": tgSession, "message": "success"})
	}

	// Run returns once the read loop fails (client gone) or the client stops.
	// The socket is unusable by then, so the error is not reported.
	_ = tgClient.Run(c, func(ctx context.Context) error {
		for {
			message := &types.SocketMessage{}
			if err := conn.ReadJSON(message); err != nil {
				return err
			}

			switch {
			case message.AuthType == "qr":
				go func() {
					authorization, err := tgClient.QR().Auth(ctx, loggedIn, func(ctx context.Context, token qrlogin.Token) error {
						send(map[string]interface{}{"type": "auth", "payload": map[string]string{"token": token.URL()}})
						return nil
					})
					if tgerr.Is(err, "SESSION_PASSWORD_NEEDED") {
						send(map[string]interface{}{"type": "auth", "message": "2FA required"})
						return
					}
					if err != nil {
						sendError(err.Error())
						return
					}
					completeLogin(ctx, authorization)
				}()

			case message.AuthType == "phone" && message.Message == "sendcode":
				go func() {
					res, err := tgClient.Auth().SendCode(ctx, message.PhoneNo, tgauth.SendCodeOptions{})
					if err != nil {
						sendError(err.Error())
						return
					}
					// SendCode returns an interface with several variants. A panic
					// here would be outside gin's recovery and crash the process.
					code, ok := res.(*tg.AuthSentCode)
					if !ok {
						sendError("unexpected send code response")
						return
					}
					send(map[string]interface{}{"type": "auth", "payload": map[string]string{"phoneCodeHash": code.PhoneCodeHash}})
				}()

			case message.AuthType == "phone" && message.Message == "signin":
				go func() {
					authorization, err := tgClient.Auth().SignIn(ctx, message.PhoneNo, message.PhoneCode, message.PhoneCodeHash)
					if errors.Is(err, tgauth.ErrPasswordAuthNeeded) {
						send(map[string]interface{}{"type": "auth", "message": "2FA required"})
						return
					}
					if err != nil {
						sendError(err.Error())
						return
					}
					completeLogin(ctx, authorization)
				}()

			case message.AuthType == "2fa" && message.Password != "":
				go func() {
					authorization, err := tgClient.Auth().Password(ctx, message.Password)
					if err != nil {
						sendError(err.Error())
						return
					}
					completeLogin(ctx, authorization)
				}()
			}
		}
	})
}