diff --git a/controllers/network.go b/controllers/network.go index 3258d10d..077a3c1c 100644 --- a/controllers/network.go +++ b/controllers/network.go @@ -225,15 +225,6 @@ func updateNetworkACLv2(w http.ResponseWriter, r *http.Request) { return } - var networkACL acls.ACLContainer - networkACL, err = networkACL.Get(acls.ContainerID(networkID)) - if err != nil { - logger.Log(0, r.Header.Get("user"), - fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", networkID, err)) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - _network := &schema.Network{ ID: networkID, } diff --git a/controllers/network_test.go b/controllers/network_test.go index 258f4883..503e96ea 100644 --- a/controllers/network_test.go +++ b/controllers/network_test.go @@ -69,7 +69,7 @@ func TestGetNetwork(t *testing.T) { }) t.Run("GetNonExistantNetwork", func(t *testing.T) { network, err := logic.GetNetwork("doesnotexist") - assert.EqualError(t, err, "no result found") + assert.EqualError(t, err, "record not found") assert.Equal(t, "", network.NetID) }) } @@ -215,7 +215,10 @@ func TestIpv6Network(t *testing.T) { func deleteAllNetworks() { deleteAllNodes() - database.DeleteAllRecords(database.NETWORKS_TABLE_NAME) + networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO())) + for _, network := range networks { + _ = network.Delete(db.WithContext(context.TODO())) + } } func createNet() { diff --git a/controllers/node_test.go b/controllers/node_test.go index 8996b2f5..8d4ca123 100644 --- a/controllers/node_test.go +++ b/controllers/node_test.go @@ -1,11 +1,15 @@ package controller import ( + "context" + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/schema" + "log" "net" "testing" "github.com/google/uuid" - "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic/acls" "github.com/gravitl/netmaker/logic/acls/nodeacls" @@ -23,12 +27,12 @@ func TestGetNetworkNodes(t *testing.T) { t.Run("BadNet", func(t *testing.T) { node, err := logic.GetNetworkNodes("badnet") assert.Nil(t, err) - assert.Equal(t, []models.Node{}, node) + assert.Equal(t, []models.Node(nil), node) }) t.Run("NoNodes", func(t *testing.T) { node, err := logic.GetNetworkNodes("skynet") assert.Nil(t, err) - assert.Equal(t, []models.Node{}, node) + assert.Equal(t, []models.Node(nil), node) }) t.Run("Success", func(t *testing.T) { createTestNode() @@ -51,6 +55,7 @@ func TestValidateEgressGateway(t *testing.T) { func TestNodeACLs(t *testing.T) { deleteAllNodes() + deleteAllAcls() node1 := createNodeWithParams("", "10.0.0.50/32") node2 := createNodeWithParams("", "10.0.0.100/32") logic.AssociateNodeToHost(node1, &linuxHost) @@ -92,7 +97,7 @@ func TestNodeACLs(t *testing.T) { currentACL.Save(acls.ContainerID(node1.Network)) }) t.Run("node acls correct after add new node not allowed", func(t *testing.T) { - node3 := createNodeWithParams("", "10.0.0.100/32") + node3 := createNodeWithParams("", "10.0.0.75/32") createNodeHosts() n, e := logic.GetNetwork(node3.Network) assert.Nil(t, e) @@ -123,7 +128,14 @@ func TestNodeACLs(t *testing.T) { } func deleteAllNodes() { - database.DeleteAllRecords(database.NODES_TABLE_NAME) + nodes, _ := (&schema.Node{}).ListAll(db.WithContext(context.TODO())) + for _, node := range nodes { + _ = node.Delete(db.WithContext(context.TODO())) + } +} + +func deleteAllAcls() { + _ = database.DeleteAllRecords(database.NODE_ACLS_TABLE_NAME) } func createTestNode() *models.Node { @@ -162,7 +174,10 @@ func createNodeHosts() { OS: "linux", Name: "linuxhost", } - _ = logic.CreateHost(&linuxHost) + err := logic.CreateHost(&linuxHost) + if err != nil { + log.Fatal(err) + } nonLinuxHost = models.Host{ ID: uuid.New(), OS: "windows", @@ -171,5 +186,8 @@ func createNodeHosts() { HostPass: "password", } - _ = logic.CreateHost(&nonLinuxHost) + err = logic.CreateHost(&nonLinuxHost) + if err != nil { + log.Fatal(err) + } } diff --git a/converters/node.go b/converters/node.go index ead6218e..d4830c19 100644 --- a/converters/node.go +++ b/converters/node.go @@ -72,7 +72,7 @@ func ToSchemaNode(node models.Node) schema.Node { egressGatewayNodeConfig = &config } - var failOverPeers datatypes.JSONMap + failOverPeers := make(datatypes.JSONMap) if node.IsFailOver { for peer := range node.FailOverPeers { failOverPeers[peer] = true diff --git a/database/postgres.go b/database/postgres.go index c28e9425..c1917227 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -5,8 +5,6 @@ import ( "database/sql" "errors" "github.com/gravitl/netmaker/db" - "time" - _ "github.com/lib/pq" ) @@ -35,9 +33,6 @@ func initPGDB() error { return dbOpenErr } - PGDB.SetMaxOpenConns(5) - PGDB.SetConnMaxLifetime(time.Hour) - return PGDB.Ping() } diff --git a/functions/helpers_test.go b/functions/helpers_test.go index 7c5dd1aa..7903e1f2 100644 --- a/functions/helpers_test.go +++ b/functions/helpers_test.go @@ -48,18 +48,21 @@ func TestMain(m *testing.M) { } func TestNetworkExists(t *testing.T) { - database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID) + _network := &schema.Network{ + ID: testNetwork.NetID, + } + _ = _network.Delete(db.WithContext(context.TODO())) exists, err := logic.NetworkExists(testNetwork.NetID) - assert.NotNil(t, err) + assert.Nil(t, err) assert.False(t, exists) - err = logic.SaveNetwork(testNetwork) + err = _network.Create(db.WithContext(context.TODO())) assert.Nil(t, err) exists, err = logic.NetworkExists(testNetwork.NetID) assert.Nil(t, err) assert.True(t, exists) - err = database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID) + err = _network.Delete(db.WithContext(context.TODO())) assert.Nil(t, err) } diff --git a/logic/hosts.go b/logic/hosts.go index 3b36a23a..c8043957 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -133,7 +133,7 @@ func CreateHost(h *models.Host) error { return errors.New("free tier limits exceeded on machines") } _, err := GetHost(h.ID.String()) - if (err != nil && !database.IsEmptyRecord(err)) || (err == nil) { + if (err != nil && !errors.Is(err, gorm.ErrRecordNotFound)) || (err == nil) { return ErrHostExists } diff --git a/pro/logic/ext_acls.go b/pro/logic/ext_acls.go index aa2c99a4..e94e559d 100644 --- a/pro/logic/ext_acls.go +++ b/pro/logic/ext_acls.go @@ -142,6 +142,6 @@ func UpdateProNodeACLs(node *models.Node) error { } } - _, err = currentACLs.Save(acls.ContainerID(node.Network)) + _, _ = currentACLs.Save(acls.ContainerID(node.Network)) return nil } diff --git a/schema/acl.go b/schema/acl.go index 7e09152f..99f30c0e 100644 --- a/schema/acl.go +++ b/schema/acl.go @@ -39,7 +39,10 @@ func (a *ACL) Create(ctx context.Context) error { } func (a *ACL) Get(ctx context.Context) error { - return db.FromContext(ctx).Model(a).First(a).Error + return db.FromContext(ctx).Model(a). + Where("id = ?", a.ID). + First(a). + Error } func (a *ACL) ListAll(ctx context.Context) ([]ACL, error) { diff --git a/schema/host.go b/schema/host.go index 49a08225..a0dcb46f 100644 --- a/schema/host.go +++ b/schema/host.go @@ -57,7 +57,10 @@ func (h *Host) Create(ctx context.Context) error { } func (h *Host) Get(ctx context.Context) error { - return db.FromContext(ctx).Model(h).First(h).Error + return db.FromContext(ctx).Model(h). + Where("id = ?", h.ID). + First(h). + Error } func (h *Host) GetNodes(ctx context.Context) ([]Node, error) { diff --git a/schema/job.go b/schema/job.go index 7fd3a23c..fa88d729 100644 --- a/schema/job.go +++ b/schema/job.go @@ -31,5 +31,8 @@ func (j *Job) Create(ctx context.Context) error { // Get returns a job record with the given Job.ID. func (j *Job) Get(ctx context.Context) error { - return db.FromContext(ctx).Model(j).First(j).Error + return db.FromContext(ctx).Model(j). + Where("id = ?", j.ID). + First(j). + Error } diff --git a/schema/network.go b/schema/network.go index 43747d93..8cd027c3 100644 --- a/schema/network.go +++ b/schema/network.go @@ -46,7 +46,10 @@ func (n *Network) Create(ctx context.Context) error { } func (n *Network) Get(ctx context.Context) error { - return db.FromContext(ctx).Model(n).First(n).Error + return db.FromContext(ctx).Model(n). + Where("id = ?", n.ID). + First(n). + Error } func (n *Network) GetNodes(ctx context.Context) ([]Node, error) { diff --git a/schema/node.go b/schema/node.go index 14ccbf2f..0416da95 100644 --- a/schema/node.go +++ b/schema/node.go @@ -115,7 +115,10 @@ func (n *Node) Create(ctx context.Context) error { } func (n *Node) Get(ctx context.Context) error { - return db.FromContext(ctx).Model(n).First(n).Error + return db.FromContext(ctx).Model(n). + Where("id = ?", n.ID). + First(n). + Error } func (n *Node) GetHost(ctx context.Context) error {