mirror of
https://github.com/gravitl/netmaker.git
synced 2025-09-05 20:54:18 +08:00
organized http logic, renamed files
This commit is contained in:
parent
6184d0b965
commit
0c6c09caa9
26 changed files with 1287 additions and 1485 deletions
189
controllers/dns.go
Normal file
189
controllers/dns.go
Normal file
|
@ -0,0 +1,189 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
)
|
||||
|
||||
func dnsHandlers(r *mux.Router) {
|
||||
|
||||
r.HandleFunc("/api/dns", securityCheck(true, http.HandlerFunc(getAllDNS))).Methods("GET")
|
||||
r.HandleFunc("/api/dns/adm/{network}/nodes", securityCheck(false, http.HandlerFunc(getNodeDNS))).Methods("GET")
|
||||
r.HandleFunc("/api/dns/adm/{network}/custom", securityCheck(false, http.HandlerFunc(getCustomDNS))).Methods("GET")
|
||||
r.HandleFunc("/api/dns/adm/{network}", securityCheck(false, http.HandlerFunc(getDNS))).Methods("GET")
|
||||
r.HandleFunc("/api/dns/{network}", securityCheck(false, http.HandlerFunc(createDNS))).Methods("POST")
|
||||
r.HandleFunc("/api/dns/adm/pushdns", securityCheck(false, http.HandlerFunc(pushDNS))).Methods("POST")
|
||||
r.HandleFunc("/api/dns/{network}/{domain}", securityCheck(false, http.HandlerFunc(deleteDNS))).Methods("DELETE")
|
||||
}
|
||||
|
||||
//Gets all nodes associated with network, including pending nodes
|
||||
func getNodeDNS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var dns []models.DNSEntry
|
||||
var params = mux.Vars(r)
|
||||
|
||||
dns, err := logic.GetNodeDNS(params["network"])
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
|
||||
//Returns all the nodes in JSON format
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(dns)
|
||||
}
|
||||
|
||||
//Gets all nodes associated with network, including pending nodes
|
||||
func getAllDNS(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
dns, err := logic.GetAllDNS()
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
//Returns all the nodes in JSON format
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(dns)
|
||||
}
|
||||
|
||||
//Gets all nodes associated with network, including pending nodes
|
||||
func getCustomDNS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var dns []models.DNSEntry
|
||||
var params = mux.Vars(r)
|
||||
|
||||
dns, err := logic.GetCustomDNS(params["network"])
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
|
||||
//Returns all the nodes in JSON format
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(dns)
|
||||
}
|
||||
|
||||
// Gets all nodes associated with network, including pending nodes
|
||||
func getDNS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var dns []models.DNSEntry
|
||||
var params = mux.Vars(r)
|
||||
|
||||
dns, err := logic.GetDNS(params["network"])
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(dns)
|
||||
}
|
||||
|
||||
func createDNS(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var entry models.DNSEntry
|
||||
var params = mux.Vars(r)
|
||||
|
||||
//get node from body of request
|
||||
_ = json.NewDecoder(r.Body).Decode(&entry)
|
||||
entry.Network = params["network"]
|
||||
|
||||
err := logic.ValidateDNSCreate(entry)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
|
||||
entry, err = CreateDNS(entry)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
err = logic.SetDNS()
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(entry)
|
||||
}
|
||||
|
||||
func deleteDNS(w http.ResponseWriter, r *http.Request) {
|
||||
// Set header
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// get params
|
||||
var params = mux.Vars(r)
|
||||
|
||||
err := logic.DeleteDNS(params["domain"], params["network"])
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
entrytext := params["domain"] + "." + params["network"]
|
||||
logger.Log(1, "deleted dns entry: ", entrytext)
|
||||
err = logic.SetDNS()
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(entrytext + " deleted.")
|
||||
}
|
||||
|
||||
// CreateDNS - creates a DNS entry
|
||||
func CreateDNS(entry models.DNSEntry) (models.DNSEntry, error) {
|
||||
|
||||
data, err := json.Marshal(&entry)
|
||||
if err != nil {
|
||||
return models.DNSEntry{}, err
|
||||
}
|
||||
key, err := logic.GetRecordKey(entry.Name, entry.Network)
|
||||
if err != nil {
|
||||
return models.DNSEntry{}, err
|
||||
}
|
||||
err = database.Insert(key, string(data), database.DNS_TABLE_NAME)
|
||||
|
||||
return entry, err
|
||||
}
|
||||
|
||||
// GetDNSEntry - gets a DNS entry
|
||||
func GetDNSEntry(domain string, network string) (models.DNSEntry, error) {
|
||||
var entry models.DNSEntry
|
||||
key, err := logic.GetRecordKey(domain, network)
|
||||
if err != nil {
|
||||
return entry, err
|
||||
}
|
||||
record, err := database.FetchRecord(database.DNS_TABLE_NAME, key)
|
||||
if err != nil {
|
||||
return entry, err
|
||||
}
|
||||
err = json.Unmarshal([]byte(record), &entry)
|
||||
return entry, err
|
||||
}
|
||||
|
||||
func pushDNS(w http.ResponseWriter, r *http.Request) {
|
||||
// Set header
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
err := logic.SetDNS()
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
logger.Log(1, r.Header.Get("user"), "pushed DNS updates to nameserver")
|
||||
json.NewEncoder(w).Encode("DNS Pushed to CoreDNS")
|
||||
}
|
|
@ -1,396 +0,0 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
)
|
||||
|
||||
func dnsHandlers(r *mux.Router) {
|
||||
|
||||
r.HandleFunc("/api/dns", securityCheck(true, http.HandlerFunc(getAllDNS))).Methods("GET")
|
||||
r.HandleFunc("/api/dns/adm/{network}/nodes", securityCheck(false, http.HandlerFunc(getNodeDNS))).Methods("GET")
|
||||
r.HandleFunc("/api/dns/adm/{network}/custom", securityCheck(false, http.HandlerFunc(getCustomDNS))).Methods("GET")
|
||||
r.HandleFunc("/api/dns/adm/{network}", securityCheck(false, http.HandlerFunc(getDNS))).Methods("GET")
|
||||
r.HandleFunc("/api/dns/{network}", securityCheck(false, http.HandlerFunc(createDNS))).Methods("POST")
|
||||
r.HandleFunc("/api/dns/adm/pushdns", securityCheck(false, http.HandlerFunc(pushDNS))).Methods("POST")
|
||||
r.HandleFunc("/api/dns/{network}/{domain}", securityCheck(false, http.HandlerFunc(deleteDNS))).Methods("DELETE")
|
||||
r.HandleFunc("/api/dns/{network}/{domain}", securityCheck(false, http.HandlerFunc(updateDNS))).Methods("PUT")
|
||||
}
|
||||
|
||||
//Gets all nodes associated with network, including pending nodes
|
||||
func getNodeDNS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var dns []models.DNSEntry
|
||||
var params = mux.Vars(r)
|
||||
|
||||
dns, err := GetNodeDNS(params["network"])
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
|
||||
//Returns all the nodes in JSON format
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(dns)
|
||||
}
|
||||
|
||||
//Gets all nodes associated with network, including pending nodes
|
||||
func getAllDNS(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
dns, err := GetAllDNS()
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
//Returns all the nodes in JSON format
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(dns)
|
||||
}
|
||||
|
||||
// GetAllDNS - gets all dns entries
|
||||
func GetAllDNS() ([]models.DNSEntry, error) {
|
||||
var dns []models.DNSEntry
|
||||
networks, err := logic.GetNetworks()
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return []models.DNSEntry{}, err
|
||||
}
|
||||
for _, net := range networks {
|
||||
netdns, err := logic.GetDNS(net.NetID)
|
||||
if err != nil {
|
||||
return []models.DNSEntry{}, nil
|
||||
}
|
||||
dns = append(dns, netdns...)
|
||||
}
|
||||
return dns, nil
|
||||
}
|
||||
|
||||
// GetNodeDNS - gets node dns
|
||||
func GetNodeDNS(network string) ([]models.DNSEntry, error) {
|
||||
|
||||
var dns []models.DNSEntry
|
||||
|
||||
collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
|
||||
if err != nil {
|
||||
return dns, err
|
||||
}
|
||||
|
||||
for _, value := range collection {
|
||||
var entry models.DNSEntry
|
||||
var node models.Node
|
||||
if err = json.Unmarshal([]byte(value), &node); err != nil {
|
||||
continue
|
||||
}
|
||||
if err = json.Unmarshal([]byte(value), &entry); node.Network == network && err == nil {
|
||||
dns = append(dns, entry)
|
||||
}
|
||||
}
|
||||
|
||||
return dns, nil
|
||||
}
|
||||
|
||||
//Gets all nodes associated with network, including pending nodes
|
||||
func getCustomDNS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var dns []models.DNSEntry
|
||||
var params = mux.Vars(r)
|
||||
|
||||
dns, err := logic.GetCustomDNS(params["network"])
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
|
||||
//Returns all the nodes in JSON format
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(dns)
|
||||
}
|
||||
|
||||
// GetDNSEntryNum - gets which entry the dns was
|
||||
func GetDNSEntryNum(domain string, network string) (int, error) {
|
||||
|
||||
num := 0
|
||||
|
||||
entries, err := logic.GetDNS(network)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for i := 0; i < len(entries); i++ {
|
||||
|
||||
if domain == entries[i].Name {
|
||||
num++
|
||||
}
|
||||
}
|
||||
|
||||
return num, nil
|
||||
}
|
||||
|
||||
// Gets all nodes associated with network, including pending nodes
|
||||
func getDNS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var dns []models.DNSEntry
|
||||
var params = mux.Vars(r)
|
||||
|
||||
dns, err := logic.GetDNS(params["network"])
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(dns)
|
||||
}
|
||||
|
||||
func createDNS(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var entry models.DNSEntry
|
||||
var params = mux.Vars(r)
|
||||
|
||||
//get node from body of request
|
||||
_ = json.NewDecoder(r.Body).Decode(&entry)
|
||||
entry.Network = params["network"]
|
||||
|
||||
err := ValidateDNSCreate(entry)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
|
||||
entry, err = CreateDNS(entry)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
err = logic.SetDNS()
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(entry)
|
||||
}
|
||||
|
||||
func updateDNS(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var params = mux.Vars(r)
|
||||
|
||||
var entry models.DNSEntry
|
||||
|
||||
//start here
|
||||
entry, err := GetDNSEntry(params["domain"], params["network"])
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
|
||||
var dnschange models.DNSEntry
|
||||
|
||||
// we decode our body request params
|
||||
err = json.NewDecoder(r.Body).Decode(&dnschange)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
// fill in any missing fields
|
||||
if dnschange.Name == "" {
|
||||
dnschange.Name = entry.Name
|
||||
}
|
||||
if dnschange.Network == "" {
|
||||
dnschange.Network = entry.Network
|
||||
}
|
||||
if dnschange.Address == "" {
|
||||
dnschange.Address = entry.Address
|
||||
}
|
||||
|
||||
err = ValidateDNSUpdate(dnschange, entry)
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
|
||||
entry, err = UpdateDNS(dnschange, entry)
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
err = logic.SetDNS()
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(entry)
|
||||
}
|
||||
|
||||
func deleteDNS(w http.ResponseWriter, r *http.Request) {
|
||||
// Set header
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// get params
|
||||
var params = mux.Vars(r)
|
||||
|
||||
err := DeleteDNS(params["domain"], params["network"])
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
entrytext := params["domain"] + "." + params["network"]
|
||||
logger.Log(1, "deleted dns entry: ", entrytext)
|
||||
err = logic.SetDNS()
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(entrytext + " deleted.")
|
||||
}
|
||||
|
||||
// CreateDNS - creates a DNS entry
|
||||
func CreateDNS(entry models.DNSEntry) (models.DNSEntry, error) {
|
||||
|
||||
data, err := json.Marshal(&entry)
|
||||
if err != nil {
|
||||
return models.DNSEntry{}, err
|
||||
}
|
||||
key, err := logic.GetRecordKey(entry.Name, entry.Network)
|
||||
if err != nil {
|
||||
return models.DNSEntry{}, err
|
||||
}
|
||||
err = database.Insert(key, string(data), database.DNS_TABLE_NAME)
|
||||
|
||||
return entry, err
|
||||
}
|
||||
|
||||
// GetDNSEntry - gets a DNS entry
|
||||
func GetDNSEntry(domain string, network string) (models.DNSEntry, error) {
|
||||
var entry models.DNSEntry
|
||||
key, err := logic.GetRecordKey(domain, network)
|
||||
if err != nil {
|
||||
return entry, err
|
||||
}
|
||||
record, err := database.FetchRecord(database.DNS_TABLE_NAME, key)
|
||||
if err != nil {
|
||||
return entry, err
|
||||
}
|
||||
err = json.Unmarshal([]byte(record), &entry)
|
||||
return entry, err
|
||||
}
|
||||
|
||||
// UpdateDNS - updates DNS entry
|
||||
func UpdateDNS(dnschange models.DNSEntry, entry models.DNSEntry) (models.DNSEntry, error) {
|
||||
|
||||
key, err := logic.GetRecordKey(entry.Name, entry.Network)
|
||||
if err != nil {
|
||||
return entry, err
|
||||
}
|
||||
if dnschange.Name != "" {
|
||||
entry.Name = dnschange.Name
|
||||
}
|
||||
if dnschange.Address != "" {
|
||||
entry.Address = dnschange.Address
|
||||
}
|
||||
newkey, err := logic.GetRecordKey(entry.Name, entry.Network)
|
||||
|
||||
err = database.DeleteRecord(database.DNS_TABLE_NAME, key)
|
||||
if err != nil {
|
||||
return entry, err
|
||||
}
|
||||
|
||||
data, err := json.Marshal(&entry)
|
||||
err = database.Insert(newkey, string(data), database.DNS_TABLE_NAME)
|
||||
return entry, err
|
||||
}
|
||||
|
||||
// DeleteDNS - deletes a DNS entry
|
||||
func DeleteDNS(domain string, network string) error {
|
||||
key, err := logic.GetRecordKey(domain, network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = database.DeleteRecord(database.DNS_TABLE_NAME, key)
|
||||
return err
|
||||
}
|
||||
|
||||
func pushDNS(w http.ResponseWriter, r *http.Request) {
|
||||
// Set header
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
err := logic.SetDNS()
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
logger.Log(1, r.Header.Get("user"), "pushed DNS updates to nameserver")
|
||||
json.NewEncoder(w).Encode("DNS Pushed to CoreDNS")
|
||||
}
|
||||
|
||||
// ValidateDNSCreate - checks if an entry is valid
|
||||
func ValidateDNSCreate(entry models.DNSEntry) error {
|
||||
|
||||
v := validator.New()
|
||||
|
||||
_ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
|
||||
num, err := GetDNSEntryNum(entry.Name, entry.Network)
|
||||
return err == nil && num == 0
|
||||
})
|
||||
|
||||
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
|
||||
_, err := logic.GetParentNetwork(entry.Network)
|
||||
return err == nil
|
||||
})
|
||||
|
||||
err := v.Struct(entry)
|
||||
if err != nil {
|
||||
for _, e := range err.(validator.ValidationErrors) {
|
||||
logger.Log(1, e.Error())
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateDNSUpdate - validates a DNS update
|
||||
func ValidateDNSUpdate(change models.DNSEntry, entry models.DNSEntry) error {
|
||||
|
||||
v := validator.New()
|
||||
|
||||
_ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
|
||||
//if name & net not changing name we are good
|
||||
if change.Name == entry.Name && change.Network == entry.Network {
|
||||
return true
|
||||
}
|
||||
num, err := GetDNSEntryNum(change.Name, change.Network)
|
||||
return err == nil && num == 0
|
||||
})
|
||||
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
|
||||
_, err := logic.GetParentNetwork(change.Network)
|
||||
if err != nil {
|
||||
logger.Log(0, err.Error())
|
||||
}
|
||||
return err == nil
|
||||
})
|
||||
|
||||
err := v.Struct(change)
|
||||
|
||||
if err != nil {
|
||||
for _, e := range err.(validator.ValidationErrors) {
|
||||
logger.Log(1, e.Error())
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
|
@ -17,21 +17,21 @@ func TestGetAllDNS(t *testing.T) {
|
|||
deleteAllNetworks()
|
||||
createNet()
|
||||
t.Run("NoEntries", func(t *testing.T) {
|
||||
entries, err := GetAllDNS()
|
||||
entries, err := logic.GetAllDNS()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, []models.DNSEntry(nil), entries)
|
||||
})
|
||||
t.Run("OneEntry", func(t *testing.T) {
|
||||
entry := models.DNSEntry{"10.0.0.3", "newhost", "skynet"}
|
||||
CreateDNS(entry)
|
||||
entries, err := GetAllDNS()
|
||||
entries, err := logic.GetAllDNS()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(entries))
|
||||
})
|
||||
t.Run("MultipleEntry", func(t *testing.T) {
|
||||
entry := models.DNSEntry{"10.0.0.7", "anotherhost", "skynet"}
|
||||
CreateDNS(entry)
|
||||
entries, err := GetAllDNS()
|
||||
entries, err := logic.GetAllDNS()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(entries))
|
||||
})
|
||||
|
@ -43,13 +43,13 @@ func TestGetNodeDNS(t *testing.T) {
|
|||
deleteAllNetworks()
|
||||
createNet()
|
||||
t.Run("NoNodes", func(t *testing.T) {
|
||||
dns, err := GetNodeDNS("skynet")
|
||||
dns, err := logic.GetNodeDNS("skynet")
|
||||
assert.EqualError(t, err, "could not find any records")
|
||||
assert.Equal(t, []models.DNSEntry(nil), dns)
|
||||
})
|
||||
t.Run("NodeExists", func(t *testing.T) {
|
||||
createTestNode()
|
||||
dns, err := GetNodeDNS("skynet")
|
||||
dns, err := logic.GetNodeDNS("skynet")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10.0.0.1", dns[0].Address)
|
||||
})
|
||||
|
@ -57,7 +57,7 @@ func TestGetNodeDNS(t *testing.T) {
|
|||
createnode := models.Node{PublicKey: "DM5qhLAE20PG9BbfBCger+Ac9D2NDOwCtY1rbYDLf34=", Endpoint: "10.100.100.3", MacAddress: "01:02:03:04:05:07", Password: "password", Network: "skynet"}
|
||||
_, err := logic.CreateNode(createnode, "skynet")
|
||||
assert.Nil(t, err)
|
||||
dns, err := GetNodeDNS("skynet")
|
||||
dns, err := logic.GetNodeDNS("skynet")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(dns))
|
||||
})
|
||||
|
@ -105,7 +105,7 @@ func TestGetDNSEntryNum(t *testing.T) {
|
|||
deleteAllNetworks()
|
||||
createNet()
|
||||
t.Run("NoNodes", func(t *testing.T) {
|
||||
num, err := GetDNSEntryNum("myhost", "skynet")
|
||||
num, err := logic.GetDNSEntryNum("myhost", "skynet")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, num)
|
||||
})
|
||||
|
@ -113,7 +113,7 @@ func TestGetDNSEntryNum(t *testing.T) {
|
|||
entry := models.DNSEntry{"10.0.0.2", "newhost", "skynet"}
|
||||
_, err := CreateDNS(entry)
|
||||
assert.Nil(t, err)
|
||||
num, err := GetDNSEntryNum("newhost", "skynet")
|
||||
num, err := logic.GetDNSEntryNum("newhost", "skynet")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, num)
|
||||
})
|
||||
|
@ -248,33 +248,34 @@ func TestGetDNSEntry(t *testing.T) {
|
|||
assert.Equal(t, models.DNSEntry{}, entry)
|
||||
})
|
||||
}
|
||||
func TestUpdateDNS(t *testing.T) {
|
||||
var newentry models.DNSEntry
|
||||
database.InitializeDatabase()
|
||||
deleteAllDNS(t)
|
||||
deleteAllNetworks()
|
||||
createNet()
|
||||
entry := models.DNSEntry{"10.0.0.2", "newhost", "skynet"}
|
||||
CreateDNS(entry)
|
||||
t.Run("change address", func(t *testing.T) {
|
||||
newentry.Address = "10.0.0.75"
|
||||
updated, err := UpdateDNS(newentry, entry)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, newentry.Address, updated.Address)
|
||||
})
|
||||
t.Run("change name", func(t *testing.T) {
|
||||
newentry.Name = "newname"
|
||||
updated, err := UpdateDNS(newentry, entry)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, newentry.Name, updated.Name)
|
||||
})
|
||||
t.Run("change network", func(t *testing.T) {
|
||||
newentry.Network = "wirecat"
|
||||
updated, err := UpdateDNS(newentry, entry)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, newentry.Network, updated.Network)
|
||||
})
|
||||
}
|
||||
|
||||
// func TestUpdateDNS(t *testing.T) {
|
||||
// var newentry models.DNSEntry
|
||||
// database.InitializeDatabase()
|
||||
// deleteAllDNS(t)
|
||||
// deleteAllNetworks()
|
||||
// createNet()
|
||||
// entry := models.DNSEntry{"10.0.0.2", "newhost", "skynet"}
|
||||
// CreateDNS(entry)
|
||||
// t.Run("change address", func(t *testing.T) {
|
||||
// newentry.Address = "10.0.0.75"
|
||||
// updated, err := UpdateDNS(newentry, entry)
|
||||
// assert.Nil(t, err)
|
||||
// assert.Equal(t, newentry.Address, updated.Address)
|
||||
// })
|
||||
// t.Run("change name", func(t *testing.T) {
|
||||
// newentry.Name = "newname"
|
||||
// updated, err := UpdateDNS(newentry, entry)
|
||||
// assert.Nil(t, err)
|
||||
// assert.Equal(t, newentry.Name, updated.Name)
|
||||
// })
|
||||
// t.Run("change network", func(t *testing.T) {
|
||||
// newentry.Network = "wirecat"
|
||||
// updated, err := UpdateDNS(newentry, entry)
|
||||
// assert.Nil(t, err)
|
||||
// assert.NotEqual(t, newentry.Network, updated.Network)
|
||||
// })
|
||||
// }
|
||||
func TestDeleteDNS(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
deleteAllDNS(t)
|
||||
|
@ -283,16 +284,16 @@ func TestDeleteDNS(t *testing.T) {
|
|||
entry := models.DNSEntry{"10.0.0.2", "newhost", "skynet"}
|
||||
CreateDNS(entry)
|
||||
t.Run("EntryExists", func(t *testing.T) {
|
||||
err := DeleteDNS("newhost", "skynet")
|
||||
err := logic.DeleteDNS("newhost", "skynet")
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
t.Run("NodeExists", func(t *testing.T) {
|
||||
err := DeleteDNS("myhost", "skynet")
|
||||
err := logic.DeleteDNS("myhost", "skynet")
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("NoEntries", func(t *testing.T) {
|
||||
err := DeleteDNS("myhost", "skynet")
|
||||
err := logic.DeleteDNS("myhost", "skynet")
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
|
@ -305,34 +306,34 @@ func TestValidateDNSUpdate(t *testing.T) {
|
|||
entry := models.DNSEntry{"10.0.0.2", "myhost", "skynet"}
|
||||
t.Run("BadNetwork", func(t *testing.T) {
|
||||
change := models.DNSEntry{"10.0.0.2", "myhost", "badnet"}
|
||||
err := ValidateDNSUpdate(change, entry)
|
||||
err := logic.ValidateDNSUpdate(change, entry)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Network' failed on the 'network_exists' tag")
|
||||
})
|
||||
t.Run("EmptyNetwork", func(t *testing.T) {
|
||||
//this can't actually happen as change.Network is populated if is blank
|
||||
change := models.DNSEntry{"10.0.0.2", "myhost", ""}
|
||||
err := ValidateDNSUpdate(change, entry)
|
||||
err := logic.ValidateDNSUpdate(change, entry)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Network' failed on the 'network_exists' tag")
|
||||
})
|
||||
t.Run("EmptyAddress", func(t *testing.T) {
|
||||
//this can't actually happen as change.Address is populated if is blank
|
||||
change := models.DNSEntry{"", "myhost", "skynet"}
|
||||
err := ValidateDNSUpdate(change, entry)
|
||||
err := logic.ValidateDNSUpdate(change, entry)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Address' failed on the 'required' tag")
|
||||
})
|
||||
t.Run("BadAddress", func(t *testing.T) {
|
||||
change := models.DNSEntry{"10.0.256.1", "myhost", "skynet"}
|
||||
err := ValidateDNSUpdate(change, entry)
|
||||
err := logic.ValidateDNSUpdate(change, entry)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Address' failed on the 'ip' tag")
|
||||
})
|
||||
t.Run("EmptyName", func(t *testing.T) {
|
||||
//this can't actually happen as change.Name is populated if is blank
|
||||
change := models.DNSEntry{"10.0.0.2", "", "skynet"}
|
||||
err := ValidateDNSUpdate(change, entry)
|
||||
err := logic.ValidateDNSUpdate(change, entry)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'required' tag")
|
||||
})
|
||||
|
@ -342,7 +343,7 @@ func TestValidateDNSUpdate(t *testing.T) {
|
|||
name = name + "a"
|
||||
}
|
||||
change := models.DNSEntry{"10.0.0.2", name, "skynet"}
|
||||
err := ValidateDNSUpdate(change, entry)
|
||||
err := logic.ValidateDNSUpdate(change, entry)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'max' tag")
|
||||
})
|
||||
|
@ -350,39 +351,39 @@ func TestValidateDNSUpdate(t *testing.T) {
|
|||
change := models.DNSEntry{"10.0.0.2", "myhost", "wirecat"}
|
||||
CreateDNS(entry)
|
||||
CreateDNS(change)
|
||||
err := ValidateDNSUpdate(change, entry)
|
||||
err := logic.ValidateDNSUpdate(change, entry)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'name_unique' tag")
|
||||
//cleanup
|
||||
err = DeleteDNS("myhost", "wirecat")
|
||||
err = logic.DeleteDNS("myhost", "wirecat")
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
|
||||
}
|
||||
func TestValidateDNSCreate(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
_ = DeleteDNS("mynode", "skynet")
|
||||
_ = logic.DeleteDNS("mynode", "skynet")
|
||||
t.Run("NoNetwork", func(t *testing.T) {
|
||||
entry := models.DNSEntry{"10.0.0.2", "myhost", "badnet"}
|
||||
err := ValidateDNSCreate(entry)
|
||||
err := logic.ValidateDNSCreate(entry)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Network' failed on the 'network_exists' tag")
|
||||
})
|
||||
t.Run("EmptyAddress", func(t *testing.T) {
|
||||
entry := models.DNSEntry{"", "myhost", "skynet"}
|
||||
err := ValidateDNSCreate(entry)
|
||||
err := logic.ValidateDNSCreate(entry)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Address' failed on the 'required' tag")
|
||||
})
|
||||
t.Run("BadAddress", func(t *testing.T) {
|
||||
entry := models.DNSEntry{"10.0.256.1", "myhost", "skynet"}
|
||||
err := ValidateDNSCreate(entry)
|
||||
err := logic.ValidateDNSCreate(entry)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Address' failed on the 'ip' tag")
|
||||
})
|
||||
t.Run("EmptyName", func(t *testing.T) {
|
||||
entry := models.DNSEntry{"10.0.0.2", "", "skynet"}
|
||||
err := ValidateDNSCreate(entry)
|
||||
err := logic.ValidateDNSCreate(entry)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'required' tag")
|
||||
})
|
||||
|
@ -392,24 +393,24 @@ func TestValidateDNSCreate(t *testing.T) {
|
|||
name = name + "a"
|
||||
}
|
||||
entry := models.DNSEntry{"10.0.0.2", name, "skynet"}
|
||||
err := ValidateDNSCreate(entry)
|
||||
err := logic.ValidateDNSCreate(entry)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'max' tag")
|
||||
})
|
||||
t.Run("NameUnique", func(t *testing.T) {
|
||||
entry := models.DNSEntry{"10.0.0.2", "myhost", "skynet"}
|
||||
_, _ = CreateDNS(entry)
|
||||
err := ValidateDNSCreate(entry)
|
||||
err := logic.ValidateDNSCreate(entry)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'name_unique' tag")
|
||||
})
|
||||
}
|
||||
|
||||
func deleteAllDNS(t *testing.T) {
|
||||
dns, err := GetAllDNS()
|
||||
dns, err := logic.GetAllDNS()
|
||||
assert.Nil(t, err)
|
||||
for _, record := range dns {
|
||||
err := DeleteDNS(record.Name, record.Network)
|
||||
err := logic.DeleteDNS(record.Name, record.Network)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
}
|
|
@ -7,7 +7,6 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
|
@ -16,7 +15,6 @@ import (
|
|||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/skip2/go-qrcode"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
func extClientHandlers(r *mux.Router) {
|
||||
|
@ -45,7 +43,7 @@ func getNetworkExtClients(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var extclients []models.ExtClient
|
||||
var params = mux.Vars(r)
|
||||
extclients, err := GetNetworkExtClients(params["network"])
|
||||
extclients, err := logic.GetNetworkExtClients(params["network"])
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -56,27 +54,6 @@ func getNetworkExtClients(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(extclients)
|
||||
}
|
||||
|
||||
// GetNetworkExtClients - gets the ext clients of given network
|
||||
func GetNetworkExtClients(network string) ([]models.ExtClient, error) {
|
||||
var extclients []models.ExtClient
|
||||
|
||||
records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME)
|
||||
if err != nil {
|
||||
return extclients, err
|
||||
}
|
||||
for _, value := range records {
|
||||
var extclient models.ExtClient
|
||||
err = json.Unmarshal([]byte(value), &extclient)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if extclient.Network == network {
|
||||
extclients = append(extclients, extclient)
|
||||
}
|
||||
}
|
||||
return extclients, err
|
||||
}
|
||||
|
||||
//A separate function to get all extclients, not just extclients for a particular network.
|
||||
//Not quite sure if this is necessary. Probably necessary based on front end but may want to review after iteration 1 if it's being used or not
|
||||
func getAllExtClients(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -100,7 +77,7 @@ func getAllExtClients(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
} else {
|
||||
for _, network := range networksSlice {
|
||||
extclients, err := GetNetworkExtClients(network)
|
||||
extclients, err := logic.GetNetworkExtClients(network)
|
||||
if err == nil {
|
||||
clients = append(clients, extclients...)
|
||||
}
|
||||
|
@ -121,7 +98,7 @@ func getExtClient(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
clientid := params["clientid"]
|
||||
network := params["network"]
|
||||
client, err := GetExtClient(clientid, network)
|
||||
client, err := logic.GetExtClient(clientid, network)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -131,22 +108,6 @@ func getExtClient(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(client)
|
||||
}
|
||||
|
||||
// GetExtClient - gets a single ext client on a network
|
||||
func GetExtClient(clientid string, network string) (models.ExtClient, error) {
|
||||
var extclient models.ExtClient
|
||||
key, err := logic.GetRecordKey(clientid, network)
|
||||
if err != nil {
|
||||
return extclient, err
|
||||
}
|
||||
data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key)
|
||||
if err != nil {
|
||||
return extclient, err
|
||||
}
|
||||
err = json.Unmarshal([]byte(data), &extclient)
|
||||
|
||||
return extclient, err
|
||||
}
|
||||
|
||||
//Get an individual extclient. Nothin fancy here folks.
|
||||
func getExtClientConf(w http.ResponseWriter, r *http.Request) {
|
||||
// set header.
|
||||
|
@ -155,7 +116,7 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
|
|||
var params = mux.Vars(r)
|
||||
clientid := params["clientid"]
|
||||
networkid := params["network"]
|
||||
client, err := GetExtClient(clientid, networkid)
|
||||
client, err := logic.GetExtClient(clientid, networkid)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -240,47 +201,6 @@ Endpoint = %s
|
|||
json.NewEncoder(w).Encode(client)
|
||||
}
|
||||
|
||||
// CreateExtClient - creates an extclient
|
||||
func CreateExtClient(extclient models.ExtClient) error {
|
||||
if extclient.PrivateKey == "" {
|
||||
privateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
extclient.PrivateKey = privateKey.String()
|
||||
extclient.PublicKey = privateKey.PublicKey().String()
|
||||
}
|
||||
|
||||
if extclient.Address == "" {
|
||||
newAddress, err := logic.UniqueAddress(extclient.Network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
extclient.Address = newAddress
|
||||
}
|
||||
|
||||
if extclient.ClientID == "" {
|
||||
extclient.ClientID = models.GenerateNodeName()
|
||||
}
|
||||
|
||||
extclient.LastModified = time.Now().Unix()
|
||||
|
||||
key, err := logic.GetRecordKey(extclient.ClientID, extclient.Network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.Marshal(&extclient)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = database.Insert(key, string(data), database.EXT_CLIENT_TABLE_NAME); err != nil {
|
||||
return err
|
||||
}
|
||||
err = logic.SetNetworkNodesLastModified(extclient.Network)
|
||||
return err
|
||||
}
|
||||
|
||||
/**
|
||||
* To create a extclient
|
||||
* Must have valid key and be unique
|
||||
|
@ -312,7 +232,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
|
|||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
err = CreateExtClient(extclient)
|
||||
err = logic.CreateExtClient(extclient)
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
|
@ -344,7 +264,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
|
|||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
newclient, err := UpdateExtClient(newExtClient.ClientID, params["network"], oldExtClient)
|
||||
newclient, err := logic.UpdateExtClient(newExtClient.ClientID, params["network"], oldExtClient)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -354,45 +274,6 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(newclient)
|
||||
}
|
||||
|
||||
// UpdateExtClient - only supports name changes right now
|
||||
func UpdateExtClient(newclientid string, network string, client models.ExtClient) (models.ExtClient, error) {
|
||||
|
||||
err := DeleteExtClient(network, client.ClientID)
|
||||
if err != nil {
|
||||
return client, err
|
||||
}
|
||||
client.ClientID = newclientid
|
||||
CreateExtClient(client)
|
||||
return client, err
|
||||
}
|
||||
|
||||
// DeleteExtClient - deletes an existing ext client
|
||||
func DeleteExtClient(network string, clientid string) error {
|
||||
key, err := logic.GetRecordKey(clientid, network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = database.DeleteRecord(database.EXT_CLIENT_TABLE_NAME, key)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteGatewayExtClients - deletes ext clients based on gateway (mac) of ingress node and network
|
||||
func DeleteGatewayExtClients(gatewayID string, networkName string) error {
|
||||
currentExtClients, err := GetNetworkExtClients(networkName)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
for _, extClient := range currentExtClients {
|
||||
if extClient.IngressGatewayID == gatewayID {
|
||||
if err = DeleteExtClient(networkName, extClient.ClientID); err != nil {
|
||||
logger.Log(1, "failed to remove ext client", extClient.ClientID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//Delete a extclient
|
||||
//Pretty straightforward
|
||||
func deleteExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -402,7 +283,7 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
|
|||
// get params
|
||||
var params = mux.Vars(r)
|
||||
|
||||
err := DeleteExtClient(params["network"], params["clientid"])
|
||||
err := logic.DeleteExtClient(params["network"], params["clientid"])
|
||||
|
||||
if err != nil {
|
||||
err = errors.New("Could not delete extclient " + params["clientid"])
|
|
@ -1,17 +1,13 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/functions"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
|
@ -32,7 +28,6 @@ func networkHandlers(r *mux.Router) {
|
|||
r.HandleFunc("/api/networks/{networkname}/keyupdate", securityCheck(false, http.HandlerFunc(keyUpdate))).Methods("POST")
|
||||
r.HandleFunc("/api/networks/{networkname}/keys", securityCheck(false, http.HandlerFunc(createAccessKey))).Methods("POST")
|
||||
r.HandleFunc("/api/networks/{networkname}/keys", securityCheck(false, http.HandlerFunc(getAccessKeys))).Methods("GET")
|
||||
r.HandleFunc("/api/networks/{networkname}/signuptoken", securityCheck(false, http.HandlerFunc(getSignupToken))).Methods("GET")
|
||||
r.HandleFunc("/api/networks/{networkname}/keys/{name}", securityCheck(false, http.HandlerFunc(deleteAccessKey))).Methods("DELETE")
|
||||
}
|
||||
|
||||
|
@ -73,34 +68,13 @@ func getNetworks(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(allnetworks)
|
||||
}
|
||||
|
||||
func ValidateNetworkUpdate(network models.Network) error {
|
||||
v := validator.New()
|
||||
|
||||
_ = v.RegisterValidation("netid_valid", func(fl validator.FieldLevel) bool {
|
||||
if fl.Field().String() == "" {
|
||||
return true
|
||||
}
|
||||
inCharSet := functions.NameInNetworkCharSet(fl.Field().String())
|
||||
return inCharSet
|
||||
})
|
||||
|
||||
err := v.Struct(network)
|
||||
|
||||
if err != nil {
|
||||
for _, e := range err.(validator.ValidationErrors) {
|
||||
logger.Log(1, "validator", e.Error())
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
//Simple get network function
|
||||
func getNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
// set header.
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var params = mux.Vars(r)
|
||||
netname := params["networkname"]
|
||||
network, err := GetNetwork(netname)
|
||||
network, err := logic.GetNetwork(netname)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -113,25 +87,11 @@ func getNetwork(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(network)
|
||||
}
|
||||
|
||||
func GetNetwork(name string) (models.Network, error) {
|
||||
var network models.Network
|
||||
record, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, name)
|
||||
if err != nil {
|
||||
return network, err
|
||||
}
|
||||
|
||||
if err = json.Unmarshal([]byte(record), &network); err != nil {
|
||||
return models.Network{}, err
|
||||
}
|
||||
|
||||
return network, nil
|
||||
}
|
||||
|
||||
func keyUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var params = mux.Vars(r)
|
||||
netname := params["networkname"]
|
||||
network, err := KeyUpdate(netname)
|
||||
network, err := logic.KeyUpdate(netname)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -141,33 +101,6 @@ func keyUpdate(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(network)
|
||||
}
|
||||
|
||||
func KeyUpdate(netname string) (models.Network, error) {
|
||||
err := functions.NetworkNodesUpdateAction(netname, models.NODE_UPDATE_KEY)
|
||||
if err != nil {
|
||||
return models.Network{}, err
|
||||
}
|
||||
return models.Network{}, nil
|
||||
}
|
||||
|
||||
//Update a network
|
||||
func AlertNetwork(netid string) error {
|
||||
|
||||
var network models.Network
|
||||
network, err := logic.GetParentNetwork(netid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updatetime := time.Now().Unix()
|
||||
network.NodesLastModified = updatetime
|
||||
network.NetworkLastModified = updatetime
|
||||
data, err := json.Marshal(&network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
database.Insert(netid, string(data), database.NETWORKS_TABLE_NAME)
|
||||
return nil
|
||||
}
|
||||
|
||||
//Update a network
|
||||
func updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
@ -253,7 +186,7 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var params = mux.Vars(r)
|
||||
network := params["networkname"]
|
||||
err := DeleteNetwork(network)
|
||||
err := logic.DeleteNetwork(network)
|
||||
|
||||
if err != nil {
|
||||
errtype := "badrequest"
|
||||
|
@ -268,29 +201,6 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode("success")
|
||||
}
|
||||
|
||||
func DeleteNetwork(network string) error {
|
||||
nodeCount, err := functions.GetNetworkNonServerNodeCount(network)
|
||||
if nodeCount == 0 || database.IsEmptyRecord(err) {
|
||||
// delete server nodes first then db records
|
||||
servers, err := logic.GetSortedNetworkServerNodes(network)
|
||||
if err == nil {
|
||||
for _, s := range servers {
|
||||
if err = logic.DeleteNode(&s, true); err != nil {
|
||||
logger.Log(2, "could not removed server", s.Name, "before deleting network", network)
|
||||
} else {
|
||||
logger.Log(2, "removed server", s.Name, "before deleting network", network)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.Log(1, "could not remove servers before deleting network", network)
|
||||
}
|
||||
return database.DeleteRecord(database.NETWORKS_TABLE_NAME, network)
|
||||
}
|
||||
return errors.New("node check failed. All nodes must be deleted before deleting network")
|
||||
}
|
||||
|
||||
//Create a network
|
||||
//Pretty simple
|
||||
func createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
@ -304,49 +214,27 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
err = CreateNetwork(network)
|
||||
err = logic.CreateNetwork(network)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
logger.Log(1, r.Header.Get("user"), "created network", network.NetID)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
//json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
func CreateNetwork(network models.Network) error {
|
||||
|
||||
network.SetDefaults()
|
||||
network.SetNodesLastModified()
|
||||
network.SetNetworkLastModified()
|
||||
network.KeyUpdateTimeStamp = time.Now().Unix()
|
||||
|
||||
err := logic.ValidateNetwork(&network, false)
|
||||
if err != nil {
|
||||
//returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.Marshal(&network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if servercfg.IsClientMode() != "off" {
|
||||
var success bool
|
||||
success, err = serverctl.AddNetwork(network.NetID)
|
||||
if err != nil || !success {
|
||||
DeleteNetwork(network.NetID)
|
||||
logic.DeleteNetwork(network.NetID)
|
||||
if err == nil {
|
||||
err = errors.New("Failed to add server to network " + network.DisplayName)
|
||||
}
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
logger.Log(1, r.Header.Get("user"), "created network", network.NetID)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// BEGIN KEY MANAGEMENT SECTION
|
||||
|
@ -366,7 +254,7 @@ func createAccessKey(w http.ResponseWriter, r *http.Request) {
|
|||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
key, err := CreateAccessKey(accesskey, network)
|
||||
key, err := logic.CreateAccessKey(accesskey, network)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return
|
||||
|
@ -374,131 +262,14 @@ func createAccessKey(w http.ResponseWriter, r *http.Request) {
|
|||
logger.Log(1, r.Header.Get("user"), "created access key", accesskey.Name, "on", netname)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(key)
|
||||
//w.Write([]byte(accesskey.AccessString))
|
||||
}
|
||||
|
||||
func CreateAccessKey(accesskey models.AccessKey, network models.Network) (models.AccessKey, error) {
|
||||
|
||||
if accesskey.Name == "" {
|
||||
accesskey.Name = functions.GenKeyName()
|
||||
}
|
||||
|
||||
if accesskey.Value == "" {
|
||||
accesskey.Value = functions.GenKey()
|
||||
}
|
||||
if accesskey.Uses == 0 {
|
||||
accesskey.Uses = 1
|
||||
}
|
||||
|
||||
checkkeys, err := GetKeys(network.NetID)
|
||||
if err != nil {
|
||||
return models.AccessKey{}, errors.New("could not retrieve network keys")
|
||||
}
|
||||
|
||||
for _, key := range checkkeys {
|
||||
if key.Name == accesskey.Name {
|
||||
return models.AccessKey{}, errors.New("duplicate AccessKey Name")
|
||||
}
|
||||
}
|
||||
privAddr := ""
|
||||
if network.IsLocal != "" {
|
||||
privAddr = network.LocalRange
|
||||
}
|
||||
|
||||
netID := network.NetID
|
||||
|
||||
var accessToken models.AccessToken
|
||||
s := servercfg.GetServerConfig()
|
||||
servervals := models.ServerConfig{
|
||||
CoreDNSAddr: s.CoreDNSAddr,
|
||||
APIConnString: s.APIConnString,
|
||||
APIHost: s.APIHost,
|
||||
APIPort: s.APIPort,
|
||||
GRPCConnString: s.GRPCConnString,
|
||||
GRPCHost: s.GRPCHost,
|
||||
GRPCPort: s.GRPCPort,
|
||||
GRPCSSL: s.GRPCSSL,
|
||||
CheckinInterval: s.CheckinInterval,
|
||||
}
|
||||
accessToken.ServerConfig = servervals
|
||||
accessToken.ClientConfig.Network = netID
|
||||
accessToken.ClientConfig.Key = accesskey.Value
|
||||
accessToken.ClientConfig.LocalRange = privAddr
|
||||
|
||||
tokenjson, err := json.Marshal(accessToken)
|
||||
if err != nil {
|
||||
return accesskey, err
|
||||
}
|
||||
|
||||
accesskey.AccessString = base64.StdEncoding.EncodeToString([]byte(tokenjson))
|
||||
|
||||
//validate accesskey
|
||||
v := validator.New()
|
||||
err = v.Struct(accesskey)
|
||||
if err != nil {
|
||||
for _, e := range err.(validator.ValidationErrors) {
|
||||
logger.Log(1, "validator", e.Error())
|
||||
}
|
||||
return models.AccessKey{}, err
|
||||
}
|
||||
|
||||
network.AccessKeys = append(network.AccessKeys, accesskey)
|
||||
data, err := json.Marshal(&network)
|
||||
if err != nil {
|
||||
return models.AccessKey{}, err
|
||||
}
|
||||
if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
|
||||
return models.AccessKey{}, err
|
||||
}
|
||||
|
||||
return accesskey, nil
|
||||
}
|
||||
|
||||
func GetSignupToken(netID string) (models.AccessKey, error) {
|
||||
|
||||
var accesskey models.AccessKey
|
||||
var accessToken models.AccessToken
|
||||
s := servercfg.GetServerConfig()
|
||||
servervals := models.ServerConfig{
|
||||
APIConnString: s.APIConnString,
|
||||
APIHost: s.APIHost,
|
||||
APIPort: s.APIPort,
|
||||
GRPCConnString: s.GRPCConnString,
|
||||
GRPCHost: s.GRPCHost,
|
||||
GRPCPort: s.GRPCPort,
|
||||
GRPCSSL: s.GRPCSSL,
|
||||
}
|
||||
accessToken.ServerConfig = servervals
|
||||
|
||||
tokenjson, err := json.Marshal(accessToken)
|
||||
if err != nil {
|
||||
return accesskey, err
|
||||
}
|
||||
|
||||
accesskey.AccessString = base64.StdEncoding.EncodeToString([]byte(tokenjson))
|
||||
return accesskey, nil
|
||||
}
|
||||
func getSignupToken(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var params = mux.Vars(r)
|
||||
netID := params["networkname"]
|
||||
|
||||
token, err := GetSignupToken(netID)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
logger.Log(2, r.Header.Get("user"), "got signup token", netID)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(token)
|
||||
}
|
||||
|
||||
//pretty simple get
|
||||
// pretty simple get
|
||||
func getAccessKeys(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var params = mux.Vars(r)
|
||||
network := params["networkname"]
|
||||
keys, err := GetKeys(network)
|
||||
keys, err := logic.GetKeys(network)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -510,26 +281,14 @@ func getAccessKeys(w http.ResponseWriter, r *http.Request) {
|
|||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(keys)
|
||||
}
|
||||
func GetKeys(net string) ([]models.AccessKey, error) {
|
||||
|
||||
record, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, net)
|
||||
if err != nil {
|
||||
return []models.AccessKey{}, err
|
||||
}
|
||||
network, err := functions.ParseNetwork(record)
|
||||
if err != nil {
|
||||
return []models.AccessKey{}, err
|
||||
}
|
||||
return network.AccessKeys, nil
|
||||
}
|
||||
|
||||
//delete key. Has to do a little funky logic since it's not a collection item
|
||||
// delete key. Has to do a little funky logic since it's not a collection item
|
||||
func deleteAccessKey(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var params = mux.Vars(r)
|
||||
keyname := params["name"]
|
||||
netname := params["networkname"]
|
||||
err := DeleteKey(keyname, netname)
|
||||
err := logic.DeleteKey(keyname, netname)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return
|
||||
|
@ -537,33 +296,3 @@ func deleteAccessKey(w http.ResponseWriter, r *http.Request) {
|
|||
logger.Log(1, r.Header.Get("user"), "deleted access key", keyname, "on network,", netname)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
func DeleteKey(keyname, netname string) error {
|
||||
network, err := logic.GetParentNetwork(netname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//basically, turn the list of access keys into the list of access keys before and after the item
|
||||
//have not done any error handling for if there's like...1 item. I think it works? need to test.
|
||||
found := false
|
||||
var updatedKeys []models.AccessKey
|
||||
for _, currentkey := range network.AccessKeys {
|
||||
if currentkey.Name == keyname {
|
||||
found = true
|
||||
} else {
|
||||
updatedKeys = append(updatedKeys, currentkey)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return errors.New("key " + keyname + " does not exist")
|
||||
}
|
||||
network.AccessKeys = updatedKeys
|
||||
data, err := json.Marshal(&network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -25,7 +25,7 @@ func TestCreateNetwork(t *testing.T) {
|
|||
network.AddressRange = "10.0.0.1/24"
|
||||
network.DisplayName = "mynetwork"
|
||||
|
||||
err := CreateNetwork(network)
|
||||
err := logic.CreateNetwork(network)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
func TestGetNetwork(t *testing.T) {
|
||||
|
@ -33,12 +33,12 @@ func TestGetNetwork(t *testing.T) {
|
|||
createNet()
|
||||
|
||||
t.Run("GetExistingNetwork", func(t *testing.T) {
|
||||
network, err := GetNetwork("skynet")
|
||||
network, err := logic.GetNetwork("skynet")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "skynet", network.NetID)
|
||||
})
|
||||
t.Run("GetNonExistantNetwork", func(t *testing.T) {
|
||||
network, err := GetNetwork("doesnotexist")
|
||||
network, err := logic.GetNetwork("doesnotexist")
|
||||
assert.EqualError(t, err, "no result found")
|
||||
assert.Equal(t, "", network.NetID)
|
||||
})
|
||||
|
@ -51,11 +51,11 @@ func TestDeleteNetwork(t *testing.T) {
|
|||
t.Run("NetworkwithNodes", func(t *testing.T) {
|
||||
})
|
||||
t.Run("DeleteExistingNetwork", func(t *testing.T) {
|
||||
err := DeleteNetwork("skynet")
|
||||
err := logic.DeleteNetwork("skynet")
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
t.Run("NonExistantNetwork", func(t *testing.T) {
|
||||
err := DeleteNetwork("skynet")
|
||||
err := logic.DeleteNetwork("skynet")
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
|
@ -64,12 +64,12 @@ func TestKeyUpdate(t *testing.T) {
|
|||
t.Skip() //test is failing on last assert --- not sure why
|
||||
database.InitializeDatabase()
|
||||
createNet()
|
||||
existing, err := GetNetwork("skynet")
|
||||
existing, err := logic.GetNetwork("skynet")
|
||||
assert.Nil(t, err)
|
||||
time.Sleep(time.Second * 1)
|
||||
network, err := KeyUpdate("skynet")
|
||||
network, err := logic.KeyUpdate("skynet")
|
||||
assert.Nil(t, err)
|
||||
network, err = GetNetwork("skynet")
|
||||
network, err = logic.GetNetwork("skynet")
|
||||
assert.Nil(t, err)
|
||||
assert.Greater(t, network.KeyUpdateTimeStamp, existing.KeyUpdateTimeStamp)
|
||||
}
|
||||
|
@ -77,70 +77,70 @@ func TestKeyUpdate(t *testing.T) {
|
|||
func TestCreateKey(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
createNet()
|
||||
keys, _ := GetKeys("skynet")
|
||||
keys, _ := logic.GetKeys("skynet")
|
||||
for _, key := range keys {
|
||||
DeleteKey(key.Name, "skynet")
|
||||
logic.DeleteKey(key.Name, "skynet")
|
||||
}
|
||||
var accesskey models.AccessKey
|
||||
var network models.Network
|
||||
network.NetID = "skynet"
|
||||
t.Run("NameTooLong", func(t *testing.T) {
|
||||
network, err := GetNetwork("skynet")
|
||||
network, err := logic.GetNetwork("skynet")
|
||||
assert.Nil(t, err)
|
||||
accesskey.Name = "Thisisareallylongkeynamethatwillfail"
|
||||
_, err = CreateAccessKey(accesskey, network)
|
||||
_, err = logic.CreateAccessKey(accesskey, network)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'max' tag")
|
||||
})
|
||||
t.Run("BlankName", func(t *testing.T) {
|
||||
network, err := GetNetwork("skynet")
|
||||
network, err := logic.GetNetwork("skynet")
|
||||
assert.Nil(t, err)
|
||||
accesskey.Name = ""
|
||||
key, err := CreateAccessKey(accesskey, network)
|
||||
key, err := logic.CreateAccessKey(accesskey, network)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, "", key.Name)
|
||||
})
|
||||
t.Run("InvalidValue", func(t *testing.T) {
|
||||
network, err := GetNetwork("skynet")
|
||||
network, err := logic.GetNetwork("skynet")
|
||||
assert.Nil(t, err)
|
||||
accesskey.Value = "bad-value"
|
||||
_, err = CreateAccessKey(accesskey, network)
|
||||
_, err = logic.CreateAccessKey(accesskey, network)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Value' failed on the 'alphanum' tag")
|
||||
})
|
||||
t.Run("BlankValue", func(t *testing.T) {
|
||||
network, err := GetNetwork("skynet")
|
||||
network, err := logic.GetNetwork("skynet")
|
||||
assert.Nil(t, err)
|
||||
accesskey.Name = "mykey"
|
||||
accesskey.Value = ""
|
||||
key, err := CreateAccessKey(accesskey, network)
|
||||
key, err := logic.CreateAccessKey(accesskey, network)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, "", key.Value)
|
||||
assert.Equal(t, accesskey.Name, key.Name)
|
||||
})
|
||||
t.Run("ValueTooLong", func(t *testing.T) {
|
||||
network, err := GetNetwork("skynet")
|
||||
network, err := logic.GetNetwork("skynet")
|
||||
assert.Nil(t, err)
|
||||
accesskey.Name = "keyname"
|
||||
accesskey.Value = "AccessKeyValuethatistoolong"
|
||||
_, err = CreateAccessKey(accesskey, network)
|
||||
_, err = logic.CreateAccessKey(accesskey, network)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "Field validation for 'Value' failed on the 'max' tag")
|
||||
})
|
||||
t.Run("BlankUses", func(t *testing.T) {
|
||||
network, err := GetNetwork("skynet")
|
||||
network, err := logic.GetNetwork("skynet")
|
||||
assert.Nil(t, err)
|
||||
accesskey.Uses = 0
|
||||
accesskey.Value = ""
|
||||
key, err := CreateAccessKey(accesskey, network)
|
||||
key, err := logic.CreateAccessKey(accesskey, network)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, key.Uses)
|
||||
})
|
||||
t.Run("DuplicateKey", func(t *testing.T) {
|
||||
network, err := GetNetwork("skynet")
|
||||
network, err := logic.GetNetwork("skynet")
|
||||
assert.Nil(t, err)
|
||||
accesskey.Name = "mykey"
|
||||
_, err = CreateAccessKey(accesskey, network)
|
||||
_, err = logic.CreateAccessKey(accesskey, network)
|
||||
assert.NotNil(t, err)
|
||||
assert.EqualError(t, err, "duplicate AccessKey Name")
|
||||
})
|
||||
|
@ -150,21 +150,21 @@ func TestGetKeys(t *testing.T) {
|
|||
database.InitializeDatabase()
|
||||
deleteAllNetworks()
|
||||
createNet()
|
||||
network, err := GetNetwork("skynet")
|
||||
network, err := logic.GetNetwork("skynet")
|
||||
assert.Nil(t, err)
|
||||
var key models.AccessKey
|
||||
key.Name = "mykey"
|
||||
_, err = CreateAccessKey(key, network)
|
||||
_, err = logic.CreateAccessKey(key, network)
|
||||
assert.Nil(t, err)
|
||||
t.Run("KeyExists", func(t *testing.T) {
|
||||
keys, err := GetKeys(network.NetID)
|
||||
keys, err := logic.GetKeys(network.NetID)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, models.AccessKey{}, keys)
|
||||
})
|
||||
t.Run("NonExistantKey", func(t *testing.T) {
|
||||
err := DeleteKey("mykey", "skynet")
|
||||
err := logic.DeleteKey("mykey", "skynet")
|
||||
assert.Nil(t, err)
|
||||
keys, err := GetKeys(network.NetID)
|
||||
keys, err := logic.GetKeys(network.NetID)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, []models.AccessKey(nil), keys)
|
||||
})
|
||||
|
@ -172,18 +172,18 @@ func TestGetKeys(t *testing.T) {
|
|||
func TestDeleteKey(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
createNet()
|
||||
network, err := GetNetwork("skynet")
|
||||
network, err := logic.GetNetwork("skynet")
|
||||
assert.Nil(t, err)
|
||||
var key models.AccessKey
|
||||
key.Name = "mykey"
|
||||
_, err = CreateAccessKey(key, network)
|
||||
_, err = logic.CreateAccessKey(key, network)
|
||||
assert.Nil(t, err)
|
||||
t.Run("ExistingKey", func(t *testing.T) {
|
||||
err := DeleteKey("mykey", "skynet")
|
||||
err := logic.DeleteKey("mykey", "skynet")
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
t.Run("NonExistantKey", func(t *testing.T) {
|
||||
err := DeleteKey("mykey", "skynet")
|
||||
err := logic.DeleteKey("mykey", "skynet")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "key mykey does not exist", err.Error())
|
||||
})
|
||||
|
@ -325,7 +325,7 @@ func TestValidateNetworkUpdate(t *testing.T) {
|
|||
for _, tc := range cases {
|
||||
t.Run(tc.testname, func(t *testing.T) {
|
||||
network := models.Network(tc.network)
|
||||
err := ValidateNetworkUpdate(network)
|
||||
err := logic.ValidateNetworkUpdate(network)
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), tc.errMessage)
|
||||
})
|
||||
|
@ -336,7 +336,7 @@ func deleteAllNetworks() {
|
|||
deleteAllNodes()
|
||||
nets, _ := logic.GetNetworks()
|
||||
for _, net := range nets {
|
||||
DeleteNetwork(net.NetID)
|
||||
logic.DeleteNetwork(net.NetID)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -345,13 +345,13 @@ func createNet() {
|
|||
network.NetID = "skynet"
|
||||
network.AddressRange = "10.0.0.1/24"
|
||||
network.DisplayName = "mynetwork"
|
||||
_, err := GetNetwork("skynet")
|
||||
_, err := logic.GetNetwork("skynet")
|
||||
if err != nil {
|
||||
CreateNetwork(network)
|
||||
logic.CreateNetwork(network)
|
||||
}
|
||||
}
|
||||
|
||||
func getNet() models.Network {
|
||||
network, _ := GetNetwork("skynet")
|
||||
network, _ := logic.GetNetwork("skynet")
|
||||
return network
|
||||
}
|
|
@ -2,10 +2,8 @@ package controller
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
|
@ -349,7 +347,7 @@ func getLastModified(w http.ResponseWriter, r *http.Request) {
|
|||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var params = mux.Vars(r)
|
||||
network, err := GetNetwork(params["network"])
|
||||
network, err := logic.GetNetwork(params["network"])
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -431,7 +429,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
|
|||
func uncordonNode(w http.ResponseWriter, r *http.Request) {
|
||||
var params = mux.Vars(r)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
node, err := UncordonNode(params["network"], params["macaddress"])
|
||||
node, err := logic.UncordonNode(params["network"], params["macaddress"])
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -441,28 +439,6 @@ func uncordonNode(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode("SUCCESS")
|
||||
}
|
||||
|
||||
// UncordonNode - approves a node to join a network
|
||||
func UncordonNode(network, macaddress string) (models.Node, error) {
|
||||
node, err := logic.GetNodeByMacAddress(network, macaddress)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
node.SetLastModified()
|
||||
node.IsPending = "no"
|
||||
node.PullChanges = "yes"
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
key, err := logic.GetRecordKey(node.MacAddress, node.Network)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
|
||||
err = database.Insert(key, string(data), database.NODES_TABLE_NAME)
|
||||
return node, err
|
||||
}
|
||||
|
||||
func createEgressGateway(w http.ResponseWriter, r *http.Request) {
|
||||
var gateway models.EgressGatewayRequest
|
||||
var params = mux.Vars(r)
|
||||
|
@ -474,7 +450,7 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
gateway.NetID = params["network"]
|
||||
gateway.NodeID = params["macaddress"]
|
||||
node, err := CreateEgressGateway(gateway)
|
||||
node, err := logic.CreateEgressGateway(gateway)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -484,80 +460,12 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(node)
|
||||
}
|
||||
|
||||
// CreateEgressGateway - creates an egress gateway
|
||||
func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, error) {
|
||||
node, err := logic.GetNodeByMacAddress(gateway.NetID, gateway.NodeID)
|
||||
if node.OS == "windows" || node.OS == "macos" { // add in darwin later
|
||||
return models.Node{}, errors.New(node.OS + " is unsupported for egress gateways")
|
||||
}
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = ValidateEgressGateway(gateway)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
node.IsEgressGateway = "yes"
|
||||
node.EgressGatewayRanges = gateway.Ranges
|
||||
postUpCmd := "iptables -A FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -A POSTROUTING -o " + gateway.Interface + " -j MASQUERADE"
|
||||
postDownCmd := "iptables -D FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -D POSTROUTING -o " + gateway.Interface + " -j MASQUERADE"
|
||||
if gateway.PostUp != "" {
|
||||
postUpCmd = gateway.PostUp
|
||||
}
|
||||
if gateway.PostDown != "" {
|
||||
postDownCmd = gateway.PostDown
|
||||
}
|
||||
if node.PostUp != "" {
|
||||
if !strings.Contains(node.PostUp, postUpCmd) {
|
||||
postUpCmd = node.PostUp + "; " + postUpCmd
|
||||
}
|
||||
}
|
||||
if node.PostDown != "" {
|
||||
if !strings.Contains(node.PostDown, postDownCmd) {
|
||||
postDownCmd = node.PostDown + "; " + postDownCmd
|
||||
}
|
||||
}
|
||||
key, err := logic.GetRecordKey(gateway.NodeID, gateway.NetID)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
node.PostUp = postUpCmd
|
||||
node.PostDown = postDownCmd
|
||||
node.SetLastModified()
|
||||
node.PullChanges = "yes"
|
||||
nodeData, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
if err = database.Insert(key, string(nodeData), database.NODES_TABLE_NAME); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
if err = functions.NetworkNodesUpdatePullChanges(node.Network); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
func ValidateEgressGateway(gateway models.EgressGatewayRequest) error {
|
||||
var err error
|
||||
//isIp := functions.IsIpCIDR(gateway.RangeString)
|
||||
empty := len(gateway.Ranges) == 0
|
||||
if empty {
|
||||
err = errors.New("IP Ranges Cannot Be Empty")
|
||||
}
|
||||
empty = gateway.Interface == ""
|
||||
if empty {
|
||||
err = errors.New("Interface cannot be empty")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func deleteEgressGateway(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var params = mux.Vars(r)
|
||||
nodeMac := params["macaddress"]
|
||||
netid := params["network"]
|
||||
node, err := DeleteEgressGateway(netid, nodeMac)
|
||||
node, err := logic.DeleteEgressGateway(netid, nodeMac)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -567,48 +475,13 @@ func deleteEgressGateway(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(node)
|
||||
}
|
||||
|
||||
// DeleteEgressGateway - deletes egress from node
|
||||
func DeleteEgressGateway(network, macaddress string) (models.Node, error) {
|
||||
|
||||
node, err := logic.GetNodeByMacAddress(network, macaddress)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
|
||||
node.IsEgressGateway = "no"
|
||||
node.EgressGatewayRanges = []string{}
|
||||
node.PostUp = ""
|
||||
node.PostDown = ""
|
||||
if node.IsIngressGateway == "yes" { // check if node is still an ingress gateway before completely deleting postdown/up rules
|
||||
node.PostUp = "iptables -A FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -A POSTROUTING -o " + node.Interface + " -j MASQUERADE"
|
||||
node.PostDown = "iptables -D FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -D POSTROUTING -o " + node.Interface + " -j MASQUERADE"
|
||||
}
|
||||
node.SetLastModified()
|
||||
node.PullChanges = "yes"
|
||||
key, err := logic.GetRecordKey(node.MacAddress, node.Network)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
if err = database.Insert(key, string(data), database.NODES_TABLE_NAME); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
if err = functions.NetworkNodesUpdatePullChanges(network); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
// == INGRESS ==
|
||||
func createIngressGateway(w http.ResponseWriter, r *http.Request) {
|
||||
var params = mux.Vars(r)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
nodeMac := params["macaddress"]
|
||||
netid := params["network"]
|
||||
node, err := CreateIngressGateway(netid, nodeMac)
|
||||
node, err := logic.CreateIngressGateway(netid, nodeMac)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -618,62 +491,11 @@ func createIngressGateway(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(node)
|
||||
}
|
||||
|
||||
// CreateIngressGateway - creates an ingress gateway
|
||||
func CreateIngressGateway(netid string, macaddress string) (models.Node, error) {
|
||||
|
||||
node, err := logic.GetNodeByMacAddress(netid, macaddress)
|
||||
if node.OS == "windows" || node.OS == "macos" { // add in darwin later
|
||||
return models.Node{}, errors.New(node.OS + " is unsupported for ingress gateways")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
|
||||
network, err := logic.GetParentNetwork(netid)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
node.IsIngressGateway = "yes"
|
||||
node.IngressGatewayRange = network.AddressRange
|
||||
postUpCmd := "iptables -A FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -A POSTROUTING -o " + node.Interface + " -j MASQUERADE"
|
||||
postDownCmd := "iptables -D FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -D POSTROUTING -o " + node.Interface + " -j MASQUERADE"
|
||||
if node.PostUp != "" {
|
||||
if !strings.Contains(node.PostUp, postUpCmd) {
|
||||
postUpCmd = node.PostUp + "; " + postUpCmd
|
||||
}
|
||||
}
|
||||
if node.PostDown != "" {
|
||||
if !strings.Contains(node.PostDown, postDownCmd) {
|
||||
postDownCmd = node.PostDown + "; " + postDownCmd
|
||||
}
|
||||
}
|
||||
node.SetLastModified()
|
||||
node.PostUp = postUpCmd
|
||||
node.PostDown = postDownCmd
|
||||
node.PullChanges = "yes"
|
||||
node.UDPHolePunch = "no"
|
||||
key, err := logic.GetRecordKey(node.MacAddress, node.Network)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = database.Insert(key, string(data), database.NODES_TABLE_NAME)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = logic.SetNetworkNodesLastModified(netid)
|
||||
return node, err
|
||||
}
|
||||
|
||||
func deleteIngressGateway(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var params = mux.Vars(r)
|
||||
nodeMac := params["macaddress"]
|
||||
node, err := DeleteIngressGateway(params["network"], nodeMac)
|
||||
node, err := logic.DeleteIngressGateway(params["network"], nodeMac)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -683,44 +505,6 @@ func deleteIngressGateway(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(node)
|
||||
}
|
||||
|
||||
// DeleteIngressGateway - deletes an ingress gateway
|
||||
func DeleteIngressGateway(networkName string, macaddress string) (models.Node, error) {
|
||||
|
||||
node, err := logic.GetNodeByMacAddress(networkName, macaddress)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
network, err := logic.GetParentNetwork(networkName)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
// delete ext clients belonging to ingress gateway
|
||||
if err = DeleteGatewayExtClients(macaddress, networkName); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
|
||||
node.UDPHolePunch = network.DefaultUDPHolePunch
|
||||
node.LastModified = time.Now().Unix()
|
||||
node.IsIngressGateway = "no"
|
||||
node.IngressGatewayRange = ""
|
||||
node.PullChanges = "yes"
|
||||
|
||||
key, err := logic.GetRecordKey(node.MacAddress, node.Network)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = database.Insert(key, string(data), database.NODES_TABLE_NAME)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = logic.SetNetworkNodesLastModified(networkName)
|
||||
return node, err
|
||||
}
|
||||
|
||||
func updateNode(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
|
@ -760,8 +544,8 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
if relayupdate {
|
||||
UpdateRelay(node.Network, node.RelayAddrs, newNode.RelayAddrs)
|
||||
if err = functions.NetworkNodesUpdatePullChanges(node.Network); err != nil {
|
||||
logic.UpdateRelay(node.Network, node.RelayAddrs, newNode.RelayAddrs)
|
||||
if err = logic.NetworkNodesUpdatePullChanges(node.Network); err != nil {
|
||||
logger.Log(1, "error setting relay updates:", err.Error())
|
||||
}
|
||||
}
|
|
@ -17,7 +17,7 @@ func TestCreateEgressGateway(t *testing.T) {
|
|||
deleteAllNetworks()
|
||||
createNet()
|
||||
t.Run("NoNodes", func(t *testing.T) {
|
||||
node, err := CreateEgressGateway(gateway)
|
||||
node, err := logic.CreateEgressGateway(gateway)
|
||||
assert.Equal(t, models.Node{}, node)
|
||||
assert.EqualError(t, err, "unable to get record key")
|
||||
})
|
||||
|
@ -26,7 +26,7 @@ func TestCreateEgressGateway(t *testing.T) {
|
|||
gateway.NetID = "skynet"
|
||||
gateway.NodeID = testnode.MacAddress
|
||||
|
||||
node, err := CreateEgressGateway(gateway)
|
||||
node, err := logic.CreateEgressGateway(gateway)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "yes", node.IsEgressGateway)
|
||||
assert.Equal(t, gateway.Ranges, node.EgressGatewayRanges)
|
||||
|
@ -45,11 +45,11 @@ func TestDeleteEgressGateway(t *testing.T) {
|
|||
gateway.NetID = "skynet"
|
||||
gateway.NodeID = testnode.MacAddress
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
node, err := CreateEgressGateway(gateway)
|
||||
node, err := logic.CreateEgressGateway(gateway)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "yes", node.IsEgressGateway)
|
||||
assert.Equal(t, []string{"10.100.100.0/24"}, node.EgressGatewayRanges)
|
||||
node, err = DeleteEgressGateway(gateway.NetID, gateway.NodeID)
|
||||
node, err = logic.DeleteEgressGateway(gateway.NetID, gateway.NodeID)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "no", node.IsEgressGateway)
|
||||
assert.Equal(t, []string([]string{}), node.EgressGatewayRanges)
|
||||
|
@ -57,7 +57,7 @@ func TestDeleteEgressGateway(t *testing.T) {
|
|||
assert.Equal(t, "", node.PostDown)
|
||||
})
|
||||
t.Run("NotGateway", func(t *testing.T) {
|
||||
node, err := DeleteEgressGateway(gateway.NetID, gateway.NodeID)
|
||||
node, err := logic.DeleteEgressGateway(gateway.NetID, gateway.NodeID)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "no", node.IsEgressGateway)
|
||||
assert.Equal(t, []string([]string{}), node.EgressGatewayRanges)
|
||||
|
@ -65,12 +65,12 @@ func TestDeleteEgressGateway(t *testing.T) {
|
|||
assert.Equal(t, "", node.PostDown)
|
||||
})
|
||||
t.Run("BadNode", func(t *testing.T) {
|
||||
node, err := DeleteEgressGateway(gateway.NetID, "01:02:03")
|
||||
node, err := logic.DeleteEgressGateway(gateway.NetID, "01:02:03")
|
||||
assert.EqualError(t, err, "no result found")
|
||||
assert.Equal(t, models.Node{}, node)
|
||||
})
|
||||
t.Run("BadNet", func(t *testing.T) {
|
||||
node, err := DeleteEgressGateway("badnet", gateway.NodeID)
|
||||
node, err := logic.DeleteEgressGateway("badnet", gateway.NodeID)
|
||||
assert.EqualError(t, err, "no result found")
|
||||
assert.Equal(t, models.Node{}, node)
|
||||
})
|
||||
|
@ -106,17 +106,17 @@ func TestUncordonNode(t *testing.T) {
|
|||
createNet()
|
||||
node := createTestNode()
|
||||
t.Run("BadNet", func(t *testing.T) {
|
||||
resp, err := UncordonNode("badnet", node.MacAddress)
|
||||
resp, err := logic.UncordonNode("badnet", node.MacAddress)
|
||||
assert.Equal(t, models.Node{}, resp)
|
||||
assert.EqualError(t, err, "no result found")
|
||||
})
|
||||
t.Run("BadMac", func(t *testing.T) {
|
||||
resp, err := UncordonNode("skynet", "01:02:03")
|
||||
resp, err := logic.UncordonNode("skynet", "01:02:03")
|
||||
assert.Equal(t, models.Node{}, resp)
|
||||
assert.EqualError(t, err, "no result found")
|
||||
})
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
resp, err := UncordonNode("skynet", node.MacAddress)
|
||||
resp, err := logic.UncordonNode("skynet", node.MacAddress)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "no", resp.IsPending)
|
||||
})
|
||||
|
@ -127,19 +127,19 @@ func TestValidateEgressGateway(t *testing.T) {
|
|||
t.Run("EmptyRange", func(t *testing.T) {
|
||||
gateway.Interface = "eth0"
|
||||
gateway.Ranges = []string{}
|
||||
err := ValidateEgressGateway(gateway)
|
||||
err := logic.ValidateEgressGateway(gateway)
|
||||
assert.EqualError(t, err, "IP Ranges Cannot Be Empty")
|
||||
})
|
||||
t.Run("EmptyInterface", func(t *testing.T) {
|
||||
gateway.Interface = ""
|
||||
err := ValidateEgressGateway(gateway)
|
||||
err := logic.ValidateEgressGateway(gateway)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "Interface cannot be empty", err.Error())
|
||||
})
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
gateway.Interface = "eth0"
|
||||
gateway.Ranges = []string{"10.100.100.0/24"}
|
||||
err := ValidateEgressGateway(gateway)
|
||||
err := logic.ValidateEgressGateway(gateway)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
|
@ -2,13 +2,9 @@ package controller
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/functions"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
|
@ -25,7 +21,7 @@ func createRelay(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
relay.NetID = params["network"]
|
||||
relay.NodeID = params["macaddress"]
|
||||
node, err := CreateRelay(relay)
|
||||
node, err := logic.CreateRelay(relay)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -35,52 +31,12 @@ func createRelay(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(node)
|
||||
}
|
||||
|
||||
// CreateRelay - creates a relay
|
||||
func CreateRelay(relay models.RelayRequest) (models.Node, error) {
|
||||
node, err := logic.GetNodeByMacAddress(relay.NetID, relay.NodeID)
|
||||
if node.OS == "macos" { // add in darwin later
|
||||
return models.Node{}, errors.New(node.OS + " is unsupported for relay")
|
||||
}
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = ValidateRelay(relay)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
node.IsRelay = "yes"
|
||||
node.RelayAddrs = relay.RelayAddrs
|
||||
|
||||
key, err := logic.GetRecordKey(relay.NodeID, relay.NetID)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
node.SetLastModified()
|
||||
node.PullChanges = "yes"
|
||||
nodeData, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
if err = database.Insert(key, string(nodeData), database.NODES_TABLE_NAME); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = SetRelayedNodes("yes", node.Network, node.RelayAddrs)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
|
||||
if err = functions.NetworkNodesUpdatePullChanges(node.Network); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
func deleteRelay(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var params = mux.Vars(r)
|
||||
nodeMac := params["macaddress"]
|
||||
netid := params["network"]
|
||||
node, err := DeleteRelay(netid, nodeMac)
|
||||
node, err := logic.DeleteRelay(netid, nodeMac)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -89,92 +45,3 @@ func deleteRelay(w http.ResponseWriter, r *http.Request) {
|
|||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(node)
|
||||
}
|
||||
|
||||
// SetRelayedNodes- set relayed nodes
|
||||
func SetRelayedNodes(yesOrno string, networkName string, addrs []string) error {
|
||||
|
||||
collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, value := range collections {
|
||||
|
||||
var node models.Node
|
||||
err := json.Unmarshal([]byte(value), &node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if node.Network == networkName {
|
||||
for _, addr := range addrs {
|
||||
if addr == node.Address || addr == node.Address6 {
|
||||
node.IsRelayed = yesOrno
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
node.SetID()
|
||||
database.Insert(node.ID, string(data), database.NODES_TABLE_NAME)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateRelay - checks if relay is valid
|
||||
func ValidateRelay(relay models.RelayRequest) error {
|
||||
var err error
|
||||
//isIp := functions.IsIpCIDR(gateway.RangeString)
|
||||
empty := len(relay.RelayAddrs) == 0
|
||||
if empty {
|
||||
err = errors.New("IP Ranges Cannot Be Empty")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateRelay - updates a relay
|
||||
func UpdateRelay(network string, oldAddrs []string, newAddrs []string) {
|
||||
time.Sleep(time.Second / 4)
|
||||
err := SetRelayedNodes("no", network, oldAddrs)
|
||||
if err != nil {
|
||||
logger.Log(1, err.Error())
|
||||
}
|
||||
err = SetRelayedNodes("yes", network, newAddrs)
|
||||
if err != nil {
|
||||
logger.Log(1, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteRelay - deletes a relay
|
||||
func DeleteRelay(network, macaddress string) (models.Node, error) {
|
||||
|
||||
node, err := logic.GetNodeByMacAddress(network, macaddress)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = SetRelayedNodes("no", node.Network, node.RelayAddrs)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
|
||||
node.IsRelay = "no"
|
||||
node.RelayAddrs = []string{}
|
||||
node.SetLastModified()
|
||||
node.PullChanges = "yes"
|
||||
key, err := logic.GetRecordKey(node.MacAddress, node.Network)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
if err = database.Insert(key, string(data), database.NODES_TABLE_NAME); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
if err = functions.NetworkNodesUpdatePullChanges(network); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
|
|
@ -112,36 +112,6 @@ func authenticateDNSToken(tokenString string) bool {
|
|||
return tokens[1] == servercfg.GetDNSKey()
|
||||
}
|
||||
|
||||
// ValidateUserToken - self explained
|
||||
func ValidateUserToken(token string, user string, adminonly bool) error {
|
||||
var tokenSplit = strings.Split(token, " ")
|
||||
//I put this in in case the user doesn't put in a token at all (in which case it's empty)
|
||||
//There's probably a smarter way of handling this.
|
||||
var authToken = "928rt238tghgwe@TY@$Y@#WQAEGB2FC#@HG#@$Hddd"
|
||||
|
||||
if len(tokenSplit) > 1 {
|
||||
authToken = tokenSplit[1]
|
||||
} else {
|
||||
return errors.New("Missing Auth Token.")
|
||||
}
|
||||
|
||||
username, _, isadmin, err := logic.VerifyUserToken(authToken)
|
||||
if err != nil {
|
||||
return errors.New("Error Verifying Auth Token")
|
||||
}
|
||||
isAuthorized := false
|
||||
if adminonly {
|
||||
isAuthorized = isadmin
|
||||
} else {
|
||||
isAuthorized = username == user || isadmin
|
||||
}
|
||||
if !isAuthorized {
|
||||
return errors.New("You are unauthorized to access this endpoint.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func continueIfUserMatch(next http.Handler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var errorResponse = models.ErrorResponse{
|
||||
|
|
|
@ -241,29 +241,29 @@ func TestUpdateUser(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestValidateUserToken(t *testing.T) {
|
||||
t.Run("EmptyToken", func(t *testing.T) {
|
||||
err := ValidateUserToken("", "", false)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "Missing Auth Token.", err.Error())
|
||||
})
|
||||
t.Run("InvalidToken", func(t *testing.T) {
|
||||
err := ValidateUserToken("Bearer: badtoken", "", false)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "Error Verifying Auth Token", err.Error())
|
||||
})
|
||||
t.Run("InvalidUser", func(t *testing.T) {
|
||||
t.Skip()
|
||||
err := ValidateUserToken("Bearer: secretkey", "baduser", false)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "Error Verifying Auth Token", err.Error())
|
||||
//need authorization
|
||||
})
|
||||
t.Run("ValidToken", func(t *testing.T) {
|
||||
err := ValidateUserToken("Bearer: secretkey", "", true)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
// func TestValidateUserToken(t *testing.T) {
|
||||
// t.Run("EmptyToken", func(t *testing.T) {
|
||||
// err := ValidateUserToken("", "", false)
|
||||
// assert.NotNil(t, err)
|
||||
// assert.Equal(t, "Missing Auth Token.", err.Error())
|
||||
// })
|
||||
// t.Run("InvalidToken", func(t *testing.T) {
|
||||
// err := ValidateUserToken("Bearer: badtoken", "", false)
|
||||
// assert.NotNil(t, err)
|
||||
// assert.Equal(t, "Error Verifying Auth Token", err.Error())
|
||||
// })
|
||||
// t.Run("InvalidUser", func(t *testing.T) {
|
||||
// t.Skip()
|
||||
// err := ValidateUserToken("Bearer: secretkey", "baduser", false)
|
||||
// assert.NotNil(t, err)
|
||||
// assert.Equal(t, "Error Verifying Auth Token", err.Error())
|
||||
// //need authorization
|
||||
// })
|
||||
// t.Run("ValidToken", func(t *testing.T) {
|
||||
// err := ValidateUserToken("Bearer: secretkey", "", true)
|
||||
// assert.Nil(t, err)
|
||||
// })
|
||||
// }
|
||||
|
||||
func TestVerifyAuthRequest(t *testing.T) {
|
||||
database.InitializeDatabase()
|
|
@ -2,24 +2,14 @@ package functions
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
)
|
||||
|
||||
// ParseNetwork - parses a network into a model
|
||||
func ParseNetwork(value string) (models.Network, error) {
|
||||
var network models.Network
|
||||
err := json.Unmarshal([]byte(value), &network)
|
||||
return network, err
|
||||
}
|
||||
|
||||
// ParseNode - parses a node into a model
|
||||
func ParseNode(value string) (models.Node, error) {
|
||||
var node models.Node
|
||||
|
@ -131,72 +121,6 @@ func NetworkExists(name string) (bool, error) {
|
|||
return len(network) > 0, nil
|
||||
}
|
||||
|
||||
// NetworkNodesUpdateAction - updates action of network nodes
|
||||
func NetworkNodesUpdateAction(networkName string, action string) error {
|
||||
|
||||
collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for _, value := range collections {
|
||||
var node models.Node
|
||||
err := json.Unmarshal([]byte(value), &node)
|
||||
if err != nil {
|
||||
fmt.Println("error in node address assignment!")
|
||||
return err
|
||||
}
|
||||
if action == models.NODE_UPDATE_KEY && node.IsStatic == "yes" {
|
||||
continue
|
||||
}
|
||||
if node.Network == networkName {
|
||||
node.Action = action
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
node.SetID()
|
||||
database.Insert(node.ID, string(data), database.NODES_TABLE_NAME)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NetworkNodesUpdatePullChanges - tells nodes on network to pull
|
||||
func NetworkNodesUpdatePullChanges(networkName string) error {
|
||||
|
||||
collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for _, value := range collections {
|
||||
var node models.Node
|
||||
err := json.Unmarshal([]byte(value), &node)
|
||||
if err != nil {
|
||||
fmt.Println("error in node address assignment!")
|
||||
return err
|
||||
}
|
||||
if node.Network == networkName {
|
||||
node.PullChanges = "yes"
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
node.SetID()
|
||||
database.Insert(node.ID, string(data), database.NODES_TABLE_NAME)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsNetworkDisplayNameUnique - checks if network display name unique
|
||||
func IsNetworkDisplayNameUnique(name string) (bool, error) {
|
||||
|
||||
|
@ -228,28 +152,6 @@ func IsMacAddressUnique(macaddress string, networkName string) (bool, error) {
|
|||
return true, nil
|
||||
}
|
||||
|
||||
// GetNetworkNonServerNodeCount - get number of network non server nodes
|
||||
func GetNetworkNonServerNodeCount(networkName string) (int, error) {
|
||||
|
||||
collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
|
||||
count := 0
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return count, err
|
||||
}
|
||||
for _, value := range collection {
|
||||
var node models.Node
|
||||
if err = json.Unmarshal([]byte(value), &node); err != nil {
|
||||
return count, err
|
||||
} else {
|
||||
if node.Network == networkName && node.IsServer != "yes" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// IsKeyValidGlobal - checks if a key is valid globally
|
||||
func IsKeyValidGlobal(keyvalue string) bool {
|
||||
|
||||
|
@ -278,31 +180,6 @@ func IsKeyValidGlobal(keyvalue string) bool {
|
|||
return isvalid
|
||||
}
|
||||
|
||||
//TODO: Contains a fatal error return. Need to change
|
||||
//This just gets a network object from a network name
|
||||
//Should probably just be GetNetwork. kind of a dumb name.
|
||||
//Used in contexts where it's not the Parent network.
|
||||
|
||||
//Similar to above but checks if Cidr range is valid
|
||||
//At least this guy's got some print statements
|
||||
//still not good error handling
|
||||
|
||||
//This checks to make sure a network name is valid.
|
||||
//Switch to REGEX?
|
||||
|
||||
// NameInNetworkCharSet - see if name is in charset for networks
|
||||
func NameInNetworkCharSet(name string) bool {
|
||||
|
||||
charset := "abcdefghijklmnopqrstuvwxyz1234567890-_."
|
||||
|
||||
for _, char := range name {
|
||||
if !strings.Contains(charset, strings.ToLower(string(char))) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// NameInDNSCharSet - name in dns char set
|
||||
func NameInDNSCharSet(name string) bool {
|
||||
|
||||
|
@ -387,43 +264,6 @@ func GetAllExtClients() ([]models.ExtClient, error) {
|
|||
return extclients, nil
|
||||
}
|
||||
|
||||
// GenKey - generates access key
|
||||
func GenKey() string {
|
||||
|
||||
var seededRand *rand.Rand = rand.New(
|
||||
rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
length := 16
|
||||
charset := "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[seededRand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
//generate a key value
|
||||
//we should probably just have 1 random string generator
|
||||
//that can be used across all functions
|
||||
//have a "base string" a "length" and a "charset"
|
||||
|
||||
// GenKeyName - generates a key name
|
||||
func GenKeyName() string {
|
||||
|
||||
var seededRand *rand.Rand = rand.New(
|
||||
rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
length := 5
|
||||
charset := "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[seededRand.Intn(len(charset))]
|
||||
}
|
||||
return "key" + string(b)
|
||||
}
|
||||
|
||||
// DeleteKey - deletes a key
|
||||
func DeleteKey(network models.Network, i int) {
|
||||
|
||||
|
|
|
@ -1,13 +1,147 @@
|
|||
package logic
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
)
|
||||
|
||||
const (
|
||||
charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
)
|
||||
|
||||
// CreateAccessKey - create access key
|
||||
func CreateAccessKey(accesskey models.AccessKey, network models.Network) (models.AccessKey, error) {
|
||||
|
||||
if accesskey.Name == "" {
|
||||
accesskey.Name = genKeyName()
|
||||
}
|
||||
|
||||
if accesskey.Value == "" {
|
||||
accesskey.Value = genKey()
|
||||
}
|
||||
if accesskey.Uses == 0 {
|
||||
accesskey.Uses = 1
|
||||
}
|
||||
|
||||
checkkeys, err := GetKeys(network.NetID)
|
||||
if err != nil {
|
||||
return models.AccessKey{}, errors.New("could not retrieve network keys")
|
||||
}
|
||||
|
||||
for _, key := range checkkeys {
|
||||
if key.Name == accesskey.Name {
|
||||
return models.AccessKey{}, errors.New("duplicate AccessKey Name")
|
||||
}
|
||||
}
|
||||
privAddr := ""
|
||||
if network.IsLocal != "" {
|
||||
privAddr = network.LocalRange
|
||||
}
|
||||
|
||||
netID := network.NetID
|
||||
|
||||
var accessToken models.AccessToken
|
||||
s := servercfg.GetServerConfig()
|
||||
servervals := models.ServerConfig{
|
||||
CoreDNSAddr: s.CoreDNSAddr,
|
||||
APIConnString: s.APIConnString,
|
||||
APIHost: s.APIHost,
|
||||
APIPort: s.APIPort,
|
||||
GRPCConnString: s.GRPCConnString,
|
||||
GRPCHost: s.GRPCHost,
|
||||
GRPCPort: s.GRPCPort,
|
||||
GRPCSSL: s.GRPCSSL,
|
||||
CheckinInterval: s.CheckinInterval,
|
||||
}
|
||||
accessToken.ServerConfig = servervals
|
||||
accessToken.ClientConfig.Network = netID
|
||||
accessToken.ClientConfig.Key = accesskey.Value
|
||||
accessToken.ClientConfig.LocalRange = privAddr
|
||||
|
||||
tokenjson, err := json.Marshal(accessToken)
|
||||
if err != nil {
|
||||
return accesskey, err
|
||||
}
|
||||
|
||||
accesskey.AccessString = base64.StdEncoding.EncodeToString([]byte(tokenjson))
|
||||
|
||||
//validate accesskey
|
||||
v := validator.New()
|
||||
err = v.Struct(accesskey)
|
||||
if err != nil {
|
||||
for _, e := range err.(validator.ValidationErrors) {
|
||||
logger.Log(1, "validator", e.Error())
|
||||
}
|
||||
return models.AccessKey{}, err
|
||||
}
|
||||
|
||||
network.AccessKeys = append(network.AccessKeys, accesskey)
|
||||
data, err := json.Marshal(&network)
|
||||
if err != nil {
|
||||
return models.AccessKey{}, err
|
||||
}
|
||||
if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
|
||||
return models.AccessKey{}, err
|
||||
}
|
||||
|
||||
return accesskey, nil
|
||||
}
|
||||
|
||||
// DeleteKey - deletes a key
|
||||
func DeleteKey(keyname, netname string) error {
|
||||
network, err := GetParentNetwork(netname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//basically, turn the list of access keys into the list of access keys before and after the item
|
||||
//have not done any error handling for if there's like...1 item. I think it works? need to test.
|
||||
found := false
|
||||
var updatedKeys []models.AccessKey
|
||||
for _, currentkey := range network.AccessKeys {
|
||||
if currentkey.Name == keyname {
|
||||
found = true
|
||||
} else {
|
||||
updatedKeys = append(updatedKeys, currentkey)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return errors.New("key " + keyname + " does not exist")
|
||||
}
|
||||
network.AccessKeys = updatedKeys
|
||||
data, err := json.Marshal(&network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetKeys - fetches keys for network
|
||||
func GetKeys(net string) ([]models.AccessKey, error) {
|
||||
|
||||
record, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, net)
|
||||
if err != nil {
|
||||
return []models.AccessKey{}, err
|
||||
}
|
||||
network, err := ParseNetwork(record)
|
||||
if err != nil {
|
||||
return []models.AccessKey{}, err
|
||||
}
|
||||
return network.AccessKeys, nil
|
||||
}
|
||||
|
||||
// DecrimentKey - decriments key uses
|
||||
func DecrimentKey(networkName string, keyvalue string) {
|
||||
|
||||
|
@ -71,3 +205,33 @@ func RemoveKeySensitiveInfo(keys []models.AccessKey) []models.AccessKey {
|
|||
}
|
||||
return returnKeys
|
||||
}
|
||||
|
||||
// == private methods ==
|
||||
|
||||
func genKeyName() string {
|
||||
|
||||
var seededRand *rand.Rand = rand.New(
|
||||
rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
length := 5
|
||||
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[seededRand.Intn(len(charset))]
|
||||
}
|
||||
return "key" + string(b)
|
||||
}
|
||||
|
||||
func genKey() string {
|
||||
|
||||
var seededRand *rand.Rand = rand.New(
|
||||
rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
length := 16
|
||||
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[seededRand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
|
103
logic/dns.go
103
logic/dns.go
|
@ -5,6 +5,7 @@ import (
|
|||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
|
@ -140,3 +141,105 @@ func SetCorefile(domains string) error {
|
|||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAllDNS - gets all dns entries
|
||||
func GetAllDNS() ([]models.DNSEntry, error) {
|
||||
var dns []models.DNSEntry
|
||||
networks, err := GetNetworks()
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return []models.DNSEntry{}, err
|
||||
}
|
||||
for _, net := range networks {
|
||||
netdns, err := GetDNS(net.NetID)
|
||||
if err != nil {
|
||||
return []models.DNSEntry{}, nil
|
||||
}
|
||||
dns = append(dns, netdns...)
|
||||
}
|
||||
return dns, nil
|
||||
}
|
||||
|
||||
// GetDNSEntryNum - gets which entry the dns was
|
||||
func GetDNSEntryNum(domain string, network string) (int, error) {
|
||||
|
||||
num := 0
|
||||
|
||||
entries, err := GetDNS(network)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for i := 0; i < len(entries); i++ {
|
||||
|
||||
if domain == entries[i].Name {
|
||||
num++
|
||||
}
|
||||
}
|
||||
|
||||
return num, nil
|
||||
}
|
||||
|
||||
// ValidateDNSCreate - checks if an entry is valid
|
||||
func ValidateDNSCreate(entry models.DNSEntry) error {
|
||||
|
||||
v := validator.New()
|
||||
|
||||
_ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
|
||||
num, err := GetDNSEntryNum(entry.Name, entry.Network)
|
||||
return err == nil && num == 0
|
||||
})
|
||||
|
||||
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
|
||||
_, err := GetParentNetwork(entry.Network)
|
||||
return err == nil
|
||||
})
|
||||
|
||||
err := v.Struct(entry)
|
||||
if err != nil {
|
||||
for _, e := range err.(validator.ValidationErrors) {
|
||||
logger.Log(1, e.Error())
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateDNSUpdate - validates a DNS update
|
||||
func ValidateDNSUpdate(change models.DNSEntry, entry models.DNSEntry) error {
|
||||
|
||||
v := validator.New()
|
||||
|
||||
_ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
|
||||
//if name & net not changing name we are good
|
||||
if change.Name == entry.Name && change.Network == entry.Network {
|
||||
return true
|
||||
}
|
||||
num, err := GetDNSEntryNum(change.Name, change.Network)
|
||||
return err == nil && num == 0
|
||||
})
|
||||
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
|
||||
_, err := GetParentNetwork(change.Network)
|
||||
if err != nil {
|
||||
logger.Log(0, err.Error())
|
||||
}
|
||||
return err == nil
|
||||
})
|
||||
|
||||
err := v.Struct(change)
|
||||
|
||||
if err != nil {
|
||||
for _, e := range err.(validator.ValidationErrors) {
|
||||
logger.Log(1, e.Error())
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteDNS - deletes a DNS entry
|
||||
func DeleteDNS(domain string, network string) error {
|
||||
key, err := GetRecordKey(domain, network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = database.DeleteRecord(database.DNS_TABLE_NAME, key)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -2,10 +2,12 @@ package logic
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// GetExtPeersList - gets the ext peers lists
|
||||
|
@ -63,3 +65,103 @@ func GetEgressRangesOnNetwork(client *models.ExtClient) ([]string, error) {
|
|||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteExtClient - deletes an existing ext client
|
||||
func DeleteExtClient(network string, clientid string) error {
|
||||
key, err := GetRecordKey(clientid, network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = database.DeleteRecord(database.EXT_CLIENT_TABLE_NAME, key)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetNetworkExtClients - gets the ext clients of given network
|
||||
func GetNetworkExtClients(network string) ([]models.ExtClient, error) {
|
||||
var extclients []models.ExtClient
|
||||
|
||||
records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME)
|
||||
if err != nil {
|
||||
return extclients, err
|
||||
}
|
||||
for _, value := range records {
|
||||
var extclient models.ExtClient
|
||||
err = json.Unmarshal([]byte(value), &extclient)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if extclient.Network == network {
|
||||
extclients = append(extclients, extclient)
|
||||
}
|
||||
}
|
||||
return extclients, err
|
||||
}
|
||||
|
||||
// GetExtClient - gets a single ext client on a network
|
||||
func GetExtClient(clientid string, network string) (models.ExtClient, error) {
|
||||
var extclient models.ExtClient
|
||||
key, err := GetRecordKey(clientid, network)
|
||||
if err != nil {
|
||||
return extclient, err
|
||||
}
|
||||
data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key)
|
||||
if err != nil {
|
||||
return extclient, err
|
||||
}
|
||||
err = json.Unmarshal([]byte(data), &extclient)
|
||||
|
||||
return extclient, err
|
||||
}
|
||||
|
||||
// CreateExtClient - creates an extclient
|
||||
func CreateExtClient(extclient models.ExtClient) error {
|
||||
if extclient.PrivateKey == "" {
|
||||
privateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
extclient.PrivateKey = privateKey.String()
|
||||
extclient.PublicKey = privateKey.PublicKey().String()
|
||||
}
|
||||
|
||||
if extclient.Address == "" {
|
||||
newAddress, err := UniqueAddress(extclient.Network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
extclient.Address = newAddress
|
||||
}
|
||||
|
||||
if extclient.ClientID == "" {
|
||||
extclient.ClientID = models.GenerateNodeName()
|
||||
}
|
||||
|
||||
extclient.LastModified = time.Now().Unix()
|
||||
|
||||
key, err := GetRecordKey(extclient.ClientID, extclient.Network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.Marshal(&extclient)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = database.Insert(key, string(data), database.EXT_CLIENT_TABLE_NAME); err != nil {
|
||||
return err
|
||||
}
|
||||
err = SetNetworkNodesLastModified(extclient.Network)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateExtClient - only supports name changes right now
|
||||
func UpdateExtClient(newclientid string, network string, client models.ExtClient) (models.ExtClient, error) {
|
||||
|
||||
err := DeleteExtClient(network, client.ClientID)
|
||||
if err != nil {
|
||||
return client, err
|
||||
}
|
||||
client.ClientID = newclientid
|
||||
CreateExtClient(client)
|
||||
return client, err
|
||||
}
|
||||
|
|
221
logic/gateway.go
Normal file
221
logic/gateway.go
Normal file
|
@ -0,0 +1,221 @@
|
|||
package logic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
)
|
||||
|
||||
// CreateEgressGateway - creates an egress gateway
|
||||
func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, error) {
|
||||
node, err := GetNodeByMacAddress(gateway.NetID, gateway.NodeID)
|
||||
if node.OS == "windows" || node.OS == "macos" { // add in darwin later
|
||||
return models.Node{}, errors.New(node.OS + " is unsupported for egress gateways")
|
||||
}
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = ValidateEgressGateway(gateway)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
node.IsEgressGateway = "yes"
|
||||
node.EgressGatewayRanges = gateway.Ranges
|
||||
postUpCmd := "iptables -A FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -A POSTROUTING -o " + gateway.Interface + " -j MASQUERADE"
|
||||
postDownCmd := "iptables -D FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -D POSTROUTING -o " + gateway.Interface + " -j MASQUERADE"
|
||||
if gateway.PostUp != "" {
|
||||
postUpCmd = gateway.PostUp
|
||||
}
|
||||
if gateway.PostDown != "" {
|
||||
postDownCmd = gateway.PostDown
|
||||
}
|
||||
if node.PostUp != "" {
|
||||
if !strings.Contains(node.PostUp, postUpCmd) {
|
||||
postUpCmd = node.PostUp + "; " + postUpCmd
|
||||
}
|
||||
}
|
||||
if node.PostDown != "" {
|
||||
if !strings.Contains(node.PostDown, postDownCmd) {
|
||||
postDownCmd = node.PostDown + "; " + postDownCmd
|
||||
}
|
||||
}
|
||||
key, err := GetRecordKey(gateway.NodeID, gateway.NetID)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
node.PostUp = postUpCmd
|
||||
node.PostDown = postDownCmd
|
||||
node.SetLastModified()
|
||||
node.PullChanges = "yes"
|
||||
nodeData, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
if err = database.Insert(key, string(nodeData), database.NODES_TABLE_NAME); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
if err = NetworkNodesUpdatePullChanges(node.Network); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
func ValidateEgressGateway(gateway models.EgressGatewayRequest) error {
|
||||
var err error
|
||||
|
||||
empty := len(gateway.Ranges) == 0
|
||||
if empty {
|
||||
err = errors.New("IP Ranges Cannot Be Empty")
|
||||
}
|
||||
empty = gateway.Interface == ""
|
||||
if empty {
|
||||
err = errors.New("Interface cannot be empty")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteEgressGateway - deletes egress from node
|
||||
func DeleteEgressGateway(network, macaddress string) (models.Node, error) {
|
||||
|
||||
node, err := GetNodeByMacAddress(network, macaddress)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
|
||||
node.IsEgressGateway = "no"
|
||||
node.EgressGatewayRanges = []string{}
|
||||
node.PostUp = ""
|
||||
node.PostDown = ""
|
||||
if node.IsIngressGateway == "yes" { // check if node is still an ingress gateway before completely deleting postdown/up rules
|
||||
node.PostUp = "iptables -A FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -A POSTROUTING -o " + node.Interface + " -j MASQUERADE"
|
||||
node.PostDown = "iptables -D FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -D POSTROUTING -o " + node.Interface + " -j MASQUERADE"
|
||||
}
|
||||
node.SetLastModified()
|
||||
node.PullChanges = "yes"
|
||||
key, err := GetRecordKey(node.MacAddress, node.Network)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
if err = database.Insert(key, string(data), database.NODES_TABLE_NAME); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
if err = NetworkNodesUpdatePullChanges(network); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
// CreateIngressGateway - creates an ingress gateway
|
||||
func CreateIngressGateway(netid string, macaddress string) (models.Node, error) {
|
||||
|
||||
node, err := GetNodeByMacAddress(netid, macaddress)
|
||||
if node.OS == "windows" || node.OS == "macos" { // add in darwin later
|
||||
return models.Node{}, errors.New(node.OS + " is unsupported for ingress gateways")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
|
||||
network, err := GetParentNetwork(netid)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
node.IsIngressGateway = "yes"
|
||||
node.IngressGatewayRange = network.AddressRange
|
||||
postUpCmd := "iptables -A FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -A POSTROUTING -o " + node.Interface + " -j MASQUERADE"
|
||||
postDownCmd := "iptables -D FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -D POSTROUTING -o " + node.Interface + " -j MASQUERADE"
|
||||
if node.PostUp != "" {
|
||||
if !strings.Contains(node.PostUp, postUpCmd) {
|
||||
postUpCmd = node.PostUp + "; " + postUpCmd
|
||||
}
|
||||
}
|
||||
if node.PostDown != "" {
|
||||
if !strings.Contains(node.PostDown, postDownCmd) {
|
||||
postDownCmd = node.PostDown + "; " + postDownCmd
|
||||
}
|
||||
}
|
||||
node.SetLastModified()
|
||||
node.PostUp = postUpCmd
|
||||
node.PostDown = postDownCmd
|
||||
node.PullChanges = "yes"
|
||||
node.UDPHolePunch = "no"
|
||||
key, err := GetRecordKey(node.MacAddress, node.Network)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = database.Insert(key, string(data), database.NODES_TABLE_NAME)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = SetNetworkNodesLastModified(netid)
|
||||
return node, err
|
||||
}
|
||||
|
||||
// DeleteIngressGateway - deletes an ingress gateway
|
||||
func DeleteIngressGateway(networkName string, macaddress string) (models.Node, error) {
|
||||
|
||||
node, err := GetNodeByMacAddress(networkName, macaddress)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
network, err := GetParentNetwork(networkName)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
// delete ext clients belonging to ingress gateway
|
||||
if err = DeleteGatewayExtClients(macaddress, networkName); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
|
||||
node.UDPHolePunch = network.DefaultUDPHolePunch
|
||||
node.LastModified = time.Now().Unix()
|
||||
node.IsIngressGateway = "no"
|
||||
node.IngressGatewayRange = ""
|
||||
node.PullChanges = "yes"
|
||||
|
||||
key, err := GetRecordKey(node.MacAddress, node.Network)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = database.Insert(key, string(data), database.NODES_TABLE_NAME)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = SetNetworkNodesLastModified(networkName)
|
||||
return node, err
|
||||
}
|
||||
|
||||
// DeleteGatewayExtClients - deletes ext clients based on gateway (mac) of ingress node and network
|
||||
func DeleteGatewayExtClients(gatewayID string, networkName string) error {
|
||||
currentExtClients, err := GetNetworkExtClients(networkName)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
for _, extClient := range currentExtClients {
|
||||
if extClient.IngressGatewayID == gatewayID {
|
||||
if err = DeleteExtClient(networkName, extClient.ClientID); err != nil {
|
||||
logger.Log(1, "failed to remove ext client", extClient.ClientID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -7,6 +7,7 @@ import (
|
|||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
|
@ -38,6 +39,107 @@ func GetNetworks() ([]models.Network, error) {
|
|||
return networks, err
|
||||
}
|
||||
|
||||
// DeleteNetwork - deletes a network
|
||||
func DeleteNetwork(network string) error {
|
||||
nodeCount, err := GetNetworkNonServerNodeCount(network)
|
||||
if nodeCount == 0 || database.IsEmptyRecord(err) {
|
||||
// delete server nodes first then db records
|
||||
servers, err := GetSortedNetworkServerNodes(network)
|
||||
if err == nil {
|
||||
for _, s := range servers {
|
||||
if err = DeleteNode(&s, true); err != nil {
|
||||
logger.Log(2, "could not removed server", s.Name, "before deleting network", network)
|
||||
} else {
|
||||
logger.Log(2, "removed server", s.Name, "before deleting network", network)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.Log(1, "could not remove servers before deleting network", network)
|
||||
}
|
||||
return database.DeleteRecord(database.NETWORKS_TABLE_NAME, network)
|
||||
}
|
||||
return errors.New("node check failed. All nodes must be deleted before deleting network")
|
||||
}
|
||||
|
||||
// CreateNetwork - creates a network in database
|
||||
func CreateNetwork(network models.Network) error {
|
||||
|
||||
network.SetDefaults()
|
||||
network.SetNodesLastModified()
|
||||
network.SetNetworkLastModified()
|
||||
network.KeyUpdateTimeStamp = time.Now().Unix()
|
||||
|
||||
err := ValidateNetwork(&network, false)
|
||||
if err != nil {
|
||||
//returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.Marshal(&network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// NetworkNodesUpdatePullChanges - tells nodes on network to pull
|
||||
func NetworkNodesUpdatePullChanges(networkName string) error {
|
||||
|
||||
collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for _, value := range collections {
|
||||
var node models.Node
|
||||
err := json.Unmarshal([]byte(value), &node)
|
||||
if err != nil {
|
||||
fmt.Println("error in node address assignment!")
|
||||
return err
|
||||
}
|
||||
if node.Network == networkName {
|
||||
node.PullChanges = "yes"
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
node.SetID()
|
||||
database.Insert(node.ID, string(data), database.NODES_TABLE_NAME)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNetworkNonServerNodeCount - get number of network non server nodes
|
||||
func GetNetworkNonServerNodeCount(networkName string) (int, error) {
|
||||
|
||||
collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
|
||||
count := 0
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return count, err
|
||||
}
|
||||
for _, value := range collection {
|
||||
var node models.Node
|
||||
if err = json.Unmarshal([]byte(value), &node); err != nil {
|
||||
return count, err
|
||||
} else {
|
||||
if node.Network == networkName && node.IsServer != "yes" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetParentNetwork - get parent network
|
||||
func GetParentNetwork(networkname string) (models.Network, error) {
|
||||
|
||||
|
@ -462,8 +564,91 @@ func ValidateNetwork(network *models.Network, isUpdate bool) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// ParseNetwork - parses a network into a model
|
||||
func ParseNetwork(value string) (models.Network, error) {
|
||||
var network models.Network
|
||||
err := json.Unmarshal([]byte(value), &network)
|
||||
return network, err
|
||||
}
|
||||
|
||||
// ValidateNetworkUpdate - checks if network is valid to update
|
||||
func ValidateNetworkUpdate(network models.Network) error {
|
||||
v := validator.New()
|
||||
|
||||
_ = v.RegisterValidation("netid_valid", func(fl validator.FieldLevel) bool {
|
||||
if fl.Field().String() == "" {
|
||||
return true
|
||||
}
|
||||
inCharSet := nameInNetworkCharSet(fl.Field().String())
|
||||
return inCharSet
|
||||
})
|
||||
|
||||
err := v.Struct(network)
|
||||
|
||||
if err != nil {
|
||||
for _, e := range err.(validator.ValidationErrors) {
|
||||
logger.Log(1, "validator", e.Error())
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// KeyUpdate - updates keys on network
|
||||
func KeyUpdate(netname string) (models.Network, error) {
|
||||
err := networkNodesUpdateAction(netname, models.NODE_UPDATE_KEY)
|
||||
if err != nil {
|
||||
return models.Network{}, err
|
||||
}
|
||||
return models.Network{}, nil
|
||||
}
|
||||
|
||||
// == Private ==
|
||||
|
||||
func networkNodesUpdateAction(networkName string, action string) error {
|
||||
|
||||
collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for _, value := range collections {
|
||||
var node models.Node
|
||||
err := json.Unmarshal([]byte(value), &node)
|
||||
if err != nil {
|
||||
fmt.Println("error in node address assignment!")
|
||||
return err
|
||||
}
|
||||
if action == models.NODE_UPDATE_KEY && node.IsStatic == "yes" {
|
||||
continue
|
||||
}
|
||||
if node.Network == networkName {
|
||||
node.Action = action
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
node.SetID()
|
||||
database.Insert(node.ID, string(data), database.NODES_TABLE_NAME)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func nameInNetworkCharSet(name string) bool {
|
||||
|
||||
charset := "abcdefghijklmnopqrstuvwxyz1234567890-_."
|
||||
|
||||
for _, char := range name {
|
||||
if !strings.Contains(charset, strings.ToLower(string(char))) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func deleteInterface(ifacename string, postdown string) error {
|
||||
var err error
|
||||
if !ncutils.IsKernel() {
|
||||
|
|
|
@ -63,6 +63,28 @@ func GetSortedNetworkServerNodes(network string) ([]models.Node, error) {
|
|||
return nodes, nil
|
||||
}
|
||||
|
||||
// UncordonNode - approves a node to join a network
|
||||
func UncordonNode(network, macaddress string) (models.Node, error) {
|
||||
node, err := GetNodeByMacAddress(network, macaddress)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
node.SetLastModified()
|
||||
node.IsPending = "no"
|
||||
node.PullChanges = "yes"
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
key, err := GetRecordKey(node.MacAddress, node.Network)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
|
||||
err = database.Insert(key, string(data), database.NODES_TABLE_NAME)
|
||||
return node, err
|
||||
}
|
||||
|
||||
// GetPeers - gets the peers of a given node
|
||||
func GetPeers(node models.Node) ([]models.Node, error) {
|
||||
if node.IsServer == "yes" && IsLeader(&node) {
|
||||
|
|
140
logic/relay.go
Normal file
140
logic/relay.go
Normal file
|
@ -0,0 +1,140 @@
|
|||
package logic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
)
|
||||
|
||||
// CreateRelay - creates a relay
|
||||
func CreateRelay(relay models.RelayRequest) (models.Node, error) {
|
||||
node, err := GetNodeByMacAddress(relay.NetID, relay.NodeID)
|
||||
if node.OS == "macos" { // add in darwin later
|
||||
return models.Node{}, errors.New(node.OS + " is unsupported for relay")
|
||||
}
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = ValidateRelay(relay)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
node.IsRelay = "yes"
|
||||
node.RelayAddrs = relay.RelayAddrs
|
||||
|
||||
key, err := GetRecordKey(relay.NodeID, relay.NetID)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
node.SetLastModified()
|
||||
node.PullChanges = "yes"
|
||||
nodeData, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
if err = database.Insert(key, string(nodeData), database.NODES_TABLE_NAME); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = SetRelayedNodes("yes", node.Network, node.RelayAddrs)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
|
||||
if err = NetworkNodesUpdatePullChanges(node.Network); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
// SetRelayedNodes- set relayed nodes
|
||||
func SetRelayedNodes(yesOrno string, networkName string, addrs []string) error {
|
||||
|
||||
collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, value := range collections {
|
||||
|
||||
var node models.Node
|
||||
err := json.Unmarshal([]byte(value), &node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if node.Network == networkName {
|
||||
for _, addr := range addrs {
|
||||
if addr == node.Address || addr == node.Address6 {
|
||||
node.IsRelayed = yesOrno
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
node.SetID()
|
||||
database.Insert(node.ID, string(data), database.NODES_TABLE_NAME)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateRelay - checks if relay is valid
|
||||
func ValidateRelay(relay models.RelayRequest) error {
|
||||
var err error
|
||||
//isIp := functions.IsIpCIDR(gateway.RangeString)
|
||||
empty := len(relay.RelayAddrs) == 0
|
||||
if empty {
|
||||
err = errors.New("IP Ranges Cannot Be Empty")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateRelay - updates a relay
|
||||
func UpdateRelay(network string, oldAddrs []string, newAddrs []string) {
|
||||
time.Sleep(time.Second / 4)
|
||||
err := SetRelayedNodes("no", network, oldAddrs)
|
||||
if err != nil {
|
||||
logger.Log(1, err.Error())
|
||||
}
|
||||
err = SetRelayedNodes("yes", network, newAddrs)
|
||||
if err != nil {
|
||||
logger.Log(1, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteRelay - deletes a relay
|
||||
func DeleteRelay(network, macaddress string) (models.Node, error) {
|
||||
|
||||
node, err := GetNodeByMacAddress(network, macaddress)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
err = SetRelayedNodes("no", node.Network, node.RelayAddrs)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
|
||||
node.IsRelay = "no"
|
||||
node.RelayAddrs = []string{}
|
||||
node.SetLastModified()
|
||||
node.PullChanges = "yes"
|
||||
key, err := GetRecordKey(node.MacAddress, node.Network)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
data, err := json.Marshal(&node)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
if err = database.Insert(key, string(data), database.NODES_TABLE_NAME); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
if err = NetworkNodesUpdatePullChanges(network); err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
return node, nil
|
||||
}
|
Loading…
Add table
Reference in a new issue