diff --git a/Dockerfile b/Dockerfile
index 667fe1f1..179d9a6c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,6 +11,7 @@ FROM alpine:3.21.2
 # add a c lib
 # set the working directory
 WORKDIR /root/
+RUN apk upgrade --no-cache && apk add --no-cache sqlite
 RUN mkdir -p /etc/netclient/config
 COPY --from=builder /app/netmaker .
 COPY --from=builder /app/config config
diff --git a/controllers/hosts.go b/controllers/hosts.go
index abed0d3a..c17cf90a 100644
--- a/controllers/hosts.go
+++ b/controllers/hosts.go
@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"net/http"
+	"time"
 
 	"github.com/google/uuid"
 	"github.com/gorilla/mux"
@@ -48,6 +49,8 @@ func hostHandlers(r *mux.Router) {
 		Methods(http.MethodPost)
 	r.HandleFunc("/api/v1/fallback/host/{hostid}", Authorize(true, false, "host", http.HandlerFunc(hostUpdateFallback))).
 		Methods(http.MethodPut)
+	r.HandleFunc("/api/v1/host/{hostid}/peer_info", Authorize(true, false, "host", http.HandlerFunc(getHostPeerInfo))).
+		Methods(http.MethodGet)
 	r.HandleFunc("/api/emqx/hosts", logic.SecurityCheck(true, http.HandlerFunc(delEmqxHosts))).
 		Methods(http.MethodDelete)
 	r.HandleFunc("/api/v1/auth-register/host", socketHandler)
@@ -232,7 +235,7 @@ func pull(w http.ResponseWriter, r *http.Request) {
 			slog.Error("failed to get node:", "id", node.ID, "error", err)
 			continue
 		}
-		if node.FailedOverBy != uuid.Nil {
+		if node.FailedOverBy != uuid.Nil && r.URL.Query().Get("reset_failovered") == "true" {
 			logic.ResetFailedOverPeer(&node)
 			sendPeerUpdate = true
 		}
@@ -943,6 +946,7 @@ func syncHosts(w http.ResponseWriter, r *http.Request) {
 					slog.Info("host sync requested", "user", user, "host", host.ID.String())
 				}
 			}(host)
+			time.Sleep(time.Millisecond * 100)
 		}
 	}()
 
@@ -1017,3 +1021,29 @@ func delEmqxHosts(w http.ResponseWriter, r *http.Request) {
 	}
 	logic.ReturnSuccessResponse(w, r, "deleted hosts data on emqx")
 }
+
+// @Summary     Fetches host peerinfo
+// @Router      /api/v1/host/{hostid}/peer_info [get]
+// @Tags        Hosts
+// @Security    oauth
+// @Param       hostid path string true "Host ID"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func getHostPeerInfo(w http.ResponseWriter, r *http.Request) {
+	hostID := mux.Vars(r)["hostid"]
+	host, err := logic.GetHost(hostID)
+	if err != nil {
+		slog.Error("failed to retrieve host", "hostid", hostID, "error", err)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	peerInfo, err := logic.GetHostPeerInfo(host)
+	if err != nil {
+		// the host exists at this point, so a failure here is reported as the documented 500
+		slog.Error("failed to retrieve host peerinfo", "hostid", hostID, "error", err)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	logic.ReturnSuccessResponseWithJson(w, r, peerInfo, "fetched host peer info")
+}
diff --git a/logic/acls.go b/logic/acls.go
index 38079bcc..613db4d6 100644
--- a/logic/acls.go
+++ b/logic/acls.go
@@ -164,6 +164,7 @@ func storeAclInCache(a models.Acl) {
 	aclCacheMutex.Lock()
 	defer
aclCacheMutex.Unlock()
 	aclCacheMap[a.ID] = a
+
 }
 
 func removeAclFromCache(a models.Acl) {
@@ -585,6 +586,7 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
 				return true
 			}
 		}
+
 	}
 	// list device policies
 	policies := listDevicePolicies(models.NetworkID(peer.Network))
diff --git a/logic/extpeers.go b/logic/extpeers.go
index e61d42b9..706c5631 100644
--- a/logic/extpeers.go
+++ b/logic/extpeers.go
@@ -461,9 +461,7 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
 	defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
 	nodes, _ := GetNetworkNodes(node.Network)
 	nodes = append(nodes, GetStaticNodesByNetwork(models.NetworkID(node.Network), true)...)
-	//fmt.Printf("=====> NODES: %+v \n\n", nodes)
 	userNodes := GetStaticUserNodesByNetwork(models.NetworkID(node.Network))
-	//fmt.Printf("=====> USER NODES %+v \n\n", userNodes)
 	for _, userNodeI := range userNodes {
 		for _, peer := range nodes {
 			if peer.IsUserNode {
diff --git a/logic/nodes.go b/logic/nodes.go
index 0b0c2495..f47e6512 100644
--- a/logic/nodes.go
+++ b/logic/nodes.go
@@ -40,9 +40,7 @@ func getNodeFromCache(nodeID string) (node models.Node, ok bool) {
 }
 func getNodesFromCache() (nodes []models.Node) {
 	nodeCacheMutex.RLock()
-	for _, node := range nodesCacheMap {
-		nodes = append(nodes, node)
-	}
+	nodes = slices.Collect(maps.Values(nodesCacheMap))
 	nodeCacheMutex.RUnlock()
 	return
 }
@@ -141,7 +139,7 @@ func GetNetworkNodesMemory(allNodes []models.Node, network string) []models.Node
 		defer nodeNetworkCacheMutex.Unlock()
 		return slices.Collect(maps.Values(networkNodes))
 	}
-	var nodes = []models.Node{}
+	var nodes = make([]models.Node, 0, len(allNodes))
 	for i := range allNodes {
 		node := allNodes[i]
 		if node.Network == network {
diff --git a/logic/peers.go b/logic/peers.go
index 076c0dbb..64f6acbb 100644
--- a/logic/peers.go
+++ b/logic/peers.go
@@ -59,6 +59,75 @@ var (
 	}
 )
 
+// GetHostPeerInfo - fetches required peer info per network
+//
+// For every connected node of the host it collects the peers of that node's
+// network the node may communicate with, plus ext clients when the node is an
+// ingress gateway. Nodes that fail to load or are being deleted are skipped.
+func GetHostPeerInfo(host *models.Host) (models.HostPeerInfo, error) {
+	allNodes, err := GetAllNodes()
+	if err != nil {
+		return models.HostPeerInfo{}, err
+	}
+	peerInfo := models.HostPeerInfo{
+		NetworkPeerIDs: make(map[models.NetworkID]models.PeerMap, len(host.Nodes)),
+	}
+	for _, nodeID := range host.Nodes {
+		node, err := GetNodeByID(nodeID)
+		if err != nil {
+			continue
+		}
+		if !node.Connected || node.PendingDelete || node.Action == models.NODE_DELETE {
+			continue
+		}
+		networkPeersInfo := make(models.PeerMap)
+		// NOTE(review): error ignored, mirroring GetFwRulesOnIngressGateway - confirm a failed lookup leaves Enabled false
+		defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
+		for _, peer := range GetNetworkNodesMemory(allNodes, node.Network) {
+			if peer.ID == node.ID {
+				// skip yourself
+				continue
+			}
+			// state checks come first so a filtered-out peer never costs an ACL or host lookup
+			if peer.Action == models.NODE_DELETE || peer.PendingDelete || !peer.Connected {
+				continue
+			}
+			if !nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) {
+				continue
+			}
+			if !defaultDevicePolicy.Enabled && !IsPeerAllowed(node, peer, false) {
+				continue
+			}
+			peerHost, err := GetHost(peer.HostID.String())
+			if err != nil {
+				logger.Log(1, "no peer host", peer.HostID.String(), err.Error())
+				continue
+			}
+			networkPeersInfo[peerHost.PublicKey.String()] = models.IDandAddr{
+				ID:         peer.ID.String(),
+				HostID:     peerHost.ID.String(),
+				Address:    peer.PrimaryAddress(),
+				Name:       peerHost.Name,
+				Network:    peer.Network,
+				ListenPort: peerHost.ListenPort,
+			}
+		}
+		if node.IsIngressGateway {
+			// ext clients are best effort: a failure only leaves them out of the result
+			_, extPeerIDAndAddrs, _, err := GetExtPeers(&node, &node)
+			if err != nil {
+				logger.Log(1, "failed to get ext peers for node", node.ID.String(), err.Error())
+			} else {
+				for _, extPeerIDAndAddr := range extPeerIDAndAddrs {
+					networkPeersInfo[extPeerIDAndAddr.ID] = extPeerIDAndAddr
+				}
+			}
+		}
+		peerInfo.NetworkPeerIDs[models.NetworkID(node.Network)] = networkPeersInfo
+	}
+	return peerInfo, nil
+}
+
 // GetPeerUpdateForHost - gets the consolidated peer update for the host from all networks
 func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.Node,
 	deletedNode *models.Node, deletedClients []models.ExtClient) (models.HostPeerUpdate, error) {
@@ -295,15 +364,15 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 				peerConfig.Endpoint.IP = peer.LocalAddress.IP
 				peerConfig.Endpoint.Port = peerHost.ListenPort
 			}
-			allowedips := GetAllowedIPs(&node, &peer, nil)
-			allowedToComm := IsPeerAllowed(node, peer, false)
+			// with the default policy enabled the per-peer ACL evaluation is skipped
+			allowedToComm := defaultDevicePolicy.Enabled || IsPeerAllowed(node, peer, false)
 			if peer.Action != models.NODE_DELETE &&
 				!peer.PendingDelete &&
 				peer.Connected &&
 				nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) &&
 				(defaultDevicePolicy.Enabled || allowedToComm) &&
 				(deletedNode == nil || (deletedNode != nil && peer.ID.String() != deletedNode.ID.String())) {
-				peerConfig.AllowedIPs = allowedips // only append allowed IPs if valid connection
+				peerConfig.AllowedIPs = GetAllowedIPs(&node, &peer, nil) // only append allowed IPs if valid connection
 			}
 
 			var nodePeer wgtypes.PeerConfig
diff --git a/logic/pro/failover b/logic/pro/failover
deleted file mode 100644
index e69de29b..00000000
diff --git a/models/mqtt.go b/models/mqtt.go
index 4b7ce10a..80c3d5b0 100644
--- a/models/mqtt.go
+++ b/models/mqtt.go
@@ -6,6 +6,11 @@ import (
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 
+// HostPeerInfo - struct for host peer info, one peer map per network
+type HostPeerInfo struct {
+	NetworkPeerIDs map[NetworkID]PeerMap `json:"network_peers"`
+}
+
 // HostPeerUpdate - struct for host peer updates
 type HostPeerUpdate struct {
 	Host Host `json:"host"`
diff --git a/pro/controllers/failover.go b/pro/controllers/failover.go
index a4fddad9..13a9df30 100644
--- a/pro/controllers/failover.go
+++ b/pro/controllers/failover.go
@@ -19,7 +19,7 @@ import (
 
 // FailOverHandlers - handlers for FailOver
 func FailOverHandlers(r *mux.Router) {
-	r.HandleFunc("/api/v1/node/{nodeid}/failover", http.HandlerFunc(getfailOver)).
+	r.HandleFunc("/api/v1/node/{nodeid}/failover", controller.Authorize(true, false, "host", http.HandlerFunc(getfailOver))).
 		Methods(http.MethodGet)
 	r.HandleFunc("/api/v1/node/{nodeid}/failover", logic.SecurityCheck(true, http.HandlerFunc(createfailOver))).
 		Methods(http.MethodPost)
@@ -29,6 +29,8 @@ func FailOverHandlers(r *mux.Router) {
 		Methods(http.MethodPost)
 	r.HandleFunc("/api/v1/node/{nodeid}/failover_me", controller.Authorize(true, false, "host", http.HandlerFunc(failOverME))).
 		Methods(http.MethodPost)
+	r.HandleFunc("/api/v1/node/{nodeid}/failover_check", controller.Authorize(true, false, "host", http.HandlerFunc(checkfailOverCtx))).
+		Methods(http.MethodGet)
 }
 
 // @Summary Get failover node
@@ -44,7 +46,6 @@ func getfailOver(w http.ResponseWriter, r *http.Request) {
 	// confirm host exists
 	node, err := logic.GetNodeByID(nodeid)
 	if err != nil {
-		slog.Error("failed to get node:", "node", nodeid, "error", err.Error())
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
@@ -140,6 +141,7 @@ func deletefailOver(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
+	proLogic.RemoveFailOverFromCache(node.Network)
 	go func() {
 		proLogic.ResetFailOver(&node)
 		mq.PublishPeerUpdate(false)
@@ -265,10 +267,9 @@ func failOverME(w http.ResponseWriter, r *http.Request) {
 		)
 		return
 	}
-
 	err = proLogic.SetFailOverCtx(failOverNode, node, peerNode)
 	if err != nil {
-		slog.Error("failed to create failover", "id", node.ID.String(),
+		slog.Debug("failed to create failover", "id", node.ID.String(),
 			"network", node.Network, "error", err)
 		logic.ReturnErrorResponse(
 			w,
@@ -293,3 +294,89 @@ func failOverME(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	logic.ReturnSuccessResponse(w, r, "relayed successfully")
 }
+
+// @Summary     checkfailOverCtx
+// @Router      /api/v1/node/{nodeid}/failover_check [get]
+// @Tags        PRO
+// @Param       nodeid path string true "Node ID"
+// @Accept      json
+// @Param       body body models.FailOverMeReq true "Failover request"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func checkfailOverCtx(w http.ResponseWriter, r *http.Request) {
+	nodeid := mux.Vars(r)["nodeid"]
+	// confirm node exists
+	node, err := logic.GetNodeByID(nodeid)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "failed to get node:", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	host, err := logic.GetHost(node.HostID.String())
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	failOverNode, exists := proLogic.FailOverExists(node.Network)
+	if !exists {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				fmt.Errorf("req-from: %s, failover node doesn't exist in the network", host.Name),
+				"badrequest",
+			),
+		)
+		return
+	}
+	var failOverReq models.FailOverMeReq
+	if err := json.NewDecoder(r.Body).Decode(&failOverReq); err != nil {
+		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	peerNode, err := logic.GetNodeByID(failOverReq.NodeID)
+	if err != nil {
+		slog.Error("peer not found: ", "nodeid", failOverReq.NodeID, "error", err)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("peer not found"), "badrequest"))
+		return
+	}
+	if err := failOverPairConflict(node, peerNode); err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	if err := proLogic.CheckFailOverCtx(failOverNode, node, peerNode); err != nil {
+		// an already-set ctx is a normal answer to this probe, so it is logged at debug like in failOverME
+		slog.Debug("failover ctx cannot be set ", "error", err)
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(fmt.Errorf("failover ctx cannot be set: %v", err), "internal"),
+		)
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+	logic.ReturnSuccessResponse(w, r, "failover can be set")
+}
+
+// failOverPairConflict reports why node must not fail over to peerNode: either
+// one of them is itself the failover node, or the two are already tied together
+// by a relay or internet gateway relationship. It returns nil when nothing conflicts.
+func failOverPairConflict(node, peerNode models.Node) error {
+	switch {
+	case peerNode.IsFailOver:
+		return errors.New("peer is acting as failover")
+	case node.IsFailOver:
+		return errors.New("node is acting as failover")
+	case node.IsRelayed && node.RelayedBy == peerNode.ID.String():
+		return errors.New("node is relayed by peer node")
+	case node.IsRelay && peerNode.RelayedBy == node.ID.String():
+		return errors.New("node acting as relay for the peer node")
+	case node.IsInternetGateway && peerNode.InternetGwID == node.ID.String():
+		return errors.New("node acting as internet gw for the peer node")
+	case node.InternetGwID != "" && node.InternetGwID == peerNode.ID.String():
+		return errors.New("node using a internet gw by the peer node")
+	}
+	return nil
+}
diff --git a/pro/initialize.go b/pro/initialize.go
index 6c6587f9..4292d3f9 100644
--- a/pro/initialize.go
+++ b/pro/initialize.go
@@ -90,6 +90,7 @@ func InitPro() {
 			slog.Error("no OAuth provider found or not configured, continuing without OAuth")
 		}
 		proLogic.LoadNodeMetricsToCache()
+		proLogic.InitFailOverCache()
 	})
 	logic.ResetFailOver = proLogic.ResetFailOver
 	logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer
diff --git
a/pro/logic/failover.go b/pro/logic/failover.go
index 4b0debdc..d4ac5ff6 100644
--- a/pro/logic/failover.go
+++ b/pro/logic/failover.go
@@ -13,7 +13,59 @@ import (
 )
 
 var failOverCtxMutex = &sync.RWMutex{}
+var failOverCacheMutex = &sync.RWMutex{}
+
+// failOverCache - ID of the failover node per network, guarded by failOverCacheMutex
+var failOverCache = make(map[models.NetworkID]string)
 
+// InitFailOverCache - fills the failover cache from the stored networks and nodes
+func InitFailOverCache() {
+	// load before taking the lock so cache readers are never blocked on storage access
+	networks, err := logic.GetNetworks()
+	if err != nil {
+		slog.Warn("failover cache not initialized: failed to get networks", "error", err)
+		return
+	}
+	allNodes, err := logic.GetAllNodes()
+	if err != nil {
+		slog.Warn("failover cache not initialized: failed to get nodes", "error", err)
+		return
+	}
+	found := make(map[models.NetworkID]string, len(networks))
+	for _, network := range networks {
+		for _, node := range logic.GetNetworkNodesMemory(allNodes, network.NetID) {
+			if node.IsFailOver {
+				found[models.NetworkID(network.NetID)] = node.ID.String()
+				break
+			}
+		}
+	}
+	failOverCacheMutex.Lock()
+	defer failOverCacheMutex.Unlock()
+	for netID, nodeID := range found {
+		failOverCache[netID] = nodeID
+	}
+}
+
+// failOverCtxIsSet - reports whether victimNode and peerNode already fail over to each other through failOverNode
+func failOverCtxIsSet(failOverNode, victimNode, peerNode models.Node) bool {
+	// indexing a nil map is safe and reports the key as absent, so no nil guards are needed
+	_, peerHasFailovered := peerNode.FailOverPeers[victimNode.ID.String()]
+	_, victimHasFailovered := victimNode.FailOverPeers[peerNode.ID.String()]
+	return peerHasFailovered && victimHasFailovered &&
+		victimNode.FailedOverBy == failOverNode.ID && peerNode.FailedOverBy == failOverNode.ID
+}
+
+// CheckFailOverCtx - returns an error when the failover ctx between victimNode and peerNode is already set
+func CheckFailOverCtx(failOverNode, victimNode, peerNode models.Node) error {
+	failOverCtxMutex.RLock()
+	defer failOverCtxMutex.RUnlock()
+	if failOverCtxIsSet(failOverNode, victimNode, peerNode) {
+		return errors.New("failover ctx is already set")
+	}
+	return nil
+}
+
 func SetFailOverCtx(failOverNode, victimNode, peerNode models.Node) error {
 	failOverCtxMutex.Lock()
 	defer failOverCtxMutex.Unlock()
@@ -23,13 +75,13 @@ func SetFailOverCtx(failOverNode, victimNode, peerNode models.Node) error {
 	if victimNode.FailOverPeers == nil {
 		victimNode.FailOverPeers = make(map[string]struct{})
 	}
+	if failOverCtxIsSet(failOverNode, victimNode, peerNode) {
+		return errors.New("failover ctx is already set")
+	}
 	peerNode.FailOverPeers[victimNode.ID.String()] = struct{}{}
 	victimNode.FailOverPeers[peerNode.ID.String()] = struct{}{}
 	victimNode.FailedOverBy = failOverNode.ID
 	peerNode.FailedOverBy = failOverNode.ID
-	if err := logic.UpsertNode(&failOverNode); err != nil {
-		return err
-	}
 	if err := logic.UpsertNode(&victimNode); err != nil {
 		return err
 	}
@@ -50,17 +102,29 @@ func GetFailOverNode(network string, allNodes []models.Node) (models.Node, error
 	return models.Node{}, errors.New("auto relay not found")
 }
 
+// RemoveFailOverFromCache - drops the cached failover node of the network
+func RemoveFailOverFromCache(network string) {
+	failOverCacheMutex.Lock()
+	defer failOverCacheMutex.Unlock()
+	delete(failOverCache, models.NetworkID(network))
+}
+
+// SetFailOverInCache - records node as the failover node of its network
+func SetFailOverInCache(node models.Node) {
+	failOverCacheMutex.Lock()
+	defer failOverCacheMutex.Unlock()
+	failOverCache[models.NetworkID(node.Network)] = node.ID.String()
+}
+
 // FailOverExists - checks if failOver exists already in the network
 func FailOverExists(network string) (failOverNode models.Node, exists bool) {
-	nodes, err := logic.GetNetworkNodes(network)
-	if err != nil {
-		return
-	}
-	for _, node := range nodes {
-		if node.IsFailOver {
-			exists = true
-			failOverNode = node
-			return
+	failOverCacheMutex.RLock()
+	nodeID, ok := failOverCache[models.NetworkID(network)]
+	failOverCacheMutex.RUnlock()
+	if ok {
+		// resolved outside the lock; an entry whose node is gone or no longer a failover counts as absent
+		if node, err := logic.GetNodeByID(nodeID); err == nil && node.IsFailOver {
+			return node, true
 		}
 	}
 	return
@@ -185,5 +249,6 @@ func CreateFailOver(node models.Node) error {
 		slog.Error("failed to upsert node", "node", node.ID.String(), "error", err)
 		return err
 	}
+	SetFailOverInCache(node)
 	return nil
 }