diff --git a/main.go b/main.go index a5d79299..5fb8607a 100644 --- a/main.go +++ b/main.go @@ -173,14 +173,24 @@ func startControllers() { logger.Log(0, "No Server Mode selected, so nothing is being served! Set Agent mode (AGENT_BACKEND) or Rest mode (REST_BACKEND) or MessageQueue (MESSAGEQUEUE_BACKEND) to 'true'.") } // starts the stun server - go stunserver.Start() - go nmproxy.Start(logic.ProxyMgmChan) + waitnetwork.Add(1) + go stunserver.Start(&waitnetwork) + waitnetwork.Add(1) go func() { + defer waitnetwork.Done() + ctx, cancel := context.WithCancel(context.Background()) + waitnetwork.Add(1) + go nmproxy.Start(ctx, logic.ProxyMgmChan, servercfg.GetAPIHost()) err := serverctl.SyncServerNetworkWithProxy() if err != nil { logger.Log(0, "failed to sync proxy with server interfaces: ", err.Error()) } + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGTERM, os.Interrupt) + <-quit + cancel() }() + waitnetwork.Wait() } diff --git a/netclient/functions/daemon.go b/netclient/functions/daemon.go index 7db09e31..a4cb462c 100644 --- a/netclient/functions/daemon.go +++ b/netclient/functions/daemon.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "log" + "net" "net/http" "os" "os/signal" @@ -33,7 +34,7 @@ import ( var ProxyMgmChan = make(chan *manager.ManagerAction, 100) var messageCache = new(sync.Map) -var ProxyStatus = "OFF" + var serverSet map[string]bool var mqclient mqtt.Client @@ -123,16 +124,19 @@ func startGoRoutines(wg *sync.WaitGroup) context.CancelFunc { } wg.Add(1) go Checkin(ctx, wg) - if ProxyStatus == "OFF" { - ProxyStatus = "ON" - go nmproxy.Start(ProxyMgmChan) - } else { - log.Println("Proxy already running...") + + if len(networks) != 0 { + cfg := config.ClientConfig{} + cfg.Network = networks[0] + cfg.ReadConfig() + apiHost, _, err := net.SplitHostPort(cfg.Server.API) + if err == nil { + go nmproxy.Start(ctx, ProxyMgmChan, apiHost) + } } - go func() { + go func(networks []string) { - networks, _ := ncutils.GetSystemNetworks() for _, network := range networks { logger.Log(0, "Collecting interface and peers info to configure proxy...") cfg := config.ClientConfig{} @@ -153,7 +157,7 @@ func startGoRoutines(wg *sync.WaitGroup) context.CancelFunc { } - }() + }(networks) return cancel } func GetNodeInfo(cfg *config.ClientConfig) (models.NodeGet, error) { diff --git a/nm-proxy/nm-proxy.go b/nm-proxy/nm-proxy.go index 8cbbdead..11089fbe 100644 --- a/nm-proxy/nm-proxy.go +++ b/nm-proxy/nm-proxy.go @@ -1,6 +1,7 @@ package nmproxy import ( + "context" "log" "net" "os" @@ -17,11 +18,11 @@ import ( 2. Delete - remove close all conns for the interface,cleanup */ -func Start(mgmChan chan *manager.ManagerAction) { +func Start(ctx context.Context, mgmChan chan *manager.ManagerAction, apiServerAddr string) { log.Println("Starting Proxy...") common.IsHostNetwork = (os.Getenv("HOST_NETWORK") == "" || os.Getenv("HOST_NETWORK") == "on") go manager.StartProxyManager(mgmChan) - hInfo := stun.GetHostInfo() + hInfo := stun.GetHostInfo(apiServerAddr) stun.Host = hInfo log.Printf("HOSTINFO: %+v", hInfo) if IsPublicIP(hInfo.PrivIp) { @@ -32,7 +33,8 @@ func Start(mgmChan chan *manager.ManagerAction) { if err != nil { log.Fatal("failed to create proxy: ", err) } - server.NmProxyServer.Listen() + server.NmProxyServer.Listen(ctx) + } // IsPublicIP indicates whether IP is public or not. diff --git a/nm-proxy/server/server.go b/nm-proxy/server/server.go index 06bfb847..ec92f9fb 100644 --- a/nm-proxy/server/server.go +++ b/nm-proxy/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "log" "net" @@ -32,63 +33,80 @@ type ProxyServer struct { } // Proxy.Listen - begins listening for packets -func (p *ProxyServer) Listen() { +func (p *ProxyServer) Listen(ctx context.Context) { // Buffer with indicated body size buffer := make([]byte, 1532) for { - // Read Packet - n, source, err := p.Server.ReadFromUDP(buffer) - if err != nil { // in future log errors? - log.Println("RECV ERROR: ", err) - continue - } - var srcPeerKeyHash, dstPeerKeyHash string - n, srcPeerKeyHash, dstPeerKeyHash = packet.ExtractInfo(buffer, n) - //log.Printf("--------> RECV PKT [DSTPORT: %d], [SRCKEYHASH: %s], SourceIP: [%s] \n", localWgPort, srcPeerKeyHash, source.IP.String()) - if common.IsRelay && dstPeerKeyHash != "" && srcPeerKeyHash != "" { - if _, ok := common.WgIfaceKeyMap[dstPeerKeyHash]; !ok { - log.Println("----------> Relaying######") - // check for routing map and forward to right proxy - if remoteMap, ok := common.RelayPeerMap[srcPeerKeyHash]; ok { - if conf, ok := remoteMap[dstPeerKeyHash]; ok { - log.Printf("--------> Relaying PKT [ SourceIP: %s:%d ], [ SourceKeyHash: %s ], [ DstIP: %s:%d ], [ DstHashKey: %s ] \n", - source.IP.String(), source.Port, srcPeerKeyHash, conf.Endpoint.String(), conf.Endpoint.Port, dstPeerKeyHash) - _, err = NmProxyServer.Server.WriteToUDP(buffer[:n+32], conf.Endpoint) - if err != nil { - log.Println("Failed to send to remote: ", err) - } - } - } else { - if remoteMap, ok := common.RelayPeerMap[dstPeerKeyHash]; ok { + select { + case <-ctx.Done(): + log.Println("--------->### Shutting down Proxy.....") + // clean up proxy connections + for iface, peers := range common.WgIFaceMap { + log.Println("########------------> CLEANING UP: ", iface) + for _, peerI := range peers { + peerI.Proxy.Cancel() + } + } + // close server connection + NmProxyServer.Server.Close() + return + default: + // Read Packet + n, source, err := p.Server.ReadFromUDP(buffer) + if err != nil { // in future log errors? + log.Println("RECV ERROR: ", err) + continue + } + var srcPeerKeyHash, dstPeerKeyHash string + n, srcPeerKeyHash, dstPeerKeyHash = packet.ExtractInfo(buffer, n) + //log.Printf("--------> RECV PKT [DSTPORT: %d], [SRCKEYHASH: %s], SourceIP: [%s] \n", localWgPort, srcPeerKeyHash, source.IP.String()) + if common.IsRelay && dstPeerKeyHash != "" && srcPeerKeyHash != "" { + if _, ok := common.WgIfaceKeyMap[dstPeerKeyHash]; !ok { + + log.Println("----------> Relaying######") + // check for routing map and forward to right proxy + if remoteMap, ok := common.RelayPeerMap[srcPeerKeyHash]; ok { if conf, ok := remoteMap[dstPeerKeyHash]; ok { - log.Printf("--------> Relaying BACK TO RELAYED NODE PKT [ SourceIP: %s ], [ SourceKeyHash: %s ], [ DstIP: %s ], [ DstHashKey: %s ] \n", - source.String(), srcPeerKeyHash, conf.Endpoint.String(), dstPeerKeyHash) + log.Printf("--------> Relaying PKT [ SourceIP: %s:%d ], [ SourceKeyHash: %s ], [ DstIP: %s:%d ], [ DstHashKey: %s ] \n", + source.IP.String(), source.Port, srcPeerKeyHash, conf.Endpoint.String(), conf.Endpoint.Port, dstPeerKeyHash) _, err = NmProxyServer.Server.WriteToUDP(buffer[:n+32], conf.Endpoint) if err != nil { log.Println("Failed to send to remote: ", err) } } + } else { + if remoteMap, ok := common.RelayPeerMap[dstPeerKeyHash]; ok { + if conf, ok := remoteMap[dstPeerKeyHash]; ok { + log.Printf("--------> Relaying BACK TO RELAYED NODE PKT [ SourceIP: %s ], [ SourceKeyHash: %s ], [ DstIP: %s ], [ DstHashKey: %s ] \n", + source.String(), srcPeerKeyHash, conf.Endpoint.String(), dstPeerKeyHash) + _, err = NmProxyServer.Server.WriteToUDP(buffer[:n+32], conf.Endpoint) + if err != nil { + log.Println("Failed to send to remote: ", err) + } + } + } } - } + } } - } - if peerInfo, ok := common.PeerKeyHashMap[srcPeerKeyHash]; ok { - if peers, ok := common.WgIFaceMap[peerInfo.Interface]; ok { - if peerI, ok := peers[peerInfo.PeerKey]; ok { - log.Printf("PROXING TO LOCAL!!!---> %s <<<< %s <<<<<<<< %s [[ RECV PKT [SRCKEYHASH: %s], [DSTKEYHASH: %s], SourceIP: [%s] ]]\n", - peerI.Proxy.LocalConn.RemoteAddr(), peerI.Proxy.LocalConn.LocalAddr(), - fmt.Sprintf("%s:%d", source.IP.String(), source.Port), srcPeerKeyHash, dstPeerKeyHash, source.IP.String()) - _, err = peerI.Proxy.LocalConn.Write(buffer[:n]) - if err != nil { - log.Println("Failed to proxy to Wg local interface: ", err) - continue + if peerInfo, ok := common.PeerKeyHashMap[srcPeerKeyHash]; ok { + if peers, ok := common.WgIFaceMap[peerInfo.Interface]; ok { + if peerI, ok := peers[peerInfo.PeerKey]; ok { + log.Printf("PROXING TO LOCAL!!!---> %s <<<< %s <<<<<<<< %s [[ RECV PKT [SRCKEYHASH: %s], [DSTKEYHASH: %s], SourceIP: [%s] ]]\n", + peerI.Proxy.LocalConn.RemoteAddr(), peerI.Proxy.LocalConn.LocalAddr(), + fmt.Sprintf("%s:%d", source.IP.String(), source.Port), srcPeerKeyHash, dstPeerKeyHash, source.IP.String()) + _, err = peerI.Proxy.LocalConn.Write(buffer[:n]) + if err != nil { + log.Println("Failed to proxy to Wg local interface: ", err) + continue + } + } - } + } } diff --git a/nm-proxy/stun/stun.go b/nm-proxy/stun/stun.go index 558e1d2e..8b282827 100644 --- a/nm-proxy/stun/stun.go +++ b/nm-proxy/stun/stun.go @@ -20,11 +20,12 @@ type HostInfo struct { var Host HostInfo -func GetHostInfo() (info HostInfo) { +func GetHostInfo(stunHostAddr string) (info HostInfo) { - s, err := net.ResolveUDPAddr("udp", "stun.nm.134.209.115.146.nip.io:3478") + s, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:3478", stunHostAddr)) if err != nil { log.Println("Resolve: ", err) + return } l := &net.UDPAddr{ IP: net.ParseIP(""), @@ -32,13 +33,14 @@ func GetHostInfo() (info HostInfo) { } conn, err := net.DialUDP("udp", l, s) if err != nil { - log.Fatal(err) + log.Println(err) + return } defer conn.Close() - fmt.Printf("%+v\n", conn.LocalAddr()) c, err := stun.NewClient(conn) if err != nil { - panic(err) + log.Println(err) + return } defer c.Close() re := strings.Split(conn.LocalAddr().String(), ":") @@ -49,17 +51,19 @@ func GetHostInfo() (info HostInfo) { // Sending request to STUN server, waiting for response message. if err := c.Do(message, func(res stun.Event) { if res.Error != nil { - panic(res.Error) + log.Println("stun error: ", res.Error) + return } // Decoding XOR-MAPPED-ADDRESS attribute from message. var xorAddr stun.XORMappedAddress if err := xorAddr.GetFrom(res.Message); err != nil { - panic(err) + log.Println("stun error: ", res.Error) + return } info.PublicIp = xorAddr.IP info.PubPort = xorAddr.Port }); err != nil { - panic(err) + log.Println("stun error: ", err) } return } diff --git a/stun-server/stun-server.go b/stun-server/stun-server.go index c0728de0..5081fed6 100644 --- a/stun-server/stun-server.go +++ b/stun-server/stun-server.go @@ -1,10 +1,15 @@ package stunserver import ( + "context" "fmt" "log" "net" + "os" + "os/signal" "strings" + "sync" + "syscall" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/servercfg" @@ -19,9 +24,8 @@ import ( // nor ALTERNATE-SERVER, nor credentials mechanisms. It does not support // backwards compatibility with RFC 3489. type Server struct { - Addr string - LogAllErrors bool - log Logger + Addr string + Ctx context.Context } // Logger is used for logging formatted messages. @@ -72,54 +76,62 @@ func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message) error { buf := make([]byte, 1024) n, addr, err := c.ReadFrom(buf) if err != nil { - s.log.Printf("ReadFrom: %v", err) + logger.Log(1, "ReadFrom: %v", err.Error()) return nil } log.Printf("read %d bytes from %s\n", n, addr) if _, err = req.Write(buf[:n]); err != nil { - s.log.Printf("Write: %v", err) + logger.Log(1, "Write: %v", err.Error()) return err } if err = basicProcess(addr, buf[:n], req, res); err != nil { if err == errNotSTUNMessage { return nil } - s.log.Printf("basicProcess: %v", err) + logger.Log(1, "basicProcess: %v", err.Error()) return nil } _, err = c.WriteTo(res.Raw, addr) if err != nil { - s.log.Printf("WriteTo: %v", err) + logger.Log(1, "WriteTo: %v", err.Error()) } return err } // Serve reads packets from connections and responds to BINDING requests. -func (s *Server) Serve(c net.PacketConn) error { +func (s *Server) serve(c net.PacketConn) error { var ( res = new(stun.Message) req = new(stun.Message) ) for { - if err := s.serveConn(c, res, req); err != nil { - s.log.Printf("serve: %v", err) - return err + select { + case <-s.Ctx.Done(): + logger.Log(0, "Shutting down stun server...") + c.Close() + return nil + default: + if err := s.serveConn(c, res, req); err != nil { + logger.Log(1, "serve: %v", err.Error()) + continue + } + res.Reset() + req.Reset() } - res.Reset() - req.Reset() } } -// ListenUDPAndServe listens on laddr and process incoming packets. -func ListenUDPAndServe(serverNet, laddr string) error { +// listenUDPAndServe listens on laddr and process incoming packets. +func listenUDPAndServe(ctx context.Context, serverNet, laddr string) error { c, err := net.ListenPacket(serverNet, laddr) if err != nil { return err } s := &Server{ - log: defaultLogger, + Addr: laddr, + Ctx: ctx, } - return s.Serve(c) + return s.serve(c) } func normalize(address string) string { @@ -132,11 +144,18 @@ func normalize(address string) string { return address } -func Start() { - +func Start(wg *sync.WaitGroup) { + defer wg.Done() + ctx, cancel := context.WithCancel(context.Background()) + go func() { + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGTERM, os.Interrupt) + <-quit + cancel() + }() normalized := normalize(fmt.Sprintf("0.0.0.0:%s", servercfg.GetStunPort())) logger.Log(0, "netmaker-stun listening on", normalized, "via udp") - err := ListenUDPAndServe("udp", normalized) + err := listenUDPAndServe(ctx, "udp", normalized) if err != nil { logger.Log(0, "failed to start stun server: ", err.Error()) }