netmaker/logic/enrollmentkey.go

225 lines
5.6 KiB
Go
Raw Normal View History

2023-02-15 06:21:51 +08:00
package logic
import (
2023-02-16 05:32:16 +08:00
b64 "encoding/base64"
2023-02-15 06:21:51 +08:00
"encoding/json"
2023-02-16 04:52:58 +08:00
"errors"
2023-02-15 06:21:51 +08:00
"fmt"
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models"
)
// EnrollmentErrors - struct for holding EnrollmentKey error messages.
// Each field is a distinct sentinel error value; callers should compare
// against them with errors.Is.
//
//   - InvalidCreate: CreateEnrollmentKey built a key that failed Validate
//   - NoKeyFound: no stored key matches the requested value
//   - InvalidKey: upsertEnrollmentKey was handed a nil key
//   - NoUsesRemaining: the stored use count is exhausted
//   - FailedToTokenize: Tokenize got a nil key or an empty server address
//   - FailedToDeTokenize: DeTokenize got an empty token
var EnrollmentErrors = struct {
	InvalidCreate, NoKeyFound, InvalidKey error
	NoUsesRemaining                       error
	FailedToTokenize, FailedToDeTokenize  error
}{
	InvalidCreate:      fmt.Errorf("invalid enrollment key created"),
	NoKeyFound:         fmt.Errorf("no enrollmentkey found"),
	InvalidKey:         fmt.Errorf("invalid key provided"),
	NoUsesRemaining:    fmt.Errorf("no uses remaining"),
	FailedToTokenize:   fmt.Errorf("failed to tokenize"),
	FailedToDeTokenize: fmt.Errorf("failed to detokenize"),
}
// CreateEnrollmentKey - creates a new enrollment key in db
func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool) (k *models.EnrollmentKey, err error) {
newKeyID, err := getUniqueEnrollmentID()
if err != nil {
return nil, err
}
k = &models.EnrollmentKey{
Value: newKeyID,
Expiration: time.Time{},
UsesRemaining: 0,
Unlimited: unlimited,
Networks: []string{},
Tags: []string{},
2023-05-05 23:03:59 +08:00
Type: models.Undefined,
2023-02-15 06:21:51 +08:00
}
if uses > 0 {
k.UsesRemaining = uses
2023-05-05 23:03:59 +08:00
k.Type = models.Uses
2023-05-08 18:42:16 +08:00
} else if !expiration.IsZero() {
2023-02-15 06:21:51 +08:00
k.Expiration = expiration
2023-05-05 23:03:59 +08:00
k.Type = models.TimeExpiration
2023-05-08 18:42:16 +08:00
} else if k.Unlimited {
2023-05-05 23:03:59 +08:00
k.Type = models.Unlimited
2023-02-15 06:21:51 +08:00
}
if len(networks) > 0 {
k.Networks = networks
}
if len(tags) > 0 {
k.Tags = tags
}
if ok := k.Validate(); !ok {
2023-02-17 07:56:45 +08:00
return nil, EnrollmentErrors.InvalidCreate
2023-02-15 06:21:51 +08:00
}
if err = upsertEnrollmentKey(k); err != nil {
return nil, err
}
return
}
// GetAllEnrollmentKeys - fetches all enrollment keys from DB
func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) {
2023-02-15 06:21:51 +08:00
currentKeys, err := getEnrollmentKeysMap()
if err != nil {
return nil, err
}
var currentKeysList = []*models.EnrollmentKey{}
2023-02-15 06:21:51 +08:00
for k := range currentKeys {
currentKeysList = append(currentKeysList, currentKeys[k])
2023-02-15 06:21:51 +08:00
}
return currentKeysList, nil
}
// GetEnrollmentKey - fetches a single enrollment key
// returns nil and error if not found
func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) {
currentKeys, err := getEnrollmentKeysMap()
if err != nil {
return nil, err
}
if key, ok := currentKeys[value]; ok {
return key, nil
}
2023-02-17 07:56:45 +08:00
return nil, EnrollmentErrors.NoKeyFound
2023-02-15 06:21:51 +08:00
}
// DeleteEnrollmentKey - delete's a given enrollment key by value
func DeleteEnrollmentKey(value string) error {
2023-02-16 04:27:26 +08:00
_, err := GetEnrollmentKey(value)
if err != nil {
return err
}
2023-02-15 06:21:51 +08:00
return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
}
2023-02-16 04:52:58 +08:00
// TryToUseEnrollmentKey - checks first if key can be decremented
// returns true if it is decremented or isvalid
func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool {
key, err := decrementEnrollmentKey(k.Value)
if err != nil {
2023-02-17 07:56:45 +08:00
if errors.Is(err, EnrollmentErrors.NoUsesRemaining) {
2023-02-16 04:52:58 +08:00
return k.IsValid()
}
} else {
k.UsesRemaining = key.UsesRemaining
return true
}
return false
}
2023-02-16 05:32:16 +08:00
// Tokenize - tokenizes an enrollment key to be used via registration
// and attaches it to the Token field on the struct
func Tokenize(k *models.EnrollmentKey, serverAddr string) error {
2023-02-16 23:56:13 +08:00
if len(serverAddr) == 0 || k == nil {
2023-02-17 07:56:45 +08:00
return EnrollmentErrors.FailedToTokenize
2023-02-16 05:32:16 +08:00
}
newToken := models.EnrollmentToken{
Server: serverAddr,
Value: k.Value,
}
data, err := json.Marshal(&newToken)
if err != nil {
return err
}
2023-02-16 23:56:13 +08:00
k.Token = b64.StdEncoding.EncodeToString(data)
2023-02-16 05:32:16 +08:00
return nil
}
// DeTokenize - detokenizes a base64 encoded string
// and finds the associated enrollment key
func DeTokenize(b64Token string) (*models.EnrollmentKey, error) {
if len(b64Token) == 0 {
2023-02-17 07:56:45 +08:00
return nil, EnrollmentErrors.FailedToDeTokenize
2023-02-16 05:32:16 +08:00
}
tokenData, err := b64.StdEncoding.DecodeString(b64Token)
if err != nil {
return nil, err
}
var newToken models.EnrollmentToken
err = json.Unmarshal(tokenData, &newToken)
if err != nil {
return nil, err
}
k, err := GetEnrollmentKey(newToken.Value)
if err != nil {
return nil, err
}
return k, nil
}
2023-02-16 04:52:58 +08:00
// == private ==
// decrementEnrollmentKey - decrements the uses on a key if above 0 remaining
func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
2023-02-16 04:27:26 +08:00
k, err := GetEnrollmentKey(value)
if err != nil {
return nil, err
}
if k.UsesRemaining == 0 {
2023-02-17 07:56:45 +08:00
return nil, EnrollmentErrors.NoUsesRemaining
2023-02-16 04:27:26 +08:00
}
k.UsesRemaining = k.UsesRemaining - 1
if err = upsertEnrollmentKey(k); err != nil {
return nil, err
}
return k, nil
}
2023-02-15 06:21:51 +08:00
func upsertEnrollmentKey(k *models.EnrollmentKey) error {
if k == nil {
2023-02-17 07:56:45 +08:00
return EnrollmentErrors.InvalidKey
2023-02-15 06:21:51 +08:00
}
data, err := json.Marshal(k)
if err != nil {
return err
}
return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME)
}
func getUniqueEnrollmentID() (string, error) {
currentKeys, err := getEnrollmentKeysMap()
if err != nil {
return "", err
}
newID := RandomString(models.EnrollmentKeyLength)
2023-02-16 04:27:26 +08:00
for _, ok := currentKeys[newID]; ok; {
newID = RandomString(models.EnrollmentKeyLength)
2023-02-15 06:21:51 +08:00
}
return newID, nil
}
func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) {
records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
if err != nil {
if !database.IsEmptyRecord(err) {
return nil, err
}
}
2023-02-16 04:27:26 +08:00
if records == nil {
records = make(map[string]string)
}
2023-02-17 04:13:40 +08:00
currentKeys := make(map[string]*models.EnrollmentKey, 0)
2023-02-15 06:21:51 +08:00
if len(records) > 0 {
for k := range records {
var currentKey models.EnrollmentKey
if err = json.Unmarshal([]byte(records[k]), &currentKey); err != nil {
continue
}
currentKeys[k] = &currentKey
}
}
return currentKeys, nil
}