diff --git a/controllers/server.go b/controllers/server.go index bdcf51bd..6f5ee765 100644 --- a/controllers/server.go +++ b/controllers/server.go @@ -297,9 +297,10 @@ func updateSettings(w http.ResponseWriter, r *http.Request) { func reInit(curr, new models.ServerSettings, force bool) { logic.SettingsMutex.Lock() defer logic.SettingsMutex.Unlock() - logic.InitializeAuthProvider() + logic.ResetAuthProvider() logic.EmailInit() logic.SetVerbosity(int(logic.GetServerSettings().Verbosity)) + logic.ResetIDPSyncHook() // check if auto update is changed if force { if curr.NetclientAutoUpdate != new.NetclientAutoUpdate { diff --git a/controllers/user.go b/controllers/user.go index 386095b2..f71f9064 100644 --- a/controllers/user.go +++ b/controllers/user.go @@ -37,6 +37,8 @@ func userHandlers(r *mux.Router) { r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceUsers, http.HandlerFunc(createUser)))).Methods(http.MethodPost) r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(deleteUser))).Methods(http.MethodDelete) r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUser)))).Methods(http.MethodGet) + r.HandleFunc("/api/users/{username}/enable", logic.SecurityCheck(true, http.HandlerFunc(enableUserAccount))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}/disable", logic.SecurityCheck(true, http.HandlerFunc(disableUserAccount))).Methods(http.MethodPost) r.HandleFunc("/api/v1/users", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserV1)))).Methods(http.MethodGet) r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods(http.MethodGet) r.HandleFunc("/api/v1/users/roles", logic.SecurityCheck(true, http.HandlerFunc(ListRoles))).Methods(http.MethodGet) @@ -270,6 +272,13 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) { logic.ReturnErrorResponse(response, request, logic.FormatError(errors.New("user is registered via SSO"), "badrequest")) return } + + if user.AccountDisabled { + err = errors.New("user account disabled") + logic.ReturnErrorResponse(response, request, logic.FormatError(err, "unauthorized")) + return + } + if !user.IsSuperAdmin && !logic.IsBasicAuthEnabled() { logic.ReturnErrorResponse( response, @@ -446,6 +455,65 @@ func getUser(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(user) } +// @Summary Enable a user's account +// @Router /api/users/{username}/enable [post] +// @Tags Users +// @Param username path string true "Username of the user to enable" +// @Success 200 {object} models.SuccessResponse +// @Failure 400 {object} models.ErrorResponse +// @Failure 500 {object} models.ErrorResponse +func enableUserAccount(w http.ResponseWriter, r *http.Request) { + username := mux.Vars(r)["username"] + user, err := logic.GetUser(username) + if err != nil { + logger.Log(0, "failed to fetch user: ", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + + user.AccountDisabled = false + err = logic.UpsertUser(*user) + if err != nil { + logger.Log(0, "failed to enable user account: ", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + } + + logic.ReturnSuccessResponse(w, r, "user account enabled") +} + +// @Summary Disable a user's account +// @Router /api/users/{username}/disable [post] +// @Tags Users +// @Param username path string true "Username of the user to disable" +// @Success 200 {object} models.SuccessResponse +// @Failure 400 {object} models.ErrorResponse +// @Failure 500 {object} models.ErrorResponse +func disableUserAccount(w http.ResponseWriter, r *http.Request) { + username := mux.Vars(r)["username"] + user, err := logic.GetUser(username) + if err != nil { + logger.Log(0, "failed to fetch user: ", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + + if user.PlatformRoleID == models.SuperAdminRole { + err = errors.New("cannot disable super-admin user account") + logger.Log(0, err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + + user.AccountDisabled = true + err = logic.UpsertUser(*user) + if err != nil { + logger.Log(0, "failed to disable user account: ", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + } + + logic.ReturnSuccessResponse(w, r, "user account disabled") +} + // swagger:route GET /api/v1/users user getUserV1 // // Get an individual user with role info. diff --git a/database/database.go b/database/database.go index 483eb35f..59abd7dd 100644 --- a/database/database.go +++ b/database/database.go @@ -19,8 +19,6 @@ const ( DELETED_NODES_TABLE_NAME = "deletednodes" // USERS_TABLE_NAME - users table USERS_TABLE_NAME = "users" - // ACCESS_TOKENS_TABLE_NAME - access tokens table - ACCESS_TOKENS_TABLE_NAME = "user_access_tokens" // USER_PERMISSIONS_TABLE_NAME - user permissions table USER_PERMISSIONS_TABLE_NAME = "user_permissions" // CERTS_TABLE_NAME - certificates table diff --git a/go.mod b/go.mod index 3c1b664c..82bd5006 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ toolchain go1.23.7 require ( github.com/blang/semver v3.5.1+incompatible - github.com/eclipse/paho.mqtt.golang v1.4.3 + github.com/eclipse/paho.mqtt.golang v1.5.0 github.com/go-playground/validator/v10 v10.26.0 github.com/golang-jwt/jwt/v4 v4.5.2 github.com/google/uuid v1.6.0 @@ -21,7 +21,7 @@ require ( github.com/txn2/txeh v1.5.5 go.uber.org/automaxprocs v1.6.0 golang.org/x/crypto v0.38.0 - golang.org/x/net v0.37.0 // indirect + golang.org/x/net v0.39.0 // indirect golang.org/x/oauth2 v0.29.0 golang.org/x/sys v0.33.0 // indirect golang.org/x/text v0.25.0 // indirect @@ -42,11 +42,13 @@ require ( ) require ( + github.com/google/go-cmp v0.7.0 github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e github.com/guumaster/tablewriter v0.0.10 github.com/matryer/is v1.4.1 github.com/olekukonko/tablewriter v0.0.5 github.com/spf13/cobra v1.9.1 + google.golang.org/api v0.229.0 gopkg.in/mail.v2 v2.3.1 gorm.io/datatypes v1.2.5 gorm.io/driver/postgres v1.5.11 @@ -55,11 +57,17 @@ require ( ) require ( - cloud.google.com/go/compute/metadata v0.3.0 // indirect + cloud.google.com/go/auth v0.16.0 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect + cloud.google.com/go/compute/metadata v0.6.0 // indirect github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/go-jose/go-jose/v4 v4.0.5 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect - github.com/google/go-cmp v0.7.0 // indirect + github.com/google/s2a-go v0.1.9 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect + github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -68,18 +76,25 @@ require ( github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect - github.com/kr/text v0.2.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/seancfoley/bintree v1.3.1 // indirect github.com/spf13/pflag v1.0.6 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect + go.opentelemetry.io/otel v1.35.0 // indirect + go.opentelemetry.io/otel/metric v1.35.0 // indirect + go.opentelemetry.io/otel/trace v1.35.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e // indirect + google.golang.org/grpc v1.71.1 // indirect + google.golang.org/protobuf v1.36.6 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gorm.io/driver/mysql v1.5.6 // indirect ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/felixge/httpsnoop v1.0.3 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/hashicorp/go-version v1.7.0 diff --git a/go.sum b/go.sum index f7823c29..313acd7b 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,9 @@ -cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= -cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +cloud.google.com/go/auth v0.16.0 h1:Pd8P1s9WkcrBE2n/PhAwKsdrR35V3Sg2II9B+ndM3CU= +cloud.google.com/go/auth v0.16.0/go.mod h1:1howDHJ5IETh/LwYs3ZxvlkXF48aSqqJUM+5o02dNOI= +cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= +cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= +cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= +cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ= @@ -9,18 +13,22 @@ github.com/c-robinson/iplib v1.0.8/go.mod h1:i3LuuFL1hRT5gFpBRnEydzw8R6yhGkF4szN github.com/coreos/go-oidc/v3 v3.14.1 h1:9ePWwfdwC4QKRlCXsJGou56adA/owXczOzwKdOumLqk= github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmrfah6hnSYEU= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/eclipse/paho.mqtt.golang v1.4.3 h1:2kwcUGn8seMUfWndX0hGbvH8r7crgcJguQNCyp70xik= -github.com/eclipse/paho.mqtt.golang v1.4.3/go.mod h1:CSYvoAlsMkhYOXh/oKyxa8EcBci6dVkLCbo5tTC1RIE= -github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk= -github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/eclipse/paho.mqtt.golang v1.5.0 h1:EH+bUVJNgttidWFkLLVKaQPGmkTUfQQqjOsyvMGvD6o= +github.com/eclipse/paho.mqtt.golang v1.5.0/go.mod h1:du/2qNQVqJf/Sqs4MEL77kR8QTqANF7XU7Fk0aOTAgk= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE= github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -38,12 +46,18 @@ github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0kt github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4= +github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= +github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q= +github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e h1:XmA6L9IPRdUr28a+SK/oMchGgQy159wvzXA5tJ7l+40= github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e/go.mod h1:AFIo+02s+12CEg8Gzz9kzhCbmbq6JcKNrhHffCGA9z4= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= @@ -72,8 +86,8 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= @@ -123,14 +137,30 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/txn2/txeh v1.5.5 h1:UN4e/lCK5HGw/gGAi2GCVrNKg0GTCUWs7gs5riaZlz4= github.com/txn2/txeh v1.5.5/go.mod h1:qYzGG9kCzeVEI12geK4IlanHWY8X4uy/I3NcW7mk8g4= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 h1:x7wzEgXfnzJcHDwStJT+mxOz4etr2EcexjqhBvmoakw= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0/go.mod h1:rg+RlpR5dKwaS95IyyZqj5Wd4E13lk/msnTS0Xl9lJM= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 h1:sbiXRNDSWJOTobXh5HyQKjq6wUC5tNybqjIqDpAY4CU= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0/go.mod h1:69uWxva0WgAA/4bu2Yy70SLDBwZXuQ6PbBpbsa5iZrQ= +go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= +go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= +go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= +go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= +go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY= +go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg= +go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o= +go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w= +go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= +go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= -golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= -golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98= golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= @@ -139,8 +169,20 @@ golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= +golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20221104135756-97bc4ad4a1cb h1:9aqVcYEDHmSNb0uOWukxV5lHV09WqiSiCuhEgWNETLY= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20221104135756-97bc4ad4a1cb/go.mod h1:mQqgjkW8GQQcJQsbBvK890TKqUK1DfKWkuBGbOkuMHQ= +google.golang.org/api v0.229.0 h1:p98ymMtqeJ5i3lIBMj5MpR9kzIIgzpHHh8vQ+vgAzx8= +google.golang.org/api v0.229.0/go.mod h1:wyDfmq5g1wYJWn29O22FDWN48P7Xcz0xz+LBpptYvB0= +google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 h1:GVIKPyP/kLIyVOgOnTwFOrvQaQUzOzGMCxgFUOEmm24= +google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422/go.mod h1:b6h1vNKhxaSoEI+5jc3PJUCustfli/mRab7295pY7rw= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e h1:ztQaXfzEXTmCBvbtWYRhJxW+0iJcz2qXfd38/e9l7bA= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI= +google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/logic/auth.go b/logic/auth.go index b2fa7d0d..cb427484 100644 --- a/logic/auth.go +++ b/logic/auth.go @@ -8,15 +8,16 @@ import ( "fmt" "time" + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/schema" + "github.com/go-playground/validator/v10" "golang.org/x/crypto/bcrypt" "golang.org/x/exp/slog" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" - "github.com/gravitl/netmaker/schema" ) const ( @@ -31,7 +32,8 @@ func ClearSuperUserCache() { superUser = models.User{} } -var InitializeAuthProvider = func() string { return "" } +var ResetAuthProvider = func() {} +var ResetIDPSyncHook = func() {} // HasSuperAdmin - checks if server has an superadmin/owner func HasSuperAdmin() (bool, error) { @@ -303,11 +305,55 @@ func UpdateUser(userchange, user *models.User) (*models.User, error) { if err := IsNetworkRolesValid(userchange.NetworkRoles); err != nil { return userchange, errors.New("invalid network roles: " + err.Error()) } + + if userchange.DisplayName != "" { + if user.ExternalIdentityProviderID != "" && + user.DisplayName != userchange.DisplayName { + return userchange, errors.New("display name cannot be updated for external user") + } + + user.DisplayName = userchange.DisplayName + } + + if user.ExternalIdentityProviderID != "" && + userchange.AccountDisabled != user.AccountDisabled { + return userchange, errors.New("account status cannot be updated for external user") + } + // Reset Gw Access for service users go UpdateUserGwAccess(*user, *userchange) if userchange.PlatformRoleID != "" { user.PlatformRoleID = userchange.PlatformRoleID } + + for groupID := range userchange.UserGroups { + _, ok := user.UserGroups[groupID] + if !ok { + group, err := GetUserGroup(groupID) + if err != nil { + return userchange, err + } + + if group.ExternalIdentityProviderID != "" { + return userchange, errors.New("cannot modify membership of external groups") + } + } + } + + for groupID := range user.UserGroups { + _, ok := userchange.UserGroups[groupID] + if !ok { + group, err := GetUserGroup(groupID) + if err != nil { + return userchange, err + } + + if group.ExternalIdentityProviderID != "" { + return userchange, errors.New("cannot modify membership of external groups") + } + } + } + user.UserGroups = userchange.UserGroups user.NetworkRoles = userchange.NetworkRoles AddGlobalNetRolesToAdmins(user) diff --git a/logic/jwts.go b/logic/jwts.go index 64eea672..3872568c 100644 --- a/logic/jwts.go +++ b/logic/jwts.go @@ -163,9 +163,11 @@ func GetUserNameFromToken(authtoken string) (username string, err error) { // VerifyUserToken func will used to Verify the JWT Token while using APIS func VerifyUserToken(tokenString string) (username string, issuperadmin, isadmin bool, err error) { claims := &models.UserClaims{} + if tokenString == servercfg.GetMasterKey() && servercfg.GetMasterKey() != "" { return MasterUser, true, true, nil } + token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { return jwtSecretKey, nil }) diff --git a/logic/security.go b/logic/security.go index f4c8a23e..843ffa27 100644 --- a/logic/security.go +++ b/logic/security.go @@ -1,6 +1,7 @@ package logic import ( + "errors" "net/http" "strings" @@ -32,6 +33,19 @@ func SecurityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc { ReturnErrorResponse(w, r, FormatError(err, "unauthorized")) return } + + user, err := GetUser(username) + if err != nil { + ReturnErrorResponse(w, r, FormatError(err, "unauthorized")) + return + } + + if user.AccountDisabled { + err = errors.New("user account disabled") + ReturnErrorResponse(w, r, FormatError(err, "unauthorized")) + return + } + // detect masteradmin if username == MasterUser { r.Header.Set("ismaster", "yes") diff --git a/logic/settings.go b/logic/settings.go index 4b062743..d704b6e5 100644 --- a/logic/settings.go +++ b/logic/settings.go @@ -272,6 +272,26 @@ func GetAzureTenant() string { return GetServerSettings().AzureTenant } +// IsSyncEnabled returns whether auth provider sync is enabled. +func IsSyncEnabled() bool { + return GetServerSettings().SyncEnabled +} + +// GetIDPSyncInterval returns the interval at which the netmaker should sync +// data from IDP. +func GetIDPSyncInterval() time.Duration { + syncInterval, err := time.ParseDuration(GetServerSettings().IDPSyncInterval) + if err != nil { + return 24 * time.Hour + } + + if syncInterval == 0 { + return 24 * time.Hour + } + + return syncInterval +} + // GetMetricsPort - get metrics port func GetMetricsPort() int { return GetServerSettings().MetricsPort diff --git a/logic/user_mgmt.go b/logic/user_mgmt.go index 7eb3de7b..c85a6b5b 100644 --- a/logic/user_mgmt.go +++ b/logic/user_mgmt.go @@ -50,6 +50,8 @@ var MigrateUserRoleAndGroups = func(u models.User) { } +var MigrateGroups = func() {} + var UpdateUserGwAccess = func(currentUser, changeUser models.User) {} var UpdateRole = func(r models.UserRolePermissionTemplate) error { return nil } diff --git a/logic/users.go b/logic/users.go index 168fd928..4b9b5171 100644 --- a/logic/users.go +++ b/logic/users.go @@ -41,13 +41,15 @@ func GetReturnUser(username string) (models.ReturnUser, error) { // ToReturnUser - gets a user as a return user func ToReturnUser(user models.User) models.ReturnUser { return models.ReturnUser{ - UserName: user.UserName, - PlatformRoleID: user.PlatformRoleID, - AuthType: user.AuthType, - UserGroups: user.UserGroups, - NetworkRoles: user.NetworkRoles, - RemoteGwIDs: user.RemoteGwIDs, - LastLoginTime: user.LastLoginTime, + UserName: user.UserName, + DisplayName: user.DisplayName, + AccountDisabled: user.AccountDisabled, + AuthType: user.AuthType, + RemoteGwIDs: user.RemoteGwIDs, + UserGroups: user.UserGroups, + PlatformRoleID: user.PlatformRoleID, + NetworkRoles: user.NetworkRoles, + LastLoginTime: user.LastLoginTime, } } @@ -78,7 +80,7 @@ func GetSuperAdmin() (models.ReturnUser, error) { return models.ReturnUser{}, err } for _, user := range users { - if user.IsSuperAdmin { + if user.IsSuperAdmin || user.PlatformRoleID == models.SuperAdminRole { return user, nil } } @@ -113,7 +115,7 @@ func IsPendingUser(username string) bool { return false } -func ListPendingUsers() ([]models.ReturnUser, error) { +func ListPendingReturnUsers() ([]models.ReturnUser, error) { pendingUsers := []models.ReturnUser{} records, err := database.FetchRecords(database.PENDING_USERS_TABLE_NAME) if err != nil && !database.IsEmptyRecord(err) { @@ -129,6 +131,22 @@ func ListPendingUsers() ([]models.ReturnUser, error) { return pendingUsers, nil } +func ListPendingUsers() ([]models.User, error) { + var pendingUsers []models.User + records, err := database.FetchRecords(database.PENDING_USERS_TABLE_NAME) + if err != nil && !database.IsEmptyRecord(err) { + return pendingUsers, err + } + for _, record := range records { + var u models.User + err = json.Unmarshal([]byte(record), &u) + if err == nil { + pendingUsers = append(pendingUsers, u) + } + } + return pendingUsers, nil +} + func GetUserMap() (map[string]models.User, error) { userMap := make(map[string]models.User) records, err := database.FetchRecords(database.USERS_TABLE_NAME) diff --git a/migrate/migrate.go b/migrate/migrate.go index 484aa879..83c240f2 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -29,6 +29,7 @@ func Run() { assignSuperAdmin() createDefaultTagsAndPolicies() removeOldUserGrps() + syncGroups() syncUsers() updateHosts() updateNodes() @@ -393,6 +394,10 @@ func MigrateEmqx() { } +func syncGroups() { + logic.MigrateGroups() +} + func syncUsers() { // create default network user roles for existing networks if servercfg.IsPro { diff --git a/models/settings.go b/models/settings.go index c7aa394c..7ae0fb45 100644 --- a/models/settings.go +++ b/models/settings.go @@ -15,7 +15,13 @@ type ServerSettings struct { OIDCIssuer string `json:"oidcissuer"` ClientID string `json:"client_id"` ClientSecret string `json:"client_secret"` + SyncEnabled bool `json:"sync_enabled"` + GoogleAdminEmail string `json:"google_admin_email"` + GoogleSACredsJson string `json:"google_sa_creds_json"` AzureTenant string `json:"azure_tenant"` + UserFilters []string `json:"user_filters"` + GroupFilters []string `json:"group_filters"` + IDPSyncInterval string `json:"idp_sync_interval"` Telemetry string `json:"telemetry"` BasicAuth bool `json:"basic_auth"` JwtValidityDuration int `json:"jwt_validity_duration"` diff --git a/models/user_mgmt.go b/models/user_mgmt.go index 17d6689e..94fa9595 100644 --- a/models/user_mgmt.go +++ b/models/user_mgmt.go @@ -144,17 +144,20 @@ type CreateGroupReq struct { } type UserGroup struct { - ID UserGroupID `json:"id"` - Default bool `json:"default"` - Name string `json:"name"` - NetworkRoles map[NetworkID]map[UserRoleID]struct{} `json:"network_roles"` - MetaData string `json:"meta_data"` + ID UserGroupID `json:"id"` + ExternalIdentityProviderID string `json:"external_identity_provider_id"` + Default bool `json:"default"` + Name string `json:"name"` + NetworkRoles map[NetworkID]map[UserRoleID]struct{} `json:"network_roles"` + MetaData string `json:"meta_data"` } // User struct - struct for Users type User struct { UserName string `json:"username" bson:"username" validate:"min=3,in_charset|email"` ExternalIdentityProviderID string `json:"external_identity_provider_id"` + DisplayName string `json:"display_name"` + AccountDisabled bool `json:"account_disabled"` Password string `json:"password" bson:"password" validate:"required,min=5"` IsAdmin bool `json:"isadmin" bson:"isadmin"` // deprecated IsSuperAdmin bool `json:"issuperadmin"` // deprecated @@ -174,15 +177,18 @@ type ReturnUserWithRolesAndGroups struct { // ReturnUser - return user struct type ReturnUser struct { - UserName string `json:"username"` - IsAdmin bool `json:"isadmin"` - IsSuperAdmin bool `json:"issuperadmin"` - AuthType AuthType `json:"auth_type"` - RemoteGwIDs map[string]struct{} `json:"remote_gw_ids"` // deprecated - UserGroups map[UserGroupID]struct{} `json:"user_group_ids"` - PlatformRoleID UserRoleID `json:"platform_role_id"` - NetworkRoles map[NetworkID]map[UserRoleID]struct{} `json:"network_roles"` - LastLoginTime time.Time `json:"last_login_time"` + UserName string `json:"username"` + ExternalIdentityProviderID string `json:"external_identity_provider_id"` + DisplayName string `json:"display_name"` + AccountDisabled bool `json:"account_disabled"` + IsAdmin bool `json:"isadmin"` + IsSuperAdmin bool `json:"issuperadmin"` + AuthType AuthType `json:"auth_type"` + RemoteGwIDs map[string]struct{} `json:"remote_gw_ids"` // deprecated + UserGroups map[UserGroupID]struct{} `json:"user_group_ids"` + PlatformRoleID UserRoleID `json:"platform_role_id"` + NetworkRoles map[NetworkID]map[UserRoleID]struct{} `json:"network_roles"` + LastLoginTime time.Time `json:"last_login_time"` } // UserAuthParams - user auth params struct diff --git a/pro/auth/auth.go b/pro/auth/auth.go index 215a6263..70c9de13 100644 --- a/pro/auth/auth.go +++ b/pro/auth/auth.go @@ -34,6 +34,7 @@ const ( // OAuthUser - generic OAuth strategy user type OAuthUser struct { + ID string `json:"id" bson:"id"` Name string `json:"name" bson:"name"` Email string `json:"email" bson:"email"` Login string `json:"login" bson:"login"` @@ -63,6 +64,17 @@ func getCurrentAuthFunctions() map[string]interface{} { } } +// ResetAuthProvider resets the auth provider configuration. +func ResetAuthProvider() { + settings := logic.GetServerSettings() + + if settings.AuthProvider == "" { + auth_provider = nil + } + + InitializeAuthProvider() +} + // InitializeAuthProvider - initializes the auth provider if any is present func InitializeAuthProvider() string { var functions = getCurrentAuthFunctions() diff --git a/pro/auth/azure-ad.go b/pro/auth/azure-ad.go index e67edc3e..f6ce5d63 100644 --- a/pro/auth/azure-ad.go +++ b/pro/auth/azure-ad.go @@ -111,7 +111,7 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - user.ExternalIdentityProviderID = content.UserPrincipalName + user.ExternalIdentityProviderID = content.ID if err = logic.CreateUser(&user); err != nil { handleSomethingWentWrong(w) return @@ -124,7 +124,9 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) { return } err = logic.InsertPendingUser(&models.User{ - UserName: content.Email, + UserName: content.Email, + ExternalIdentityProviderID: content.ID, + AuthType: models.OAuth, }) if err != nil { handleSomethingWentWrong(w) @@ -152,6 +154,12 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) { handleOauthUserNotFound(w) return } + + if user.AccountDisabled { + handleUserAccountDisabled(w) + return + } + userRole, err := logic.GetRole(user.PlatformRoleID) if err != nil { handleSomethingWentWrong(w) diff --git a/pro/auth/error.go b/pro/auth/error.go index d9beecb6..a5863753 100644 --- a/pro/auth/error.go +++ b/pro/auth/error.go @@ -113,6 +113,8 @@ var notallowedtosignup = fmt.Sprintf(htmlBaseTemplate, `

Your email is not al var authTypeMismatch = fmt.Sprintf(htmlBaseTemplate, `

It looks like you already have an account with us using Basic Authentication.

To continue, please log in with your existing credentials or reset your password if needed.

`) +var userAccountDisabled = fmt.Sprintf(htmlBaseTemplate, `

Your account has been disabled. Please contact your administrator for more information about your account.

`) + func handleOauthUserNotFound(response http.ResponseWriter) { response.Header().Set("Content-Type", "text/html; charset=utf-8") response.WriteHeader(http.StatusNotFound) @@ -166,3 +168,9 @@ func handleAuthTypeMismatch(response http.ResponseWriter) { response.WriteHeader(http.StatusBadRequest) response.Write([]byte(authTypeMismatch)) } + +func handleUserAccountDisabled(response http.ResponseWriter) { + response.Header().Set("Content-Type", "text/html; charset=utf-8") + response.WriteHeader(http.StatusUnauthorized) + response.Write([]byte(userAccountDisabled)) +} diff --git a/pro/auth/github.go b/pro/auth/github.go index 0d543f48..a7d468d5 100644 --- a/pro/auth/github.go +++ b/pro/auth/github.go @@ -111,7 +111,7 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - user.ExternalIdentityProviderID = content.Login + user.ExternalIdentityProviderID = content.ID if err = logic.CreateUser(&user); err != nil { handleSomethingWentWrong(w) return @@ -124,7 +124,9 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) { return } err = logic.InsertPendingUser(&models.User{ - UserName: content.Email, + UserName: content.Email, + ExternalIdentityProviderID: content.ID, + AuthType: models.OAuth, }) if err != nil { handleSomethingWentWrong(w) @@ -143,6 +145,12 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) { handleOauthUserNotFound(w) return } + + if user.AccountDisabled { + handleUserAccountDisabled(w) + return + } + userRole, err := logic.GetRole(user.PlatformRoleID) if err != nil { handleSomethingWentWrong(w) diff --git a/pro/auth/google.go b/pro/auth/google.go index 97bf3143..767645f9 100644 --- a/pro/auth/google.go +++ b/pro/auth/google.go @@ -105,7 +105,9 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) { return } err = logic.InsertPendingUser(&models.User{ - UserName: content.Email, + UserName: content.Email, + ExternalIdentityProviderID: content.ID, + AuthType: models.OAuth, }) if err != nil { handleSomethingWentWrong(w) @@ -136,6 +138,11 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) { return } + if user.AccountDisabled { + handleUserAccountDisabled(w) + return + } + userRole, err := logic.GetRole(user.PlatformRoleID) if err != nil { handleSomethingWentWrong(w) diff --git a/pro/auth/headless_callback.go b/pro/auth/headless_callback.go index de0627ca..2e13ddc9 100644 --- a/pro/auth/headless_callback.go +++ b/pro/auth/headless_callback.go @@ -64,7 +64,9 @@ func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) { if err != nil { if database.IsEmptyRecord(err) { // user must not exist, so try to make one err = logic.InsertPendingUser(&models.User{ - UserName: userClaims.getUserName(), + UserName: userClaims.getUserName(), + ExternalIdentityProviderID: userClaims.ID, + AuthType: models.OAuth, }) if err != nil { handleSomethingWentWrong(w) diff --git a/pro/auth/oidc.go b/pro/auth/oidc.go index 37f8918f..30fdd08f 100644 --- a/pro/auth/oidc.go +++ b/pro/auth/oidc.go @@ -102,7 +102,7 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - user.ExternalIdentityProviderID = content.Email + user.ExternalIdentityProviderID = content.ID if err = logic.CreateUser(&user); err != nil { handleSomethingWentWrong(w) return @@ -115,7 +115,9 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) { return } err = logic.InsertPendingUser(&models.User{ - UserName: content.Email, + UserName: content.Email, + ExternalIdentityProviderID: content.ID, + AuthType: models.OAuth, }) if err != nil { handleSomethingWentWrong(w) @@ -143,6 +145,12 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) { handleOauthUserNotFound(w) return } + + if user.AccountDisabled { + handleUserAccountDisabled(w) + return + } + userRole, err := logic.GetRole(user.PlatformRoleID) if err != nil { handleSomethingWentWrong(w) @@ -224,6 +232,8 @@ func getOIDCUserInfo(state string, code string) (u *OAuthUser, e error) { e = fmt.Errorf("error when claiming OIDCUser: \"%s\"", err.Error()) } + u.ID = idToken.Subject + return } diff --git a/pro/auth/sync.go b/pro/auth/sync.go new file mode 100644 index 00000000..bd1f2ec0 --- /dev/null +++ b/pro/auth/sync.go @@ -0,0 +1,281 @@ +package auth + +import ( + "fmt" + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/pro/idp" + "github.com/gravitl/netmaker/pro/idp/azure" + "github.com/gravitl/netmaker/pro/idp/google" + proLogic "github.com/gravitl/netmaker/pro/logic" + "strings" + "time" +) + +var syncTicker *time.Ticker + +func StartSyncHook() { + syncTicker = time.NewTicker(logic.GetIDPSyncInterval()) + + for range syncTicker.C { + err := SyncFromIDP() + if err != nil { + logger.Log(0, "failed to sync from idp: ", err.Error()) + } else { + logger.Log(0, "sync from idp complete") + } + } +} + +func ResetIDPSyncHook() { + if syncTicker != nil { + syncTicker.Stop() + if logic.IsSyncEnabled() { + go StartSyncHook() + } + } +} + +func SyncFromIDP() error { + settings := logic.GetServerSettings() + + var idpClient idp.Client + var idpUsers []idp.User + var idpGroups []idp.Group + var err error + + switch settings.AuthProvider { + case "google": + idpClient, err = google.NewGoogleWorkspaceClient() + if err != nil { + return err + } + case "azure-ad": + idpClient = azure.NewAzureEntraIDClient() + default: + if settings.AuthProvider != "" { + return fmt.Errorf("invalid auth provider: %s", settings.AuthProvider) + } + } + + if settings.AuthProvider != "" && idpClient != nil { + idpUsers, err = idpClient.GetUsers() + if err != nil { + return err + } + + idpGroups, err = idpClient.GetGroups() + if err != nil { + return err + } + } + + err = syncUsers(idpUsers) + if err != nil { + return err + } + + return syncGroups(idpGroups) +} + +func syncUsers(idpUsers []idp.User) error { + dbUsers, err := logic.GetUsersDB() + if err != nil && !database.IsEmptyRecord(err) { + return err + } + + password, err := logic.FetchPassValue("") + if err != nil { + return err + } + + idpUsersMap := make(map[string]struct{}) + for _, user := range idpUsers { + idpUsersMap[user.Username] = struct{}{} + } + + dbUsersMap := make(map[string]models.User) + for _, user := range dbUsers { + dbUsersMap[user.UserName] = user + } + + filters := logic.GetServerSettings().UserFilters + + for _, user := range idpUsers { + var found bool + for _, filter := range filters { + if strings.HasPrefix(user.Username, filter) { + found = true + break + } + } + + // if there are filters but none of them match, then skip this user. + if len(filters) > 0 && !found { + continue + } + + dbUser, ok := dbUsersMap[user.Username] + if !ok { + // create the user only if it doesn't exist. + err = logic.CreateUser(&models.User{ + UserName: user.Username, + ExternalIdentityProviderID: user.ID, + DisplayName: user.DisplayName, + AccountDisabled: user.AccountDisabled, + Password: password, + AuthType: models.OAuth, + PlatformRoleID: models.ServiceUser, + }) + if err != nil { + return err + } + } else if dbUser.AuthType == models.OAuth { + if dbUser.AccountDisabled != user.AccountDisabled || + dbUser.DisplayName != user.DisplayName || + dbUser.ExternalIdentityProviderID != user.ID { + + dbUser.AccountDisabled = user.AccountDisabled + dbUser.DisplayName = user.DisplayName + dbUser.ExternalIdentityProviderID = user.ID + + err = logic.UpsertUser(dbUser) + if err != nil { + return err + } + } + } else { + logger.Log(0, "user with username "+user.Username+" already exists, skipping creation") + continue + } + } + + for _, user := range dbUsersMap { + if user.ExternalIdentityProviderID == "" { + continue + } + if _, ok := idpUsersMap[user.UserName]; !ok { + // delete the user if it has been deleted on idp. + err = logic.DeleteUser(user.UserName) + if err != nil { + return err + } + } + } + + return nil +} + +func syncGroups(idpGroups []idp.Group) error { + dbGroups, err := proLogic.ListUserGroups() + if err != nil && !database.IsEmptyRecord(err) { + return err + } + + dbUsers, err := logic.GetUsersDB() + if err != nil && !database.IsEmptyRecord(err) { + return err + } + + idpGroupsMap := make(map[string]struct{}) + for _, group := range idpGroups { + idpGroupsMap[group.ID] = struct{}{} + } + + dbGroupsMap := make(map[string]models.UserGroup) + for _, group := range dbGroups { + if group.ExternalIdentityProviderID != "" { + dbGroupsMap[group.ExternalIdentityProviderID] = group + } + } + + dbUsersMap := make(map[string]models.User) + for _, user := range dbUsers { + if user.ExternalIdentityProviderID != "" { + dbUsersMap[user.ExternalIdentityProviderID] = user + } + } + + modifiedUsers := make(map[string]struct{}) + + filters := logic.GetServerSettings().GroupFilters + + for _, group := range idpGroups { + var found bool + for _, filter := range filters { + if strings.HasPrefix(group.Name, filter) { + found = true + break + } + } + + // if there are filters but none of them match, then skip this group. + if len(filters) > 0 && !found { + continue + } + + dbGroup, ok := dbGroupsMap[group.ID] + if !ok { + err := proLogic.CreateUserGroup(models.UserGroup{ + ExternalIdentityProviderID: group.ID, + Default: false, + Name: group.Name, + }) + if err != nil { + return err + } + } else { + dbGroup.Name = group.Name + err = proLogic.UpdateUserGroup(dbGroup) + if err != nil { + return err + } + } + + groupMembersMap := make(map[string]struct{}) + for _, member := range group.Members { + groupMembersMap[member] = struct{}{} + } + + for _, user := range dbUsers { + // use dbGroup.Name because the group name may have been changed on idp. + _, inNetmakerGroup := user.UserGroups[models.UserGroupID(dbGroup.Name)] + _, inIDPGroup := groupMembersMap[user.ExternalIdentityProviderID] + + if inNetmakerGroup && !inIDPGroup { + // use dbGroup.Name because the group name may have been changed on idp. + delete(dbUsersMap[user.ExternalIdentityProviderID].UserGroups, models.UserGroupID(dbGroup.Name)) + modifiedUsers[user.ExternalIdentityProviderID] = struct{}{} + } + + if !inNetmakerGroup && inIDPGroup { + // use dbGroup.Name because the group name may have been changed on idp. + dbUsersMap[user.ExternalIdentityProviderID].UserGroups[models.UserGroupID(dbGroup.Name)] = struct{}{} + modifiedUsers[user.ExternalIdentityProviderID] = struct{}{} + } + } + } + + for userID := range modifiedUsers { + err = logic.UpsertUser(dbUsersMap[userID]) + if err != nil { + return err + } + } + + for _, group := range dbGroups { + if group.ExternalIdentityProviderID != "" { + if _, ok := idpGroupsMap[group.ExternalIdentityProviderID]; !ok { + // delete the group if it has been deleted on idp. + err = proLogic.DeleteUserGroup(group.ID) + if err != nil { + return err + } + } + } + } + + return nil +} diff --git a/pro/controllers/users.go b/pro/controllers/users.go index 94aade4c..272305d2 100644 --- a/pro/controllers/users.go +++ b/pro/controllers/users.go @@ -62,6 +62,9 @@ func UserHandlers(r *mux.Router) { r.HandleFunc("/api/users/{username}/remote_access_gw/{remote_access_gateway_id}", logic.SecurityCheck(true, http.HandlerFunc(removeUserFromRemoteAccessGW))).Methods(http.MethodDelete) r.HandleFunc("/api/users/{username}/remote_access_gw", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserRemoteAccessGwsV1)))).Methods(http.MethodGet) r.HandleFunc("/api/users/ingress/{ingress_id}", logic.SecurityCheck(true, http.HandlerFunc(ingressGatewayUsers))).Methods(http.MethodGet) + + r.HandleFunc("/api/idp/sync", logic.SecurityCheck(true, http.HandlerFunc(syncIDP))).Methods(http.MethodPost) + r.HandleFunc("/api/idp", logic.SecurityCheck(true, http.HandlerFunc(removeIDPIntegration))).Methods(http.MethodDelete) } // swagger:route POST /api/v1/users/invite-signup user userInviteSignUp @@ -546,6 +549,9 @@ func updateUserGroup(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } + + userGroup.ExternalIdentityProviderID = currUserG.ExternalIdentityProviderID + err = proLogic.UpdateUserGroup(userGroup) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) @@ -1423,7 +1429,7 @@ func getPendingUsers(w http.ResponseWriter, r *http.Request) { // set header. w.Header().Set("Content-Type", "application/json") - users, err := logic.ListPendingUsers() + users, err := logic.ListPendingReturnUsers() if err != nil { logger.Log(0, "failed to fetch users: ", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) @@ -1461,9 +1467,11 @@ func approvePendingUser(w http.ResponseWriter, r *http.Request) { return } if err = logic.CreateUser(&models.User{ - UserName: user.UserName, - Password: newPass, - PlatformRoleID: models.ServiceUser, + UserName: user.UserName, + ExternalIdentityProviderID: user.ExternalIdentityProviderID, + Password: newPass, + AuthType: user.AuthType, + PlatformRoleID: models.ServiceUser, }); err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to create user: %s", err), "internal")) return @@ -1505,7 +1513,7 @@ func deletePendingUser(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") var params = mux.Vars(r) username := params["username"] - users, err := logic.ListPendingUsers() + users, err := logic.ListPendingReturnUsers() if err != nil { logger.Log(0, "failed to fetch users: ", err.Error()) @@ -1569,3 +1577,82 @@ func deleteAllPendingUsers(w http.ResponseWriter, r *http.Request) { }) logic.ReturnSuccessResponse(w, r, "cleared all pending users") } + +// @Summary Sync users and groups from idp. +// @Router /api/idp/sync [post] +// @Tags IDP +// @Success 200 {object} models.SuccessResponse +func syncIDP(w http.ResponseWriter, r *http.Request) { + go func() { + err := proAuth.SyncFromIDP() + if err != nil { + logger.Log(0, "failed to sync from idp: ", err.Error()) + } else { + logger.Log(0, "sync from idp complete") + } + }() + + logic.ReturnSuccessResponse(w, r, "starting sync from idp") +} + +// @Summary Remove idp integration. +// @Router /api/idp [delete] +// @Tags IDP +// @Success 200 {object} models.SuccessResponse +// @Failure 500 {object} models.ErrorResponse +func removeIDPIntegration(w http.ResponseWriter, r *http.Request) { + superAdmin, err := logic.GetSuperAdmin() + if err != nil { + logic.ReturnErrorResponse( + w, + r, + logic.FormatError(fmt.Errorf("failed to get superadmin: %v", err), "internal"), + ) + return + } + + if superAdmin.AuthType == models.OAuth { + logic.ReturnErrorResponse( + w, + r, + logic.FormatError(fmt.Errorf("cannot remove idp integration with superadmin oauth user"), "badrequest"), + ) + return + } + + settings := logic.GetServerSettings() + settings.AuthProvider = "" + settings.OIDCIssuer = "" + settings.ClientID = "" + settings.ClientSecret = "" + settings.SyncEnabled = false + settings.GoogleAdminEmail = "" + settings.GoogleSACredsJson = "" + settings.AzureTenant = "" + settings.UserFilters = nil + settings.GroupFilters = nil + + err = logic.UpsertServerSettings(settings) + if err != nil { + logic.ReturnErrorResponse( + w, + r, + logic.FormatError(fmt.Errorf("failed to remove idp integration: %v", err), "internal"), + ) + return + } + + proAuth.ResetAuthProvider() + proAuth.ResetIDPSyncHook() + + go func() { + err := proAuth.SyncFromIDP() + if err != nil { + logger.Log(0, "failed to sync from idp: ", err.Error()) + } else { + logger.Log(0, "sync from idp complete") + } + }() + + logic.ReturnSuccessResponse(w, r, "removed idp integration successfully") +} diff --git a/pro/idp/azure/azure.go b/pro/idp/azure/azure.go new file mode 100644 index 00000000..57fc736d --- /dev/null +++ b/pro/idp/azure/azure.go @@ -0,0 +1,167 @@ +package azure + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/pro/idp" + "net/http" + "net/url" +) + +type Client struct { + clientID string + clientSecret string + tenantID string +} + +func NewAzureEntraIDClient() *Client { + settings := logic.GetServerSettings() + + return &Client{ + clientID: settings.ClientID, + clientSecret: settings.ClientSecret, + tenantID: settings.AzureTenant, + } +} + +func (a *Client) GetUsers() ([]idp.User, error) { + accessToken, err := a.getAccessToken() + if err != nil { + return nil, err + } + + client := &http.Client{} + req, err := http.NewRequest("GET", "https://graph.microsoft.com/v1.0/users?$select=id,userPrincipalName,displayName,accountEnabled", nil) + if err != nil { + return nil, err + } + + req.Header.Add("Authorization", "Bearer "+accessToken) + req.Header.Add("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { + _ = resp.Body.Close() + }() + + var users getUsersResponse + err = json.NewDecoder(resp.Body).Decode(&users) + if err != nil { + return nil, err + } + + retval := make([]idp.User, len(users.Value)) + for i, user := range users.Value { + retval[i] = idp.User{ + ID: user.Id, + Username: user.UserPrincipalName, + DisplayName: user.DisplayName, + AccountDisabled: !user.AccountEnabled, + } + } + + return retval, nil +} + +func (a *Client) GetGroups() ([]idp.Group, error) { + accessToken, err := a.getAccessToken() + if err != nil { + return nil, err + } + + client := &http.Client{} + req, err := http.NewRequest("GET", "https://graph.microsoft.com/v1.0/groups?$select=id,displayName&$expand=members($select=id)", nil) + if err != nil { + return nil, err + } + + req.Header.Add("Authorization", "Bearer "+accessToken) + req.Header.Add("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { + _ = resp.Body.Close() + }() + + var groups getGroupsResponse + err = json.NewDecoder(resp.Body).Decode(&groups) + if err != nil { + return nil, err + } + + retval := make([]idp.Group, len(groups.Value)) + for i, group := range groups.Value { + retvalMembers := make([]string, len(group.Members)) + for j, member := range group.Members { + retvalMembers[j] = member.Id + } + + retval[i] = idp.Group{ + ID: group.Id, + Name: group.DisplayName, + Members: retvalMembers, + } + } + + return retval, nil +} + +func (a *Client) getAccessToken() (string, error) { + tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", a.tenantID) + + var data = url.Values{} + data.Set("grant_type", "client_credentials") + data.Set("client_id", a.clientID) + data.Set("client_secret", a.clientSecret) + data.Set("scope", "https://graph.microsoft.com/.default") + + resp, err := http.PostForm(tokenURL, data) + if err != nil { + return "", err + } + defer func() { + _ = resp.Body.Close() + }() + + var tokenResp map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&tokenResp) + if err != nil { + return "", err + } + + if token, ok := tokenResp["access_token"].(string); ok { + return token, nil + } + + return "", errors.New("failed to get access token") +} + +type getUsersResponse struct { + OdataContext string `json:"@odata.context"` + Value []struct { + Id string `json:"id"` + UserPrincipalName string `json:"userPrincipalName"` + DisplayName string `json:"displayName"` + AccountEnabled bool `json:"accountEnabled"` + } `json:"value"` +} + +type getGroupsResponse struct { + OdataContext string `json:"@odata.context"` + Value []struct { + Id string `json:"id"` + DisplayName string `json:"displayName"` + Members []struct { + OdataType string `json:"@odata.type"` + Id string `json:"id"` + } `json:"members"` + } `json:"value"` +} diff --git a/pro/idp/google/google.go b/pro/idp/google/google.go new file mode 100644 index 00000000..12f961c2 --- /dev/null +++ b/pro/idp/google/google.go @@ -0,0 +1,115 @@ +package google + +import ( + "context" + "encoding/base64" + "encoding/json" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/pro/idp" + admindir "google.golang.org/api/admin/directory/v1" + "google.golang.org/api/impersonate" + "google.golang.org/api/option" +) + +type Client struct { + service *admindir.Service +} + +func NewGoogleWorkspaceClient() (*Client, error) { + settings := logic.GetServerSettings() + + credsJson, err := base64.StdEncoding.DecodeString(settings.GoogleSACredsJson) + if err != nil { + return nil, err + } + + credsJsonMap := make(map[string]interface{}) + err = json.Unmarshal(credsJson, &credsJsonMap) + if err != nil { + return nil, err + } + + source, err := impersonate.CredentialsTokenSource( + context.TODO(), + impersonate.CredentialsConfig{ + TargetPrincipal: credsJsonMap["client_email"].(string), + Scopes: []string{ + admindir.AdminDirectoryUserReadonlyScope, + admindir.AdminDirectoryGroupReadonlyScope, + admindir.AdminDirectoryGroupMemberReadonlyScope, + }, + Subject: settings.GoogleAdminEmail, + }, + option.WithCredentialsJSON(credsJson), + ) + if err != nil { + return nil, err + } + + service, err := admindir.NewService( + context.TODO(), + option.WithTokenSource(source), + ) + if err != nil { + return nil, err + } + + return &Client{ + service: service, + }, nil +} + +func (g *Client) GetUsers() ([]idp.User, error) { + var retval []idp.User + err := g.service.Users.List(). + Customer("my_customer"). + Fields("users(id,primaryEmail,name,suspended)", "nextPageToken"). + Pages(context.TODO(), func(users *admindir.Users) error { + for _, user := range users.Users { + retval = append(retval, idp.User{ + ID: user.Id, + Username: user.PrimaryEmail, + DisplayName: user.Name.FullName, + AccountDisabled: user.Suspended, + }) + } + + return nil + }) + + return retval, err +} + +func (g *Client) GetGroups() ([]idp.Group, error) { + var retval []idp.Group + err := g.service.Groups.List(). + Customer("my_customer"). + Fields("groups(id,name)", "nextPageToken"). + Pages(context.TODO(), func(groups *admindir.Groups) error { + for _, group := range groups.Groups { + var retvalMembers []string + err := g.service.Members.List(group.Id). + Fields("members(id)", "nextPageToken"). + Pages(context.TODO(), func(members *admindir.Members) error { + for _, member := range members.Members { + retvalMembers = append(retvalMembers, member.Id) + } + + return nil + }) + if err != nil { + return err + } + + retval = append(retval, idp.Group{ + ID: group.Id, + Name: group.Name, + Members: retvalMembers, + }) + } + + return nil + }) + + return retval, err +} diff --git a/pro/idp/idp.go b/pro/idp/idp.go new file mode 100644 index 00000000..a76b65ff --- /dev/null +++ b/pro/idp/idp.go @@ -0,0 +1,19 @@ +package idp + +type Client interface { + GetUsers() ([]User, error) + GetGroups() ([]Group, error) +} + +type User struct { + ID string + Username string + DisplayName string + AccountDisabled bool +} + +type Group struct { + ID string + Name string + Members []string +} diff --git a/pro/initialize.go b/pro/initialize.go index 67705a3f..3b3dc942 100644 --- a/pro/initialize.go +++ b/pro/initialize.go @@ -93,6 +93,7 @@ func InitPro() { } proLogic.LoadNodeMetricsToCache() proLogic.InitFailOverCache() + auth.StartSyncHook() email.Init() proLogic.EventWatcher() }) @@ -135,12 +136,14 @@ func InitPro() { logic.UpdateUserGwAccess = proLogic.UpdateUserGwAccess logic.CreateDefaultUserPolicies = proLogic.CreateDefaultUserPolicies logic.MigrateUserRoleAndGroups = proLogic.MigrateUserRoleAndGroups + logic.MigrateGroups = proLogic.MigrateGroups logic.IntialiseGroups = proLogic.UserGroupsInit logic.AddGlobalNetRolesToAdmins = proLogic.AddGlobalNetRolesToAdmins logic.GetUserGroupsInNetwork = proLogic.GetUserGroupsInNetwork logic.GetUserGroup = proLogic.GetUserGroup logic.GetNodeStatus = proLogic.GetNodeStatus - logic.InitializeAuthProvider = auth.InitializeAuthProvider + logic.ResetAuthProvider = auth.ResetAuthProvider + logic.ResetIDPSyncHook = auth.ResetIDPSyncHook logic.EmailInit = email.Init logic.LogEvent = proLogic.LogEvent } diff --git a/pro/logic/migrate.go b/pro/logic/migrate.go index fedef3c9..5fac1ead 100644 --- a/pro/logic/migrate.go +++ b/pro/logic/migrate.go @@ -1,14 +1,75 @@ package logic import ( - "fmt" + "encoding/json" + "github.com/google/uuid" + "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" ) +func MigrateGroups() { + groups, err := ListUserGroups() + if err != nil { + return + } + + groupMapping := make(map[models.UserGroupID]models.UserGroupID) + + for _, group := range groups { + if group.Default { + continue + } + + _, err := uuid.Parse(string(group.ID)) + if err == nil { + // group id is already an uuid, so no need to update + continue + } + + oldGroupID := group.ID + group.ID = models.UserGroupID(uuid.NewString()) + groupMapping[oldGroupID] = group.ID + + groupBytes, err := json.Marshal(group) + if err != nil { + continue + } + + err = database.Insert(group.ID.String(), string(groupBytes), database.USER_GROUPS_TABLE_NAME) + if err != nil { + continue + } + + err = database.DeleteRecord(database.USER_GROUPS_TABLE_NAME, oldGroupID.String()) + if err != nil { + continue + } + } + + users, err := logic.GetUsersDB() + if err != nil { + return + } + + for _, user := range users { + userGroups := make(map[models.UserGroupID]struct{}) + for groupID := range user.UserGroups { + newGroupID, ok := groupMapping[groupID] + if !ok { + userGroups[groupID] = struct{}{} + } else { + userGroups[newGroupID] = struct{}{} + } + } + + user.UserGroups = userGroups + logic.UpsertUser(user) + } +} + func MigrateUserRoleAndGroups(user models.User) { - var err error if user.PlatformRoleID == models.AdminRole || user.PlatformRoleID == models.SuperAdminRole { return } @@ -20,22 +81,21 @@ func MigrateUserRoleAndGroups(user models.User) { if err != nil { continue } - var g models.UserGroup + var groupID models.UserGroupID if user.PlatformRoleID == models.ServiceUser { - g, err = GetUserGroup(models.UserGroupID(fmt.Sprintf("%s-%s-grp", gwNode.Network, models.NetworkUser))) + groupID = GetDefaultNetworkUserGroupID(models.NetworkID(gwNode.Network)) } else { - g, err = GetUserGroup(models.UserGroupID(fmt.Sprintf("%s-%s-grp", - gwNode.Network, models.NetworkAdmin))) + groupID = GetDefaultNetworkAdminGroupID(models.NetworkID(gwNode.Network)) } if err != nil { continue } - user.UserGroups[g.ID] = struct{}{} + user.UserGroups[groupID] = struct{}{} } } if len(user.NetworkRoles) > 0 { for netID, netRoles := range user.NetworkRoles { - var g models.UserGroup + var groupID models.UserGroupID adminAccess := false for netRoleID := range netRoles { permTemplate, err := logic.GetRole(netRoleID) @@ -47,19 +107,15 @@ func MigrateUserRoleAndGroups(user models.User) { } if user.PlatformRoleID == models.ServiceUser { - g, err = GetUserGroup(models.UserGroupID(fmt.Sprintf("%s-%s-grp", netID, models.NetworkUser))) + groupID = GetDefaultNetworkUserGroupID(netID) } else { - role := models.NetworkUser if adminAccess { - role = models.NetworkAdmin + groupID = GetDefaultNetworkAdminGroupID(netID) + } else { + groupID = GetDefaultNetworkUserGroupID(netID) } - g, err = GetUserGroup(models.UserGroupID(fmt.Sprintf("%s-%s-grp", - netID, role))) } - if err != nil { - continue - } - user.UserGroups[g.ID] = struct{}{} + user.UserGroups[groupID] = struct{}{} user.NetworkRoles = make(map[models.NetworkID]map[models.UserRoleID]struct{}) } diff --git a/pro/logic/user_mgmt.go b/pro/logic/user_mgmt.go index 0be37520..9f530968 100644 --- a/pro/logic/user_mgmt.go +++ b/pro/logic/user_mgmt.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/google/uuid" "time" "github.com/gravitl/netmaker/database" @@ -14,6 +15,11 @@ import ( "golang.org/x/exp/slog" ) +var ( + globalNetworksAdminGroupID = models.UserGroupID(fmt.Sprintf("global-%s-grp", models.NetworkAdmin)) + globalNetworksUserGroupID = models.UserGroupID(fmt.Sprintf("global-%s-grp", models.NetworkUser)) +) + var ServiceUserPermissionTemplate = models.UserRolePermissionTemplate{ ID: models.ServiceUser, Default: true, @@ -111,7 +117,7 @@ func UserRolesInit() { func UserGroupsInit() { // create default network groups var NetworkGlobalAdminGroup = models.UserGroup{ - ID: models.UserGroupID(fmt.Sprintf("global-%s-grp", models.NetworkAdmin)), + ID: globalNetworksAdminGroupID, Default: true, Name: "All Networks Admin Group", MetaData: "can manage configuration of all networks", @@ -122,11 +128,11 @@ func UserGroupsInit() { }, } var NetworkGlobalUserGroup = models.UserGroup{ - ID: models.UserGroupID(fmt.Sprintf("global-%s-grp", models.NetworkUser)), + ID: globalNetworksUserGroupID, Name: "All Networks User Group", Default: true, NetworkRoles: map[models.NetworkID]map[models.UserRoleID]struct{}{ - models.NetworkID(models.AllNetworks): { + models.AllNetworks: { models.UserRoleID(fmt.Sprintf("global-%s", models.NetworkUser)): {}, }, }, @@ -215,7 +221,7 @@ func CreateDefaultNetworkRolesAndGroups(netID models.NetworkID) { // create default network groups var NetworkAdminGroup = models.UserGroup{ - ID: models.UserGroupID(fmt.Sprintf("%s-%s-grp", netID, models.NetworkAdmin)), + ID: GetDefaultNetworkAdminGroupID(netID), Name: fmt.Sprintf("%s Admin Group", netID), Default: true, NetworkRoles: map[models.NetworkID]map[models.UserRoleID]struct{}{ @@ -226,7 +232,7 @@ func CreateDefaultNetworkRolesAndGroups(netID models.NetworkID) { MetaData: fmt.Sprintf("can manage your network `%s` configuration including adding and removing devices.", netID), } var NetworkUserGroup = models.UserGroup{ - ID: models.UserGroupID(fmt.Sprintf("%s-%s-grp", netID, models.NetworkUser)), + ID: GetDefaultNetworkUserGroupID(netID), Name: fmt.Sprintf("%s User Group", netID), Default: true, NetworkRoles: map[models.NetworkID]map[models.UserRoleID]struct{}{ @@ -248,28 +254,29 @@ func DeleteNetworkRoles(netID string) { if err != nil { return } - defaultUserGrp := fmt.Sprintf("%s-%s-grp", netID, models.NetworkUser) - defaultAdminGrp := fmt.Sprintf("%s-%s-grp", netID, models.NetworkAdmin) + + defaultAdminGrpID := GetDefaultNetworkAdminGroupID(models.NetworkID(netID)) + defaultUserGrpID := GetDefaultNetworkUserGroupID(models.NetworkID(netID)) for _, user := range users { var upsert bool if _, ok := user.NetworkRoles[models.NetworkID(netID)]; ok { delete(user.NetworkRoles, models.NetworkID(netID)) upsert = true } - if _, ok := user.UserGroups[models.UserGroupID(defaultUserGrp)]; ok { - delete(user.UserGroups, models.UserGroupID(defaultUserGrp)) + if _, ok := user.UserGroups[defaultUserGrpID]; ok { + delete(user.UserGroups, defaultUserGrpID) upsert = true } - if _, ok := user.UserGroups[models.UserGroupID(defaultAdminGrp)]; ok { - delete(user.UserGroups, models.UserGroupID(defaultAdminGrp)) + if _, ok := user.UserGroups[defaultAdminGrpID]; ok { + delete(user.UserGroups, defaultAdminGrpID) upsert = true } if upsert { logic.UpsertUser(user) } } - database.DeleteRecord(database.USER_GROUPS_TABLE_NAME, defaultUserGrp) - database.DeleteRecord(database.USER_GROUPS_TABLE_NAME, defaultAdminGrp) + database.DeleteRecord(database.USER_GROUPS_TABLE_NAME, defaultUserGrpID.String()) + database.DeleteRecord(database.USER_GROUPS_TABLE_NAME, defaultAdminGrpID.String()) userGs, _ := ListUserGroups() for _, userGI := range userGs { if _, ok := userGI.NetworkRoles[models.NetworkID(netID)]; ok { @@ -524,14 +531,31 @@ func ValidateUpdateGroupReq(g models.UserGroup) error { // CreateUserGroup - creates new user group func CreateUserGroup(g models.UserGroup) error { - // check if role already exists - if g.ID == "" { - return errors.New("group id cannot be empty") + // default groups are currently created directly in the db. + // this check is only to prevent future errors. + if g.Default && g.ID == "" { + return errors.New("group id cannot be empty for default group") } - _, err := database.FetchRecord(database.USER_GROUPS_TABLE_NAME, g.ID.String()) - if err == nil { - return errors.New("group already exists") + + if !g.Default { + g.ID = models.UserGroupID(uuid.NewString()) } + + // check if the group already exists + if g.Name == "" { + return errors.New("group name cannot be empty") + } + groups, err := ListUserGroups() + if err != nil { + return err + } + + for _, group := range groups { + if group.Name == g.Name { + return errors.New("group already exists") + } + } + d, err := json.Marshal(g) if err != nil { return err @@ -553,6 +577,14 @@ func GetUserGroup(gid models.UserGroupID) (models.UserGroup, error) { return ug, nil } +func GetDefaultNetworkAdminGroupID(networkID models.NetworkID) models.UserGroupID { + return models.UserGroupID(fmt.Sprintf("%s-%s-grp", networkID, models.NetworkAdmin)) +} + +func GetDefaultNetworkUserGroupID(networkID models.NetworkID) models.UserGroupID { + return models.UserGroupID(fmt.Sprintf("%s-%s-grp", networkID, models.NetworkUser)) +} + // ListUserGroups - lists user groups func ListUserGroups() ([]models.UserGroup, error) { data, err := database.FetchRecords(database.USER_GROUPS_TABLE_NAME) @@ -573,7 +605,7 @@ func ListUserGroups() ([]models.UserGroup, error) { // UpdateUserGroup - updates new user group func UpdateUserGroup(g models.UserGroup) error { - // check if group exists + // check if the group exists if g.ID == "" { return errors.New("group id cannot be empty") } @@ -591,7 +623,7 @@ func UpdateUserGroup(g models.UserGroup) error { // DeleteUserGroup - deletes user group func DeleteUserGroup(gid models.UserGroupID) error { users, err := logic.GetUsersDB() - if err != nil { + if err != nil && !database.IsEmptyRecord(err) { return err } for _, user := range users { @@ -1110,6 +1142,8 @@ func CreateDefaultUserPolicies(netID models.NetworkID) { } if !logic.IsAclExists(fmt.Sprintf("%s.%s-grp", netID, models.NetworkAdmin)) { + networkAdminGroupID := GetDefaultNetworkAdminGroupID(netID) + defaultUserAcl := models.Acl{ ID: fmt.Sprintf("%s.%s-grp", netID, models.NetworkAdmin), Name: "Network Admin", @@ -1122,11 +1156,11 @@ func CreateDefaultUserPolicies(netID models.NetworkID) { Src: []models.AclPolicyTag{ { ID: models.UserGroupAclID, - Value: fmt.Sprintf("%s-%s-grp", netID, models.NetworkAdmin), + Value: globalNetworksAdminGroupID.String(), }, { ID: models.UserGroupAclID, - Value: fmt.Sprintf("global-%s-grp", models.NetworkAdmin), + Value: networkAdminGroupID.String(), }, }, Dst: []models.AclPolicyTag{ @@ -1143,6 +1177,8 @@ func CreateDefaultUserPolicies(netID models.NetworkID) { } if !logic.IsAclExists(fmt.Sprintf("%s.%s-grp", netID, models.NetworkUser)) { + networkUserGroupID := GetDefaultNetworkUserGroupID(netID) + defaultUserAcl := models.Acl{ ID: fmt.Sprintf("%s.%s-grp", netID, models.NetworkUser), Name: "Network User", @@ -1155,11 +1191,11 @@ func CreateDefaultUserPolicies(netID models.NetworkID) { Src: []models.AclPolicyTag{ { ID: models.UserGroupAclID, - Value: fmt.Sprintf("%s-%s-grp", netID, models.NetworkUser), + Value: globalNetworksAdminGroupID.String(), }, { ID: models.UserGroupAclID, - Value: fmt.Sprintf("global-%s-grp", models.NetworkUser), + Value: networkUserGroupID.String(), }, }, @@ -1198,5 +1234,6 @@ func AddGlobalNetRolesToAdmins(u *models.User) { return } u.UserGroups = make(map[models.UserGroupID]struct{}) - u.UserGroups[models.UserGroupID(fmt.Sprintf("global-%s-grp", models.NetworkAdmin))] = struct{}{} + + u.UserGroups[globalNetworksAdminGroupID] = struct{}{} }