mirror of
https://github.com/gravitl/netmaker.git
synced 2025-09-04 04:04:17 +08:00
NET-2000: Api access tokens (#3418)
* feat: api access tokens * revoke all user tokens * redefine access token api routes, add auto egress option to enrollment keys * fix revoked tokens to be unauthorized * remove unused functions * convert access token to sql schema * switch access token to sql schema * revoke token generated by an user * add user token creation restriction by user role * add forbidden check for access token creation * revoke user token when group or role is changed * add default group to admin users on update * fix token removal on user update * fix token removal on user update
This commit is contained in:
parent
d5bdc723fc
commit
ca95954fb5
28 changed files with 507 additions and 200 deletions
|
@ -160,6 +160,7 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
|
|||
enrollmentKeyBody.Unlimited,
|
||||
relayId,
|
||||
false,
|
||||
enrollmentKeyBody.AutoEgress,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())
|
||||
|
|
|
@ -6,7 +6,9 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/gravitl/netmaker/auth"
|
||||
|
@ -37,7 +39,164 @@ func userHandlers(r *mux.Router) {
|
|||
r.HandleFunc("/api/v1/users", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserV1)))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/v1/users/roles", logic.SecurityCheck(true, http.HandlerFunc(ListRoles))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/v1/users/access_token", logic.SecurityCheck(true, http.HandlerFunc(createUserAccessToken))).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/v1/users/access_token", logic.SecurityCheck(true, http.HandlerFunc(getUserAccessTokens))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/v1/users/access_token", logic.SecurityCheck(true, http.HandlerFunc(deleteUserAccessTokens))).Methods(http.MethodDelete)
|
||||
}
|
||||
|
||||
// @Summary Authenticate a user to retrieve an authorization token
|
||||
// @Router /api/v1/users/{username}/access_token [post]
|
||||
// @Tags Auth
|
||||
// @Accept json
|
||||
// @Param body body models.UserAuthParams true "Authentication parameters"
|
||||
// @Success 200 {object} models.SuccessResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 401 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func createUserAccessToken(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Auth request consists of Mac Address and Password (from node that is authorizing
|
||||
// in case of Master, auth is ignored and mac is set to "mastermac"
|
||||
var req models.UserAccessToken
|
||||
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
logger.Log(0, "error decoding request body: ",
|
||||
err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
if req.Name == "" {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("name is required"), "badrequest"))
|
||||
return
|
||||
}
|
||||
if req.UserName == "" {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
|
||||
return
|
||||
}
|
||||
caller, err := logic.GetUser(r.Header.Get("user"))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
|
||||
return
|
||||
}
|
||||
user, err := logic.GetUser(req.UserName)
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
|
||||
return
|
||||
}
|
||||
if caller.UserName != user.UserName && caller.PlatformRoleID != models.SuperAdminRole {
|
||||
if caller.PlatformRoleID == models.AdminRole {
|
||||
if user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("not enough permissions to create token for user "+user.UserName), logic.Forbidden_Msg))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("not enough permissions to create token for user "+user.UserName), logic.Forbidden_Msg))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
req.ID = uuid.New().String()
|
||||
req.CreatedBy = r.Header.Get("user")
|
||||
req.CreatedAt = time.Now()
|
||||
jwt, err := logic.CreateUserAccessJwtToken(user.UserName, user.PlatformRoleID, req.ExpiresAt, req.ID)
|
||||
if jwt == "" {
|
||||
// very unlikely that err is !nil and no jwt returned, but handle it anyways.
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
logic.FormatError(errors.New("error creating access token "+err.Error()), "internal"),
|
||||
)
|
||||
return
|
||||
}
|
||||
err = req.Create()
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
logic.FormatError(errors.New("error creating access token "+err.Error()), "internal"),
|
||||
)
|
||||
return
|
||||
}
|
||||
logic.ReturnSuccessResponseWithJson(w, r, models.SuccessfulUserLoginResponse{
|
||||
AuthToken: jwt,
|
||||
UserName: req.UserName,
|
||||
}, "api access token has generated for user "+req.UserName)
|
||||
}
|
||||
|
||||
// @Summary Authenticate a user to retrieve an authorization token
|
||||
// @Router /api/v1/users/{username}/access_token [post]
|
||||
// @Tags Auth
|
||||
// @Accept json
|
||||
// @Param body body models.UserAuthParams true "Authentication parameters"
|
||||
// @Success 200 {object} models.SuccessResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 401 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func getUserAccessTokens(w http.ResponseWriter, r *http.Request) {
|
||||
username := r.URL.Query().Get("username")
|
||||
if username == "" {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
|
||||
return
|
||||
}
|
||||
logic.ReturnSuccessResponseWithJson(w, r, (&models.UserAccessToken{UserName: username}).ListByUser(), "fetched api access tokens for user "+username)
|
||||
}
|
||||
|
||||
// @Summary Authenticate a user to retrieve an authorization token
|
||||
// @Router /api/v1/users/{username}/access_token [post]
|
||||
// @Tags Auth
|
||||
// @Accept json
|
||||
// @Param body body models.UserAuthParams true "Authentication parameters"
|
||||
// @Success 200 {object} models.SuccessResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 401 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.URL.Query().Get("id")
|
||||
if id == "" {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
|
||||
return
|
||||
}
|
||||
a := models.UserAccessToken{
|
||||
ID: id,
|
||||
}
|
||||
err := a.Get()
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
|
||||
return
|
||||
}
|
||||
caller, err := logic.GetUser(r.Header.Get("user"))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
|
||||
return
|
||||
}
|
||||
user, err := logic.GetUser(a.UserName)
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
|
||||
return
|
||||
}
|
||||
if caller.UserName != user.UserName && caller.PlatformRoleID != models.SuperAdminRole {
|
||||
if caller.PlatformRoleID == models.AdminRole {
|
||||
if user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("not enough permissions to delete token of user "+user.UserName), logic.Forbidden_Msg))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("not enough permissions to delete token of user "+user.UserName), logic.Forbidden_Msg))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = (&models.UserAccessToken{ID: id}).Delete()
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
logic.FormatError(errors.New("error deleting access token "+err.Error()), "internal"),
|
||||
)
|
||||
return
|
||||
}
|
||||
logic.ReturnSuccessResponseWithJson(w, r, nil, "revoked access token")
|
||||
}
|
||||
|
||||
// @Summary Authenticate a user to retrieve an authorization token
|
||||
|
@ -592,7 +751,10 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
|
|||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden"))
|
||||
return
|
||||
}
|
||||
|
||||
logic.AddGlobalNetRolesToAdmins(&userchange)
|
||||
if userchange.PlatformRoleID != user.PlatformRoleID || !logic.CompareMaps(user.UserGroups, userchange.UserGroups) {
|
||||
(&models.UserAccessToken{UserName: user.UserName}).DeleteAllUserTokens()
|
||||
}
|
||||
user, err = logic.UpdateUser(&userchange, user)
|
||||
if err != nil {
|
||||
logger.Log(0, username,
|
||||
|
@ -671,17 +833,12 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
}
|
||||
success, err := logic.DeleteUser(username)
|
||||
err = logic.DeleteUser(username)
|
||||
if err != nil {
|
||||
logger.Log(0, username,
|
||||
"failed to delete user: ", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
} else if !success {
|
||||
err := errors.New("delete unsuccessful")
|
||||
logger.Log(0, username, err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
// check and delete extclient with this ownerID
|
||||
go func() {
|
||||
|
|
|
@ -1,18 +1,12 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/netclient/ncutils"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -25,6 +19,8 @@ const (
|
|||
DELETED_NODES_TABLE_NAME = "deletednodes"
|
||||
// USERS_TABLE_NAME - users table
|
||||
USERS_TABLE_NAME = "users"
|
||||
// ACCESS_TOKENS_TABLE_NAME - access tokens table
|
||||
ACCESS_TOKENS_TABLE_NAME = "user_access_tokens"
|
||||
// USER_PERMISSIONS_TABLE_NAME - user permissions table
|
||||
USER_PERMISSIONS_TABLE_NAME = "user_permissions"
|
||||
// CERTS_TABLE_NAME - certificates table
|
||||
|
@ -129,6 +125,7 @@ var Tables = []string{
|
|||
TAG_TABLE_NAME,
|
||||
ACLS_TABLE_NAME,
|
||||
PEER_ACK_TABLE,
|
||||
// ACCESS_TOKENS_TABLE_NAME,
|
||||
}
|
||||
|
||||
func getCurrentDB() map[string]interface{} {
|
||||
|
@ -160,7 +157,7 @@ func InitializeDatabase() error {
|
|||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
createTables()
|
||||
return initializeUUID()
|
||||
return nil
|
||||
}
|
||||
|
||||
func createTables() {
|
||||
|
@ -173,35 +170,17 @@ func CreateTable(tableName string) error {
|
|||
return getCurrentDB()[CREATE_TABLE].(func(string) error)(tableName)
|
||||
}
|
||||
|
||||
// IsJSONString - checks if valid json
|
||||
func IsJSONString(value string) bool {
|
||||
var jsonInt interface{}
|
||||
var nodeInt models.Node
|
||||
return json.Unmarshal([]byte(value), &jsonInt) == nil || json.Unmarshal([]byte(value), &nodeInt) == nil
|
||||
}
|
||||
|
||||
// Insert - inserts object into db
|
||||
func Insert(key string, value string, tableName string) error {
|
||||
dbMutex.Lock()
|
||||
defer dbMutex.Unlock()
|
||||
if key != "" && value != "" && IsJSONString(value) {
|
||||
if key != "" && value != "" {
|
||||
return getCurrentDB()[INSERT].(func(string, string, string) error)(key, value, tableName)
|
||||
} else {
|
||||
return errors.New("invalid insert " + key + " : " + value)
|
||||
}
|
||||
}
|
||||
|
||||
// InsertPeer - inserts peer into db
|
||||
func InsertPeer(key string, value string) error {
|
||||
dbMutex.Lock()
|
||||
defer dbMutex.Unlock()
|
||||
if key != "" && value != "" && IsJSONString(value) {
|
||||
return getCurrentDB()[INSERT_PEER].(func(string, string) error)(key, value)
|
||||
} else {
|
||||
return errors.New("invalid peer insert " + key + " : " + value)
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteRecord - deletes a record from db
|
||||
func DeleteRecord(tableName string, key string) error {
|
||||
dbMutex.Lock()
|
||||
|
@ -243,44 +222,6 @@ func FetchRecords(tableName string) (map[string]string, error) {
|
|||
return getCurrentDB()[FETCH_ALL].(func(string) (map[string]string, error))(tableName)
|
||||
}
|
||||
|
||||
// initializeUUID - create a UUID record for server if none exists
|
||||
func initializeUUID() error {
|
||||
records, err := FetchRecords(SERVER_UUID_TABLE_NAME)
|
||||
if err != nil {
|
||||
if !IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
} else if len(records) > 0 {
|
||||
return nil
|
||||
}
|
||||
// setup encryption keys
|
||||
var trafficPubKey, trafficPrivKey, errT = box.GenerateKey(rand.Reader) // generate traffic keys
|
||||
if errT != nil {
|
||||
return errT
|
||||
}
|
||||
tPriv, err := ncutils.ConvertKeyToBytes(trafficPrivKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tPub, err := ncutils.ConvertKeyToBytes(trafficPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
telemetry := models.Telemetry{
|
||||
UUID: uuid.NewString(),
|
||||
TrafficKeyPriv: tPriv,
|
||||
TrafficKeyPub: tPub,
|
||||
}
|
||||
telJSON, err := json.Marshal(&telemetry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return Insert(SERVER_UUID_RECORD_KEY, string(telJSON), SERVER_UUID_TABLE_NAME)
|
||||
}
|
||||
|
||||
// CloseDB - closes a database gracefully
|
||||
func CloseDB() {
|
||||
getCurrentDB()[CLOSE_DB].(func())()
|
||||
|
|
|
@ -59,7 +59,7 @@ func pgCreateTable(tableName string) error {
|
|||
}
|
||||
|
||||
func pgInsert(key string, value string, tableName string) error {
|
||||
if key != "" && value != "" && IsJSONString(value) {
|
||||
if key != "" && value != "" {
|
||||
insertSQL := "INSERT INTO " + tableName + " (key, value) VALUES ($1, $2) ON CONFLICT (key) DO UPDATE SET value = $3;"
|
||||
statement, err := PGDB.Prepare(insertSQL)
|
||||
if err != nil {
|
||||
|
@ -77,7 +77,7 @@ func pgInsert(key string, value string, tableName string) error {
|
|||
}
|
||||
|
||||
func pgInsertPeer(key string, value string) error {
|
||||
if key != "" && value != "" && IsJSONString(value) {
|
||||
if key != "" && value != "" {
|
||||
err := pgInsert(key, value, PEERS_TABLE_NAME)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -43,7 +43,7 @@ func rqliteCreateTable(tableName string) error {
|
|||
}
|
||||
|
||||
func rqliteInsert(key string, value string, tableName string) error {
|
||||
if key != "" && value != "" && IsJSONString(value) {
|
||||
if key != "" && value != "" {
|
||||
_, err := RQliteDatabase.WriteOne("INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES ('" + key + "', '" + value + "')")
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -54,7 +54,7 @@ func rqliteInsert(key string, value string, tableName string) error {
|
|||
}
|
||||
|
||||
func rqliteInsertPeer(key string, value string) error {
|
||||
if key != "" && value != "" && IsJSONString(value) {
|
||||
if key != "" && value != "" {
|
||||
_, err := RQliteDatabase.WriteOne("INSERT OR REPLACE INTO " + PEERS_TABLE_NAME + " (key, value) VALUES ('" + key + "', '" + value + "')")
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -61,7 +61,7 @@ func sqliteCreateTable(tableName string) error {
|
|||
}
|
||||
|
||||
func sqliteInsert(key string, value string, tableName string) error {
|
||||
if key != "" && value != "" && IsJSONString(value) {
|
||||
if key != "" && value != "" {
|
||||
insertSQL := "INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES (?, ?)"
|
||||
statement, err := SqliteDB.Prepare(insertSQL)
|
||||
if err != nil {
|
||||
|
@ -78,7 +78,7 @@ func sqliteInsert(key string, value string, tableName string) error {
|
|||
}
|
||||
|
||||
func sqliteInsertPeer(key string, value string) error {
|
||||
if key != "" && value != "" && IsJSONString(value) {
|
||||
if key != "" && value != "" {
|
||||
err := sqliteInsert(key, value, PEERS_TABLE_NAME)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -1,59 +0,0 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SetPeers - sets peers for a network
|
||||
func SetPeers(newPeers map[string]string, networkName string) bool {
|
||||
areEqual := PeersAreEqual(newPeers, networkName)
|
||||
if !areEqual {
|
||||
jsonData, err := json.Marshal(newPeers)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
InsertPeer(networkName, string(jsonData))
|
||||
return true
|
||||
}
|
||||
return !areEqual
|
||||
}
|
||||
|
||||
// GetPeers - gets peers for a given network
|
||||
func GetPeers(networkName string) (map[string]string, error) {
|
||||
record, err := FetchRecord(PEERS_TABLE_NAME, networkName)
|
||||
if err != nil && !IsEmptyRecord(err) {
|
||||
return nil, err
|
||||
}
|
||||
currentDataMap := make(map[string]string)
|
||||
if IsEmptyRecord(err) {
|
||||
return currentDataMap, nil
|
||||
}
|
||||
err = json.Unmarshal([]byte(record), ¤tDataMap)
|
||||
return currentDataMap, err
|
||||
}
|
||||
|
||||
// PeersAreEqual - checks if peers are the same
|
||||
func PeersAreEqual(toCompare map[string]string, networkName string) bool {
|
||||
currentDataMap, err := GetPeers(networkName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if len(currentDataMap) != len(toCompare) {
|
||||
return false
|
||||
}
|
||||
for k := range currentDataMap {
|
||||
if toCompare[k] != currentDataMap[k] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsEmptyRecord - checks for if it's an empty record error or not
|
||||
func IsEmptyRecord(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(err.Error(), NO_RECORD) || strings.Contains(err.Error(), NO_RECORDS)
|
||||
}
|
11
database/utils.go
Normal file
11
database/utils.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package database
|
||||
|
||||
import "strings"
|
||||
|
||||
// IsEmptyRecord - checks for if it's an empty record error or not
|
||||
func IsEmptyRecord(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(err.Error(), NO_RECORD) || strings.Contains(err.Error(), NO_RECORDS)
|
||||
}
|
|
@ -2,7 +2,9 @@ package db
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"os"
|
||||
|
||||
"github.com/gravitl/netmaker/config"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
|
@ -14,10 +16,21 @@ type connector interface {
|
|||
connect() (*gorm.DB, error)
|
||||
}
|
||||
|
||||
// GetDB - gets the database type
|
||||
func GetDB() string {
|
||||
database := "sqlite"
|
||||
if os.Getenv("DATABASE") != "" {
|
||||
database = os.Getenv("DATABASE")
|
||||
} else if config.Config.Server.Database != "" {
|
||||
database = config.Config.Server.Database
|
||||
}
|
||||
return database
|
||||
}
|
||||
|
||||
// newConnector detects the database being
|
||||
// used and returns the corresponding connector.
|
||||
func newConnector() (connector, error) {
|
||||
switch servercfg.GetDB() {
|
||||
switch GetDB() {
|
||||
case "sqlite":
|
||||
return &sqliteConnector{}, nil
|
||||
case "postgres":
|
||||
|
|
7
db/db.go
7
db/db.go
|
@ -3,9 +3,10 @@ package db
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ctxKey string
|
||||
|
@ -74,10 +75,6 @@ func Middleware(next http.Handler) http.Handler {
|
|||
//
|
||||
// The function panics, if a connection does not exist.
|
||||
func FromContext(ctx context.Context) *gorm.DB {
|
||||
db, ok := ctx.Value(dbCtxKey).(*gorm.DB)
|
||||
if !ok {
|
||||
panic(ErrDBNotFound)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
|
|
@ -2,7 +2,10 @@ package db
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/gravitl/netmaker/config"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
@ -15,7 +18,7 @@ type postgresConnector struct{}
|
|||
// postgresConnector.connect connects and
|
||||
// initializes a connection to postgres.
|
||||
func (pg *postgresConnector) connect() (*gorm.DB, error) {
|
||||
pgConf := servercfg.GetSQLConf()
|
||||
pgConf := GetSQLConf()
|
||||
dsn := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=5",
|
||||
pgConf.Host,
|
||||
|
@ -47,3 +50,68 @@ func (pg *postgresConnector) connect() (*gorm.DB, error) {
|
|||
|
||||
return db, nil
|
||||
}
|
||||
func GetSQLConf() config.SQLConfig {
|
||||
var cfg config.SQLConfig
|
||||
cfg.Host = GetSQLHost()
|
||||
cfg.Port = GetSQLPort()
|
||||
cfg.Username = GetSQLUser()
|
||||
cfg.Password = GetSQLPass()
|
||||
cfg.DB = GetSQLDB()
|
||||
cfg.SSLMode = GetSQLSSLMode()
|
||||
return cfg
|
||||
}
|
||||
func GetSQLHost() string {
|
||||
host := "localhost"
|
||||
if os.Getenv("SQL_HOST") != "" {
|
||||
host = os.Getenv("SQL_HOST")
|
||||
} else if config.Config.SQL.Host != "" {
|
||||
host = config.Config.SQL.Host
|
||||
}
|
||||
return host
|
||||
}
|
||||
func GetSQLPort() int32 {
|
||||
port := int32(5432)
|
||||
envport, err := strconv.Atoi(os.Getenv("SQL_PORT"))
|
||||
if err == nil && envport != 0 {
|
||||
port = int32(envport)
|
||||
} else if config.Config.SQL.Port != 0 {
|
||||
port = config.Config.SQL.Port
|
||||
}
|
||||
return port
|
||||
}
|
||||
func GetSQLUser() string {
|
||||
user := "postgres"
|
||||
if os.Getenv("SQL_USER") != "" {
|
||||
user = os.Getenv("SQL_USER")
|
||||
} else if config.Config.SQL.Username != "" {
|
||||
user = config.Config.SQL.Username
|
||||
}
|
||||
return user
|
||||
}
|
||||
func GetSQLPass() string {
|
||||
pass := "nopass"
|
||||
if os.Getenv("SQL_PASS") != "" {
|
||||
pass = os.Getenv("SQL_PASS")
|
||||
} else if config.Config.SQL.Password != "" {
|
||||
pass = config.Config.SQL.Password
|
||||
}
|
||||
return pass
|
||||
}
|
||||
func GetSQLDB() string {
|
||||
db := "netmaker"
|
||||
if os.Getenv("SQL_DB") != "" {
|
||||
db = os.Getenv("SQL_DB")
|
||||
} else if config.Config.SQL.DB != "" {
|
||||
db = config.Config.SQL.DB
|
||||
}
|
||||
return db
|
||||
}
|
||||
func GetSQLSSLMode() string {
|
||||
sslmode := "disable"
|
||||
if os.Getenv("SQL_SSL_MODE") != "" {
|
||||
sslmode = os.Getenv("SQL_SSL_MODE")
|
||||
} else if config.Config.SQL.SSLMode != "" {
|
||||
sslmode = config.Config.SQL.SSLMode
|
||||
}
|
||||
return sslmode
|
||||
}
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// sqliteConnector for initializing and
|
||||
|
@ -28,7 +29,7 @@ func (s *sqliteConnector) connect() (*gorm.DB, error) {
|
|||
}
|
||||
}
|
||||
|
||||
dbFilePath := filepath.Join("data", "netmaker_v1.db")
|
||||
dbFilePath := filepath.Join("data", "netmaker.db")
|
||||
|
||||
// ensure netmaker_v1.db exists.
|
||||
_, err = os.Stat(dbFilePath)
|
||||
|
|
|
@ -169,6 +169,7 @@ func CreateUser(user *models.User) error {
|
|||
if IsOauthUser(user) == nil {
|
||||
user.AuthType = models.OAuth
|
||||
}
|
||||
AddGlobalNetRolesToAdmins(user)
|
||||
_, err = CreateUserJWT(user.UserName, user.PlatformRoleID)
|
||||
if err != nil {
|
||||
logger.Log(0, "failed to generate token", err.Error())
|
||||
|
@ -186,7 +187,6 @@ func CreateUser(user *models.User) error {
|
|||
logger.Log(0, "failed to insert user", err.Error())
|
||||
return err
|
||||
}
|
||||
AddGlobalNetRolesToAdmins(*user)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -305,7 +305,7 @@ func UpdateUser(userchange, user *models.User) (*models.User, error) {
|
|||
}
|
||||
user.UserGroups = userchange.UserGroups
|
||||
user.NetworkRoles = userchange.NetworkRoles
|
||||
AddGlobalNetRolesToAdmins(*user)
|
||||
AddGlobalNetRolesToAdmins(user)
|
||||
err := ValidateUser(user)
|
||||
if err != nil {
|
||||
return &models.User{}, err
|
||||
|
@ -349,19 +349,18 @@ func ValidateUser(user *models.User) error {
|
|||
}
|
||||
|
||||
// DeleteUser - deletes a given user
|
||||
func DeleteUser(user string) (bool, error) {
|
||||
func DeleteUser(user string) error {
|
||||
|
||||
if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 {
|
||||
return false, errors.New("user does not exist")
|
||||
return errors.New("user does not exist")
|
||||
}
|
||||
|
||||
err := database.DeleteRecord(database.USERS_TABLE_NAME, user)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
go RemoveUserFromAclPolicy(user)
|
||||
|
||||
return true, nil
|
||||
return (&models.UserAccessToken{UserName: user}).DeleteAllUserTokens()
|
||||
}
|
||||
|
||||
func SetAuthSecret(secret string) error {
|
||||
|
|
|
@ -38,7 +38,7 @@ var (
|
|||
)
|
||||
|
||||
// CreateEnrollmentKey - creates a new enrollment key in db
|
||||
func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID, defaultKey bool) (*models.EnrollmentKey, error) {
|
||||
func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID, defaultKey, autoEgress bool) (*models.EnrollmentKey, error) {
|
||||
newKeyID, err := getUniqueEnrollmentID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -54,6 +54,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
|
|||
Relay: relay,
|
||||
Groups: groups,
|
||||
Default: defaultKey,
|
||||
AutoEgress: autoEgress,
|
||||
}
|
||||
if uses > 0 {
|
||||
k.UsesRemaining = uses
|
||||
|
|
|
@ -14,35 +14,35 @@ func TestCreateEnrollmentKey(t *testing.T) {
|
|||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
t.Run("Can_Not_Create_Key", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
|
||||
assert.Nil(t, newKey)
|
||||
assert.NotNil(t, err)
|
||||
assert.ErrorIs(t, err, models.ErrInvalidEnrollmentKey)
|
||||
})
|
||||
t.Run("Can_Create_Key_Uses", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
|
||||
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, newKey.UsesRemaining)
|
||||
assert.True(t, newKey.IsValid())
|
||||
})
|
||||
t.Run("Can_Create_Key_Time", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil, false)
|
||||
newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil, false, false)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, newKey.IsValid())
|
||||
})
|
||||
t.Run("Can_Create_Key_Unlimited", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false)
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false, false)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, newKey.IsValid())
|
||||
})
|
||||
t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, newKey.IsValid())
|
||||
assert.True(t, len(newKey.Networks) == 2)
|
||||
})
|
||||
t.Run("Can_Create_Key_WithTags", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil, false)
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil, false, false)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, newKey.IsValid())
|
||||
assert.True(t, len(newKey.Tags) == 2)
|
||||
|
@ -62,7 +62,7 @@ func TestCreateEnrollmentKey(t *testing.T) {
|
|||
func TestDelete_EnrollmentKey(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
|
||||
t.Run("Can_Delete_Key", func(t *testing.T) {
|
||||
assert.True(t, newKey.IsValid())
|
||||
err := DeleteEnrollmentKey(newKey.Value, false)
|
||||
|
@ -83,7 +83,7 @@ func TestDelete_EnrollmentKey(t *testing.T) {
|
|||
func TestDecrement_EnrollmentKey(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
|
||||
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
|
||||
t.Run("Check_initial_uses", func(t *testing.T) {
|
||||
assert.True(t, newKey.IsValid())
|
||||
assert.Equal(t, newKey.UsesRemaining, 1)
|
||||
|
@ -107,9 +107,9 @@ func TestDecrement_EnrollmentKey(t *testing.T) {
|
|||
func TestUsability_EnrollmentKey(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
|
||||
key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil, false)
|
||||
key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false)
|
||||
key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
|
||||
key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil, false, false)
|
||||
key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false, false)
|
||||
t.Run("Check if valid use key can be used", func(t *testing.T) {
|
||||
assert.Equal(t, key1.UsesRemaining, 1)
|
||||
ok := TryToUseEnrollmentKey(key1)
|
||||
|
@ -145,7 +145,7 @@ func removeAllEnrollments() {
|
|||
func TestTokenize_EnrollmentKeys(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
|
||||
const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5"
|
||||
const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
|
||||
const serverAddr = "api.myserver.com"
|
||||
|
@ -178,7 +178,7 @@ func TestTokenize_EnrollmentKeys(t *testing.T) {
|
|||
func TestDeTokenize_EnrollmentKeys(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
|
||||
const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
|
||||
const serverAddr = "api.myserver.com"
|
||||
|
||||
|
|
|
@ -52,12 +52,38 @@ func CreateJWT(uuid string, macAddress string, network string) (response string,
|
|||
return "", err
|
||||
}
|
||||
|
||||
// CreateUserJWT - creates a user jwt token
|
||||
func CreateUserAccessJwtToken(username string, role models.UserRoleID, d time.Time, tokenID string) (response string, err error) {
|
||||
claims := &models.UserClaims{
|
||||
UserName: username,
|
||||
Role: role,
|
||||
TokenType: models.AccessTokenType,
|
||||
Api: servercfg.ServerInfo.APIHost,
|
||||
RacAutoDisable: servercfg.GetRacAutoDisable() && (role != models.SuperAdminRole && role != models.AdminRole),
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "Netmaker",
|
||||
Subject: fmt.Sprintf("user|%s", username),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(d),
|
||||
ID: tokenID,
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString(jwtSecretKey)
|
||||
if err == nil {
|
||||
return tokenString, nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
// CreateUserJWT - creates a user jwt token
|
||||
func CreateUserJWT(username string, role models.UserRoleID) (response string, err error) {
|
||||
expirationTime := time.Now().Add(servercfg.GetServerConfig().JwtValidityDuration)
|
||||
claims := &models.UserClaims{
|
||||
UserName: username,
|
||||
Role: role,
|
||||
TokenType: models.UserIDTokenType,
|
||||
RacAutoDisable: servercfg.GetRacAutoDisable() && (role != models.SuperAdminRole && role != models.AdminRole),
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "Netmaker",
|
||||
|
@ -75,18 +101,6 @@ func CreateUserJWT(username string, role models.UserRoleID) (response string, er
|
|||
return "", err
|
||||
}
|
||||
|
||||
// VerifyJWT verifies Auth Header
|
||||
func VerifyJWT(bearerToken string) (username string, issuperadmin, isadmin bool, err error) {
|
||||
token := ""
|
||||
tokenSplit := strings.Split(bearerToken, " ")
|
||||
if len(tokenSplit) > 1 {
|
||||
token = tokenSplit[1]
|
||||
} else {
|
||||
return "", false, false, errors.New("invalid auth header")
|
||||
}
|
||||
return VerifyUserToken(token)
|
||||
}
|
||||
|
||||
func GetUserNameFromToken(authtoken string) (username string, err error) {
|
||||
claims := &models.UserClaims{}
|
||||
var tokenSplit = strings.Split(authtoken, " ")
|
||||
|
@ -107,6 +121,20 @@ func GetUserNameFromToken(authtoken string) (username string, err error) {
|
|||
if err != nil {
|
||||
return "", Unauthorized_Err
|
||||
}
|
||||
if claims.TokenType == models.AccessTokenType {
|
||||
jti := claims.ID
|
||||
if jti != "" {
|
||||
a := models.UserAccessToken{ID: jti}
|
||||
// check if access token is active
|
||||
err := a.Get()
|
||||
if err != nil {
|
||||
err = errors.New("token revoked")
|
||||
return "", err
|
||||
}
|
||||
a.LastUsed = time.Now()
|
||||
a.Update()
|
||||
}
|
||||
}
|
||||
|
||||
if token != nil && token.Valid {
|
||||
var user *models.User
|
||||
|
@ -131,15 +159,26 @@ func GetUserNameFromToken(authtoken string) (username string, err error) {
|
|||
// VerifyUserToken func will used to Verify the JWT Token while using APIS
|
||||
func VerifyUserToken(tokenString string) (username string, issuperadmin, isadmin bool, err error) {
|
||||
claims := &models.UserClaims{}
|
||||
|
||||
if tokenString == servercfg.GetMasterKey() && servercfg.GetMasterKey() != "" {
|
||||
return MasterUser, true, true, nil
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return jwtSecretKey, nil
|
||||
})
|
||||
|
||||
if claims.TokenType == models.AccessTokenType {
|
||||
jti := claims.ID
|
||||
if jti != "" {
|
||||
a := models.UserAccessToken{ID: jti}
|
||||
// check if access token is active
|
||||
err := a.Get()
|
||||
if err != nil {
|
||||
err = errors.New("token revoked")
|
||||
return "", false, false, err
|
||||
}
|
||||
a.LastUsed = time.Now()
|
||||
a.Update()
|
||||
}
|
||||
}
|
||||
if token != nil && token.Valid {
|
||||
var user *models.User
|
||||
// check that user exists
|
||||
|
|
|
@ -302,6 +302,7 @@ func CreateNetwork(network models.Network) (models.Network, error) {
|
|||
true,
|
||||
uuid.Nil,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
|
||||
return network, nil
|
||||
|
|
|
@ -61,7 +61,7 @@ var CreateDefaultNetworkRolesAndGroups = func(netID models.NetworkID) {}
|
|||
var CreateDefaultUserPolicies = func(netID models.NetworkID) {}
|
||||
var GetUserGroupsInNetwork = func(netID models.NetworkID) (networkGrps map[models.UserGroupID]models.UserGroup) { return }
|
||||
var GetUserGroup = func(groupId models.UserGroupID) (userGrps models.UserGroup, err error) { return }
|
||||
var AddGlobalNetRolesToAdmins = func(u models.User) {}
|
||||
var AddGlobalNetRolesToAdmins = func(u *models.User) {}
|
||||
|
||||
// GetRole - fetches role template by id
|
||||
func GetRole(roleID models.UserRoleID) (models.UserRolePermissionTemplate, error) {
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
@ -201,3 +202,23 @@ func VersionLessThan(v1, v2 string) (bool, error) {
|
|||
}
|
||||
return sv1.LT(sv2), nil
|
||||
}
|
||||
|
||||
// Compare any two maps with any key and value types
|
||||
func CompareMaps[K comparable, V any](a, b map[K]V) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
for key, valA := range a {
|
||||
valB, ok := b[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(valA, valB) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
51
main.go
51
main.go
|
@ -3,6 +3,8 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -12,9 +14,11 @@ import (
|
|||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/config"
|
||||
controller "github.com/gravitl/netmaker/controllers"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/functions"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
|
@ -22,9 +26,11 @@ import (
|
|||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/mq"
|
||||
"github.com/gravitl/netmaker/netclient/ncutils"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"github.com/gravitl/netmaker/serverctl"
|
||||
_ "go.uber.org/automaxprocs"
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
|
@ -99,8 +105,13 @@ func initialize() { // Client Mode Prereq Check
|
|||
if err = database.InitializeDatabase(); err != nil {
|
||||
logger.FatalLog("Error connecting to database: ", err.Error())
|
||||
}
|
||||
// initialize sql schema db.
|
||||
err = db.InitializeDB(schema.ListModels()...)
|
||||
if err != nil {
|
||||
logger.FatalLog("Error connecting to v1 database: ", err.Error())
|
||||
}
|
||||
logger.Log(0, "database successfully connected")
|
||||
|
||||
initializeUUID()
|
||||
//initialize cache
|
||||
_, _ = logic.GetNetworks()
|
||||
_, _ = logic.GetAllNodes()
|
||||
|
@ -247,3 +258,41 @@ func setGarbageCollection() {
|
|||
debug.SetGCPercent(ncutils.DEFAULT_GC_PERCENT)
|
||||
}
|
||||
}
|
||||
|
||||
// initializeUUID - create a UUID record for server if none exists
|
||||
func initializeUUID() error {
|
||||
records, err := database.FetchRecords(database.SERVER_UUID_TABLE_NAME)
|
||||
if err != nil {
|
||||
if !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
} else if len(records) > 0 {
|
||||
return nil
|
||||
}
|
||||
// setup encryption keys
|
||||
var trafficPubKey, trafficPrivKey, errT = box.GenerateKey(rand.Reader) // generate traffic keys
|
||||
if errT != nil {
|
||||
return errT
|
||||
}
|
||||
tPriv, err := ncutils.ConvertKeyToBytes(trafficPrivKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tPub, err := ncutils.ConvertKeyToBytes(trafficPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
telemetry := models.Telemetry{
|
||||
UUID: uuid.NewString(),
|
||||
TrafficKeyPriv: tPriv,
|
||||
TrafficKeyPub: tPub,
|
||||
}
|
||||
telJSON, err := json.Marshal(&telemetry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return database.Insert(database.SERVER_UUID_RECORD_KEY, string(telJSON), database.SERVER_UUID_TABLE_NAME)
|
||||
}
|
||||
|
|
|
@ -151,6 +151,7 @@ func updateEnrollmentKeys() {
|
|||
true,
|
||||
uuid.Nil,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
|
||||
}
|
||||
|
@ -405,11 +406,12 @@ func syncUsers() {
|
|||
}
|
||||
if user.PlatformRoleID == models.SuperAdminRole && !user.IsSuperAdmin {
|
||||
user.IsSuperAdmin = true
|
||||
logic.UpsertUser(user)
|
||||
|
||||
}
|
||||
if user.PlatformRoleID.String() != "" {
|
||||
logic.MigrateUserRoleAndGroups(user)
|
||||
logic.AddGlobalNetRolesToAdmins(user)
|
||||
logic.AddGlobalNetRolesToAdmins(&user)
|
||||
logic.UpsertUser(user)
|
||||
continue
|
||||
}
|
||||
user.AuthType = models.BasicAuth
|
||||
|
@ -430,9 +432,9 @@ func syncUsers() {
|
|||
} else {
|
||||
user.PlatformRoleID = models.ServiceUser
|
||||
}
|
||||
logic.UpsertUser(user)
|
||||
logic.AddGlobalNetRolesToAdmins(user)
|
||||
logic.AddGlobalNetRolesToAdmins(&user)
|
||||
logic.MigrateUserRoleAndGroups(user)
|
||||
logic.UpsertUser(user)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,13 +1,60 @@
|
|||
package models
|
||||
|
||||
// AccessToken - token used to access netmaker
|
||||
type AccessToken struct {
|
||||
APIConnString string `json:"apiconnstring"`
|
||||
ClientConfig
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/db"
|
||||
)
|
||||
|
||||
// accessTokenTableName - access tokens table
|
||||
const accessTokenTableName = "user_access_tokens"
|
||||
|
||||
// UserAccessToken - token used to access netmaker
|
||||
type UserAccessToken struct {
|
||||
ID string `gorm:"id,primary_key" json:"id"`
|
||||
Name string `gorm:"name" json:"name"`
|
||||
UserName string `gorm:"user_name" json:"user_name"`
|
||||
ExpiresAt time.Time `gorm:"expires_at" json:"expires_at"`
|
||||
LastUsed time.Time `gorm:"last_used" json:"last_used"`
|
||||
CreatedBy string `gorm:"created_by" json:"created_by"`
|
||||
CreatedAt time.Time `gorm:"created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
// ClientConfig - the config of the client
|
||||
type ClientConfig struct {
|
||||
Network string `json:"network"`
|
||||
Key string `json:"key"`
|
||||
func (a *UserAccessToken) Table() string {
|
||||
return accessTokenTableName
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) Get() error {
|
||||
return db.FromContext(context.TODO()).Table(a.Table()).First(&a).Where("id = ?", a.ID).Error
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) Update() error {
|
||||
return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Updates(&a).Error
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) Create() error {
|
||||
return db.FromContext(context.TODO()).Table(a.Table()).Create(&a).Error
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) List() (ats []UserAccessToken, err error) {
|
||||
err = db.FromContext(context.TODO()).Table(a.Table()).Find(&ats).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) ListByUser() (ats []UserAccessToken) {
|
||||
db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ?", a.UserName).Find(&ats)
|
||||
if ats == nil {
|
||||
ats = []UserAccessToken{}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) Delete() error {
|
||||
return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Delete(&a).Error
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) DeleteAllUserTokens() error {
|
||||
return db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ? OR created_by = ?", a.UserName, a.UserName).Delete(&a).Error
|
||||
|
||||
}
|
||||
|
|
|
@ -54,6 +54,7 @@ type EnrollmentKey struct {
|
|||
Relay uuid.UUID `json:"relay"`
|
||||
Groups []TagID `json:"groups"`
|
||||
Default bool `json:"default"`
|
||||
AutoEgress bool `json:"auto_egress"`
|
||||
}
|
||||
|
||||
// APIEnrollmentKey - used to create enrollment keys via API
|
||||
|
@ -66,6 +67,7 @@ type APIEnrollmentKey struct {
|
|||
Type KeyType `json:"type"`
|
||||
Relay string `json:"relay"`
|
||||
Groups []TagID `json:"groups"`
|
||||
AutoEgress bool `json:"auto_egress"`
|
||||
}
|
||||
|
||||
// RegisterResponse - the response to a successful enrollment register
|
||||
|
|
|
@ -263,6 +263,7 @@ type NodeJoinResponse struct {
|
|||
type ServerConfig struct {
|
||||
CoreDNSAddr string `yaml:"corednsaddr"`
|
||||
API string `yaml:"api"`
|
||||
APIHost string `yaml:"apihost"`
|
||||
APIPort string `yaml:"apiport"`
|
||||
DNSMode string `yaml:"dnsmode"`
|
||||
Version string `yaml:"version"`
|
||||
|
|
|
@ -13,6 +13,7 @@ type RsrcID string
|
|||
type UserRoleID string
|
||||
type UserGroupID string
|
||||
type AuthType string
|
||||
type TokenType string
|
||||
|
||||
var (
|
||||
BasicAuth AuthType = "basic_auth"
|
||||
|
@ -35,6 +36,15 @@ func GetRAGRoleID(netID, hostID string) UserRoleID {
|
|||
return UserRoleID(fmt.Sprintf("netID-%s-rag-%s", netID, hostID))
|
||||
}
|
||||
|
||||
func (t TokenType) String() string {
|
||||
return string(t)
|
||||
}
|
||||
|
||||
var (
|
||||
UserIDTokenType TokenType = "user_id_token"
|
||||
AccessTokenType TokenType = "access_token"
|
||||
)
|
||||
|
||||
var RsrcTypeMap = map[RsrcType]struct{}{
|
||||
HostRsrc: {},
|
||||
RelayRsrc: {},
|
||||
|
@ -185,6 +195,8 @@ type UserAuthParams struct {
|
|||
type UserClaims struct {
|
||||
Role UserRoleID
|
||||
UserName string
|
||||
Api string
|
||||
TokenType TokenType
|
||||
RacAutoDisable bool
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
|
|
@ -1211,11 +1211,10 @@ func GetUserGroupsInNetwork(netID models.NetworkID) (networkGrps map[models.User
|
|||
return
|
||||
}
|
||||
|
||||
func AddGlobalNetRolesToAdmins(u models.User) {
|
||||
func AddGlobalNetRolesToAdmins(u *models.User) {
|
||||
if u.PlatformRoleID != models.SuperAdminRole && u.PlatformRoleID != models.AdminRole {
|
||||
return
|
||||
}
|
||||
u.UserGroups = make(map[models.UserGroupID]struct{})
|
||||
u.UserGroups[models.UserGroupID(fmt.Sprintf("global-%s-grp", models.NetworkAdmin))] = struct{}{}
|
||||
logic.UpsertUser(u)
|
||||
}
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
package schema
|
||||
|
||||
import "github.com/gravitl/netmaker/models"
|
||||
|
||||
// ListModels lists all the models in this schema.
|
||||
func ListModels() []interface{} {
|
||||
return []interface{}{
|
||||
&Job{},
|
||||
&models.UserAccessToken{},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -137,6 +137,7 @@ func GetServerInfo() models.ServerConfig {
|
|||
cfg.MQUserName = GetMqUserName()
|
||||
cfg.MQPassword = GetMqPassword()
|
||||
}
|
||||
cfg.APIHost = GetAPIHost()
|
||||
cfg.API = GetAPIConnString()
|
||||
cfg.CoreDNSAddr = GetCoreDNSAddr()
|
||||
cfg.APIPort = GetAPIPort()
|
||||
|
|
Loading…
Add table
Reference in a new issue