diff --git a/controllers/ext_client.go b/controllers/ext_client.go index 5796439b..02d70fd3 100644 --- a/controllers/ext_client.go +++ b/controllers/ext_client.go @@ -726,6 +726,16 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { return } for _, extclient := range extclients { + // if device id is sent, then make sure extclient with the same device id + // does not exist. + if customExtClient.DeviceID != "" && extclient.DeviceID == customExtClient.DeviceID && + extclient.OwnerID == caller.UserName && nodeid == extclient.IngressGatewayID { + err = errors.New("remote client config already exists on the gateway") + slog.Error("failed to create extclient", "user", userName, "error", err) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + if extclient.RemoteAccessClientID != "" && extclient.RemoteAccessClientID == customExtClient.RemoteAccessClientID && extclient.OwnerID == caller.UserName && nodeid == extclient.IngressGatewayID { // extclient on the gw already exists for the remote access client @@ -774,6 +784,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { extclient.Enabled = parentNetwork.DefaultACL == "yes" } extclient.Os = customExtClient.Os + extclient.DeviceID = customExtClient.DeviceID extclient.DeviceName = customExtClient.DeviceName if customExtClient.IsAlreadyConnectedToInetGw { slog.Warn("RAC/Client is already connected to internet gateway. this may mask their real IP address", "client IP", customExtClient.PublicEndpoint) diff --git a/controllers/node.go b/controllers/node.go index a7f12a53..928da092 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -417,7 +417,7 @@ func getNetworkNodeStatus(w http.ResponseWriter, r *http.Request) { } nodes = logic.AddStaticNodestoList(nodes) - nodes = logic.AddStatusToNodes(nodes, false) + nodes = logic.AddStatusToNodes(nodes, true) // return all the nodes in JSON/API format apiNodesStatusMap := logic.GetNodesStatusAPI(nodes[:]) logger.Log(3, r.Header.Get("user"), "fetched all nodes they have access to") diff --git a/controllers/server.go b/controllers/server.go index 27780be4..ccad7425 100644 --- a/controllers/server.go +++ b/controllers/server.go @@ -1,8 +1,11 @@ package controller import ( + "context" "encoding/json" "errors" + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/schema" "github.com/google/go-cmp/cmp" "net/http" "os" @@ -111,10 +114,7 @@ func getUsage(w http.ResponseWriter, _ *http.Request) { if err == nil { serverUsage.Ingresses = len(ingresses) } - egresses, err := logic.GetAllEgresses() - if err == nil { - serverUsage.Egresses = len(egresses) - } + serverUsage.Egresses, _ = (&schema.Egress{}).Count(db.WithContext(context.TODO())) relays, err := logic.GetRelays() if err == nil { serverUsage.Relays = len(relays) diff --git a/logic/extpeers.go b/logic/extpeers.go index 3fa431df..e8dfd780 100644 --- a/logic/extpeers.go +++ b/logic/extpeers.go @@ -433,6 +433,9 @@ func UpdateExtClient(old *models.ExtClient, update *models.CustomExtClient) mode if update.Country != "" && update.Country != old.Country { new.Country = update.Country } + if update.DeviceID != "" && old.DeviceID == "" { + new.DeviceID = update.DeviceID + } return new } diff --git a/logic/hosts.go b/logic/hosts.go index 6d469057..bbdd740d 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -126,7 +126,7 @@ func GetAllHostsWithStatus(status models.NodeStatus) ([]models.Host, error) { nodes := GetHostNodes(&host) for _, node := range nodes { - GetNodeCheckInStatus(&node, false) + getNodeCheckInStatus(&node, false) if node.Status == status { validHosts = append(validHosts, host) break diff --git a/logic/nodes.go b/logic/nodes.go index 600f3ce1..b197a3bb 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -471,7 +471,7 @@ func AddStatusToNodes(nodes []models.Node, statusCall bool) (nodesWithStatus []m if statusCall { GetNodeStatus(&node, aclDefaultPolicyStatusMap[node.Network]) } else { - GetNodeCheckInStatus(&node, true) + getNodeCheckInStatus(&node, true) } nodesWithStatus = append(nodesWithStatus, node) diff --git a/logic/status.go b/logic/status.go index ec8324a2..76b170fc 100644 --- a/logic/status.go +++ b/logic/status.go @@ -6,9 +6,9 @@ import ( "github.com/gravitl/netmaker/models" ) -var GetNodeStatus = GetNodeCheckInStatus +var GetNodeStatus = getNodeCheckInStatus -func GetNodeCheckInStatus(node *models.Node, t bool) { +func getNodeCheckInStatus(node *models.Node, t bool) { // On CE check only last check-in time if node.IsStatic { if !node.StaticNode.Enabled { diff --git a/models/extclient.go b/models/extclient.go index 2347061d..ae9acf04 100644 --- a/models/extclient.go +++ b/models/extclient.go @@ -24,6 +24,7 @@ type ExtClient struct { PostDown string `json:"postdown" bson:"postdown"` Tags map[TagID]struct{} `json:"tags"` Os string `json:"os"` + DeviceID string `json:"device_id"` DeviceName string `json:"device_name"` PublicEndpoint string `json:"public_endpoint"` Country string `json:"country"` @@ -44,6 +45,7 @@ type CustomExtClient struct { PostDown string `json:"postdown" bson:"postdown" validate:"max=1024"` Tags map[TagID]struct{} `json:"tags"` Os string `json:"os"` + DeviceID string `json:"device_id"` DeviceName string `json:"device_name"` IsAlreadyConnectedToInetGw bool `json:"is_already_connected_to_inet_gw"` PublicEndpoint string `json:"public_endpoint"` diff --git a/pro/controllers/users.go b/pro/controllers/users.go index 2ba43037..ed0edde1 100644 --- a/pro/controllers/users.go +++ b/pro/controllers/users.go @@ -1382,6 +1382,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to fetch user %s, error: %v", username, err), "badrequest")) return } + deviceID := r.URL.Query().Get("device_id") remoteAccessClientID := r.URL.Query().Get("remote_access_clientid") var req models.UserRemoteGwsReq if remoteAccessClientID == "" { @@ -1407,58 +1408,95 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) { return } userGwNodes := proLogic.GetUserRAGNodes(*user) + + userExtClients := make(map[string][]models.ExtClient) + + // group all extclients of the requesting user by ingress + // gateway. for _, extClient := range allextClients { - node, ok := userGwNodes[extClient.IngressGatewayID] + // filter our extclients that don't belong to this user. + if extClient.OwnerID != username { + continue + } + + _, ok := userExtClients[extClient.IngressGatewayID] + if !ok { + userExtClients[extClient.IngressGatewayID] = []models.ExtClient{} + } + + userExtClients[extClient.IngressGatewayID] = append(userExtClients[extClient.IngressGatewayID], extClient) + } + + for ingressGatewayID, extClients := range userExtClients { + node, ok := userGwNodes[ingressGatewayID] if !ok { continue } - if extClient.RemoteAccessClientID == req.RemoteAccessClientID && extClient.OwnerID == username { - host, err := logic.GetHost(node.HostID.String()) - if err != nil { - continue + var gwClient models.ExtClient + var found bool + if deviceID != "" { + for _, extClient := range extClients { + if extClient.DeviceID == deviceID { + gwClient = extClient + found = true + break + } } - network, err := logic.GetNetwork(node.Network) - if err != nil { - slog.Error("failed to get node network", "error", err) - continue - } - nodesWithStatus := logic.AddStatusToNodes([]models.Node{node}, false) - if len(nodesWithStatus) > 0 { - node = nodesWithStatus[0] - } - - gws := userGws[node.Network] - if extClient.DNS == "" { - extClient.DNS = node.IngressDNS - } - - extClient.IngressGatewayEndpoint = utils.GetExtClientEndpoint( - host.EndpointIP, - host.EndpointIPv6, - logic.GetPeerListenPort(host), - ) - extClient.AllowedIPs = logic.GetExtclientAllowedIPs(extClient) - gws = append(gws, models.UserRemoteGws{ - GwID: node.ID.String(), - GWName: host.Name, - Network: node.Network, - GwClient: extClient, - Connected: true, - IsInternetGateway: node.IsInternetGateway, - GwPeerPublicKey: host.PublicKey.String(), - GwListenPort: logic.GetPeerListenPort(host), - Metadata: node.Metadata, - AllowedEndpoints: getAllowedRagEndpoints(&node, host), - NetworkAddresses: []string{network.AddressRange, network.AddressRange6}, - Status: node.Status, - DnsAddress: node.IngressDNS, - Addresses: utils.NoEmptyStringToCsv(node.Address.String(), node.Address6.String()), - }) - userGws[node.Network] = gws - delete(userGwNodes, node.ID.String()) } + + if !found { + // TODO: prevent ip clashes. + if len(extClients) > 0 { + gwClient = extClients[0] + } + } + + host, err := logic.GetHost(node.HostID.String()) + if err != nil { + continue + } + network, err := logic.GetNetwork(node.Network) + if err != nil { + slog.Error("failed to get node network", "error", err) + continue + } + nodesWithStatus := logic.AddStatusToNodes([]models.Node{node}, false) + if len(nodesWithStatus) > 0 { + node = nodesWithStatus[0] + } + + gws := userGws[node.Network] + if gwClient.DNS == "" { + gwClient.DNS = node.IngressDNS + } + + gwClient.IngressGatewayEndpoint = utils.GetExtClientEndpoint( + host.EndpointIP, + host.EndpointIPv6, + logic.GetPeerListenPort(host), + ) + gwClient.AllowedIPs = logic.GetExtclientAllowedIPs(gwClient) + gws = append(gws, models.UserRemoteGws{ + GwID: node.ID.String(), + GWName: host.Name, + Network: node.Network, + GwClient: gwClient, + Connected: true, + IsInternetGateway: node.IsInternetGateway, + GwPeerPublicKey: host.PublicKey.String(), + GwListenPort: logic.GetPeerListenPort(host), + Metadata: node.Metadata, + AllowedEndpoints: getAllowedRagEndpoints(&node, host), + NetworkAddresses: []string{network.AddressRange, network.AddressRange6}, + Status: node.Status, + DnsAddress: node.IngressDNS, + Addresses: utils.NoEmptyStringToCsv(node.Address.String(), node.Address6.String()), + }) + userGws[node.Network] = gws + delete(userGwNodes, node.ID.String()) } + // add remaining gw nodes to resp for gwID := range userGwNodes { node, err := logic.GetNodeByID(gwID) diff --git a/pro/util.go b/pro/util.go index 80e6e57a..141e9469 100644 --- a/pro/util.go +++ b/pro/util.go @@ -4,10 +4,11 @@ package pro import ( + "context" "encoding/base64" - + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/models" - + "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/logic" ) @@ -49,10 +50,7 @@ func getCurrentServerUsage() (limits Usage) { if err == nil { limits.Ingresses = len(ingresses) } - egresses, err := logic.GetAllEgresses() - if err == nil { - limits.Egresses = len(egresses) - } + limits.Egresses, _ = (&schema.Egress{}).Count(db.WithContext(context.TODO())) relays, err := logic.GetRelays() if err == nil { limits.Relays = len(relays) diff --git a/schema/egress.go b/schema/egress.go index 2c711f58..69246afe 100644 --- a/schema/egress.go +++ b/schema/egress.go @@ -63,6 +63,12 @@ func (e *Egress) ListByNetwork(ctx context.Context) (egs []Egress, err error) { return } +func (e *Egress) Count(ctx context.Context) (int, error) { + var count int64 + err := db.FromContext(ctx).Model(&Egress{}).Count(&count).Error + return int(count), err +} + func (e *Egress) Delete(ctx context.Context) error { return db.FromContext(ctx).Table(e.Table()).Where("id = ?", e.ID).Delete(&e).Error }