diff --git a/controllers/controller.go b/controllers/controller.go index 571671bd..38dc3ecd 100644 --- a/controllers/controller.go +++ b/controllers/controller.go @@ -4,11 +4,8 @@ import ( "context" "fmt" "net/http" - "os" - "os/signal" "strings" "sync" - "syscall" "time" "github.com/gorilla/handlers" @@ -32,7 +29,7 @@ var HttpHandlers = []interface{}{ } // HandleRESTRequests - handles the rest requests -func HandleRESTRequests(wg *sync.WaitGroup) { +func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) { defer wg.Done() r := mux.NewRouter() @@ -58,18 +55,14 @@ func HandleRESTRequests(wg *sync.WaitGroup) { }() logger.Log(0, "REST Server successfully started on port ", port, " (REST)") - // Relay os.Interrupt to our channel (os.Interrupt = CTRL+C) - // Ignore other incoming signals - ctx, stop := signal.NotifyContext(context.TODO(), syscall.SIGTERM, os.Interrupt) - defer stop() - // Block main routine until a signal is received // As long as user doesn't press CTRL+C a message is not passed and our main routine keeps running <-ctx.Done() - // After receiving CTRL+C Properly stop the server logger.Log(0, "Stopping the REST server...") + if err := srv.Shutdown(context.TODO()); err != nil { + logger.Log(0, "REST shutdown error occurred -", err.Error()) + } logger.Log(0, "REST Server closed.") logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay))) - srv.Shutdown(context.TODO()) } diff --git a/logic/nodes.go b/logic/nodes.go index 45a0ac2c..ba4f4687 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -1,7 +1,6 @@ package logic import ( - "context" "encoding/json" "errors" "fmt" @@ -421,35 +420,6 @@ func updateProNodeACLS(node *models.Node) error { return nil } -func PurgePendingNodes(ctx context.Context) { - ticker := time.NewTicker(NodePurgeCheckTime) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - nodes, err := GetAllNodes() - if err != nil { - logger.Log(0, "PurgePendingNodes failed to retrieve nodes", err.Error()) - continue - } - for _, node := range nodes { - if node.PendingDelete { - modified := node.LastModified - if time.Since(modified) > NodePurgeTime { - if err := DeleteNode(&node, true); err != nil { - logger.Log(0, "failed to purge node", node.ID.String(), err.Error()) - } else { - logger.Log(0, "purged node ", node.ID.String()) - } - } - } - } - } - } -} - // createNode - creates a node in database func createNode(node *models.Node) error { host, err := GetHost(node.HostID.String()) diff --git a/main.go b/main.go index 53cf4712..0cda211e 100644 --- a/main.go +++ b/main.go @@ -36,12 +36,16 @@ func main() { setupConfig(*absoluteConfigPath) servercfg.SetVersion(version) fmt.Println(models.RetrieveLogo()) // print the logo - // fmt.Println(models.ProLogo()) - initialize() // initial db and acls; gen cert if required + initialize() // initial db and acls setGarbageCollection() setVerbosity() defer database.CloseDB() - startControllers() // start the api endpoint and mq + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, os.Interrupt) + defer stop() + var waitGroup sync.WaitGroup + startControllers(&waitGroup, ctx) // start the api endpoint and mq and stun + <-ctx.Done() + waitGroup.Wait() } func setupConfig(absoluteConfigPath string) { @@ -110,8 +114,7 @@ func initialize() { // Client Mode Prereq Check } } -func startControllers() { - var waitnetwork sync.WaitGroup +func startControllers(wg *sync.WaitGroup, ctx context.Context) { if servercfg.IsDNSMode() { err := logic.SetDNS() if err != nil { @@ -127,13 +130,13 @@ func startControllers() { logger.FatalLog("Unable to Set host. Exiting...", err.Error()) } } - waitnetwork.Add(1) - go controller.HandleRESTRequests(&waitnetwork) + wg.Add(1) + go controller.HandleRESTRequests(wg, ctx) } //Run MessageQueue if servercfg.IsMessageQueueBackend() { - waitnetwork.Add(1) - go runMessageQueue(&waitnetwork) + wg.Add(1) + go runMessageQueue(wg, ctx) } if !servercfg.IsRestBackend() && !servercfg.IsMessageQueueBackend() { @@ -141,34 +144,17 @@ func startControllers() { } // starts the stun server - waitnetwork.Add(1) - go stunserver.Start(&waitnetwork) - if servercfg.IsProxyEnabled() { - - waitnetwork.Add(1) - go func() { - defer waitnetwork.Done() - _, cancel := context.WithCancel(context.Background()) - waitnetwork.Add(1) - - //go nmproxy.Start(ctx, logic.ProxyMgmChan, servercfg.GetAPIHost()) - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGTERM, os.Interrupt) - <-quit - cancel() - }() - } - - waitnetwork.Wait() + wg.Add(1) + go stunserver.Start(wg, ctx) } // Should we be using a context vice a waitgroup???????????? -func runMessageQueue(wg *sync.WaitGroup) { +func runMessageQueue(wg *sync.WaitGroup, ctx context.Context) { defer wg.Done() brokerHost, secure := servercfg.GetMessageQueueEndpoint() logger.Log(0, "connecting to mq broker at", brokerHost, "with TLS?", fmt.Sprintf("%v", secure)) mq.SetupMQTT() - ctx, cancel := context.WithCancel(context.Background()) + defer mq.CloseClient() go mq.Keepalive(ctx) go func() { peerUpdate := make(chan *models.Node) @@ -179,11 +165,7 @@ func runMessageQueue(wg *sync.WaitGroup) { } } }() - go logic.PurgePendingNodes(ctx) - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGTERM, os.Interrupt) - <-quit - cancel() + <-ctx.Done() logger.Log(0, "Message Queue shutting down") } diff --git a/mq/mq.go b/mq/mq.go index 056dc925..9fabc90f 100644 --- a/mq/mq.go +++ b/mq/mq.go @@ -100,3 +100,8 @@ func Keepalive(ctx context.Context) { func IsConnected() bool { return mqclient != nil && mqclient.IsConnected() } + +// CloseClient - function to close the mq connection from server +func CloseClient() { + mqclient.Disconnect(250) +} diff --git a/stun-server/stun-server.go b/stun-server/stun-server.go index 7e4b768e..9bc22b14 100644 --- a/stun-server/stun-server.go +++ b/stun-server/stun-server.go @@ -4,11 +4,8 @@ import ( "context" "fmt" "net" - "os" - "os/signal" "strings" "sync" - "syscall" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/servercfg" @@ -23,7 +20,6 @@ import ( // backwards compatibility with RFC 3489. type Server struct { Addr string - Ctx context.Context } var ( @@ -60,48 +56,58 @@ func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error { ) } -func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message) error { +func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message, ctx context.Context) error { if c == nil { return nil } + go func(ctx context.Context) { + <-ctx.Done() + if c != nil { + // kill connection on server shutdown + c.Close() + } + }(ctx) + buf := make([]byte, 1024) - n, addr, err := c.ReadFrom(buf) + n, addr, err := c.ReadFrom(buf) // this be blocky af if err != nil { - logger.Log(1, "ReadFrom: %v", err.Error()) + if !strings.Contains(err.Error(), "use of closed network connection") { + logger.Log(1, "STUN read error:", err.Error()) + } return nil } + if _, err = req.Write(buf[:n]); err != nil { - logger.Log(1, "Write: %v", err.Error()) + logger.Log(1, "STUN write error:", err.Error()) return err } if err = basicProcess(addr, buf[:n], req, res); err != nil { if err == errNotSTUNMessage { return nil } - logger.Log(1, "basicProcess: %v", err.Error()) + logger.Log(1, "STUN process error:", err.Error()) return nil } _, err = c.WriteTo(res.Raw, addr) if err != nil { - logger.Log(1, "WriteTo: %v", err.Error()) + logger.Log(1, "STUN response write error", err.Error()) } return err } // Serve reads packets from connections and responds to BINDING requests. -func (s *Server) serve(c net.PacketConn) error { +func (s *Server) serve(c net.PacketConn, ctx context.Context) error { var ( res = new(stun.Message) req = new(stun.Message) ) for { select { - case <-s.Ctx.Done(): - logger.Log(0, "Shutting down stun server...") - c.Close() + case <-ctx.Done(): + logger.Log(0, "shut down STUN server") return nil default: - if err := s.serveConn(c, res, req); err != nil { + if err := s.serveConn(c, res, req, ctx); err != nil { logger.Log(1, "serve: %v", err.Error()) continue } @@ -119,9 +125,8 @@ func listenUDPAndServe(ctx context.Context, serverNet, laddr string) error { } s := &Server{ Addr: laddr, - Ctx: ctx, } - return s.serve(c) + return s.serve(c, ctx) } func normalize(address string) string { @@ -135,19 +140,15 @@ func normalize(address string) string { } // Start - starts the stun server -func Start(wg *sync.WaitGroup) { - ctx, cancel := context.WithCancel(context.Background()) - go func(wg *sync.WaitGroup) { - defer wg.Done() - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGTERM, os.Interrupt) - <-quit - cancel() - }(wg) +func Start(wg *sync.WaitGroup, ctx context.Context) { + defer wg.Done() normalized := normalize(fmt.Sprintf("0.0.0.0:%d", servercfg.GetStunPort())) logger.Log(0, "netmaker-stun listening on", normalized, "via udp") - err := listenUDPAndServe(ctx, "udp", normalized) - if err != nil { - logger.Log(0, "failed to start stun server: ", err.Error()) + if err := listenUDPAndServe(ctx, "udp", normalized); err != nil { + if strings.Contains(err.Error(), "closed network connection") { + logger.Log(0, "shutdown STUN server") + } else { + logger.Log(0, "server: ", err.Error()) + } } }