organized http logic, renamed files

This commit is contained in:
0xdcarns 2021-12-07 12:46:55 -05:00
parent 6184d0b965
commit 0c6c09caa9
26 changed files with 1287 additions and 1485 deletions

189
controllers/dns.go Normal file
View file

@ -0,0 +1,189 @@
package controller
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
)
func dnsHandlers(r *mux.Router) {
r.HandleFunc("/api/dns", securityCheck(true, http.HandlerFunc(getAllDNS))).Methods("GET")
r.HandleFunc("/api/dns/adm/{network}/nodes", securityCheck(false, http.HandlerFunc(getNodeDNS))).Methods("GET")
r.HandleFunc("/api/dns/adm/{network}/custom", securityCheck(false, http.HandlerFunc(getCustomDNS))).Methods("GET")
r.HandleFunc("/api/dns/adm/{network}", securityCheck(false, http.HandlerFunc(getDNS))).Methods("GET")
r.HandleFunc("/api/dns/{network}", securityCheck(false, http.HandlerFunc(createDNS))).Methods("POST")
r.HandleFunc("/api/dns/adm/pushdns", securityCheck(false, http.HandlerFunc(pushDNS))).Methods("POST")
r.HandleFunc("/api/dns/{network}/{domain}", securityCheck(false, http.HandlerFunc(deleteDNS))).Methods("DELETE")
}
//Gets all nodes associated with network, including pending nodes
func getNodeDNS(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var dns []models.DNSEntry
var params = mux.Vars(r)
dns, err := logic.GetNodeDNS(params["network"])
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
//Returns all the nodes in JSON format
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(dns)
}
//Gets all nodes associated with network, including pending nodes
func getAllDNS(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
dns, err := logic.GetAllDNS()
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
//Returns all the nodes in JSON format
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(dns)
}
//Gets all nodes associated with network, including pending nodes
func getCustomDNS(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var dns []models.DNSEntry
var params = mux.Vars(r)
dns, err := logic.GetCustomDNS(params["network"])
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
//Returns all the nodes in JSON format
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(dns)
}
// Gets all nodes associated with network, including pending nodes
func getDNS(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var dns []models.DNSEntry
var params = mux.Vars(r)
dns, err := logic.GetDNS(params["network"])
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(dns)
}
func createDNS(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var entry models.DNSEntry
var params = mux.Vars(r)
//get node from body of request
_ = json.NewDecoder(r.Body).Decode(&entry)
entry.Network = params["network"]
err := logic.ValidateDNSCreate(entry)
if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest"))
return
}
entry, err = CreateDNS(entry)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
err = logic.SetDNS()
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(entry)
}
func deleteDNS(w http.ResponseWriter, r *http.Request) {
// Set header
w.Header().Set("Content-Type", "application/json")
// get params
var params = mux.Vars(r)
err := logic.DeleteDNS(params["domain"], params["network"])
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
entrytext := params["domain"] + "." + params["network"]
logger.Log(1, "deleted dns entry: ", entrytext)
err = logic.SetDNS()
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
json.NewEncoder(w).Encode(entrytext + " deleted.")
}
// CreateDNS - creates a DNS entry
func CreateDNS(entry models.DNSEntry) (models.DNSEntry, error) {
data, err := json.Marshal(&entry)
if err != nil {
return models.DNSEntry{}, err
}
key, err := logic.GetRecordKey(entry.Name, entry.Network)
if err != nil {
return models.DNSEntry{}, err
}
err = database.Insert(key, string(data), database.DNS_TABLE_NAME)
return entry, err
}
// GetDNSEntry - gets a DNS entry
func GetDNSEntry(domain string, network string) (models.DNSEntry, error) {
var entry models.DNSEntry
key, err := logic.GetRecordKey(domain, network)
if err != nil {
return entry, err
}
record, err := database.FetchRecord(database.DNS_TABLE_NAME, key)
if err != nil {
return entry, err
}
err = json.Unmarshal([]byte(record), &entry)
return entry, err
}
func pushDNS(w http.ResponseWriter, r *http.Request) {
// Set header
w.Header().Set("Content-Type", "application/json")
err := logic.SetDNS()
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
logger.Log(1, r.Header.Get("user"), "pushed DNS updates to nameserver")
json.NewEncoder(w).Encode("DNS Pushed to CoreDNS")
}

View file

@ -1,396 +0,0 @@
package controller
import (
"encoding/json"
"net/http"
"github.com/go-playground/validator/v10"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
)
func dnsHandlers(r *mux.Router) {
r.HandleFunc("/api/dns", securityCheck(true, http.HandlerFunc(getAllDNS))).Methods("GET")
r.HandleFunc("/api/dns/adm/{network}/nodes", securityCheck(false, http.HandlerFunc(getNodeDNS))).Methods("GET")
r.HandleFunc("/api/dns/adm/{network}/custom", securityCheck(false, http.HandlerFunc(getCustomDNS))).Methods("GET")
r.HandleFunc("/api/dns/adm/{network}", securityCheck(false, http.HandlerFunc(getDNS))).Methods("GET")
r.HandleFunc("/api/dns/{network}", securityCheck(false, http.HandlerFunc(createDNS))).Methods("POST")
r.HandleFunc("/api/dns/adm/pushdns", securityCheck(false, http.HandlerFunc(pushDNS))).Methods("POST")
r.HandleFunc("/api/dns/{network}/{domain}", securityCheck(false, http.HandlerFunc(deleteDNS))).Methods("DELETE")
r.HandleFunc("/api/dns/{network}/{domain}", securityCheck(false, http.HandlerFunc(updateDNS))).Methods("PUT")
}
//Gets all nodes associated with network, including pending nodes
func getNodeDNS(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var dns []models.DNSEntry
var params = mux.Vars(r)
dns, err := GetNodeDNS(params["network"])
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
//Returns all the nodes in JSON format
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(dns)
}
//Gets all nodes associated with network, including pending nodes
func getAllDNS(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
dns, err := GetAllDNS()
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
//Returns all the nodes in JSON format
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(dns)
}
// GetAllDNS - gets all dns entries
func GetAllDNS() ([]models.DNSEntry, error) {
var dns []models.DNSEntry
networks, err := logic.GetNetworks()
if err != nil && !database.IsEmptyRecord(err) {
return []models.DNSEntry{}, err
}
for _, net := range networks {
netdns, err := logic.GetDNS(net.NetID)
if err != nil {
return []models.DNSEntry{}, nil
}
dns = append(dns, netdns...)
}
return dns, nil
}
// GetNodeDNS - gets node dns
func GetNodeDNS(network string) ([]models.DNSEntry, error) {
var dns []models.DNSEntry
collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
if err != nil {
return dns, err
}
for _, value := range collection {
var entry models.DNSEntry
var node models.Node
if err = json.Unmarshal([]byte(value), &node); err != nil {
continue
}
if err = json.Unmarshal([]byte(value), &entry); node.Network == network && err == nil {
dns = append(dns, entry)
}
}
return dns, nil
}
//Gets all nodes associated with network, including pending nodes
func getCustomDNS(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var dns []models.DNSEntry
var params = mux.Vars(r)
dns, err := logic.GetCustomDNS(params["network"])
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
//Returns all the nodes in JSON format
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(dns)
}
// GetDNSEntryNum - gets which entry the dns was
func GetDNSEntryNum(domain string, network string) (int, error) {
num := 0
entries, err := logic.GetDNS(network)
if err != nil {
return 0, err
}
for i := 0; i < len(entries); i++ {
if domain == entries[i].Name {
num++
}
}
return num, nil
}
// Gets all nodes associated with network, including pending nodes
func getDNS(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var dns []models.DNSEntry
var params = mux.Vars(r)
dns, err := logic.GetDNS(params["network"])
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(dns)
}
func createDNS(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var entry models.DNSEntry
var params = mux.Vars(r)
//get node from body of request
_ = json.NewDecoder(r.Body).Decode(&entry)
entry.Network = params["network"]
err := ValidateDNSCreate(entry)
if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest"))
return
}
entry, err = CreateDNS(entry)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
err = logic.SetDNS()
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(entry)
}
func updateDNS(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
var entry models.DNSEntry
//start here
entry, err := GetDNSEntry(params["domain"], params["network"])
if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest"))
return
}
var dnschange models.DNSEntry
// we decode our body request params
err = json.NewDecoder(r.Body).Decode(&dnschange)
if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest"))
return
}
// fill in any missing fields
if dnschange.Name == "" {
dnschange.Name = entry.Name
}
if dnschange.Network == "" {
dnschange.Network = entry.Network
}
if dnschange.Address == "" {
dnschange.Address = entry.Address
}
err = ValidateDNSUpdate(dnschange, entry)
if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest"))
return
}
entry, err = UpdateDNS(dnschange, entry)
if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest"))
return
}
err = logic.SetDNS()
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
json.NewEncoder(w).Encode(entry)
}
func deleteDNS(w http.ResponseWriter, r *http.Request) {
// Set header
w.Header().Set("Content-Type", "application/json")
// get params
var params = mux.Vars(r)
err := DeleteDNS(params["domain"], params["network"])
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
entrytext := params["domain"] + "." + params["network"]
logger.Log(1, "deleted dns entry: ", entrytext)
err = logic.SetDNS()
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
json.NewEncoder(w).Encode(entrytext + " deleted.")
}
// CreateDNS - creates a DNS entry
func CreateDNS(entry models.DNSEntry) (models.DNSEntry, error) {
data, err := json.Marshal(&entry)
if err != nil {
return models.DNSEntry{}, err
}
key, err := logic.GetRecordKey(entry.Name, entry.Network)
if err != nil {
return models.DNSEntry{}, err
}
err = database.Insert(key, string(data), database.DNS_TABLE_NAME)
return entry, err
}
// GetDNSEntry - gets a DNS entry
func GetDNSEntry(domain string, network string) (models.DNSEntry, error) {
var entry models.DNSEntry
key, err := logic.GetRecordKey(domain, network)
if err != nil {
return entry, err
}
record, err := database.FetchRecord(database.DNS_TABLE_NAME, key)
if err != nil {
return entry, err
}
err = json.Unmarshal([]byte(record), &entry)
return entry, err
}
// UpdateDNS - updates DNS entry
func UpdateDNS(dnschange models.DNSEntry, entry models.DNSEntry) (models.DNSEntry, error) {
key, err := logic.GetRecordKey(entry.Name, entry.Network)
if err != nil {
return entry, err
}
if dnschange.Name != "" {
entry.Name = dnschange.Name
}
if dnschange.Address != "" {
entry.Address = dnschange.Address
}
newkey, err := logic.GetRecordKey(entry.Name, entry.Network)
err = database.DeleteRecord(database.DNS_TABLE_NAME, key)
if err != nil {
return entry, err
}
data, err := json.Marshal(&entry)
err = database.Insert(newkey, string(data), database.DNS_TABLE_NAME)
return entry, err
}
// DeleteDNS - deletes a DNS entry
func DeleteDNS(domain string, network string) error {
key, err := logic.GetRecordKey(domain, network)
if err != nil {
return err
}
err = database.DeleteRecord(database.DNS_TABLE_NAME, key)
return err
}
func pushDNS(w http.ResponseWriter, r *http.Request) {
// Set header
w.Header().Set("Content-Type", "application/json")
err := logic.SetDNS()
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
logger.Log(1, r.Header.Get("user"), "pushed DNS updates to nameserver")
json.NewEncoder(w).Encode("DNS Pushed to CoreDNS")
}
// ValidateDNSCreate - checks if an entry is valid
func ValidateDNSCreate(entry models.DNSEntry) error {
v := validator.New()
_ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
num, err := GetDNSEntryNum(entry.Name, entry.Network)
return err == nil && num == 0
})
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
_, err := logic.GetParentNetwork(entry.Network)
return err == nil
})
err := v.Struct(entry)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
logger.Log(1, e.Error())
}
}
return err
}
// ValidateDNSUpdate - validates a DNS update
func ValidateDNSUpdate(change models.DNSEntry, entry models.DNSEntry) error {
v := validator.New()
_ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
//if name & net not changing name we are good
if change.Name == entry.Name && change.Network == entry.Network {
return true
}
num, err := GetDNSEntryNum(change.Name, change.Network)
return err == nil && num == 0
})
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
_, err := logic.GetParentNetwork(change.Network)
if err != nil {
logger.Log(0, err.Error())
}
return err == nil
})
err := v.Struct(change)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
logger.Log(1, e.Error())
}
}
return err
}

View file

@ -17,21 +17,21 @@ func TestGetAllDNS(t *testing.T) {
deleteAllNetworks()
createNet()
t.Run("NoEntries", func(t *testing.T) {
entries, err := GetAllDNS()
entries, err := logic.GetAllDNS()
assert.Nil(t, err)
assert.Equal(t, []models.DNSEntry(nil), entries)
})
t.Run("OneEntry", func(t *testing.T) {
entry := models.DNSEntry{"10.0.0.3", "newhost", "skynet"}
CreateDNS(entry)
entries, err := GetAllDNS()
entries, err := logic.GetAllDNS()
assert.Nil(t, err)
assert.Equal(t, 1, len(entries))
})
t.Run("MultipleEntry", func(t *testing.T) {
entry := models.DNSEntry{"10.0.0.7", "anotherhost", "skynet"}
CreateDNS(entry)
entries, err := GetAllDNS()
entries, err := logic.GetAllDNS()
assert.Nil(t, err)
assert.Equal(t, 2, len(entries))
})
@ -43,13 +43,13 @@ func TestGetNodeDNS(t *testing.T) {
deleteAllNetworks()
createNet()
t.Run("NoNodes", func(t *testing.T) {
dns, err := GetNodeDNS("skynet")
dns, err := logic.GetNodeDNS("skynet")
assert.EqualError(t, err, "could not find any records")
assert.Equal(t, []models.DNSEntry(nil), dns)
})
t.Run("NodeExists", func(t *testing.T) {
createTestNode()
dns, err := GetNodeDNS("skynet")
dns, err := logic.GetNodeDNS("skynet")
assert.Nil(t, err)
assert.Equal(t, "10.0.0.1", dns[0].Address)
})
@ -57,7 +57,7 @@ func TestGetNodeDNS(t *testing.T) {
createnode := models.Node{PublicKey: "DM5qhLAE20PG9BbfBCger+Ac9D2NDOwCtY1rbYDLf34=", Endpoint: "10.100.100.3", MacAddress: "01:02:03:04:05:07", Password: "password", Network: "skynet"}
_, err := logic.CreateNode(createnode, "skynet")
assert.Nil(t, err)
dns, err := GetNodeDNS("skynet")
dns, err := logic.GetNodeDNS("skynet")
assert.Nil(t, err)
assert.Equal(t, 2, len(dns))
})
@ -105,7 +105,7 @@ func TestGetDNSEntryNum(t *testing.T) {
deleteAllNetworks()
createNet()
t.Run("NoNodes", func(t *testing.T) {
num, err := GetDNSEntryNum("myhost", "skynet")
num, err := logic.GetDNSEntryNum("myhost", "skynet")
assert.Nil(t, err)
assert.Equal(t, 0, num)
})
@ -113,7 +113,7 @@ func TestGetDNSEntryNum(t *testing.T) {
entry := models.DNSEntry{"10.0.0.2", "newhost", "skynet"}
_, err := CreateDNS(entry)
assert.Nil(t, err)
num, err := GetDNSEntryNum("newhost", "skynet")
num, err := logic.GetDNSEntryNum("newhost", "skynet")
assert.Nil(t, err)
assert.Equal(t, 1, num)
})
@ -248,33 +248,34 @@ func TestGetDNSEntry(t *testing.T) {
assert.Equal(t, models.DNSEntry{}, entry)
})
}
func TestUpdateDNS(t *testing.T) {
var newentry models.DNSEntry
database.InitializeDatabase()
deleteAllDNS(t)
deleteAllNetworks()
createNet()
entry := models.DNSEntry{"10.0.0.2", "newhost", "skynet"}
CreateDNS(entry)
t.Run("change address", func(t *testing.T) {
newentry.Address = "10.0.0.75"
updated, err := UpdateDNS(newentry, entry)
assert.Nil(t, err)
assert.Equal(t, newentry.Address, updated.Address)
})
t.Run("change name", func(t *testing.T) {
newentry.Name = "newname"
updated, err := UpdateDNS(newentry, entry)
assert.Nil(t, err)
assert.Equal(t, newentry.Name, updated.Name)
})
t.Run("change network", func(t *testing.T) {
newentry.Network = "wirecat"
updated, err := UpdateDNS(newentry, entry)
assert.Nil(t, err)
assert.NotEqual(t, newentry.Network, updated.Network)
})
}
// func TestUpdateDNS(t *testing.T) {
// var newentry models.DNSEntry
// database.InitializeDatabase()
// deleteAllDNS(t)
// deleteAllNetworks()
// createNet()
// entry := models.DNSEntry{"10.0.0.2", "newhost", "skynet"}
// CreateDNS(entry)
// t.Run("change address", func(t *testing.T) {
// newentry.Address = "10.0.0.75"
// updated, err := UpdateDNS(newentry, entry)
// assert.Nil(t, err)
// assert.Equal(t, newentry.Address, updated.Address)
// })
// t.Run("change name", func(t *testing.T) {
// newentry.Name = "newname"
// updated, err := UpdateDNS(newentry, entry)
// assert.Nil(t, err)
// assert.Equal(t, newentry.Name, updated.Name)
// })
// t.Run("change network", func(t *testing.T) {
// newentry.Network = "wirecat"
// updated, err := UpdateDNS(newentry, entry)
// assert.Nil(t, err)
// assert.NotEqual(t, newentry.Network, updated.Network)
// })
// }
func TestDeleteDNS(t *testing.T) {
database.InitializeDatabase()
deleteAllDNS(t)
@ -283,16 +284,16 @@ func TestDeleteDNS(t *testing.T) {
entry := models.DNSEntry{"10.0.0.2", "newhost", "skynet"}
CreateDNS(entry)
t.Run("EntryExists", func(t *testing.T) {
err := DeleteDNS("newhost", "skynet")
err := logic.DeleteDNS("newhost", "skynet")
assert.Nil(t, err)
})
t.Run("NodeExists", func(t *testing.T) {
err := DeleteDNS("myhost", "skynet")
err := logic.DeleteDNS("myhost", "skynet")
assert.Nil(t, err)
})
t.Run("NoEntries", func(t *testing.T) {
err := DeleteDNS("myhost", "skynet")
err := logic.DeleteDNS("myhost", "skynet")
assert.Nil(t, err)
})
}
@ -305,34 +306,34 @@ func TestValidateDNSUpdate(t *testing.T) {
entry := models.DNSEntry{"10.0.0.2", "myhost", "skynet"}
t.Run("BadNetwork", func(t *testing.T) {
change := models.DNSEntry{"10.0.0.2", "myhost", "badnet"}
err := ValidateDNSUpdate(change, entry)
err := logic.ValidateDNSUpdate(change, entry)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Network' failed on the 'network_exists' tag")
})
t.Run("EmptyNetwork", func(t *testing.T) {
//this can't actually happen as change.Network is populated if is blank
change := models.DNSEntry{"10.0.0.2", "myhost", ""}
err := ValidateDNSUpdate(change, entry)
err := logic.ValidateDNSUpdate(change, entry)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Network' failed on the 'network_exists' tag")
})
t.Run("EmptyAddress", func(t *testing.T) {
//this can't actually happen as change.Address is populated if is blank
change := models.DNSEntry{"", "myhost", "skynet"}
err := ValidateDNSUpdate(change, entry)
err := logic.ValidateDNSUpdate(change, entry)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Address' failed on the 'required' tag")
})
t.Run("BadAddress", func(t *testing.T) {
change := models.DNSEntry{"10.0.256.1", "myhost", "skynet"}
err := ValidateDNSUpdate(change, entry)
err := logic.ValidateDNSUpdate(change, entry)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Address' failed on the 'ip' tag")
})
t.Run("EmptyName", func(t *testing.T) {
//this can't actually happen as change.Name is populated if is blank
change := models.DNSEntry{"10.0.0.2", "", "skynet"}
err := ValidateDNSUpdate(change, entry)
err := logic.ValidateDNSUpdate(change, entry)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'required' tag")
})
@ -342,7 +343,7 @@ func TestValidateDNSUpdate(t *testing.T) {
name = name + "a"
}
change := models.DNSEntry{"10.0.0.2", name, "skynet"}
err := ValidateDNSUpdate(change, entry)
err := logic.ValidateDNSUpdate(change, entry)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'max' tag")
})
@ -350,39 +351,39 @@ func TestValidateDNSUpdate(t *testing.T) {
change := models.DNSEntry{"10.0.0.2", "myhost", "wirecat"}
CreateDNS(entry)
CreateDNS(change)
err := ValidateDNSUpdate(change, entry)
err := logic.ValidateDNSUpdate(change, entry)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'name_unique' tag")
//cleanup
err = DeleteDNS("myhost", "wirecat")
err = logic.DeleteDNS("myhost", "wirecat")
assert.Nil(t, err)
})
}
func TestValidateDNSCreate(t *testing.T) {
database.InitializeDatabase()
_ = DeleteDNS("mynode", "skynet")
_ = logic.DeleteDNS("mynode", "skynet")
t.Run("NoNetwork", func(t *testing.T) {
entry := models.DNSEntry{"10.0.0.2", "myhost", "badnet"}
err := ValidateDNSCreate(entry)
err := logic.ValidateDNSCreate(entry)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Network' failed on the 'network_exists' tag")
})
t.Run("EmptyAddress", func(t *testing.T) {
entry := models.DNSEntry{"", "myhost", "skynet"}
err := ValidateDNSCreate(entry)
err := logic.ValidateDNSCreate(entry)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Address' failed on the 'required' tag")
})
t.Run("BadAddress", func(t *testing.T) {
entry := models.DNSEntry{"10.0.256.1", "myhost", "skynet"}
err := ValidateDNSCreate(entry)
err := logic.ValidateDNSCreate(entry)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Address' failed on the 'ip' tag")
})
t.Run("EmptyName", func(t *testing.T) {
entry := models.DNSEntry{"10.0.0.2", "", "skynet"}
err := ValidateDNSCreate(entry)
err := logic.ValidateDNSCreate(entry)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'required' tag")
})
@ -392,24 +393,24 @@ func TestValidateDNSCreate(t *testing.T) {
name = name + "a"
}
entry := models.DNSEntry{"10.0.0.2", name, "skynet"}
err := ValidateDNSCreate(entry)
err := logic.ValidateDNSCreate(entry)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'max' tag")
})
t.Run("NameUnique", func(t *testing.T) {
entry := models.DNSEntry{"10.0.0.2", "myhost", "skynet"}
_, _ = CreateDNS(entry)
err := ValidateDNSCreate(entry)
err := logic.ValidateDNSCreate(entry)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'name_unique' tag")
})
}
func deleteAllDNS(t *testing.T) {
dns, err := GetAllDNS()
dns, err := logic.GetAllDNS()
assert.Nil(t, err)
for _, record := range dns {
err := DeleteDNS(record.Name, record.Network)
err := logic.DeleteDNS(record.Name, record.Network)
assert.Nil(t, err)
}
}

View file

@ -7,7 +7,6 @@ import (
"io"
"net/http"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/database"
@ -16,7 +15,6 @@ import (
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/skip2/go-qrcode"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func extClientHandlers(r *mux.Router) {
@ -45,7 +43,7 @@ func getNetworkExtClients(w http.ResponseWriter, r *http.Request) {
var extclients []models.ExtClient
var params = mux.Vars(r)
extclients, err := GetNetworkExtClients(params["network"])
extclients, err := logic.GetNetworkExtClients(params["network"])
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -56,27 +54,6 @@ func getNetworkExtClients(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(extclients)
}
// GetNetworkExtClients - gets the ext clients of given network
func GetNetworkExtClients(network string) ([]models.ExtClient, error) {
var extclients []models.ExtClient
records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME)
if err != nil {
return extclients, err
}
for _, value := range records {
var extclient models.ExtClient
err = json.Unmarshal([]byte(value), &extclient)
if err != nil {
continue
}
if extclient.Network == network {
extclients = append(extclients, extclient)
}
}
return extclients, err
}
//A separate function to get all extclients, not just extclients for a particular network.
//Not quite sure if this is necessary. Probably necessary based on front end but may want to review after iteration 1 if it's being used or not
func getAllExtClients(w http.ResponseWriter, r *http.Request) {
@ -100,7 +77,7 @@ func getAllExtClients(w http.ResponseWriter, r *http.Request) {
}
} else {
for _, network := range networksSlice {
extclients, err := GetNetworkExtClients(network)
extclients, err := logic.GetNetworkExtClients(network)
if err == nil {
clients = append(clients, extclients...)
}
@ -121,7 +98,7 @@ func getExtClient(w http.ResponseWriter, r *http.Request) {
clientid := params["clientid"]
network := params["network"]
client, err := GetExtClient(clientid, network)
client, err := logic.GetExtClient(clientid, network)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -131,22 +108,6 @@ func getExtClient(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(client)
}
// GetExtClient - gets a single ext client on a network
func GetExtClient(clientid string, network string) (models.ExtClient, error) {
var extclient models.ExtClient
key, err := logic.GetRecordKey(clientid, network)
if err != nil {
return extclient, err
}
data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key)
if err != nil {
return extclient, err
}
err = json.Unmarshal([]byte(data), &extclient)
return extclient, err
}
//Get an individual extclient. Nothin fancy here folks.
func getExtClientConf(w http.ResponseWriter, r *http.Request) {
// set header.
@ -155,7 +116,7 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
clientid := params["clientid"]
networkid := params["network"]
client, err := GetExtClient(clientid, networkid)
client, err := logic.GetExtClient(clientid, networkid)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -240,47 +201,6 @@ Endpoint = %s
json.NewEncoder(w).Encode(client)
}
// CreateExtClient - creates an extclient
func CreateExtClient(extclient models.ExtClient) error {
if extclient.PrivateKey == "" {
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return err
}
extclient.PrivateKey = privateKey.String()
extclient.PublicKey = privateKey.PublicKey().String()
}
if extclient.Address == "" {
newAddress, err := logic.UniqueAddress(extclient.Network)
if err != nil {
return err
}
extclient.Address = newAddress
}
if extclient.ClientID == "" {
extclient.ClientID = models.GenerateNodeName()
}
extclient.LastModified = time.Now().Unix()
key, err := logic.GetRecordKey(extclient.ClientID, extclient.Network)
if err != nil {
return err
}
data, err := json.Marshal(&extclient)
if err != nil {
return err
}
if err = database.Insert(key, string(data), database.EXT_CLIENT_TABLE_NAME); err != nil {
return err
}
err = logic.SetNetworkNodesLastModified(extclient.Network)
return err
}
/**
* To create a extclient
* Must have valid key and be unique
@ -312,7 +232,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
err = CreateExtClient(extclient)
err = logic.CreateExtClient(extclient)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
@ -344,7 +264,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
newclient, err := UpdateExtClient(newExtClient.ClientID, params["network"], oldExtClient)
newclient, err := logic.UpdateExtClient(newExtClient.ClientID, params["network"], oldExtClient)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -354,45 +274,6 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(newclient)
}
// UpdateExtClient - only supports name changes right now
func UpdateExtClient(newclientid string, network string, client models.ExtClient) (models.ExtClient, error) {
err := DeleteExtClient(network, client.ClientID)
if err != nil {
return client, err
}
client.ClientID = newclientid
CreateExtClient(client)
return client, err
}
// DeleteExtClient - deletes an existing ext client
func DeleteExtClient(network string, clientid string) error {
key, err := logic.GetRecordKey(clientid, network)
if err != nil {
return err
}
err = database.DeleteRecord(database.EXT_CLIENT_TABLE_NAME, key)
return err
}
// DeleteGatewayExtClients - deletes ext clients based on gateway (mac) of ingress node and network
func DeleteGatewayExtClients(gatewayID string, networkName string) error {
currentExtClients, err := GetNetworkExtClients(networkName)
if err != nil && !database.IsEmptyRecord(err) {
return err
}
for _, extClient := range currentExtClients {
if extClient.IngressGatewayID == gatewayID {
if err = DeleteExtClient(networkName, extClient.ClientID); err != nil {
logger.Log(1, "failed to remove ext client", extClient.ClientID)
continue
}
}
}
return nil
}
//Delete a extclient
//Pretty straightforward
func deleteExtClient(w http.ResponseWriter, r *http.Request) {
@ -402,7 +283,7 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
// get params
var params = mux.Vars(r)
err := DeleteExtClient(params["network"], params["clientid"])
err := logic.DeleteExtClient(params["network"], params["clientid"])
if err != nil {
err = errors.New("Could not delete extclient " + params["clientid"])

View file

@ -1,17 +1,13 @@
package controller
import (
"encoding/base64"
"encoding/json"
"errors"
"net/http"
"strings"
"time"
"github.com/go-playground/validator/v10"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
@ -32,7 +28,6 @@ func networkHandlers(r *mux.Router) {
r.HandleFunc("/api/networks/{networkname}/keyupdate", securityCheck(false, http.HandlerFunc(keyUpdate))).Methods("POST")
r.HandleFunc("/api/networks/{networkname}/keys", securityCheck(false, http.HandlerFunc(createAccessKey))).Methods("POST")
r.HandleFunc("/api/networks/{networkname}/keys", securityCheck(false, http.HandlerFunc(getAccessKeys))).Methods("GET")
r.HandleFunc("/api/networks/{networkname}/signuptoken", securityCheck(false, http.HandlerFunc(getSignupToken))).Methods("GET")
r.HandleFunc("/api/networks/{networkname}/keys/{name}", securityCheck(false, http.HandlerFunc(deleteAccessKey))).Methods("DELETE")
}
@ -73,34 +68,13 @@ func getNetworks(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(allnetworks)
}
func ValidateNetworkUpdate(network models.Network) error {
v := validator.New()
_ = v.RegisterValidation("netid_valid", func(fl validator.FieldLevel) bool {
if fl.Field().String() == "" {
return true
}
inCharSet := functions.NameInNetworkCharSet(fl.Field().String())
return inCharSet
})
err := v.Struct(network)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
logger.Log(1, "validator", e.Error())
}
}
return err
}
//Simple get network function
func getNetwork(w http.ResponseWriter, r *http.Request) {
// set header.
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
netname := params["networkname"]
network, err := GetNetwork(netname)
network, err := logic.GetNetwork(netname)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -113,25 +87,11 @@ func getNetwork(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(network)
}
func GetNetwork(name string) (models.Network, error) {
var network models.Network
record, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, name)
if err != nil {
return network, err
}
if err = json.Unmarshal([]byte(record), &network); err != nil {
return models.Network{}, err
}
return network, nil
}
func keyUpdate(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
netname := params["networkname"]
network, err := KeyUpdate(netname)
network, err := logic.KeyUpdate(netname)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -141,33 +101,6 @@ func keyUpdate(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(network)
}
func KeyUpdate(netname string) (models.Network, error) {
err := functions.NetworkNodesUpdateAction(netname, models.NODE_UPDATE_KEY)
if err != nil {
return models.Network{}, err
}
return models.Network{}, nil
}
//Update a network
func AlertNetwork(netid string) error {
var network models.Network
network, err := logic.GetParentNetwork(netid)
if err != nil {
return err
}
updatetime := time.Now().Unix()
network.NodesLastModified = updatetime
network.NetworkLastModified = updatetime
data, err := json.Marshal(&network)
if err != nil {
return err
}
database.Insert(netid, string(data), database.NETWORKS_TABLE_NAME)
return nil
}
//Update a network
func updateNetwork(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
@ -253,7 +186,7 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
network := params["networkname"]
err := DeleteNetwork(network)
err := logic.DeleteNetwork(network)
if err != nil {
errtype := "badrequest"
@ -268,29 +201,6 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode("success")
}
func DeleteNetwork(network string) error {
nodeCount, err := functions.GetNetworkNonServerNodeCount(network)
if nodeCount == 0 || database.IsEmptyRecord(err) {
// delete server nodes first then db records
servers, err := logic.GetSortedNetworkServerNodes(network)
if err == nil {
for _, s := range servers {
if err = logic.DeleteNode(&s, true); err != nil {
logger.Log(2, "could not removed server", s.Name, "before deleting network", network)
} else {
logger.Log(2, "removed server", s.Name, "before deleting network", network)
}
}
} else {
logger.Log(1, "could not remove servers before deleting network", network)
}
return database.DeleteRecord(database.NETWORKS_TABLE_NAME, network)
}
return errors.New("node check failed. All nodes must be deleted before deleting network")
}
//Create a network
//Pretty simple
func createNetwork(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
@ -304,49 +214,27 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
return
}
err = CreateNetwork(network)
err = logic.CreateNetwork(network)
if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest"))
return
}
logger.Log(1, r.Header.Get("user"), "created network", network.NetID)
w.WriteHeader(http.StatusOK)
//json.NewEncoder(w).Encode(result)
}
func CreateNetwork(network models.Network) error {
network.SetDefaults()
network.SetNodesLastModified()
network.SetNetworkLastModified()
network.KeyUpdateTimeStamp = time.Now().Unix()
err := logic.ValidateNetwork(&network, false)
if err != nil {
//returnErrorResponse(w, r, formatError(err, "badrequest"))
return err
}
data, err := json.Marshal(&network)
if err != nil {
return err
}
if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
return err
}
if servercfg.IsClientMode() != "off" {
var success bool
success, err = serverctl.AddNetwork(network.NetID)
if err != nil || !success {
DeleteNetwork(network.NetID)
logic.DeleteNetwork(network.NetID)
if err == nil {
err = errors.New("Failed to add server to network " + network.DisplayName)
}
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
}
return err
logger.Log(1, r.Header.Get("user"), "created network", network.NetID)
w.WriteHeader(http.StatusOK)
}
// BEGIN KEY MANAGEMENT SECTION
@ -366,7 +254,7 @@ func createAccessKey(w http.ResponseWriter, r *http.Request) {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
key, err := CreateAccessKey(accesskey, network)
key, err := logic.CreateAccessKey(accesskey, network)
if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest"))
return
@ -374,131 +262,14 @@ func createAccessKey(w http.ResponseWriter, r *http.Request) {
logger.Log(1, r.Header.Get("user"), "created access key", accesskey.Name, "on", netname)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(key)
//w.Write([]byte(accesskey.AccessString))
}
func CreateAccessKey(accesskey models.AccessKey, network models.Network) (models.AccessKey, error) {
if accesskey.Name == "" {
accesskey.Name = functions.GenKeyName()
}
if accesskey.Value == "" {
accesskey.Value = functions.GenKey()
}
if accesskey.Uses == 0 {
accesskey.Uses = 1
}
checkkeys, err := GetKeys(network.NetID)
if err != nil {
return models.AccessKey{}, errors.New("could not retrieve network keys")
}
for _, key := range checkkeys {
if key.Name == accesskey.Name {
return models.AccessKey{}, errors.New("duplicate AccessKey Name")
}
}
privAddr := ""
if network.IsLocal != "" {
privAddr = network.LocalRange
}
netID := network.NetID
var accessToken models.AccessToken
s := servercfg.GetServerConfig()
servervals := models.ServerConfig{
CoreDNSAddr: s.CoreDNSAddr,
APIConnString: s.APIConnString,
APIHost: s.APIHost,
APIPort: s.APIPort,
GRPCConnString: s.GRPCConnString,
GRPCHost: s.GRPCHost,
GRPCPort: s.GRPCPort,
GRPCSSL: s.GRPCSSL,
CheckinInterval: s.CheckinInterval,
}
accessToken.ServerConfig = servervals
accessToken.ClientConfig.Network = netID
accessToken.ClientConfig.Key = accesskey.Value
accessToken.ClientConfig.LocalRange = privAddr
tokenjson, err := json.Marshal(accessToken)
if err != nil {
return accesskey, err
}
accesskey.AccessString = base64.StdEncoding.EncodeToString([]byte(tokenjson))
//validate accesskey
v := validator.New()
err = v.Struct(accesskey)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
logger.Log(1, "validator", e.Error())
}
return models.AccessKey{}, err
}
network.AccessKeys = append(network.AccessKeys, accesskey)
data, err := json.Marshal(&network)
if err != nil {
return models.AccessKey{}, err
}
if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
return models.AccessKey{}, err
}
return accesskey, nil
}
func GetSignupToken(netID string) (models.AccessKey, error) {
var accesskey models.AccessKey
var accessToken models.AccessToken
s := servercfg.GetServerConfig()
servervals := models.ServerConfig{
APIConnString: s.APIConnString,
APIHost: s.APIHost,
APIPort: s.APIPort,
GRPCConnString: s.GRPCConnString,
GRPCHost: s.GRPCHost,
GRPCPort: s.GRPCPort,
GRPCSSL: s.GRPCSSL,
}
accessToken.ServerConfig = servervals
tokenjson, err := json.Marshal(accessToken)
if err != nil {
return accesskey, err
}
accesskey.AccessString = base64.StdEncoding.EncodeToString([]byte(tokenjson))
return accesskey, nil
}
func getSignupToken(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
netID := params["networkname"]
token, err := GetSignupToken(netID)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
logger.Log(2, r.Header.Get("user"), "got signup token", netID)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(token)
}
//pretty simple get
// pretty simple get
func getAccessKeys(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
network := params["networkname"]
keys, err := GetKeys(network)
keys, err := logic.GetKeys(network)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -510,26 +281,14 @@ func getAccessKeys(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(keys)
}
func GetKeys(net string) ([]models.AccessKey, error) {
record, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, net)
if err != nil {
return []models.AccessKey{}, err
}
network, err := functions.ParseNetwork(record)
if err != nil {
return []models.AccessKey{}, err
}
return network.AccessKeys, nil
}
//delete key. Has to do a little funky logic since it's not a collection item
// delete key. Has to do a little funky logic since it's not a collection item
func deleteAccessKey(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
keyname := params["name"]
netname := params["networkname"]
err := DeleteKey(keyname, netname)
err := logic.DeleteKey(keyname, netname)
if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest"))
return
@ -537,33 +296,3 @@ func deleteAccessKey(w http.ResponseWriter, r *http.Request) {
logger.Log(1, r.Header.Get("user"), "deleted access key", keyname, "on network,", netname)
w.WriteHeader(http.StatusOK)
}
func DeleteKey(keyname, netname string) error {
network, err := logic.GetParentNetwork(netname)
if err != nil {
return err
}
//basically, turn the list of access keys into the list of access keys before and after the item
//have not done any error handling for if there's like...1 item. I think it works? need to test.
found := false
var updatedKeys []models.AccessKey
for _, currentkey := range network.AccessKeys {
if currentkey.Name == keyname {
found = true
} else {
updatedKeys = append(updatedKeys, currentkey)
}
}
if !found {
return errors.New("key " + keyname + " does not exist")
}
network.AccessKeys = updatedKeys
data, err := json.Marshal(&network)
if err != nil {
return err
}
if err := database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
return err
}
return nil
}

View file

@ -25,7 +25,7 @@ func TestCreateNetwork(t *testing.T) {
network.AddressRange = "10.0.0.1/24"
network.DisplayName = "mynetwork"
err := CreateNetwork(network)
err := logic.CreateNetwork(network)
assert.Nil(t, err)
}
func TestGetNetwork(t *testing.T) {
@ -33,12 +33,12 @@ func TestGetNetwork(t *testing.T) {
createNet()
t.Run("GetExistingNetwork", func(t *testing.T) {
network, err := GetNetwork("skynet")
network, err := logic.GetNetwork("skynet")
assert.Nil(t, err)
assert.Equal(t, "skynet", network.NetID)
})
t.Run("GetNonExistantNetwork", func(t *testing.T) {
network, err := GetNetwork("doesnotexist")
network, err := logic.GetNetwork("doesnotexist")
assert.EqualError(t, err, "no result found")
assert.Equal(t, "", network.NetID)
})
@ -51,11 +51,11 @@ func TestDeleteNetwork(t *testing.T) {
t.Run("NetworkwithNodes", func(t *testing.T) {
})
t.Run("DeleteExistingNetwork", func(t *testing.T) {
err := DeleteNetwork("skynet")
err := logic.DeleteNetwork("skynet")
assert.Nil(t, err)
})
t.Run("NonExistantNetwork", func(t *testing.T) {
err := DeleteNetwork("skynet")
err := logic.DeleteNetwork("skynet")
assert.Nil(t, err)
})
}
@ -64,12 +64,12 @@ func TestKeyUpdate(t *testing.T) {
t.Skip() //test is failing on last assert --- not sure why
database.InitializeDatabase()
createNet()
existing, err := GetNetwork("skynet")
existing, err := logic.GetNetwork("skynet")
assert.Nil(t, err)
time.Sleep(time.Second * 1)
network, err := KeyUpdate("skynet")
network, err := logic.KeyUpdate("skynet")
assert.Nil(t, err)
network, err = GetNetwork("skynet")
network, err = logic.GetNetwork("skynet")
assert.Nil(t, err)
assert.Greater(t, network.KeyUpdateTimeStamp, existing.KeyUpdateTimeStamp)
}
@ -77,70 +77,70 @@ func TestKeyUpdate(t *testing.T) {
func TestCreateKey(t *testing.T) {
database.InitializeDatabase()
createNet()
keys, _ := GetKeys("skynet")
keys, _ := logic.GetKeys("skynet")
for _, key := range keys {
DeleteKey(key.Name, "skynet")
logic.DeleteKey(key.Name, "skynet")
}
var accesskey models.AccessKey
var network models.Network
network.NetID = "skynet"
t.Run("NameTooLong", func(t *testing.T) {
network, err := GetNetwork("skynet")
network, err := logic.GetNetwork("skynet")
assert.Nil(t, err)
accesskey.Name = "Thisisareallylongkeynamethatwillfail"
_, err = CreateAccessKey(accesskey, network)
_, err = logic.CreateAccessKey(accesskey, network)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'max' tag")
})
t.Run("BlankName", func(t *testing.T) {
network, err := GetNetwork("skynet")
network, err := logic.GetNetwork("skynet")
assert.Nil(t, err)
accesskey.Name = ""
key, err := CreateAccessKey(accesskey, network)
key, err := logic.CreateAccessKey(accesskey, network)
assert.Nil(t, err)
assert.NotEqual(t, "", key.Name)
})
t.Run("InvalidValue", func(t *testing.T) {
network, err := GetNetwork("skynet")
network, err := logic.GetNetwork("skynet")
assert.Nil(t, err)
accesskey.Value = "bad-value"
_, err = CreateAccessKey(accesskey, network)
_, err = logic.CreateAccessKey(accesskey, network)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Value' failed on the 'alphanum' tag")
})
t.Run("BlankValue", func(t *testing.T) {
network, err := GetNetwork("skynet")
network, err := logic.GetNetwork("skynet")
assert.Nil(t, err)
accesskey.Name = "mykey"
accesskey.Value = ""
key, err := CreateAccessKey(accesskey, network)
key, err := logic.CreateAccessKey(accesskey, network)
assert.Nil(t, err)
assert.NotEqual(t, "", key.Value)
assert.Equal(t, accesskey.Name, key.Name)
})
t.Run("ValueTooLong", func(t *testing.T) {
network, err := GetNetwork("skynet")
network, err := logic.GetNetwork("skynet")
assert.Nil(t, err)
accesskey.Name = "keyname"
accesskey.Value = "AccessKeyValuethatistoolong"
_, err = CreateAccessKey(accesskey, network)
_, err = logic.CreateAccessKey(accesskey, network)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Value' failed on the 'max' tag")
})
t.Run("BlankUses", func(t *testing.T) {
network, err := GetNetwork("skynet")
network, err := logic.GetNetwork("skynet")
assert.Nil(t, err)
accesskey.Uses = 0
accesskey.Value = ""
key, err := CreateAccessKey(accesskey, network)
key, err := logic.CreateAccessKey(accesskey, network)
assert.Nil(t, err)
assert.Equal(t, 1, key.Uses)
})
t.Run("DuplicateKey", func(t *testing.T) {
network, err := GetNetwork("skynet")
network, err := logic.GetNetwork("skynet")
assert.Nil(t, err)
accesskey.Name = "mykey"
_, err = CreateAccessKey(accesskey, network)
_, err = logic.CreateAccessKey(accesskey, network)
assert.NotNil(t, err)
assert.EqualError(t, err, "duplicate AccessKey Name")
})
@ -150,21 +150,21 @@ func TestGetKeys(t *testing.T) {
database.InitializeDatabase()
deleteAllNetworks()
createNet()
network, err := GetNetwork("skynet")
network, err := logic.GetNetwork("skynet")
assert.Nil(t, err)
var key models.AccessKey
key.Name = "mykey"
_, err = CreateAccessKey(key, network)
_, err = logic.CreateAccessKey(key, network)
assert.Nil(t, err)
t.Run("KeyExists", func(t *testing.T) {
keys, err := GetKeys(network.NetID)
keys, err := logic.GetKeys(network.NetID)
assert.Nil(t, err)
assert.NotEqual(t, models.AccessKey{}, keys)
})
t.Run("NonExistantKey", func(t *testing.T) {
err := DeleteKey("mykey", "skynet")
err := logic.DeleteKey("mykey", "skynet")
assert.Nil(t, err)
keys, err := GetKeys(network.NetID)
keys, err := logic.GetKeys(network.NetID)
assert.Nil(t, err)
assert.Equal(t, []models.AccessKey(nil), keys)
})
@ -172,18 +172,18 @@ func TestGetKeys(t *testing.T) {
func TestDeleteKey(t *testing.T) {
database.InitializeDatabase()
createNet()
network, err := GetNetwork("skynet")
network, err := logic.GetNetwork("skynet")
assert.Nil(t, err)
var key models.AccessKey
key.Name = "mykey"
_, err = CreateAccessKey(key, network)
_, err = logic.CreateAccessKey(key, network)
assert.Nil(t, err)
t.Run("ExistingKey", func(t *testing.T) {
err := DeleteKey("mykey", "skynet")
err := logic.DeleteKey("mykey", "skynet")
assert.Nil(t, err)
})
t.Run("NonExistantKey", func(t *testing.T) {
err := DeleteKey("mykey", "skynet")
err := logic.DeleteKey("mykey", "skynet")
assert.NotNil(t, err)
assert.Equal(t, "key mykey does not exist", err.Error())
})
@ -325,7 +325,7 @@ func TestValidateNetworkUpdate(t *testing.T) {
for _, tc := range cases {
t.Run(tc.testname, func(t *testing.T) {
network := models.Network(tc.network)
err := ValidateNetworkUpdate(network)
err := logic.ValidateNetworkUpdate(network)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), tc.errMessage)
})
@ -336,7 +336,7 @@ func deleteAllNetworks() {
deleteAllNodes()
nets, _ := logic.GetNetworks()
for _, net := range nets {
DeleteNetwork(net.NetID)
logic.DeleteNetwork(net.NetID)
}
}
@ -345,13 +345,13 @@ func createNet() {
network.NetID = "skynet"
network.AddressRange = "10.0.0.1/24"
network.DisplayName = "mynetwork"
_, err := GetNetwork("skynet")
_, err := logic.GetNetwork("skynet")
if err != nil {
CreateNetwork(network)
logic.CreateNetwork(network)
}
}
func getNet() models.Network {
network, _ := GetNetwork("skynet")
network, _ := logic.GetNetwork("skynet")
return network
}

View file

@ -2,10 +2,8 @@ package controller
import (
"encoding/json"
"errors"
"net/http"
"strings"
"time"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/database"
@ -349,7 +347,7 @@ func getLastModified(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
network, err := GetNetwork(params["network"])
network, err := logic.GetNetwork(params["network"])
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -431,7 +429,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
func uncordonNode(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
w.Header().Set("Content-Type", "application/json")
node, err := UncordonNode(params["network"], params["macaddress"])
node, err := logic.UncordonNode(params["network"], params["macaddress"])
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -441,28 +439,6 @@ func uncordonNode(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode("SUCCESS")
}
// UncordonNode - approves a node to join a network
func UncordonNode(network, macaddress string) (models.Node, error) {
node, err := logic.GetNodeByMacAddress(network, macaddress)
if err != nil {
return models.Node{}, err
}
node.SetLastModified()
node.IsPending = "no"
node.PullChanges = "yes"
data, err := json.Marshal(&node)
if err != nil {
return node, err
}
key, err := logic.GetRecordKey(node.MacAddress, node.Network)
if err != nil {
return node, err
}
err = database.Insert(key, string(data), database.NODES_TABLE_NAME)
return node, err
}
func createEgressGateway(w http.ResponseWriter, r *http.Request) {
var gateway models.EgressGatewayRequest
var params = mux.Vars(r)
@ -474,7 +450,7 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) {
}
gateway.NetID = params["network"]
gateway.NodeID = params["macaddress"]
node, err := CreateEgressGateway(gateway)
node, err := logic.CreateEgressGateway(gateway)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -484,80 +460,12 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(node)
}
// CreateEgressGateway - creates an egress gateway
func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, error) {
node, err := logic.GetNodeByMacAddress(gateway.NetID, gateway.NodeID)
if node.OS == "windows" || node.OS == "macos" { // add in darwin later
return models.Node{}, errors.New(node.OS + " is unsupported for egress gateways")
}
if err != nil {
return models.Node{}, err
}
err = ValidateEgressGateway(gateway)
if err != nil {
return models.Node{}, err
}
node.IsEgressGateway = "yes"
node.EgressGatewayRanges = gateway.Ranges
postUpCmd := "iptables -A FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -A POSTROUTING -o " + gateway.Interface + " -j MASQUERADE"
postDownCmd := "iptables -D FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -D POSTROUTING -o " + gateway.Interface + " -j MASQUERADE"
if gateway.PostUp != "" {
postUpCmd = gateway.PostUp
}
if gateway.PostDown != "" {
postDownCmd = gateway.PostDown
}
if node.PostUp != "" {
if !strings.Contains(node.PostUp, postUpCmd) {
postUpCmd = node.PostUp + "; " + postUpCmd
}
}
if node.PostDown != "" {
if !strings.Contains(node.PostDown, postDownCmd) {
postDownCmd = node.PostDown + "; " + postDownCmd
}
}
key, err := logic.GetRecordKey(gateway.NodeID, gateway.NetID)
if err != nil {
return node, err
}
node.PostUp = postUpCmd
node.PostDown = postDownCmd
node.SetLastModified()
node.PullChanges = "yes"
nodeData, err := json.Marshal(&node)
if err != nil {
return node, err
}
if err = database.Insert(key, string(nodeData), database.NODES_TABLE_NAME); err != nil {
return models.Node{}, err
}
if err = functions.NetworkNodesUpdatePullChanges(node.Network); err != nil {
return models.Node{}, err
}
return node, nil
}
func ValidateEgressGateway(gateway models.EgressGatewayRequest) error {
var err error
//isIp := functions.IsIpCIDR(gateway.RangeString)
empty := len(gateway.Ranges) == 0
if empty {
err = errors.New("IP Ranges Cannot Be Empty")
}
empty = gateway.Interface == ""
if empty {
err = errors.New("Interface cannot be empty")
}
return err
}
func deleteEgressGateway(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
nodeMac := params["macaddress"]
netid := params["network"]
node, err := DeleteEgressGateway(netid, nodeMac)
node, err := logic.DeleteEgressGateway(netid, nodeMac)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -567,48 +475,13 @@ func deleteEgressGateway(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(node)
}
// DeleteEgressGateway - deletes egress from node
func DeleteEgressGateway(network, macaddress string) (models.Node, error) {
node, err := logic.GetNodeByMacAddress(network, macaddress)
if err != nil {
return models.Node{}, err
}
node.IsEgressGateway = "no"
node.EgressGatewayRanges = []string{}
node.PostUp = ""
node.PostDown = ""
if node.IsIngressGateway == "yes" { // check if node is still an ingress gateway before completely deleting postdown/up rules
node.PostUp = "iptables -A FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -A POSTROUTING -o " + node.Interface + " -j MASQUERADE"
node.PostDown = "iptables -D FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -D POSTROUTING -o " + node.Interface + " -j MASQUERADE"
}
node.SetLastModified()
node.PullChanges = "yes"
key, err := logic.GetRecordKey(node.MacAddress, node.Network)
if err != nil {
return models.Node{}, err
}
data, err := json.Marshal(&node)
if err != nil {
return models.Node{}, err
}
if err = database.Insert(key, string(data), database.NODES_TABLE_NAME); err != nil {
return models.Node{}, err
}
if err = functions.NetworkNodesUpdatePullChanges(network); err != nil {
return models.Node{}, err
}
return node, nil
}
// == INGRESS ==
func createIngressGateway(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
w.Header().Set("Content-Type", "application/json")
nodeMac := params["macaddress"]
netid := params["network"]
node, err := CreateIngressGateway(netid, nodeMac)
node, err := logic.CreateIngressGateway(netid, nodeMac)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -618,62 +491,11 @@ func createIngressGateway(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(node)
}
// CreateIngressGateway - creates an ingress gateway
func CreateIngressGateway(netid string, macaddress string) (models.Node, error) {
node, err := logic.GetNodeByMacAddress(netid, macaddress)
if node.OS == "windows" || node.OS == "macos" { // add in darwin later
return models.Node{}, errors.New(node.OS + " is unsupported for ingress gateways")
}
if err != nil {
return models.Node{}, err
}
network, err := logic.GetParentNetwork(netid)
if err != nil {
return models.Node{}, err
}
node.IsIngressGateway = "yes"
node.IngressGatewayRange = network.AddressRange
postUpCmd := "iptables -A FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -A POSTROUTING -o " + node.Interface + " -j MASQUERADE"
postDownCmd := "iptables -D FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -D POSTROUTING -o " + node.Interface + " -j MASQUERADE"
if node.PostUp != "" {
if !strings.Contains(node.PostUp, postUpCmd) {
postUpCmd = node.PostUp + "; " + postUpCmd
}
}
if node.PostDown != "" {
if !strings.Contains(node.PostDown, postDownCmd) {
postDownCmd = node.PostDown + "; " + postDownCmd
}
}
node.SetLastModified()
node.PostUp = postUpCmd
node.PostDown = postDownCmd
node.PullChanges = "yes"
node.UDPHolePunch = "no"
key, err := logic.GetRecordKey(node.MacAddress, node.Network)
if err != nil {
return models.Node{}, err
}
data, err := json.Marshal(&node)
if err != nil {
return models.Node{}, err
}
err = database.Insert(key, string(data), database.NODES_TABLE_NAME)
if err != nil {
return models.Node{}, err
}
err = logic.SetNetworkNodesLastModified(netid)
return node, err
}
func deleteIngressGateway(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
nodeMac := params["macaddress"]
node, err := DeleteIngressGateway(params["network"], nodeMac)
node, err := logic.DeleteIngressGateway(params["network"], nodeMac)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -683,44 +505,6 @@ func deleteIngressGateway(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(node)
}
// DeleteIngressGateway - deletes an ingress gateway
func DeleteIngressGateway(networkName string, macaddress string) (models.Node, error) {
node, err := logic.GetNodeByMacAddress(networkName, macaddress)
if err != nil {
return models.Node{}, err
}
network, err := logic.GetParentNetwork(networkName)
if err != nil {
return models.Node{}, err
}
// delete ext clients belonging to ingress gateway
if err = DeleteGatewayExtClients(macaddress, networkName); err != nil {
return models.Node{}, err
}
node.UDPHolePunch = network.DefaultUDPHolePunch
node.LastModified = time.Now().Unix()
node.IsIngressGateway = "no"
node.IngressGatewayRange = ""
node.PullChanges = "yes"
key, err := logic.GetRecordKey(node.MacAddress, node.Network)
if err != nil {
return models.Node{}, err
}
data, err := json.Marshal(&node)
if err != nil {
return models.Node{}, err
}
err = database.Insert(key, string(data), database.NODES_TABLE_NAME)
if err != nil {
return models.Node{}, err
}
err = logic.SetNetworkNodesLastModified(networkName)
return node, err
}
func updateNode(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
@ -760,8 +544,8 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
return
}
if relayupdate {
UpdateRelay(node.Network, node.RelayAddrs, newNode.RelayAddrs)
if err = functions.NetworkNodesUpdatePullChanges(node.Network); err != nil {
logic.UpdateRelay(node.Network, node.RelayAddrs, newNode.RelayAddrs)
if err = logic.NetworkNodesUpdatePullChanges(node.Network); err != nil {
logger.Log(1, "error setting relay updates:", err.Error())
}
}

View file

@ -17,7 +17,7 @@ func TestCreateEgressGateway(t *testing.T) {
deleteAllNetworks()
createNet()
t.Run("NoNodes", func(t *testing.T) {
node, err := CreateEgressGateway(gateway)
node, err := logic.CreateEgressGateway(gateway)
assert.Equal(t, models.Node{}, node)
assert.EqualError(t, err, "unable to get record key")
})
@ -26,7 +26,7 @@ func TestCreateEgressGateway(t *testing.T) {
gateway.NetID = "skynet"
gateway.NodeID = testnode.MacAddress
node, err := CreateEgressGateway(gateway)
node, err := logic.CreateEgressGateway(gateway)
assert.Nil(t, err)
assert.Equal(t, "yes", node.IsEgressGateway)
assert.Equal(t, gateway.Ranges, node.EgressGatewayRanges)
@ -45,11 +45,11 @@ func TestDeleteEgressGateway(t *testing.T) {
gateway.NetID = "skynet"
gateway.NodeID = testnode.MacAddress
t.Run("Success", func(t *testing.T) {
node, err := CreateEgressGateway(gateway)
node, err := logic.CreateEgressGateway(gateway)
assert.Nil(t, err)
assert.Equal(t, "yes", node.IsEgressGateway)
assert.Equal(t, []string{"10.100.100.0/24"}, node.EgressGatewayRanges)
node, err = DeleteEgressGateway(gateway.NetID, gateway.NodeID)
node, err = logic.DeleteEgressGateway(gateway.NetID, gateway.NodeID)
assert.Nil(t, err)
assert.Equal(t, "no", node.IsEgressGateway)
assert.Equal(t, []string([]string{}), node.EgressGatewayRanges)
@ -57,7 +57,7 @@ func TestDeleteEgressGateway(t *testing.T) {
assert.Equal(t, "", node.PostDown)
})
t.Run("NotGateway", func(t *testing.T) {
node, err := DeleteEgressGateway(gateway.NetID, gateway.NodeID)
node, err := logic.DeleteEgressGateway(gateway.NetID, gateway.NodeID)
assert.Nil(t, err)
assert.Equal(t, "no", node.IsEgressGateway)
assert.Equal(t, []string([]string{}), node.EgressGatewayRanges)
@ -65,12 +65,12 @@ func TestDeleteEgressGateway(t *testing.T) {
assert.Equal(t, "", node.PostDown)
})
t.Run("BadNode", func(t *testing.T) {
node, err := DeleteEgressGateway(gateway.NetID, "01:02:03")
node, err := logic.DeleteEgressGateway(gateway.NetID, "01:02:03")
assert.EqualError(t, err, "no result found")
assert.Equal(t, models.Node{}, node)
})
t.Run("BadNet", func(t *testing.T) {
node, err := DeleteEgressGateway("badnet", gateway.NodeID)
node, err := logic.DeleteEgressGateway("badnet", gateway.NodeID)
assert.EqualError(t, err, "no result found")
assert.Equal(t, models.Node{}, node)
})
@ -106,17 +106,17 @@ func TestUncordonNode(t *testing.T) {
createNet()
node := createTestNode()
t.Run("BadNet", func(t *testing.T) {
resp, err := UncordonNode("badnet", node.MacAddress)
resp, err := logic.UncordonNode("badnet", node.MacAddress)
assert.Equal(t, models.Node{}, resp)
assert.EqualError(t, err, "no result found")
})
t.Run("BadMac", func(t *testing.T) {
resp, err := UncordonNode("skynet", "01:02:03")
resp, err := logic.UncordonNode("skynet", "01:02:03")
assert.Equal(t, models.Node{}, resp)
assert.EqualError(t, err, "no result found")
})
t.Run("Success", func(t *testing.T) {
resp, err := UncordonNode("skynet", node.MacAddress)
resp, err := logic.UncordonNode("skynet", node.MacAddress)
assert.Nil(t, err)
assert.Equal(t, "no", resp.IsPending)
})
@ -127,19 +127,19 @@ func TestValidateEgressGateway(t *testing.T) {
t.Run("EmptyRange", func(t *testing.T) {
gateway.Interface = "eth0"
gateway.Ranges = []string{}
err := ValidateEgressGateway(gateway)
err := logic.ValidateEgressGateway(gateway)
assert.EqualError(t, err, "IP Ranges Cannot Be Empty")
})
t.Run("EmptyInterface", func(t *testing.T) {
gateway.Interface = ""
err := ValidateEgressGateway(gateway)
err := logic.ValidateEgressGateway(gateway)
assert.NotNil(t, err)
assert.Equal(t, "Interface cannot be empty", err.Error())
})
t.Run("Success", func(t *testing.T) {
gateway.Interface = "eth0"
gateway.Ranges = []string{"10.100.100.0/24"}
err := ValidateEgressGateway(gateway)
err := logic.ValidateEgressGateway(gateway)
assert.Nil(t, err)
})
}

View file

@ -2,13 +2,9 @@ package controller
import (
"encoding/json"
"errors"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
@ -25,7 +21,7 @@ func createRelay(w http.ResponseWriter, r *http.Request) {
}
relay.NetID = params["network"]
relay.NodeID = params["macaddress"]
node, err := CreateRelay(relay)
node, err := logic.CreateRelay(relay)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -35,52 +31,12 @@ func createRelay(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(node)
}
// CreateRelay - creates a relay
func CreateRelay(relay models.RelayRequest) (models.Node, error) {
node, err := logic.GetNodeByMacAddress(relay.NetID, relay.NodeID)
if node.OS == "macos" { // add in darwin later
return models.Node{}, errors.New(node.OS + " is unsupported for relay")
}
if err != nil {
return models.Node{}, err
}
err = ValidateRelay(relay)
if err != nil {
return models.Node{}, err
}
node.IsRelay = "yes"
node.RelayAddrs = relay.RelayAddrs
key, err := logic.GetRecordKey(relay.NodeID, relay.NetID)
if err != nil {
return node, err
}
node.SetLastModified()
node.PullChanges = "yes"
nodeData, err := json.Marshal(&node)
if err != nil {
return node, err
}
if err = database.Insert(key, string(nodeData), database.NODES_TABLE_NAME); err != nil {
return models.Node{}, err
}
err = SetRelayedNodes("yes", node.Network, node.RelayAddrs)
if err != nil {
return node, err
}
if err = functions.NetworkNodesUpdatePullChanges(node.Network); err != nil {
return models.Node{}, err
}
return node, nil
}
func deleteRelay(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
nodeMac := params["macaddress"]
netid := params["network"]
node, err := DeleteRelay(netid, nodeMac)
node, err := logic.DeleteRelay(netid, nodeMac)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -89,92 +45,3 @@ func deleteRelay(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(node)
}
// SetRelayedNodes- set relayed nodes
func SetRelayedNodes(yesOrno string, networkName string, addrs []string) error {
collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
if err != nil {
return err
}
for _, value := range collections {
var node models.Node
err := json.Unmarshal([]byte(value), &node)
if err != nil {
return err
}
if node.Network == networkName {
for _, addr := range addrs {
if addr == node.Address || addr == node.Address6 {
node.IsRelayed = yesOrno
data, err := json.Marshal(&node)
if err != nil {
return err
}
node.SetID()
database.Insert(node.ID, string(data), database.NODES_TABLE_NAME)
}
}
}
}
return nil
}
// ValidateRelay - checks if relay is valid
func ValidateRelay(relay models.RelayRequest) error {
var err error
//isIp := functions.IsIpCIDR(gateway.RangeString)
empty := len(relay.RelayAddrs) == 0
if empty {
err = errors.New("IP Ranges Cannot Be Empty")
}
return err
}
// UpdateRelay - updates a relay
func UpdateRelay(network string, oldAddrs []string, newAddrs []string) {
time.Sleep(time.Second / 4)
err := SetRelayedNodes("no", network, oldAddrs)
if err != nil {
logger.Log(1, err.Error())
}
err = SetRelayedNodes("yes", network, newAddrs)
if err != nil {
logger.Log(1, err.Error())
}
}
// DeleteRelay - deletes a relay
func DeleteRelay(network, macaddress string) (models.Node, error) {
node, err := logic.GetNodeByMacAddress(network, macaddress)
if err != nil {
return models.Node{}, err
}
err = SetRelayedNodes("no", node.Network, node.RelayAddrs)
if err != nil {
return node, err
}
node.IsRelay = "no"
node.RelayAddrs = []string{}
node.SetLastModified()
node.PullChanges = "yes"
key, err := logic.GetRecordKey(node.MacAddress, node.Network)
if err != nil {
return models.Node{}, err
}
data, err := json.Marshal(&node)
if err != nil {
return models.Node{}, err
}
if err = database.Insert(key, string(data), database.NODES_TABLE_NAME); err != nil {
return models.Node{}, err
}
if err = functions.NetworkNodesUpdatePullChanges(network); err != nil {
return models.Node{}, err
}
return node, nil
}

View file

@ -112,36 +112,6 @@ func authenticateDNSToken(tokenString string) bool {
return tokens[1] == servercfg.GetDNSKey()
}
// ValidateUserToken - self explained
func ValidateUserToken(token string, user string, adminonly bool) error {
var tokenSplit = strings.Split(token, " ")
//I put this in in case the user doesn't put in a token at all (in which case it's empty)
//There's probably a smarter way of handling this.
var authToken = "928rt238tghgwe@TY@$Y@#WQAEGB2FC#@HG#@$Hddd"
if len(tokenSplit) > 1 {
authToken = tokenSplit[1]
} else {
return errors.New("Missing Auth Token.")
}
username, _, isadmin, err := logic.VerifyUserToken(authToken)
if err != nil {
return errors.New("Error Verifying Auth Token")
}
isAuthorized := false
if adminonly {
isAuthorized = isadmin
} else {
isAuthorized = username == user || isadmin
}
if !isAuthorized {
return errors.New("You are unauthorized to access this endpoint.")
}
return nil
}
func continueIfUserMatch(next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{

View file

@ -241,29 +241,29 @@ func TestUpdateUser(t *testing.T) {
})
}
func TestValidateUserToken(t *testing.T) {
t.Run("EmptyToken", func(t *testing.T) {
err := ValidateUserToken("", "", false)
assert.NotNil(t, err)
assert.Equal(t, "Missing Auth Token.", err.Error())
})
t.Run("InvalidToken", func(t *testing.T) {
err := ValidateUserToken("Bearer: badtoken", "", false)
assert.NotNil(t, err)
assert.Equal(t, "Error Verifying Auth Token", err.Error())
})
t.Run("InvalidUser", func(t *testing.T) {
t.Skip()
err := ValidateUserToken("Bearer: secretkey", "baduser", false)
assert.NotNil(t, err)
assert.Equal(t, "Error Verifying Auth Token", err.Error())
//need authorization
})
t.Run("ValidToken", func(t *testing.T) {
err := ValidateUserToken("Bearer: secretkey", "", true)
assert.Nil(t, err)
})
}
// func TestValidateUserToken(t *testing.T) {
// t.Run("EmptyToken", func(t *testing.T) {
// err := ValidateUserToken("", "", false)
// assert.NotNil(t, err)
// assert.Equal(t, "Missing Auth Token.", err.Error())
// })
// t.Run("InvalidToken", func(t *testing.T) {
// err := ValidateUserToken("Bearer: badtoken", "", false)
// assert.NotNil(t, err)
// assert.Equal(t, "Error Verifying Auth Token", err.Error())
// })
// t.Run("InvalidUser", func(t *testing.T) {
// t.Skip()
// err := ValidateUserToken("Bearer: secretkey", "baduser", false)
// assert.NotNil(t, err)
// assert.Equal(t, "Error Verifying Auth Token", err.Error())
// //need authorization
// })
// t.Run("ValidToken", func(t *testing.T) {
// err := ValidateUserToken("Bearer: secretkey", "", true)
// assert.Nil(t, err)
// })
// }
func TestVerifyAuthRequest(t *testing.T) {
database.InitializeDatabase()

View file

@ -2,24 +2,14 @@ package functions
import (
"encoding/json"
"fmt"
"log"
"math/rand"
"strings"
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
)
// ParseNetwork - parses a network into a model
func ParseNetwork(value string) (models.Network, error) {
var network models.Network
err := json.Unmarshal([]byte(value), &network)
return network, err
}
// ParseNode - parses a node into a model
func ParseNode(value string) (models.Node, error) {
var node models.Node
@ -131,72 +121,6 @@ func NetworkExists(name string) (bool, error) {
return len(network) > 0, nil
}
// NetworkNodesUpdateAction - updates action of network nodes
func NetworkNodesUpdateAction(networkName string, action string) error {
collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
if err != nil {
if database.IsEmptyRecord(err) {
return nil
}
return err
}
for _, value := range collections {
var node models.Node
err := json.Unmarshal([]byte(value), &node)
if err != nil {
fmt.Println("error in node address assignment!")
return err
}
if action == models.NODE_UPDATE_KEY && node.IsStatic == "yes" {
continue
}
if node.Network == networkName {
node.Action = action
data, err := json.Marshal(&node)
if err != nil {
return err
}
node.SetID()
database.Insert(node.ID, string(data), database.NODES_TABLE_NAME)
}
}
return nil
}
// NetworkNodesUpdatePullChanges - tells nodes on network to pull
func NetworkNodesUpdatePullChanges(networkName string) error {
collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
if err != nil {
if database.IsEmptyRecord(err) {
return nil
}
return err
}
for _, value := range collections {
var node models.Node
err := json.Unmarshal([]byte(value), &node)
if err != nil {
fmt.Println("error in node address assignment!")
return err
}
if node.Network == networkName {
node.PullChanges = "yes"
data, err := json.Marshal(&node)
if err != nil {
return err
}
node.SetID()
database.Insert(node.ID, string(data), database.NODES_TABLE_NAME)
}
}
return nil
}
// IsNetworkDisplayNameUnique - checks if network display name unique
func IsNetworkDisplayNameUnique(name string) (bool, error) {
@ -228,28 +152,6 @@ func IsMacAddressUnique(macaddress string, networkName string) (bool, error) {
return true, nil
}
// GetNetworkNonServerNodeCount - get number of network non server nodes
func GetNetworkNonServerNodeCount(networkName string) (int, error) {
collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
count := 0
if err != nil && !database.IsEmptyRecord(err) {
return count, err
}
for _, value := range collection {
var node models.Node
if err = json.Unmarshal([]byte(value), &node); err != nil {
return count, err
} else {
if node.Network == networkName && node.IsServer != "yes" {
count++
}
}
}
return count, nil
}
// IsKeyValidGlobal - checks if a key is valid globally
func IsKeyValidGlobal(keyvalue string) bool {
@ -278,31 +180,6 @@ func IsKeyValidGlobal(keyvalue string) bool {
return isvalid
}
//TODO: Contains a fatal error return. Need to change
//This just gets a network object from a network name
//Should probably just be GetNetwork. kind of a dumb name.
//Used in contexts where it's not the Parent network.
//Similar to above but checks if Cidr range is valid
//At least this guy's got some print statements
//still not good error handling
//This checks to make sure a network name is valid.
//Switch to REGEX?
// NameInNetworkCharSet - see if name is in charset for networks
func NameInNetworkCharSet(name string) bool {
charset := "abcdefghijklmnopqrstuvwxyz1234567890-_."
for _, char := range name {
if !strings.Contains(charset, strings.ToLower(string(char))) {
return false
}
}
return true
}
// NameInDNSCharSet - name in dns char set
func NameInDNSCharSet(name string) bool {
@ -387,43 +264,6 @@ func GetAllExtClients() ([]models.ExtClient, error) {
return extclients, nil
}
// GenKey - generates access key
func GenKey() string {
var seededRand *rand.Rand = rand.New(
rand.NewSource(time.Now().UnixNano()))
length := 16
charset := "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return string(b)
}
//generate a key value
//we should probably just have 1 random string generator
//that can be used across all functions
//have a "base string" a "length" and a "charset"
// GenKeyName - generates a key name
func GenKeyName() string {
var seededRand *rand.Rand = rand.New(
rand.NewSource(time.Now().UnixNano()))
length := 5
charset := "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return "key" + string(b)
}
// DeleteKey - deletes a key
func DeleteKey(network models.Network, i int) {

View file

@ -1,13 +1,147 @@
package logic
import (
"encoding/base64"
"encoding/json"
"errors"
"math/rand"
"time"
"github.com/go-playground/validator/v10"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
)
const (
charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
)
// CreateAccessKey - create access key
func CreateAccessKey(accesskey models.AccessKey, network models.Network) (models.AccessKey, error) {
if accesskey.Name == "" {
accesskey.Name = genKeyName()
}
if accesskey.Value == "" {
accesskey.Value = genKey()
}
if accesskey.Uses == 0 {
accesskey.Uses = 1
}
checkkeys, err := GetKeys(network.NetID)
if err != nil {
return models.AccessKey{}, errors.New("could not retrieve network keys")
}
for _, key := range checkkeys {
if key.Name == accesskey.Name {
return models.AccessKey{}, errors.New("duplicate AccessKey Name")
}
}
privAddr := ""
if network.IsLocal != "" {
privAddr = network.LocalRange
}
netID := network.NetID
var accessToken models.AccessToken
s := servercfg.GetServerConfig()
servervals := models.ServerConfig{
CoreDNSAddr: s.CoreDNSAddr,
APIConnString: s.APIConnString,
APIHost: s.APIHost,
APIPort: s.APIPort,
GRPCConnString: s.GRPCConnString,
GRPCHost: s.GRPCHost,
GRPCPort: s.GRPCPort,
GRPCSSL: s.GRPCSSL,
CheckinInterval: s.CheckinInterval,
}
accessToken.ServerConfig = servervals
accessToken.ClientConfig.Network = netID
accessToken.ClientConfig.Key = accesskey.Value
accessToken.ClientConfig.LocalRange = privAddr
tokenjson, err := json.Marshal(accessToken)
if err != nil {
return accesskey, err
}
accesskey.AccessString = base64.StdEncoding.EncodeToString([]byte(tokenjson))
//validate accesskey
v := validator.New()
err = v.Struct(accesskey)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
logger.Log(1, "validator", e.Error())
}
return models.AccessKey{}, err
}
network.AccessKeys = append(network.AccessKeys, accesskey)
data, err := json.Marshal(&network)
if err != nil {
return models.AccessKey{}, err
}
if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
return models.AccessKey{}, err
}
return accesskey, nil
}
// DeleteKey - deletes a key
func DeleteKey(keyname, netname string) error {
network, err := GetParentNetwork(netname)
if err != nil {
return err
}
//basically, turn the list of access keys into the list of access keys before and after the item
//have not done any error handling for if there's like...1 item. I think it works? need to test.
found := false
var updatedKeys []models.AccessKey
for _, currentkey := range network.AccessKeys {
if currentkey.Name == keyname {
found = true
} else {
updatedKeys = append(updatedKeys, currentkey)
}
}
if !found {
return errors.New("key " + keyname + " does not exist")
}
network.AccessKeys = updatedKeys
data, err := json.Marshal(&network)
if err != nil {
return err
}
if err := database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
return err
}
return nil
}
// GetKeys - fetches keys for network
func GetKeys(net string) ([]models.AccessKey, error) {
record, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, net)
if err != nil {
return []models.AccessKey{}, err
}
network, err := ParseNetwork(record)
if err != nil {
return []models.AccessKey{}, err
}
return network.AccessKeys, nil
}
// DecrimentKey - decriments key uses
func DecrimentKey(networkName string, keyvalue string) {
@ -71,3 +205,33 @@ func RemoveKeySensitiveInfo(keys []models.AccessKey) []models.AccessKey {
}
return returnKeys
}
// == private methods ==
func genKeyName() string {
var seededRand *rand.Rand = rand.New(
rand.NewSource(time.Now().UnixNano()))
length := 5
b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return "key" + string(b)
}
func genKey() string {
var seededRand *rand.Rand = rand.New(
rand.NewSource(time.Now().UnixNano()))
length := 16
b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return string(b)
}

View file

@ -5,6 +5,7 @@ import (
"io/ioutil"
"os"
"github.com/go-playground/validator/v10"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
@ -140,3 +141,105 @@ func SetCorefile(domains string) error {
}
return err
}
// GetAllDNS - gets all dns entries
func GetAllDNS() ([]models.DNSEntry, error) {
var dns []models.DNSEntry
networks, err := GetNetworks()
if err != nil && !database.IsEmptyRecord(err) {
return []models.DNSEntry{}, err
}
for _, net := range networks {
netdns, err := GetDNS(net.NetID)
if err != nil {
return []models.DNSEntry{}, nil
}
dns = append(dns, netdns...)
}
return dns, nil
}
// GetDNSEntryNum - gets which entry the dns was
func GetDNSEntryNum(domain string, network string) (int, error) {
num := 0
entries, err := GetDNS(network)
if err != nil {
return 0, err
}
for i := 0; i < len(entries); i++ {
if domain == entries[i].Name {
num++
}
}
return num, nil
}
// ValidateDNSCreate - checks if an entry is valid
func ValidateDNSCreate(entry models.DNSEntry) error {
v := validator.New()
_ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
num, err := GetDNSEntryNum(entry.Name, entry.Network)
return err == nil && num == 0
})
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
_, err := GetParentNetwork(entry.Network)
return err == nil
})
err := v.Struct(entry)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
logger.Log(1, e.Error())
}
}
return err
}
// ValidateDNSUpdate - validates a DNS update
func ValidateDNSUpdate(change models.DNSEntry, entry models.DNSEntry) error {
v := validator.New()
_ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool {
//if name & net not changing name we are good
if change.Name == entry.Name && change.Network == entry.Network {
return true
}
num, err := GetDNSEntryNum(change.Name, change.Network)
return err == nil && num == 0
})
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
_, err := GetParentNetwork(change.Network)
if err != nil {
logger.Log(0, err.Error())
}
return err == nil
})
err := v.Struct(change)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
logger.Log(1, e.Error())
}
}
return err
}
// DeleteDNS - deletes a DNS entry
func DeleteDNS(domain string, network string) error {
key, err := GetRecordKey(domain, network)
if err != nil {
return err
}
err = database.DeleteRecord(database.DNS_TABLE_NAME, key)
return err
}

View file

@ -2,10 +2,12 @@ package logic
import (
"encoding/json"
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// GetExtPeersList - gets the ext peers lists
@ -63,3 +65,103 @@ func GetEgressRangesOnNetwork(client *models.ExtClient) ([]string, error) {
return result, nil
}
// DeleteExtClient - deletes an existing ext client
func DeleteExtClient(network string, clientid string) error {
key, err := GetRecordKey(clientid, network)
if err != nil {
return err
}
err = database.DeleteRecord(database.EXT_CLIENT_TABLE_NAME, key)
return err
}
// GetNetworkExtClients - gets the ext clients of given network
func GetNetworkExtClients(network string) ([]models.ExtClient, error) {
var extclients []models.ExtClient
records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME)
if err != nil {
return extclients, err
}
for _, value := range records {
var extclient models.ExtClient
err = json.Unmarshal([]byte(value), &extclient)
if err != nil {
continue
}
if extclient.Network == network {
extclients = append(extclients, extclient)
}
}
return extclients, err
}
// GetExtClient - gets a single ext client on a network
func GetExtClient(clientid string, network string) (models.ExtClient, error) {
var extclient models.ExtClient
key, err := GetRecordKey(clientid, network)
if err != nil {
return extclient, err
}
data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key)
if err != nil {
return extclient, err
}
err = json.Unmarshal([]byte(data), &extclient)
return extclient, err
}
// CreateExtClient - creates an extclient
func CreateExtClient(extclient models.ExtClient) error {
if extclient.PrivateKey == "" {
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return err
}
extclient.PrivateKey = privateKey.String()
extclient.PublicKey = privateKey.PublicKey().String()
}
if extclient.Address == "" {
newAddress, err := UniqueAddress(extclient.Network)
if err != nil {
return err
}
extclient.Address = newAddress
}
if extclient.ClientID == "" {
extclient.ClientID = models.GenerateNodeName()
}
extclient.LastModified = time.Now().Unix()
key, err := GetRecordKey(extclient.ClientID, extclient.Network)
if err != nil {
return err
}
data, err := json.Marshal(&extclient)
if err != nil {
return err
}
if err = database.Insert(key, string(data), database.EXT_CLIENT_TABLE_NAME); err != nil {
return err
}
err = SetNetworkNodesLastModified(extclient.Network)
return err
}
// UpdateExtClient - only supports name changes right now
func UpdateExtClient(newclientid string, network string, client models.ExtClient) (models.ExtClient, error) {
err := DeleteExtClient(network, client.ClientID)
if err != nil {
return client, err
}
client.ClientID = newclientid
CreateExtClient(client)
return client, err
}

221
logic/gateway.go Normal file
View file

@ -0,0 +1,221 @@
package logic
import (
"encoding/json"
"errors"
"strings"
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
)
// CreateEgressGateway - creates an egress gateway
func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, error) {
node, err := GetNodeByMacAddress(gateway.NetID, gateway.NodeID)
if node.OS == "windows" || node.OS == "macos" { // add in darwin later
return models.Node{}, errors.New(node.OS + " is unsupported for egress gateways")
}
if err != nil {
return models.Node{}, err
}
err = ValidateEgressGateway(gateway)
if err != nil {
return models.Node{}, err
}
node.IsEgressGateway = "yes"
node.EgressGatewayRanges = gateway.Ranges
postUpCmd := "iptables -A FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -A POSTROUTING -o " + gateway.Interface + " -j MASQUERADE"
postDownCmd := "iptables -D FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -D POSTROUTING -o " + gateway.Interface + " -j MASQUERADE"
if gateway.PostUp != "" {
postUpCmd = gateway.PostUp
}
if gateway.PostDown != "" {
postDownCmd = gateway.PostDown
}
if node.PostUp != "" {
if !strings.Contains(node.PostUp, postUpCmd) {
postUpCmd = node.PostUp + "; " + postUpCmd
}
}
if node.PostDown != "" {
if !strings.Contains(node.PostDown, postDownCmd) {
postDownCmd = node.PostDown + "; " + postDownCmd
}
}
key, err := GetRecordKey(gateway.NodeID, gateway.NetID)
if err != nil {
return node, err
}
node.PostUp = postUpCmd
node.PostDown = postDownCmd
node.SetLastModified()
node.PullChanges = "yes"
nodeData, err := json.Marshal(&node)
if err != nil {
return node, err
}
if err = database.Insert(key, string(nodeData), database.NODES_TABLE_NAME); err != nil {
return models.Node{}, err
}
if err = NetworkNodesUpdatePullChanges(node.Network); err != nil {
return models.Node{}, err
}
return node, nil
}
func ValidateEgressGateway(gateway models.EgressGatewayRequest) error {
var err error
empty := len(gateway.Ranges) == 0
if empty {
err = errors.New("IP Ranges Cannot Be Empty")
}
empty = gateway.Interface == ""
if empty {
err = errors.New("Interface cannot be empty")
}
return err
}
// DeleteEgressGateway - deletes egress from node
func DeleteEgressGateway(network, macaddress string) (models.Node, error) {
node, err := GetNodeByMacAddress(network, macaddress)
if err != nil {
return models.Node{}, err
}
node.IsEgressGateway = "no"
node.EgressGatewayRanges = []string{}
node.PostUp = ""
node.PostDown = ""
if node.IsIngressGateway == "yes" { // check if node is still an ingress gateway before completely deleting postdown/up rules
node.PostUp = "iptables -A FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -A POSTROUTING -o " + node.Interface + " -j MASQUERADE"
node.PostDown = "iptables -D FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -D POSTROUTING -o " + node.Interface + " -j MASQUERADE"
}
node.SetLastModified()
node.PullChanges = "yes"
key, err := GetRecordKey(node.MacAddress, node.Network)
if err != nil {
return models.Node{}, err
}
data, err := json.Marshal(&node)
if err != nil {
return models.Node{}, err
}
if err = database.Insert(key, string(data), database.NODES_TABLE_NAME); err != nil {
return models.Node{}, err
}
if err = NetworkNodesUpdatePullChanges(network); err != nil {
return models.Node{}, err
}
return node, nil
}
// CreateIngressGateway - creates an ingress gateway
func CreateIngressGateway(netid string, macaddress string) (models.Node, error) {
node, err := GetNodeByMacAddress(netid, macaddress)
if node.OS == "windows" || node.OS == "macos" { // add in darwin later
return models.Node{}, errors.New(node.OS + " is unsupported for ingress gateways")
}
if err != nil {
return models.Node{}, err
}
network, err := GetParentNetwork(netid)
if err != nil {
return models.Node{}, err
}
node.IsIngressGateway = "yes"
node.IngressGatewayRange = network.AddressRange
postUpCmd := "iptables -A FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -A POSTROUTING -o " + node.Interface + " -j MASQUERADE"
postDownCmd := "iptables -D FORWARD -i " + node.Interface + " -j ACCEPT; iptables -t nat -D POSTROUTING -o " + node.Interface + " -j MASQUERADE"
if node.PostUp != "" {
if !strings.Contains(node.PostUp, postUpCmd) {
postUpCmd = node.PostUp + "; " + postUpCmd
}
}
if node.PostDown != "" {
if !strings.Contains(node.PostDown, postDownCmd) {
postDownCmd = node.PostDown + "; " + postDownCmd
}
}
node.SetLastModified()
node.PostUp = postUpCmd
node.PostDown = postDownCmd
node.PullChanges = "yes"
node.UDPHolePunch = "no"
key, err := GetRecordKey(node.MacAddress, node.Network)
if err != nil {
return models.Node{}, err
}
data, err := json.Marshal(&node)
if err != nil {
return models.Node{}, err
}
err = database.Insert(key, string(data), database.NODES_TABLE_NAME)
if err != nil {
return models.Node{}, err
}
err = SetNetworkNodesLastModified(netid)
return node, err
}
// DeleteIngressGateway - deletes an ingress gateway
func DeleteIngressGateway(networkName string, macaddress string) (models.Node, error) {
node, err := GetNodeByMacAddress(networkName, macaddress)
if err != nil {
return models.Node{}, err
}
network, err := GetParentNetwork(networkName)
if err != nil {
return models.Node{}, err
}
// delete ext clients belonging to ingress gateway
if err = DeleteGatewayExtClients(macaddress, networkName); err != nil {
return models.Node{}, err
}
node.UDPHolePunch = network.DefaultUDPHolePunch
node.LastModified = time.Now().Unix()
node.IsIngressGateway = "no"
node.IngressGatewayRange = ""
node.PullChanges = "yes"
key, err := GetRecordKey(node.MacAddress, node.Network)
if err != nil {
return models.Node{}, err
}
data, err := json.Marshal(&node)
if err != nil {
return models.Node{}, err
}
err = database.Insert(key, string(data), database.NODES_TABLE_NAME)
if err != nil {
return models.Node{}, err
}
err = SetNetworkNodesLastModified(networkName)
return node, err
}
// DeleteGatewayExtClients - deletes ext clients based on gateway (mac) of ingress node and network
func DeleteGatewayExtClients(gatewayID string, networkName string) error {
currentExtClients, err := GetNetworkExtClients(networkName)
if err != nil && !database.IsEmptyRecord(err) {
return err
}
for _, extClient := range currentExtClients {
if extClient.IngressGatewayID == gatewayID {
if err = DeleteExtClient(networkName, extClient.ClientID); err != nil {
logger.Log(1, "failed to remove ext client", extClient.ClientID)
continue
}
}
}
return nil
}

View file

@ -7,6 +7,7 @@ import (
"net"
"os/exec"
"strings"
"time"
"github.com/go-playground/validator/v10"
"github.com/gravitl/netmaker/database"
@ -38,6 +39,107 @@ func GetNetworks() ([]models.Network, error) {
return networks, err
}
// DeleteNetwork - deletes a network
func DeleteNetwork(network string) error {
nodeCount, err := GetNetworkNonServerNodeCount(network)
if nodeCount == 0 || database.IsEmptyRecord(err) {
// delete server nodes first then db records
servers, err := GetSortedNetworkServerNodes(network)
if err == nil {
for _, s := range servers {
if err = DeleteNode(&s, true); err != nil {
logger.Log(2, "could not removed server", s.Name, "before deleting network", network)
} else {
logger.Log(2, "removed server", s.Name, "before deleting network", network)
}
}
} else {
logger.Log(1, "could not remove servers before deleting network", network)
}
return database.DeleteRecord(database.NETWORKS_TABLE_NAME, network)
}
return errors.New("node check failed. All nodes must be deleted before deleting network")
}
// CreateNetwork - creates a network in database
func CreateNetwork(network models.Network) error {
network.SetDefaults()
network.SetNodesLastModified()
network.SetNetworkLastModified()
network.KeyUpdateTimeStamp = time.Now().Unix()
err := ValidateNetwork(&network, false)
if err != nil {
//returnErrorResponse(w, r, formatError(err, "badrequest"))
return err
}
data, err := json.Marshal(&network)
if err != nil {
return err
}
if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
return err
}
return err
}
// NetworkNodesUpdatePullChanges - tells nodes on network to pull
func NetworkNodesUpdatePullChanges(networkName string) error {
collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
if err != nil {
if database.IsEmptyRecord(err) {
return nil
}
return err
}
for _, value := range collections {
var node models.Node
err := json.Unmarshal([]byte(value), &node)
if err != nil {
fmt.Println("error in node address assignment!")
return err
}
if node.Network == networkName {
node.PullChanges = "yes"
data, err := json.Marshal(&node)
if err != nil {
return err
}
node.SetID()
database.Insert(node.ID, string(data), database.NODES_TABLE_NAME)
}
}
return nil
}
// GetNetworkNonServerNodeCount - get number of network non server nodes
func GetNetworkNonServerNodeCount(networkName string) (int, error) {
collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
count := 0
if err != nil && !database.IsEmptyRecord(err) {
return count, err
}
for _, value := range collection {
var node models.Node
if err = json.Unmarshal([]byte(value), &node); err != nil {
return count, err
} else {
if node.Network == networkName && node.IsServer != "yes" {
count++
}
}
}
return count, nil
}
// GetParentNetwork - get parent network
func GetParentNetwork(networkname string) (models.Network, error) {
@ -462,8 +564,91 @@ func ValidateNetwork(network *models.Network, isUpdate bool) error {
return err
}
// ParseNetwork - parses a network into a model
func ParseNetwork(value string) (models.Network, error) {
var network models.Network
err := json.Unmarshal([]byte(value), &network)
return network, err
}
// ValidateNetworkUpdate - checks if network is valid to update
func ValidateNetworkUpdate(network models.Network) error {
v := validator.New()
_ = v.RegisterValidation("netid_valid", func(fl validator.FieldLevel) bool {
if fl.Field().String() == "" {
return true
}
inCharSet := nameInNetworkCharSet(fl.Field().String())
return inCharSet
})
err := v.Struct(network)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
logger.Log(1, "validator", e.Error())
}
}
return err
}
// KeyUpdate - updates keys on network
func KeyUpdate(netname string) (models.Network, error) {
err := networkNodesUpdateAction(netname, models.NODE_UPDATE_KEY)
if err != nil {
return models.Network{}, err
}
return models.Network{}, nil
}
// == Private ==
func networkNodesUpdateAction(networkName string, action string) error {
collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
if err != nil {
if database.IsEmptyRecord(err) {
return nil
}
return err
}
for _, value := range collections {
var node models.Node
err := json.Unmarshal([]byte(value), &node)
if err != nil {
fmt.Println("error in node address assignment!")
return err
}
if action == models.NODE_UPDATE_KEY && node.IsStatic == "yes" {
continue
}
if node.Network == networkName {
node.Action = action
data, err := json.Marshal(&node)
if err != nil {
return err
}
node.SetID()
database.Insert(node.ID, string(data), database.NODES_TABLE_NAME)
}
}
return nil
}
func nameInNetworkCharSet(name string) bool {
charset := "abcdefghijklmnopqrstuvwxyz1234567890-_."
for _, char := range name {
if !strings.Contains(charset, strings.ToLower(string(char))) {
return false
}
}
return true
}
func deleteInterface(ifacename string, postdown string) error {
var err error
if !ncutils.IsKernel() {

View file

@ -63,6 +63,28 @@ func GetSortedNetworkServerNodes(network string) ([]models.Node, error) {
return nodes, nil
}
// UncordonNode - approves a node to join a network
func UncordonNode(network, macaddress string) (models.Node, error) {
node, err := GetNodeByMacAddress(network, macaddress)
if err != nil {
return models.Node{}, err
}
node.SetLastModified()
node.IsPending = "no"
node.PullChanges = "yes"
data, err := json.Marshal(&node)
if err != nil {
return node, err
}
key, err := GetRecordKey(node.MacAddress, node.Network)
if err != nil {
return node, err
}
err = database.Insert(key, string(data), database.NODES_TABLE_NAME)
return node, err
}
// GetPeers - gets the peers of a given node
func GetPeers(node models.Node) ([]models.Node, error) {
if node.IsServer == "yes" && IsLeader(&node) {

140
logic/relay.go Normal file
View file

@ -0,0 +1,140 @@
package logic
import (
"encoding/json"
"errors"
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
)
// CreateRelay - creates a relay
func CreateRelay(relay models.RelayRequest) (models.Node, error) {
node, err := GetNodeByMacAddress(relay.NetID, relay.NodeID)
if node.OS == "macos" { // add in darwin later
return models.Node{}, errors.New(node.OS + " is unsupported for relay")
}
if err != nil {
return models.Node{}, err
}
err = ValidateRelay(relay)
if err != nil {
return models.Node{}, err
}
node.IsRelay = "yes"
node.RelayAddrs = relay.RelayAddrs
key, err := GetRecordKey(relay.NodeID, relay.NetID)
if err != nil {
return node, err
}
node.SetLastModified()
node.PullChanges = "yes"
nodeData, err := json.Marshal(&node)
if err != nil {
return node, err
}
if err = database.Insert(key, string(nodeData), database.NODES_TABLE_NAME); err != nil {
return models.Node{}, err
}
err = SetRelayedNodes("yes", node.Network, node.RelayAddrs)
if err != nil {
return node, err
}
if err = NetworkNodesUpdatePullChanges(node.Network); err != nil {
return models.Node{}, err
}
return node, nil
}
// SetRelayedNodes- set relayed nodes
func SetRelayedNodes(yesOrno string, networkName string, addrs []string) error {
collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
if err != nil {
return err
}
for _, value := range collections {
var node models.Node
err := json.Unmarshal([]byte(value), &node)
if err != nil {
return err
}
if node.Network == networkName {
for _, addr := range addrs {
if addr == node.Address || addr == node.Address6 {
node.IsRelayed = yesOrno
data, err := json.Marshal(&node)
if err != nil {
return err
}
node.SetID()
database.Insert(node.ID, string(data), database.NODES_TABLE_NAME)
}
}
}
}
return nil
}
// ValidateRelay - checks if relay is valid
func ValidateRelay(relay models.RelayRequest) error {
var err error
//isIp := functions.IsIpCIDR(gateway.RangeString)
empty := len(relay.RelayAddrs) == 0
if empty {
err = errors.New("IP Ranges Cannot Be Empty")
}
return err
}
// UpdateRelay - updates a relay
func UpdateRelay(network string, oldAddrs []string, newAddrs []string) {
time.Sleep(time.Second / 4)
err := SetRelayedNodes("no", network, oldAddrs)
if err != nil {
logger.Log(1, err.Error())
}
err = SetRelayedNodes("yes", network, newAddrs)
if err != nil {
logger.Log(1, err.Error())
}
}
// DeleteRelay - deletes a relay
func DeleteRelay(network, macaddress string) (models.Node, error) {
node, err := GetNodeByMacAddress(network, macaddress)
if err != nil {
return models.Node{}, err
}
err = SetRelayedNodes("no", node.Network, node.RelayAddrs)
if err != nil {
return node, err
}
node.IsRelay = "no"
node.RelayAddrs = []string{}
node.SetLastModified()
node.PullChanges = "yes"
key, err := GetRecordKey(node.MacAddress, node.Network)
if err != nil {
return models.Node{}, err
}
data, err := json.Marshal(&node)
if err != nil {
return models.Node{}, err
}
if err = database.Insert(key, string(data), database.NODES_TABLE_NAME); err != nil {
return models.Node{}, err
}
if err = NetworkNodesUpdatePullChanges(network); err != nil {
return models.Node{}, err
}
return node, nil
}