fix(go): test fixes;

This commit is contained in:
Vishal Dalwadi 2025-07-08 11:43:15 +05:30
parent c71e106421
commit c57ddbbc87
13 changed files with 60 additions and 35 deletions

View file

@ -225,15 +225,6 @@ func updateNetworkACLv2(w http.ResponseWriter, r *http.Request) {
return
}
var networkACL acls.ACLContainer
networkACL, err = networkACL.Get(acls.ContainerID(networkID))
if err != nil {
logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", networkID, err))
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
_network := &schema.Network{
ID: networkID,
}

View file

@ -69,7 +69,7 @@ func TestGetNetwork(t *testing.T) {
})
t.Run("GetNonExistantNetwork", func(t *testing.T) {
network, err := logic.GetNetwork("doesnotexist")
assert.EqualError(t, err, "no result found")
assert.EqualError(t, err, "record not found")
assert.Equal(t, "", network.NetID)
})
}
@ -215,7 +215,10 @@ func TestIpv6Network(t *testing.T) {
func deleteAllNetworks() {
deleteAllNodes()
database.DeleteAllRecords(database.NETWORKS_TABLE_NAME)
networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
for _, network := range networks {
_ = network.Delete(db.WithContext(context.TODO()))
}
}
func createNet() {

View file

@ -1,11 +1,15 @@
package controller
import (
"context"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/schema"
"log"
"net"
"testing"
"github.com/google/uuid"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/acls"
"github.com/gravitl/netmaker/logic/acls/nodeacls"
@ -23,12 +27,12 @@ func TestGetNetworkNodes(t *testing.T) {
t.Run("BadNet", func(t *testing.T) {
node, err := logic.GetNetworkNodes("badnet")
assert.Nil(t, err)
assert.Equal(t, []models.Node{}, node)
assert.Equal(t, []models.Node(nil), node)
})
t.Run("NoNodes", func(t *testing.T) {
node, err := logic.GetNetworkNodes("skynet")
assert.Nil(t, err)
assert.Equal(t, []models.Node{}, node)
assert.Equal(t, []models.Node(nil), node)
})
t.Run("Success", func(t *testing.T) {
createTestNode()
@ -51,6 +55,7 @@ func TestValidateEgressGateway(t *testing.T) {
func TestNodeACLs(t *testing.T) {
deleteAllNodes()
deleteAllAcls()
node1 := createNodeWithParams("", "10.0.0.50/32")
node2 := createNodeWithParams("", "10.0.0.100/32")
logic.AssociateNodeToHost(node1, &linuxHost)
@ -92,7 +97,7 @@ func TestNodeACLs(t *testing.T) {
currentACL.Save(acls.ContainerID(node1.Network))
})
t.Run("node acls correct after add new node not allowed", func(t *testing.T) {
node3 := createNodeWithParams("", "10.0.0.100/32")
node3 := createNodeWithParams("", "10.0.0.75/32")
createNodeHosts()
n, e := logic.GetNetwork(node3.Network)
assert.Nil(t, e)
@ -123,7 +128,14 @@ func TestNodeACLs(t *testing.T) {
}
func deleteAllNodes() {
database.DeleteAllRecords(database.NODES_TABLE_NAME)
nodes, _ := (&schema.Node{}).ListAll(db.WithContext(context.TODO()))
for _, node := range nodes {
_ = node.Delete(db.WithContext(context.TODO()))
}
}
func deleteAllAcls() {
_ = database.DeleteAllRecords(database.NODE_ACLS_TABLE_NAME)
}
func createTestNode() *models.Node {
@ -162,7 +174,10 @@ func createNodeHosts() {
OS: "linux",
Name: "linuxhost",
}
_ = logic.CreateHost(&linuxHost)
err := logic.CreateHost(&linuxHost)
if err != nil {
log.Fatal(err)
}
nonLinuxHost = models.Host{
ID: uuid.New(),
OS: "windows",
@ -171,5 +186,8 @@ func createNodeHosts() {
HostPass: "password",
}
_ = logic.CreateHost(&nonLinuxHost)
err = logic.CreateHost(&nonLinuxHost)
if err != nil {
log.Fatal(err)
}
}

View file

@ -72,7 +72,7 @@ func ToSchemaNode(node models.Node) schema.Node {
egressGatewayNodeConfig = &config
}
var failOverPeers datatypes.JSONMap
failOverPeers := make(datatypes.JSONMap)
if node.IsFailOver {
for peer := range node.FailOverPeers {
failOverPeers[peer] = true

View file

@ -5,8 +5,6 @@ import (
"database/sql"
"errors"
"github.com/gravitl/netmaker/db"
"time"
_ "github.com/lib/pq"
)
@ -35,9 +33,6 @@ func initPGDB() error {
return dbOpenErr
}
PGDB.SetMaxOpenConns(5)
PGDB.SetConnMaxLifetime(time.Hour)
return PGDB.Ping()
}

View file

@ -48,18 +48,21 @@ func TestMain(m *testing.M) {
}
func TestNetworkExists(t *testing.T) {
database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID)
_network := &schema.Network{
ID: testNetwork.NetID,
}
_ = _network.Delete(db.WithContext(context.TODO()))
exists, err := logic.NetworkExists(testNetwork.NetID)
assert.NotNil(t, err)
assert.Nil(t, err)
assert.False(t, exists)
err = logic.SaveNetwork(testNetwork)
err = _network.Create(db.WithContext(context.TODO()))
assert.Nil(t, err)
exists, err = logic.NetworkExists(testNetwork.NetID)
assert.Nil(t, err)
assert.True(t, exists)
err = database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID)
err = _network.Delete(db.WithContext(context.TODO()))
assert.Nil(t, err)
}

View file

@ -133,7 +133,7 @@ func CreateHost(h *models.Host) error {
return errors.New("free tier limits exceeded on machines")
}
_, err := GetHost(h.ID.String())
if (err != nil && !database.IsEmptyRecord(err)) || (err == nil) {
if (err != nil && !errors.Is(err, gorm.ErrRecordNotFound)) || (err == nil) {
return ErrHostExists
}

View file

@ -142,6 +142,6 @@ func UpdateProNodeACLs(node *models.Node) error {
}
}
_, err = currentACLs.Save(acls.ContainerID(node.Network))
_, _ = currentACLs.Save(acls.ContainerID(node.Network))
return nil
}

View file

@ -39,7 +39,10 @@ func (a *ACL) Create(ctx context.Context) error {
}
func (a *ACL) Get(ctx context.Context) error {
return db.FromContext(ctx).Model(a).First(a).Error
return db.FromContext(ctx).Model(a).
Where("id = ?", a.ID).
First(a).
Error
}
func (a *ACL) ListAll(ctx context.Context) ([]ACL, error) {

View file

@ -57,7 +57,10 @@ func (h *Host) Create(ctx context.Context) error {
}
func (h *Host) Get(ctx context.Context) error {
return db.FromContext(ctx).Model(h).First(h).Error
return db.FromContext(ctx).Model(h).
Where("id = ?", h.ID).
First(h).
Error
}
func (h *Host) GetNodes(ctx context.Context) ([]Node, error) {

View file

@ -31,5 +31,8 @@ func (j *Job) Create(ctx context.Context) error {
// Get returns a job record with the given Job.ID.
func (j *Job) Get(ctx context.Context) error {
return db.FromContext(ctx).Model(j).First(j).Error
return db.FromContext(ctx).Model(j).
Where("id = ?", j.ID).
First(j).
Error
}

View file

@ -46,7 +46,10 @@ func (n *Network) Create(ctx context.Context) error {
}
func (n *Network) Get(ctx context.Context) error {
return db.FromContext(ctx).Model(n).First(n).Error
return db.FromContext(ctx).Model(n).
Where("id = ?", n.ID).
First(n).
Error
}
func (n *Network) GetNodes(ctx context.Context) ([]Node, error) {

View file

@ -115,7 +115,10 @@ func (n *Node) Create(ctx context.Context) error {
}
func (n *Node) Get(ctx context.Context) error {
return db.FromContext(ctx).Model(n).First(n).Error
return db.FromContext(ctx).Model(n).
Where("id = ?", n.ID).
First(n).
Error
}
func (n *Node) GetHost(ctx context.Context) error {