package services import ( "bytes" "context" "crypto/md5" "encoding/base64" "encoding/binary" "encoding/hex" "encoding/json" "fmt" "math/big" "net" "net/http" "strconv" "time" "github.com/go-faster/errors" "github.com/golang-jwt/jwt/v5" "github.com/gorilla/websocket" "github.com/gotd/td/session" tgauth "github.com/gotd/td/telegram/auth" "github.com/gotd/td/telegram/auth/qrlogin" "github.com/gotd/td/tg" "github.com/gotd/td/tgerr" "github.com/tgdrive/teldrive/internal/api" "github.com/tgdrive/teldrive/internal/auth" "github.com/tgdrive/teldrive/internal/cache" "github.com/tgdrive/teldrive/internal/logging" "github.com/tgdrive/teldrive/internal/tgc" "github.com/tgdrive/teldrive/pkg/models" "github.com/tgdrive/teldrive/pkg/types" "go.uber.org/zap" "gorm.io/gorm" "gorm.io/gorm/clause" ) var authCookieName = "access_token" func (a *apiService) AuthLogin(ctx context.Context, session *api.SessionCreate) (*api.AuthLoginNoContent, error) { if !checkUserIsAllowed(a.cnf.JWT.AllowedUsers, session.UserName) { return nil, &apiError{code: http.StatusForbidden, err: errors.New("user not allowed")} } now := time.Now().UTC() jwtClaims := &types.JWTClaims{ Name: session.Name, UserName: session.UserName, IsPremium: session.IsPremium, RegisteredClaims: jwt.RegisteredClaims{ Subject: strconv.FormatInt(session.UserId, 10), IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(now.Add(a.cnf.JWT.SessionTime)), }} tokenhash := md5.Sum([]byte(session.Session)) hexToken := hex.EncodeToString(tokenhash[:]) jwtClaims.Hash = hexToken jwtToken, err := auth.Encode(a.cnf.JWT.Secret, jwtClaims) if err != nil { return nil, &apiError{err: err} } user := models.User{ UserId: session.UserId, Name: session.Name, UserName: session.UserName, IsPremium: session.IsPremium, } err = a.db.Transaction(func(tx *gorm.DB) error { if err := a.db.Clauses(clause.OnConflict{DoNothing: true}).Create(&user).Error; err != nil { return err } file := &models.File{ Name: "root", Type: "folder", MimeType: "drive/folder", UserId: session.UserId, Status: "active", Parts: nil, } if err := a.db.Clauses(clause.OnConflict{DoNothing: true}).Create(file).Error; err != nil { return err } return nil }) if err != nil { return nil, &apiError{err: err} } client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session.Session, a.middlewares...) var auth *tg.Authorization err = client.Run(ctx, func(ctx context.Context) error { auths, err := client.API().AccountGetAuthorizations(ctx) if err != nil { return err } for _, a := range auths.Authorizations { if a.Current { auth = &a break } } return nil }) if err != nil { return nil, &apiError{err: err} } if err := a.db.Create(&models.Session{UserId: session.UserId, Hash: hexToken, Session: session.Session, SessionDate: auth.DateCreated}).Error; err != nil { return nil, &apiError{err: err} } return &api.AuthLoginNoContent{SetCookie: setCookie(authCookieName, jwtToken, int(a.cnf.JWT.SessionTime.Seconds()))}, nil } func (a *apiService) AuthLogout(ctx context.Context) (*api.AuthLogoutNoContent, error) { authUser := auth.GetJWTUser(ctx) client, _ := tgc.AuthClient(ctx, &a.cnf.TG, authUser.TgSession, a.middlewares...) tgc.RunWithAuth(ctx, client, "", func(ctx context.Context) error { _, err := client.API().AuthLogOut(ctx) return err }) a.db.Where("hash = ?", authUser.Hash).Delete(&models.Session{}) a.cache.Delete(cache.Key("sessions", authUser.Hash), cache.Key("users", "sessions", authUser.ID)) return &api.AuthLogoutNoContent{SetCookie: setCookie(authCookieName, "", -1)}, nil } func (a *apiService) AuthSession(ctx context.Context, params api.AuthSessionParams) (api.AuthSessionRes, error) { if params.AccessToken.Value == "" { return &api.AuthSessionNoContent{}, nil } claims, err := auth.VerifyUser(a.db, a.cache, a.cnf.JWT.Secret, params.AccessToken.Value) if err != nil { return &api.AuthSessionNoContent{}, nil } claims.TgSession = "" now := time.Now().UTC() newExpires := now.Add(a.cnf.JWT.SessionTime) userId, _ := strconv.ParseInt(claims.Subject, 10, 64) session := api.Session{ Name: claims.Name, UserName: claims.UserName, UserId: userId, Hash: claims.Hash, Expires: newExpires} claims.IssuedAt = jwt.NewNumericDate(now) claims.ExpiresAt = jwt.NewNumericDate(newExpires) jweToken, err := auth.Encode(a.cnf.JWT.Secret, claims) if err != nil { return &api.AuthSessionNoContent{}, nil } return &api.SessionHeaders{SetCookie: setCookie(authCookieName, jweToken, int(a.cnf.JWT.SessionTime.Seconds())), Response: session}, nil } func (a *apiService) AuthWs(ctx context.Context) error { return nil } func (e *extendedService) AuthWs(w http.ResponseWriter, r *http.Request) { ctx := r.Context() upgrader := websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } logger := logging.FromContext(ctx) conn, err := upgrader.Upgrade(w, r, nil) if err != nil { logger.Error("websocket upgrade error", zap.Error(err)) http.Error(w, "could not upgrade connection", http.StatusBadRequest) return } defer func() { if err := conn.Close(); err != nil { logger.Error("error closing websocket connection", zap.Error(err)) } }() dispatcher := tg.NewUpdateDispatcher() loggedIn := qrlogin.OnLoginToken(dispatcher) sessionStorage := &session.StorageMemory{} tgClient, err := tgc.NoAuthClient(ctx, &e.api.cnf.TG, dispatcher, sessionStorage) if err != nil { logger.Error("error creating telegram client", zap.Error(err)) return } err = tgClient.Run(ctx, func(ctx context.Context) error { for { message := &types.SocketMessage{} err := conn.ReadJSON(message) if err != nil { if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { logger.Debug("websocket connection closed normally by client") return nil } return err } switch message.AuthType { case "qr": go func() { authorization, err := tgClient.QR().Auth(ctx, loggedIn, func(ctx context.Context, token qrlogin.Token) error { conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": map[string]string{"token": token.URL()}}) return nil }) if errors.Is(err, context.Canceled) { return } if tgerr.Is(err, "SESSION_PASSWORD_NEEDED") { conn.WriteJSON(map[string]interface{}{"type": "auth", "message": "2FA required"}) return } if err != nil { logger.Error("QR auth error", zap.Error(err)) conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()}) return } user, ok := authorization.User.AsNotEmpty() if !ok { conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"}) return } if !checkUserIsAllowed(e.api.cnf.JWT.AllowedUsers, user.Username) { conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"}) tgClient.API().AuthLogOut(ctx) return } res, _ := sessionStorage.LoadSession(ctx) sessionData := &types.SessionData{} json.Unmarshal(res, sessionData) session := prepareSession(user, &sessionData.Data) conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"}) }() case "phone": if message.Message == "sendcode" { go func() { res, err := tgClient.Auth().SendCode(ctx, message.PhoneNo, tgauth.SendCodeOptions{}) if errors.Is(err, context.Canceled) { return } logger.Error("error sending code", zap.Error(err)) if err != nil { conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()}) return } code := res.(*tg.AuthSentCode) conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": map[string]string{"phoneCodeHash": code.PhoneCodeHash}}) }() } else if message.Message == "signin" { go func() { auth, err := tgClient.Auth().SignIn(ctx, message.PhoneNo, message.PhoneCode, message.PhoneCodeHash) if errors.Is(err, context.Canceled) { return } if errors.Is(err, tgauth.ErrPasswordAuthNeeded) { conn.WriteJSON(map[string]interface{}{"type": "auth", "message": tgauth.ErrPasswordAuthNeeded.Error()}) return } if tgerr.Is(err, "PHONE_CODE_INVALID") { conn.WriteJSON(map[string]interface{}{"type": "auth", "message": "PHONE_CODE_INVALID"}) return } if err != nil { logger.Error("phone sign-in error", zap.Error(err)) conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()}) return } user, ok := auth.User.AsNotEmpty() if !ok { conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"}) return } if !checkUserIsAllowed(e.api.cnf.JWT.AllowedUsers, user.Username) { conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"}) tgClient.API().AuthLogOut(ctx) return } res, _ := sessionStorage.LoadSession(ctx) sessionData := &types.SessionData{} json.Unmarshal(res, sessionData) session := prepareSession(user, &sessionData.Data) conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"}) }() } case "2fa": if message.Password != "" { go func() { auth, err := tgClient.Auth().Password(ctx, message.Password) if errors.Is(err, context.Canceled) { return } if err != nil { logger.Error("phone sign-in error", zap.Error(err)) conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()}) return } user, ok := auth.User.AsNotEmpty() if !ok { conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"}) return } if !checkUserIsAllowed(e.api.cnf.JWT.AllowedUsers, user.Username) { conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"}) tgClient.API().AuthLogOut(ctx) return } res, _ := sessionStorage.LoadSession(ctx) sessionData := &types.SessionData{} json.Unmarshal(res, sessionData) session := prepareSession(user, &sessionData.Data) conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"}) }() } } } }) if err != nil { logger.Error("error running telegram client", zap.Error(err)) return } } func ip4toInt(IPv4Address net.IP) int64 { IPv4Int := big.NewInt(0) IPv4Int.SetBytes(IPv4Address.To4()) return IPv4Int.Int64() } func pack32BinaryIP4(ip4Address string) []byte { ipv4Decimal := ip4toInt(net.ParseIP(ip4Address)) buf := new(bytes.Buffer) binary.Write(buf, binary.BigEndian, uint32(ipv4Decimal)) return buf.Bytes() } func generateTgSession(dcId int, authKey []byte, port int) string { dcMaps := map[int]string{ 1: "149.154.175.53", 2: "149.154.167.51", 3: "149.154.175.100", 4: "149.154.167.91", 5: "91.108.56.130", } dcIDByte := byte(dcId) serverAddressBytes := pack32BinaryIP4(dcMaps[dcId]) portByte := make([]byte, 2) binary.BigEndian.PutUint16(portByte, uint16(port)) packet := make([]byte, 0) packet = append(packet, dcIDByte) packet = append(packet, serverAddressBytes...) packet = append(packet, portByte...) packet = append(packet, authKey...) base64Encoded := base64.URLEncoding.EncodeToString(packet) return "1" + base64Encoded } func checkUserIsAllowed(allowedUsers []string, userName string) bool { found := false if len(allowedUsers) > 0 { for _, user := range allowedUsers { if user == userName { found = true break } } } else { found = true } return found } func prepareSession(user *tg.User, data *session.Data) *api.SessionCreate { sessionString := generateTgSession(data.DC, data.AuthKey, 443) session := &api.SessionCreate{ Session: sessionString, UserId: user.ID, UserName: user.Username, Name: fmt.Sprintf("%s %s", user.FirstName, user.LastName), IsPremium: user.Premium, } return session } func setCookie(name, value string, maxAge int) string { cookie := http.Cookie{ Name: name, Value: value, MaxAge: maxAge, HttpOnly: true, Path: "/", SameSite: http.SameSiteLaxMode, } return cookie.String() }