diff --git a/go.mod b/go.mod index 8fe9a87..92cd1f8 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect + github.com/beevik/ntp v1.4.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/coder/websocket v1.8.12 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect diff --git a/go.sum b/go.sum index f021573..0212f13 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7Oputl github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/WinterYukky/gorm-extra-clause-plugin v0.3.0 h1:fQfTkxoRso6mlm7eOfBIk94aqamJeqxCEfU+MyLWhgo= github.com/WinterYukky/gorm-extra-clause-plugin v0.3.0/go.mod h1:GFT8TzxeeGKYXNU/65PsiN2+zNHigm9HjybnbL1T7eg= +github.com/beevik/ntp v1.4.3 h1:PlbTvE5NNy4QHmA4Mg57n7mcFTmr1W1j3gcK7L1lqho= +github.com/beevik/ntp v1.4.3/go.mod h1:Unr8Zg+2dRn7d8bHFuehIMSvvUYssHMxW3Q5Nx4RW5Q= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= diff --git a/internal/tgc/tgc.go b/internal/tgc/tgc.go index 5aacae1..f7b985a 100644 --- a/internal/tgc/tgc.go +++ b/internal/tgc/tgc.go @@ -6,6 +6,7 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/go-faster/errors" + "github.com/gotd/contrib/clock" "github.com/gotd/contrib/middleware/floodwait" "github.com/gotd/contrib/middleware/ratelimit" "github.com/gotd/td/session" @@ -38,6 +39,10 @@ func New(ctx context.Context, config *config.TGConfig, handler telegram.UpdateHa logger = logging.FromContext(ctx).Named("td") } + c, err := clock.NewNTP() + if err != nil { + return nil, errors.Wrap(err, "create clock") + } opts := telegram.Options{ Resolver: dcs.Plain(dcs.PlainOptions{ @@ -61,6 +66,7 @@ func New(ctx context.Context, config *config.TGConfig, handler telegram.UpdateHa Middlewares: middlewares, UpdateHandler: handler, Logger: logger, + Clock: c, } return telegram.NewClient(config.AppId, config.AppHash, opts), nil diff --git a/pkg/services/auth.go b/pkg/services/auth.go index 8d4deb4..31f42d5 100644 --- a/pkg/services/auth.go +++ b/pkg/services/auth.go @@ -19,15 +19,18 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/gorilla/websocket" "github.com/gotd/td/session" + "github.com/gotd/td/telegram" tgauth "github.com/gotd/td/telegram/auth" "github.com/gotd/td/telegram/auth/qrlogin" "github.com/gotd/td/tg" "github.com/gotd/td/tgerr" "github.com/tgdrive/teldrive/internal/api" "github.com/tgdrive/teldrive/internal/auth" + "github.com/tgdrive/teldrive/internal/logging" "github.com/tgdrive/teldrive/internal/tgc" "github.com/tgdrive/teldrive/pkg/models" "github.com/tgdrive/teldrive/pkg/types" + "go.uber.org/zap" "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -176,139 +179,213 @@ func (a *apiService) AuthWs(ctx context.Context) error { func (e *extendedService) AuthWs(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + upgrader := websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } + + logger := logging.FromContext(ctx).With(zap.String("handler", "AuthWs")) conn, err := upgrader.Upgrade(w, r, nil) if err != nil { + logger.Error("websocket upgrade error", zap.Error(err)) http.Error(w, "could not upgrade connection", http.StatusBadRequest) return } - defer conn.Close() + + defer func() { + if err := conn.Close(); err != nil { + logger.Error("error closing websocket connection", zap.Error(err)) + } else { + logger.Info("websocket connection closed") + } + }() dispatcher := tg.NewUpdateDispatcher() loggedIn := qrlogin.OnLoginToken(dispatcher) sessionStorage := &session.StorageMemory{} - tgClient, _ := tgc.NoAuthClient(ctx, &e.api.cnf.TG, dispatcher, sessionStorage) + tgClient, err := tgc.NoAuthClient(ctx, &e.api.cnf.TG, dispatcher, sessionStorage) + if err != nil { + logger.Error("error creating telegram client", zap.Error(err)) + return + } err = tgClient.Run(ctx, func(ctx context.Context) error { for { message := &types.SocketMessage{} err := conn.ReadJSON(message) - if err != nil { + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + logger.Debug("websocket connection closed normally by client") + return nil + } + logger.Error("websocket read error", zap.Error(err)) return err } - if message.AuthType == "qr" { - go func() { - authorization, err := tgClient.QR().Auth(ctx, loggedIn, func(ctx context.Context, token qrlogin.Token) error { - conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": map[string]string{"token": token.URL()}}) - return nil - }) - - if tgerr.Is(err, "SESSION_PASSWORD_NEEDED") { - conn.WriteJSON(map[string]interface{}{"type": "auth", "message": "2FA required"}) - return - } - - if err != nil { - conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()}) - return - } - user, ok := authorization.User.AsNotEmpty() - if !ok { - conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"}) - return - } - if !checkUserIsAllowed(e.api.cnf.JWT.AllowedUsers, user.Username) { - conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"}) - tgClient.API().AuthLogOut(ctx) - return - } - res, _ := sessionStorage.LoadSession(ctx) - sessionData := &types.SessionData{} - json.Unmarshal(res, sessionData) - session := prepareSession(user, &sessionData.Data) - conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"}) - }() - } - if message.AuthType == "phone" && message.Message == "sendcode" { - go func() { - res, err := tgClient.Auth().SendCode(ctx, message.PhoneNo, tgauth.SendCodeOptions{}) - if err != nil { - conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()}) - return - } - code := res.(*tg.AuthSentCode) - conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": map[string]string{"phoneCodeHash": code.PhoneCodeHash}}) - }() - } - if message.AuthType == "phone" && message.Message == "signin" { - go func() { - auth, err := tgClient.Auth().SignIn(ctx, message.PhoneNo, message.PhoneCode, message.PhoneCodeHash) - - if errors.Is(err, tgauth.ErrPasswordAuthNeeded) { - conn.WriteJSON(map[string]interface{}{"type": "auth", "message": "2FA required"}) - return - } - - if err != nil { - conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()}) - return - } - user, ok := auth.User.AsNotEmpty() - if !ok { - conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"}) - return - } - if !checkUserIsAllowed(e.api.cnf.JWT.AllowedUsers, user.Username) { - conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"}) - tgClient.API().AuthLogOut(ctx) - return - } - res, _ := sessionStorage.LoadSession(ctx) - sessionData := &types.SessionData{} - json.Unmarshal(res, sessionData) - session := prepareSession(user, &sessionData.Data) - conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"}) - }() - } - - if message.AuthType == "2fa" && message.Password != "" { - go func() { - auth, err := tgClient.Auth().Password(ctx, message.Password) - if err != nil { - conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()}) - return - } - user, ok := auth.User.AsNotEmpty() - if !ok { - conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"}) - return - } - if !checkUserIsAllowed(e.api.cnf.JWT.AllowedUsers, user.Username) { - conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"}) - tgClient.API().AuthLogOut(ctx) - return - } - res, _ := sessionStorage.LoadSession(ctx) - sessionData := &types.SessionData{} - json.Unmarshal(res, sessionData) - session := prepareSession(user, &sessionData.Data) - conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"}) - }() + switch message.AuthType { + case "qr": + go e.handleQRAuth(ctx, conn, tgClient, loggedIn, sessionStorage, logger) + case "phone": + if message.Message == "sendcode" { + go e.handleSendCode(ctx, conn, tgClient, message, logger) + } else if message.Message == "signin" { + go e.handleSignIn(ctx, conn, tgClient, message, sessionStorage, logger) + } + case "2fa": + if message.Password != "" { + go e.handle2FA(ctx, conn, tgClient, message, sessionStorage, logger) + } } } }) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + logger.Error("error during tgClient.Run", zap.Error(err)) + if !errors.Is(err, context.Canceled) { + } return } } +func (e *extendedService) handleQRAuth( + ctx context.Context, + conn *websocket.Conn, + tgClient *telegram.Client, + loggedIn qrlogin.LoggedIn, + sessionStorage *session.StorageMemory, + logger *zap.Logger) { + logger = logger.With(zap.String("handler", "handleQRAuth")) + authorization, err := tgClient.QR().Auth(ctx, loggedIn, func(ctx context.Context, token qrlogin.Token) error { + return conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": map[string]string{"token": token.URL()}}) + }) + + if tgerr.Is(err, "SESSION_PASSWORD_NEEDED") { + logger.Debug("2FA required for QR auth") + conn.WriteJSON(map[string]interface{}{"type": "auth", "message": "2FA required"}) + return + } + + if err != nil { + logger.Error("QR auth error", zap.Error(err)) + conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()}) + return + } + user, ok := authorization.User.AsNotEmpty() + if !ok { + logger.Error("QR auth failed, user not found") + conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"}) + return + } + if !checkUserIsAllowed(e.api.cnf.JWT.AllowedUsers, user.Username) { + logger.Error("user not allowed", zap.String("username", user.Username)) + conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"}) + tgClient.API().AuthLogOut(ctx) + return + } + + res, _ := sessionStorage.LoadSession(ctx) + sessionData := &types.SessionData{} + json.Unmarshal(res, sessionData) + session := prepareSession(user, &sessionData.Data) + conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"}) + logger.Info("QR auth success", zap.String("username", user.Username)) +} + +func (e *extendedService) handleSendCode( + ctx context.Context, + conn *websocket.Conn, + tgClient *telegram.Client, + message *types.SocketMessage, + logger *zap.Logger) { + logger = logger.With(zap.String("handler", "handleSendCode")) + + res, err := tgClient.Auth().SendCode(ctx, message.PhoneNo, tgauth.SendCodeOptions{}) + if err != nil { + logger.Error("error sending code", zap.Error(err), zap.String("phoneNo", message.PhoneNo)) + conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()}) + return + } + code := res.(*tg.AuthSentCode) + conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": map[string]string{"phoneCodeHash": code.PhoneCodeHash}}) +} + +func (e *extendedService) handleSignIn( + ctx context.Context, + conn *websocket.Conn, + tgClient *telegram.Client, + message *types.SocketMessage, + sessionStorage *session.StorageMemory, + logger *zap.Logger) { + logger = logger.With(zap.String("handler", "handleSignIn")) + + auth, err := tgClient.Auth().SignIn(ctx, message.PhoneNo, message.PhoneCode, message.PhoneCodeHash) + + if errors.Is(err, tgauth.ErrPasswordAuthNeeded) { + logger.Debug("2FA required for phone sign in") + conn.WriteJSON(map[string]interface{}{"type": "auth", "message": "2FA required"}) + return + } + + if err != nil { + logger.Error("phone sign-in error", zap.Error(err)) + conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()}) + return + } + user, ok := auth.User.AsNotEmpty() + if !ok { + logger.Error("phone sign-in failed, user not found") + conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"}) + return + } + if !checkUserIsAllowed(e.api.cnf.JWT.AllowedUsers, user.Username) { + logger.Error("user not allowed", zap.String("username", user.Username)) + conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"}) + tgClient.API().AuthLogOut(ctx) + return + } + + res, _ := sessionStorage.LoadSession(ctx) + sessionData := &types.SessionData{} + json.Unmarshal(res, sessionData) + session := prepareSession(user, &sessionData.Data) + conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"}) + logger.Debug("phone sign in success", zap.String("username", user.Username)) +} +func (e *extendedService) handle2FA( + ctx context.Context, + conn *websocket.Conn, + tgClient *telegram.Client, + message *types.SocketMessage, + sessionStorage *session.StorageMemory, + logger *zap.Logger) { + logger = logger.With(zap.String("handler", "handle2FA")) + auth, err := tgClient.Auth().Password(ctx, message.Password) + if err != nil { + logger.Error("2FA authentication error", zap.Error(err)) + conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()}) + return + } + user, ok := auth.User.AsNotEmpty() + if !ok { + logger.Error("2FA authentication failed, user not found") + conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"}) + return + } + if !checkUserIsAllowed(e.api.cnf.JWT.AllowedUsers, user.Username) { + logger.Error("user not allowed", zap.String("username", user.Username)) + conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"}) + tgClient.API().AuthLogOut(ctx) + return + } + res, _ := sessionStorage.LoadSession(ctx) + sessionData := &types.SessionData{} + json.Unmarshal(res, sessionData) + session := prepareSession(user, &sessionData.Data) + conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"}) + logger.Debug("2FA authentication success", zap.String("username", user.Username)) +} + func ip4toInt(IPv4Address net.IP) int64 { IPv4Int := big.NewInt(0) IPv4Int.SetBytes(IPv4Address.To4())