NET-2000: Api access tokens (#3418)

* feat: api access tokens

* revoke all user tokens

* redefine access token api routes, add auto egress option to enrollment keys

* fix revoked tokens to be unauthorized

* remove unused functions

* convert access token to sql schema

* switch access token to sql schema

* revoke token generated by an user

* add user token creation restriction by user role

* add forbidden check for access token creation

* revoke user token when group or role is changed

* add default group to admin users on update

* fix token removal on user update

* fix token removal on user update
This commit is contained in:
Abhishek K 2025-04-23 20:21:42 +04:00 committed by GitHub
parent d5bdc723fc
commit ca95954fb5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 507 additions and 200 deletions

View file

@ -160,6 +160,7 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
enrollmentKeyBody.Unlimited,
relayId,
false,
enrollmentKeyBody.AutoEgress,
)
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())

View file

@ -6,7 +6,9 @@ import (
"fmt"
"net/http"
"reflect"
"time"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/gravitl/netmaker/auth"
@ -37,7 +39,164 @@ func userHandlers(r *mux.Router) {
r.HandleFunc("/api/v1/users", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserV1)))).Methods(http.MethodGet)
r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods(http.MethodGet)
r.HandleFunc("/api/v1/users/roles", logic.SecurityCheck(true, http.HandlerFunc(ListRoles))).Methods(http.MethodGet)
r.HandleFunc("/api/v1/users/access_token", logic.SecurityCheck(true, http.HandlerFunc(createUserAccessToken))).Methods(http.MethodPost)
r.HandleFunc("/api/v1/users/access_token", logic.SecurityCheck(true, http.HandlerFunc(getUserAccessTokens))).Methods(http.MethodGet)
r.HandleFunc("/api/v1/users/access_token", logic.SecurityCheck(true, http.HandlerFunc(deleteUserAccessTokens))).Methods(http.MethodDelete)
}
// @Summary Authenticate a user to retrieve an authorization token
// @Router /api/v1/users/{username}/access_token [post]
// @Tags Auth
// @Accept json
// @Param body body models.UserAuthParams true "Authentication parameters"
// @Success 200 {object} models.SuccessResponse
// @Failure 400 {object} models.ErrorResponse
// @Failure 401 {object} models.ErrorResponse
// @Failure 500 {object} models.ErrorResponse
func createUserAccessToken(w http.ResponseWriter, r *http.Request) {
// Auth request consists of Mac Address and Password (from node that is authorizing
// in case of Master, auth is ignored and mac is set to "mastermac"
var req models.UserAccessToken
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
logger.Log(0, "error decoding request body: ",
err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
if req.Name == "" {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("name is required"), "badrequest"))
return
}
if req.UserName == "" {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
return
}
caller, err := logic.GetUser(r.Header.Get("user"))
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
return
}
user, err := logic.GetUser(req.UserName)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
return
}
if caller.UserName != user.UserName && caller.PlatformRoleID != models.SuperAdminRole {
if caller.PlatformRoleID == models.AdminRole {
if user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("not enough permissions to create token for user "+user.UserName), logic.Forbidden_Msg))
return
}
} else {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("not enough permissions to create token for user "+user.UserName), logic.Forbidden_Msg))
return
}
}
req.ID = uuid.New().String()
req.CreatedBy = r.Header.Get("user")
req.CreatedAt = time.Now()
jwt, err := logic.CreateUserAccessJwtToken(user.UserName, user.PlatformRoleID, req.ExpiresAt, req.ID)
if jwt == "" {
// very unlikely that err is !nil and no jwt returned, but handle it anyways.
logic.ReturnErrorResponse(
w,
r,
logic.FormatError(errors.New("error creating access token "+err.Error()), "internal"),
)
return
}
err = req.Create()
if err != nil {
logic.ReturnErrorResponse(
w,
r,
logic.FormatError(errors.New("error creating access token "+err.Error()), "internal"),
)
return
}
logic.ReturnSuccessResponseWithJson(w, r, models.SuccessfulUserLoginResponse{
AuthToken: jwt,
UserName: req.UserName,
}, "api access token has generated for user "+req.UserName)
}
// @Summary Authenticate a user to retrieve an authorization token
// @Router /api/v1/users/{username}/access_token [post]
// @Tags Auth
// @Accept json
// @Param body body models.UserAuthParams true "Authentication parameters"
// @Success 200 {object} models.SuccessResponse
// @Failure 400 {object} models.ErrorResponse
// @Failure 401 {object} models.ErrorResponse
// @Failure 500 {object} models.ErrorResponse
func getUserAccessTokens(w http.ResponseWriter, r *http.Request) {
username := r.URL.Query().Get("username")
if username == "" {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
return
}
logic.ReturnSuccessResponseWithJson(w, r, (&models.UserAccessToken{UserName: username}).ListByUser(), "fetched api access tokens for user "+username)
}
// @Summary Authenticate a user to retrieve an authorization token
// @Router /api/v1/users/{username}/access_token [post]
// @Tags Auth
// @Accept json
// @Param body body models.UserAuthParams true "Authentication parameters"
// @Success 200 {object} models.SuccessResponse
// @Failure 400 {object} models.ErrorResponse
// @Failure 401 {object} models.ErrorResponse
// @Failure 500 {object} models.ErrorResponse
func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
id := r.URL.Query().Get("id")
if id == "" {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
return
}
a := models.UserAccessToken{
ID: id,
}
err := a.Get()
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
return
}
caller, err := logic.GetUser(r.Header.Get("user"))
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
return
}
user, err := logic.GetUser(a.UserName)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
return
}
if caller.UserName != user.UserName && caller.PlatformRoleID != models.SuperAdminRole {
if caller.PlatformRoleID == models.AdminRole {
if user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("not enough permissions to delete token of user "+user.UserName), logic.Forbidden_Msg))
return
}
} else {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("not enough permissions to delete token of user "+user.UserName), logic.Forbidden_Msg))
return
}
}
err = (&models.UserAccessToken{ID: id}).Delete()
if err != nil {
logic.ReturnErrorResponse(
w,
r,
logic.FormatError(errors.New("error deleting access token "+err.Error()), "internal"),
)
return
}
logic.ReturnSuccessResponseWithJson(w, r, nil, "revoked access token")
}
// @Summary Authenticate a user to retrieve an authorization token
@ -592,7 +751,10 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden"))
return
}
logic.AddGlobalNetRolesToAdmins(&userchange)
if userchange.PlatformRoleID != user.PlatformRoleID || !logic.CompareMaps(user.UserGroups, userchange.UserGroups) {
(&models.UserAccessToken{UserName: user.UserName}).DeleteAllUserTokens()
}
user, err = logic.UpdateUser(&userchange, user)
if err != nil {
logger.Log(0, username,
@ -671,17 +833,12 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
return
}
}
success, err := logic.DeleteUser(username)
err = logic.DeleteUser(username)
if err != nil {
logger.Log(0, username,
"failed to delete user: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
} else if !success {
err := errors.New("delete unsuccessful")
logger.Log(0, username, err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
// check and delete extclient with this ownerID
go func() {

View file

@ -1,18 +1,12 @@
package database
import (
"crypto/rand"
"encoding/json"
"errors"
"sync"
"time"
"github.com/google/uuid"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/crypto/nacl/box"
)
const (
@ -25,6 +19,8 @@ const (
DELETED_NODES_TABLE_NAME = "deletednodes"
// USERS_TABLE_NAME - users table
USERS_TABLE_NAME = "users"
// ACCESS_TOKENS_TABLE_NAME - access tokens table
ACCESS_TOKENS_TABLE_NAME = "user_access_tokens"
// USER_PERMISSIONS_TABLE_NAME - user permissions table
USER_PERMISSIONS_TABLE_NAME = "user_permissions"
// CERTS_TABLE_NAME - certificates table
@ -129,6 +125,7 @@ var Tables = []string{
TAG_TABLE_NAME,
ACLS_TABLE_NAME,
PEER_ACK_TABLE,
// ACCESS_TOKENS_TABLE_NAME,
}
func getCurrentDB() map[string]interface{} {
@ -160,7 +157,7 @@ func InitializeDatabase() error {
time.Sleep(2 * time.Second)
}
createTables()
return initializeUUID()
return nil
}
func createTables() {
@ -173,35 +170,17 @@ func CreateTable(tableName string) error {
return getCurrentDB()[CREATE_TABLE].(func(string) error)(tableName)
}
// IsJSONString - checks if valid json
func IsJSONString(value string) bool {
var jsonInt interface{}
var nodeInt models.Node
return json.Unmarshal([]byte(value), &jsonInt) == nil || json.Unmarshal([]byte(value), &nodeInt) == nil
}
// Insert - inserts object into db
func Insert(key string, value string, tableName string) error {
dbMutex.Lock()
defer dbMutex.Unlock()
if key != "" && value != "" && IsJSONString(value) {
if key != "" && value != "" {
return getCurrentDB()[INSERT].(func(string, string, string) error)(key, value, tableName)
} else {
return errors.New("invalid insert " + key + " : " + value)
}
}
// InsertPeer - inserts peer into db
func InsertPeer(key string, value string) error {
dbMutex.Lock()
defer dbMutex.Unlock()
if key != "" && value != "" && IsJSONString(value) {
return getCurrentDB()[INSERT_PEER].(func(string, string) error)(key, value)
} else {
return errors.New("invalid peer insert " + key + " : " + value)
}
}
// DeleteRecord - deletes a record from db
func DeleteRecord(tableName string, key string) error {
dbMutex.Lock()
@ -243,44 +222,6 @@ func FetchRecords(tableName string) (map[string]string, error) {
return getCurrentDB()[FETCH_ALL].(func(string) (map[string]string, error))(tableName)
}
// initializeUUID - create a UUID record for server if none exists
func initializeUUID() error {
records, err := FetchRecords(SERVER_UUID_TABLE_NAME)
if err != nil {
if !IsEmptyRecord(err) {
return err
}
} else if len(records) > 0 {
return nil
}
// setup encryption keys
var trafficPubKey, trafficPrivKey, errT = box.GenerateKey(rand.Reader) // generate traffic keys
if errT != nil {
return errT
}
tPriv, err := ncutils.ConvertKeyToBytes(trafficPrivKey)
if err != nil {
return err
}
tPub, err := ncutils.ConvertKeyToBytes(trafficPubKey)
if err != nil {
return err
}
telemetry := models.Telemetry{
UUID: uuid.NewString(),
TrafficKeyPriv: tPriv,
TrafficKeyPub: tPub,
}
telJSON, err := json.Marshal(&telemetry)
if err != nil {
return err
}
return Insert(SERVER_UUID_RECORD_KEY, string(telJSON), SERVER_UUID_TABLE_NAME)
}
// CloseDB - closes a database gracefully
func CloseDB() {
getCurrentDB()[CLOSE_DB].(func())()

View file

@ -59,7 +59,7 @@ func pgCreateTable(tableName string) error {
}
func pgInsert(key string, value string, tableName string) error {
if key != "" && value != "" && IsJSONString(value) {
if key != "" && value != "" {
insertSQL := "INSERT INTO " + tableName + " (key, value) VALUES ($1, $2) ON CONFLICT (key) DO UPDATE SET value = $3;"
statement, err := PGDB.Prepare(insertSQL)
if err != nil {
@ -77,7 +77,7 @@ func pgInsert(key string, value string, tableName string) error {
}
func pgInsertPeer(key string, value string) error {
if key != "" && value != "" && IsJSONString(value) {
if key != "" && value != "" {
err := pgInsert(key, value, PEERS_TABLE_NAME)
if err != nil {
return err

View file

@ -43,7 +43,7 @@ func rqliteCreateTable(tableName string) error {
}
func rqliteInsert(key string, value string, tableName string) error {
if key != "" && value != "" && IsJSONString(value) {
if key != "" && value != "" {
_, err := RQliteDatabase.WriteOne("INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES ('" + key + "', '" + value + "')")
if err != nil {
return err
@ -54,7 +54,7 @@ func rqliteInsert(key string, value string, tableName string) error {
}
func rqliteInsertPeer(key string, value string) error {
if key != "" && value != "" && IsJSONString(value) {
if key != "" && value != "" {
_, err := RQliteDatabase.WriteOne("INSERT OR REPLACE INTO " + PEERS_TABLE_NAME + " (key, value) VALUES ('" + key + "', '" + value + "')")
if err != nil {
return err

View file

@ -61,7 +61,7 @@ func sqliteCreateTable(tableName string) error {
}
func sqliteInsert(key string, value string, tableName string) error {
if key != "" && value != "" && IsJSONString(value) {
if key != "" && value != "" {
insertSQL := "INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES (?, ?)"
statement, err := SqliteDB.Prepare(insertSQL)
if err != nil {
@ -78,7 +78,7 @@ func sqliteInsert(key string, value string, tableName string) error {
}
func sqliteInsertPeer(key string, value string) error {
if key != "" && value != "" && IsJSONString(value) {
if key != "" && value != "" {
err := sqliteInsert(key, value, PEERS_TABLE_NAME)
if err != nil {
return err

View file

@ -1,59 +0,0 @@
package database
import (
"encoding/json"
"strings"
)
// SetPeers - sets peers for a network
func SetPeers(newPeers map[string]string, networkName string) bool {
areEqual := PeersAreEqual(newPeers, networkName)
if !areEqual {
jsonData, err := json.Marshal(newPeers)
if err != nil {
return false
}
InsertPeer(networkName, string(jsonData))
return true
}
return !areEqual
}
// GetPeers - gets peers for a given network
func GetPeers(networkName string) (map[string]string, error) {
record, err := FetchRecord(PEERS_TABLE_NAME, networkName)
if err != nil && !IsEmptyRecord(err) {
return nil, err
}
currentDataMap := make(map[string]string)
if IsEmptyRecord(err) {
return currentDataMap, nil
}
err = json.Unmarshal([]byte(record), &currentDataMap)
return currentDataMap, err
}
// PeersAreEqual - checks if peers are the same
func PeersAreEqual(toCompare map[string]string, networkName string) bool {
currentDataMap, err := GetPeers(networkName)
if err != nil {
return false
}
if len(currentDataMap) != len(toCompare) {
return false
}
for k := range currentDataMap {
if toCompare[k] != currentDataMap[k] {
return false
}
}
return true
}
// IsEmptyRecord - checks for if it's an empty record error or not
func IsEmptyRecord(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), NO_RECORD) || strings.Contains(err.Error(), NO_RECORDS)
}

11
database/utils.go Normal file
View file

@ -0,0 +1,11 @@
package database
import "strings"
// IsEmptyRecord - checks for if it's an empty record error or not
func IsEmptyRecord(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), NO_RECORD) || strings.Contains(err.Error(), NO_RECORDS)
}

View file

@ -2,7 +2,9 @@ package db
import (
"errors"
"github.com/gravitl/netmaker/servercfg"
"os"
"github.com/gravitl/netmaker/config"
"gorm.io/gorm"
)
@ -14,10 +16,21 @@ type connector interface {
connect() (*gorm.DB, error)
}
// GetDB - gets the database type
func GetDB() string {
database := "sqlite"
if os.Getenv("DATABASE") != "" {
database = os.Getenv("DATABASE")
} else if config.Config.Server.Database != "" {
database = config.Config.Server.Database
}
return database
}
// newConnector detects the database being
// used and returns the corresponding connector.
func newConnector() (connector, error) {
switch servercfg.GetDB() {
switch GetDB() {
case "sqlite":
return &sqliteConnector{}, nil
case "postgres":

View file

@ -3,9 +3,10 @@ package db
import (
"context"
"errors"
"gorm.io/gorm"
"net/http"
"time"
"gorm.io/gorm"
)
type ctxKey string
@ -74,10 +75,6 @@ func Middleware(next http.Handler) http.Handler {
//
// The function panics, if a connection does not exist.
func FromContext(ctx context.Context) *gorm.DB {
db, ok := ctx.Value(dbCtxKey).(*gorm.DB)
if !ok {
panic(ErrDBNotFound)
}
return db
}

View file

@ -2,7 +2,10 @@ package db
import (
"fmt"
"github.com/gravitl/netmaker/servercfg"
"os"
"strconv"
"github.com/gravitl/netmaker/config"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@ -15,7 +18,7 @@ type postgresConnector struct{}
// postgresConnector.connect connects and
// initializes a connection to postgres.
func (pg *postgresConnector) connect() (*gorm.DB, error) {
pgConf := servercfg.GetSQLConf()
pgConf := GetSQLConf()
dsn := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=5",
pgConf.Host,
@ -47,3 +50,68 @@ func (pg *postgresConnector) connect() (*gorm.DB, error) {
return db, nil
}
func GetSQLConf() config.SQLConfig {
var cfg config.SQLConfig
cfg.Host = GetSQLHost()
cfg.Port = GetSQLPort()
cfg.Username = GetSQLUser()
cfg.Password = GetSQLPass()
cfg.DB = GetSQLDB()
cfg.SSLMode = GetSQLSSLMode()
return cfg
}
func GetSQLHost() string {
host := "localhost"
if os.Getenv("SQL_HOST") != "" {
host = os.Getenv("SQL_HOST")
} else if config.Config.SQL.Host != "" {
host = config.Config.SQL.Host
}
return host
}
func GetSQLPort() int32 {
port := int32(5432)
envport, err := strconv.Atoi(os.Getenv("SQL_PORT"))
if err == nil && envport != 0 {
port = int32(envport)
} else if config.Config.SQL.Port != 0 {
port = config.Config.SQL.Port
}
return port
}
func GetSQLUser() string {
user := "postgres"
if os.Getenv("SQL_USER") != "" {
user = os.Getenv("SQL_USER")
} else if config.Config.SQL.Username != "" {
user = config.Config.SQL.Username
}
return user
}
func GetSQLPass() string {
pass := "nopass"
if os.Getenv("SQL_PASS") != "" {
pass = os.Getenv("SQL_PASS")
} else if config.Config.SQL.Password != "" {
pass = config.Config.SQL.Password
}
return pass
}
func GetSQLDB() string {
db := "netmaker"
if os.Getenv("SQL_DB") != "" {
db = os.Getenv("SQL_DB")
} else if config.Config.SQL.DB != "" {
db = config.Config.SQL.DB
}
return db
}
func GetSQLSSLMode() string {
sslmode := "disable"
if os.Getenv("SQL_SSL_MODE") != "" {
sslmode = os.Getenv("SQL_SSL_MODE")
} else if config.Config.SQL.SSLMode != "" {
sslmode = config.Config.SQL.SSLMode
}
return sslmode
}

View file

@ -1,11 +1,12 @@
package db
import (
"os"
"path/filepath"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"os"
"path/filepath"
)
// sqliteConnector for initializing and
@ -28,7 +29,7 @@ func (s *sqliteConnector) connect() (*gorm.DB, error) {
}
}
dbFilePath := filepath.Join("data", "netmaker_v1.db")
dbFilePath := filepath.Join("data", "netmaker.db")
// ensure netmaker_v1.db exists.
_, err = os.Stat(dbFilePath)

View file

@ -169,6 +169,7 @@ func CreateUser(user *models.User) error {
if IsOauthUser(user) == nil {
user.AuthType = models.OAuth
}
AddGlobalNetRolesToAdmins(user)
_, err = CreateUserJWT(user.UserName, user.PlatformRoleID)
if err != nil {
logger.Log(0, "failed to generate token", err.Error())
@ -186,7 +187,6 @@ func CreateUser(user *models.User) error {
logger.Log(0, "failed to insert user", err.Error())
return err
}
AddGlobalNetRolesToAdmins(*user)
return nil
}
@ -305,7 +305,7 @@ func UpdateUser(userchange, user *models.User) (*models.User, error) {
}
user.UserGroups = userchange.UserGroups
user.NetworkRoles = userchange.NetworkRoles
AddGlobalNetRolesToAdmins(*user)
AddGlobalNetRolesToAdmins(user)
err := ValidateUser(user)
if err != nil {
return &models.User{}, err
@ -349,19 +349,18 @@ func ValidateUser(user *models.User) error {
}
// DeleteUser - deletes a given user
func DeleteUser(user string) (bool, error) {
func DeleteUser(user string) error {
if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 {
return false, errors.New("user does not exist")
return errors.New("user does not exist")
}
err := database.DeleteRecord(database.USERS_TABLE_NAME, user)
if err != nil {
return false, err
return err
}
go RemoveUserFromAclPolicy(user)
return true, nil
return (&models.UserAccessToken{UserName: user}).DeleteAllUserTokens()
}
func SetAuthSecret(secret string) error {

View file

@ -38,7 +38,7 @@ var (
)
// CreateEnrollmentKey - creates a new enrollment key in db
func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID, defaultKey bool) (*models.EnrollmentKey, error) {
func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID, defaultKey, autoEgress bool) (*models.EnrollmentKey, error) {
newKeyID, err := getUniqueEnrollmentID()
if err != nil {
return nil, err
@ -54,6 +54,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
Relay: relay,
Groups: groups,
Default: defaultKey,
AutoEgress: autoEgress,
}
if uses > 0 {
k.UsesRemaining = uses

View file

@ -14,35 +14,35 @@ func TestCreateEnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
t.Run("Can_Not_Create_Key", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
assert.Nil(t, newKey)
assert.NotNil(t, err)
assert.ErrorIs(t, err, models.ErrInvalidEnrollmentKey)
})
t.Run("Can_Create_Key_Uses", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
assert.Nil(t, err)
assert.Equal(t, 1, newKey.UsesRemaining)
assert.True(t, newKey.IsValid())
})
t.Run("Can_Create_Key_Time", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil, false)
newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil, false, false)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
})
t.Run("Can_Create_Key_Unlimited", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false)
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false, false)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
})
t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
assert.True(t, len(newKey.Networks) == 2)
})
t.Run("Can_Create_Key_WithTags", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil, false)
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil, false, false)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
assert.True(t, len(newKey.Tags) == 2)
@ -62,7 +62,7 @@ func TestCreateEnrollmentKey(t *testing.T) {
func TestDelete_EnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
t.Run("Can_Delete_Key", func(t *testing.T) {
assert.True(t, newKey.IsValid())
err := DeleteEnrollmentKey(newKey.Value, false)
@ -83,7 +83,7 @@ func TestDelete_EnrollmentKey(t *testing.T) {
func TestDecrement_EnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
t.Run("Check_initial_uses", func(t *testing.T) {
assert.True(t, newKey.IsValid())
assert.Equal(t, newKey.UsesRemaining, 1)
@ -107,9 +107,9 @@ func TestDecrement_EnrollmentKey(t *testing.T) {
func TestUsability_EnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil, false)
key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false)
key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil, false, false)
key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false, false)
t.Run("Check if valid use key can be used", func(t *testing.T) {
assert.Equal(t, key1.UsesRemaining, 1)
ok := TryToUseEnrollmentKey(key1)
@ -145,7 +145,7 @@ func removeAllEnrollments() {
func TestTokenize_EnrollmentKeys(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5"
const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
const serverAddr = "api.myserver.com"
@ -178,7 +178,7 @@ func TestTokenize_EnrollmentKeys(t *testing.T) {
func TestDeTokenize_EnrollmentKeys(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
const serverAddr = "api.myserver.com"

View file

@ -52,12 +52,38 @@ func CreateJWT(uuid string, macAddress string, network string) (response string,
return "", err
}
// CreateUserJWT - creates a user jwt token
func CreateUserAccessJwtToken(username string, role models.UserRoleID, d time.Time, tokenID string) (response string, err error) {
claims := &models.UserClaims{
UserName: username,
Role: role,
TokenType: models.AccessTokenType,
Api: servercfg.ServerInfo.APIHost,
RacAutoDisable: servercfg.GetRacAutoDisable() && (role != models.SuperAdminRole && role != models.AdminRole),
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "Netmaker",
Subject: fmt.Sprintf("user|%s", username),
IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(d),
ID: tokenID,
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(jwtSecretKey)
if err == nil {
return tokenString, nil
}
return "", err
}
// CreateUserJWT - creates a user jwt token
func CreateUserJWT(username string, role models.UserRoleID) (response string, err error) {
expirationTime := time.Now().Add(servercfg.GetServerConfig().JwtValidityDuration)
claims := &models.UserClaims{
UserName: username,
Role: role,
TokenType: models.UserIDTokenType,
RacAutoDisable: servercfg.GetRacAutoDisable() && (role != models.SuperAdminRole && role != models.AdminRole),
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "Netmaker",
@ -75,18 +101,6 @@ func CreateUserJWT(username string, role models.UserRoleID) (response string, er
return "", err
}
// VerifyJWT verifies Auth Header
func VerifyJWT(bearerToken string) (username string, issuperadmin, isadmin bool, err error) {
token := ""
tokenSplit := strings.Split(bearerToken, " ")
if len(tokenSplit) > 1 {
token = tokenSplit[1]
} else {
return "", false, false, errors.New("invalid auth header")
}
return VerifyUserToken(token)
}
func GetUserNameFromToken(authtoken string) (username string, err error) {
claims := &models.UserClaims{}
var tokenSplit = strings.Split(authtoken, " ")
@ -107,6 +121,20 @@ func GetUserNameFromToken(authtoken string) (username string, err error) {
if err != nil {
return "", Unauthorized_Err
}
if claims.TokenType == models.AccessTokenType {
jti := claims.ID
if jti != "" {
a := models.UserAccessToken{ID: jti}
// check if access token is active
err := a.Get()
if err != nil {
err = errors.New("token revoked")
return "", err
}
a.LastUsed = time.Now()
a.Update()
}
}
if token != nil && token.Valid {
var user *models.User
@ -131,15 +159,26 @@ func GetUserNameFromToken(authtoken string) (username string, err error) {
// VerifyUserToken func will used to Verify the JWT Token while using APIS
func VerifyUserToken(tokenString string) (username string, issuperadmin, isadmin bool, err error) {
claims := &models.UserClaims{}
if tokenString == servercfg.GetMasterKey() && servercfg.GetMasterKey() != "" {
return MasterUser, true, true, nil
}
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
return jwtSecretKey, nil
})
if claims.TokenType == models.AccessTokenType {
jti := claims.ID
if jti != "" {
a := models.UserAccessToken{ID: jti}
// check if access token is active
err := a.Get()
if err != nil {
err = errors.New("token revoked")
return "", false, false, err
}
a.LastUsed = time.Now()
a.Update()
}
}
if token != nil && token.Valid {
var user *models.User
// check that user exists

View file

@ -302,6 +302,7 @@ func CreateNetwork(network models.Network) (models.Network, error) {
true,
uuid.Nil,
true,
false,
)
return network, nil

View file

@ -61,7 +61,7 @@ var CreateDefaultNetworkRolesAndGroups = func(netID models.NetworkID) {}
var CreateDefaultUserPolicies = func(netID models.NetworkID) {}
var GetUserGroupsInNetwork = func(netID models.NetworkID) (networkGrps map[models.UserGroupID]models.UserGroup) { return }
var GetUserGroup = func(groupId models.UserGroupID) (userGrps models.UserGroup, err error) { return }
var AddGlobalNetRolesToAdmins = func(u models.User) {}
var AddGlobalNetRolesToAdmins = func(u *models.User) {}
// GetRole - fetches role template by id
func GetRole(roleID models.UserRoleID) (models.UserRolePermissionTemplate, error) {

View file

@ -10,6 +10,7 @@ import (
"log/slog"
"net"
"os"
"reflect"
"strings"
"time"
"unicode"
@ -201,3 +202,23 @@ func VersionLessThan(v1, v2 string) (bool, error) {
}
return sv1.LT(sv2), nil
}
// Compare any two maps with any key and value types
func CompareMaps[K comparable, V any](a, b map[K]V) bool {
if len(a) != len(b) {
return false
}
for key, valA := range a {
valB, ok := b[key]
if !ok {
return false
}
if !reflect.DeepEqual(valA, valB) {
return false
}
}
return true
}

51
main.go
View file

@ -3,6 +3,8 @@ package main
import (
"context"
"crypto/rand"
"encoding/json"
"flag"
"fmt"
"os"
@ -12,9 +14,11 @@ import (
"sync"
"syscall"
"github.com/google/uuid"
"github.com/gravitl/netmaker/config"
controller "github.com/gravitl/netmaker/controllers"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
@ -22,9 +26,11 @@ import (
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/mq"
"github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"github.com/gravitl/netmaker/serverctl"
_ "go.uber.org/automaxprocs"
"golang.org/x/crypto/nacl/box"
"golang.org/x/exp/slog"
)
@ -99,8 +105,13 @@ func initialize() { // Client Mode Prereq Check
if err = database.InitializeDatabase(); err != nil {
logger.FatalLog("Error connecting to database: ", err.Error())
}
// initialize sql schema db.
err = db.InitializeDB(schema.ListModels()...)
if err != nil {
logger.FatalLog("Error connecting to v1 database: ", err.Error())
}
logger.Log(0, "database successfully connected")
initializeUUID()
//initialize cache
_, _ = logic.GetNetworks()
_, _ = logic.GetAllNodes()
@ -247,3 +258,41 @@ func setGarbageCollection() {
debug.SetGCPercent(ncutils.DEFAULT_GC_PERCENT)
}
}
// initializeUUID - create a UUID record for server if none exists
func initializeUUID() error {
records, err := database.FetchRecords(database.SERVER_UUID_TABLE_NAME)
if err != nil {
if !database.IsEmptyRecord(err) {
return err
}
} else if len(records) > 0 {
return nil
}
// setup encryption keys
var trafficPubKey, trafficPrivKey, errT = box.GenerateKey(rand.Reader) // generate traffic keys
if errT != nil {
return errT
}
tPriv, err := ncutils.ConvertKeyToBytes(trafficPrivKey)
if err != nil {
return err
}
tPub, err := ncutils.ConvertKeyToBytes(trafficPubKey)
if err != nil {
return err
}
telemetry := models.Telemetry{
UUID: uuid.NewString(),
TrafficKeyPriv: tPriv,
TrafficKeyPub: tPub,
}
telJSON, err := json.Marshal(&telemetry)
if err != nil {
return err
}
return database.Insert(database.SERVER_UUID_RECORD_KEY, string(telJSON), database.SERVER_UUID_TABLE_NAME)
}

View file

@ -151,6 +151,7 @@ func updateEnrollmentKeys() {
true,
uuid.Nil,
true,
false,
)
}
@ -405,11 +406,12 @@ func syncUsers() {
}
if user.PlatformRoleID == models.SuperAdminRole && !user.IsSuperAdmin {
user.IsSuperAdmin = true
logic.UpsertUser(user)
}
if user.PlatformRoleID.String() != "" {
logic.MigrateUserRoleAndGroups(user)
logic.AddGlobalNetRolesToAdmins(user)
logic.AddGlobalNetRolesToAdmins(&user)
logic.UpsertUser(user)
continue
}
user.AuthType = models.BasicAuth
@ -430,9 +432,9 @@ func syncUsers() {
} else {
user.PlatformRoleID = models.ServiceUser
}
logic.UpsertUser(user)
logic.AddGlobalNetRolesToAdmins(user)
logic.AddGlobalNetRolesToAdmins(&user)
logic.MigrateUserRoleAndGroups(user)
logic.UpsertUser(user)
}
}

View file

@ -1,13 +1,60 @@
package models
// AccessToken - token used to access netmaker
type AccessToken struct {
APIConnString string `json:"apiconnstring"`
ClientConfig
import (
"context"
"time"
"github.com/gravitl/netmaker/db"
)
// accessTokenTableName - access tokens table
const accessTokenTableName = "user_access_tokens"
// UserAccessToken - token used to access netmaker
type UserAccessToken struct {
ID string `gorm:"id,primary_key" json:"id"`
Name string `gorm:"name" json:"name"`
UserName string `gorm:"user_name" json:"user_name"`
ExpiresAt time.Time `gorm:"expires_at" json:"expires_at"`
LastUsed time.Time `gorm:"last_used" json:"last_used"`
CreatedBy string `gorm:"created_by" json:"created_by"`
CreatedAt time.Time `gorm:"created_at" json:"created_at"`
}
// ClientConfig - the config of the client
type ClientConfig struct {
Network string `json:"network"`
Key string `json:"key"`
func (a *UserAccessToken) Table() string {
return accessTokenTableName
}
func (a *UserAccessToken) Get() error {
return db.FromContext(context.TODO()).Table(a.Table()).First(&a).Where("id = ?", a.ID).Error
}
func (a *UserAccessToken) Update() error {
return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Updates(&a).Error
}
func (a *UserAccessToken) Create() error {
return db.FromContext(context.TODO()).Table(a.Table()).Create(&a).Error
}
func (a *UserAccessToken) List() (ats []UserAccessToken, err error) {
err = db.FromContext(context.TODO()).Table(a.Table()).Find(&ats).Error
return
}
func (a *UserAccessToken) ListByUser() (ats []UserAccessToken) {
db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ?", a.UserName).Find(&ats)
if ats == nil {
ats = []UserAccessToken{}
}
return
}
func (a *UserAccessToken) Delete() error {
return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Delete(&a).Error
}
func (a *UserAccessToken) DeleteAllUserTokens() error {
return db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ? OR created_by = ?", a.UserName, a.UserName).Delete(&a).Error
}

View file

@ -54,6 +54,7 @@ type EnrollmentKey struct {
Relay uuid.UUID `json:"relay"`
Groups []TagID `json:"groups"`
Default bool `json:"default"`
AutoEgress bool `json:"auto_egress"`
}
// APIEnrollmentKey - used to create enrollment keys via API
@ -66,6 +67,7 @@ type APIEnrollmentKey struct {
Type KeyType `json:"type"`
Relay string `json:"relay"`
Groups []TagID `json:"groups"`
AutoEgress bool `json:"auto_egress"`
}
// RegisterResponse - the response to a successful enrollment register

View file

@ -263,6 +263,7 @@ type NodeJoinResponse struct {
type ServerConfig struct {
CoreDNSAddr string `yaml:"corednsaddr"`
API string `yaml:"api"`
APIHost string `yaml:"apihost"`
APIPort string `yaml:"apiport"`
DNSMode string `yaml:"dnsmode"`
Version string `yaml:"version"`

View file

@ -13,6 +13,7 @@ type RsrcID string
type UserRoleID string
type UserGroupID string
type AuthType string
type TokenType string
var (
BasicAuth AuthType = "basic_auth"
@ -35,6 +36,15 @@ func GetRAGRoleID(netID, hostID string) UserRoleID {
return UserRoleID(fmt.Sprintf("netID-%s-rag-%s", netID, hostID))
}
func (t TokenType) String() string {
return string(t)
}
var (
UserIDTokenType TokenType = "user_id_token"
AccessTokenType TokenType = "access_token"
)
var RsrcTypeMap = map[RsrcType]struct{}{
HostRsrc: {},
RelayRsrc: {},
@ -185,6 +195,8 @@ type UserAuthParams struct {
type UserClaims struct {
Role UserRoleID
UserName string
Api string
TokenType TokenType
RacAutoDisable bool
jwt.RegisteredClaims
}

View file

@ -1211,11 +1211,10 @@ func GetUserGroupsInNetwork(netID models.NetworkID) (networkGrps map[models.User
return
}
func AddGlobalNetRolesToAdmins(u models.User) {
func AddGlobalNetRolesToAdmins(u *models.User) {
if u.PlatformRoleID != models.SuperAdminRole && u.PlatformRoleID != models.AdminRole {
return
}
u.UserGroups = make(map[models.UserGroupID]struct{})
u.UserGroups[models.UserGroupID(fmt.Sprintf("global-%s-grp", models.NetworkAdmin))] = struct{}{}
logic.UpsertUser(u)
}

View file

@ -1,8 +1,11 @@
package schema
import "github.com/gravitl/netmaker/models"
// ListModels lists all the models in this schema.
func ListModels() []interface{} {
return []interface{}{
&Job{},
&models.UserAccessToken{},
}
}

View file

@ -137,6 +137,7 @@ func GetServerInfo() models.ServerConfig {
cfg.MQUserName = GetMqUserName()
cfg.MQPassword = GetMqPassword()
}
cfg.APIHost = GetAPIHost()
cfg.API = GetAPIConnString()
cfg.CoreDNSAddr = GetCoreDNSAddr()
cfg.APIPort = GetAPIPort()