adjusted main to use one single context

This commit is contained in:
0xdcarns 2023-02-24 15:37:53 -05:00
parent 3244472481
commit 5014c389ca
5 changed files with 56 additions and 105 deletions

View file

@ -4,11 +4,8 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"os"
"os/signal"
"strings" "strings"
"sync" "sync"
"syscall"
"time" "time"
"github.com/gorilla/handlers" "github.com/gorilla/handlers"
@ -32,7 +29,7 @@ var HttpHandlers = []interface{}{
} }
// HandleRESTRequests - handles the rest requests // HandleRESTRequests - handles the rest requests
func HandleRESTRequests(wg *sync.WaitGroup) { func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) {
defer wg.Done() defer wg.Done()
r := mux.NewRouter() r := mux.NewRouter()
@ -58,18 +55,14 @@ func HandleRESTRequests(wg *sync.WaitGroup) {
}() }()
logger.Log(0, "REST Server successfully started on port ", port, " (REST)") logger.Log(0, "REST Server successfully started on port ", port, " (REST)")
// Relay os.Interrupt to our channel (os.Interrupt = CTRL+C)
// Ignore other incoming signals
ctx, stop := signal.NotifyContext(context.TODO(), syscall.SIGTERM, os.Interrupt)
defer stop()
// Block main routine until a signal is received // Block main routine until a signal is received
// As long as user doesn't press CTRL+C a message is not passed and our main routine keeps running // As long as user doesn't press CTRL+C a message is not passed and our main routine keeps running
<-ctx.Done() <-ctx.Done()
// After receiving CTRL+C Properly stop the server // After receiving CTRL+C Properly stop the server
logger.Log(0, "Stopping the REST server...") logger.Log(0, "Stopping the REST server...")
if err := srv.Shutdown(context.TODO()); err != nil {
logger.Log(0, "REST shutdown error occurred -", err.Error())
}
logger.Log(0, "REST Server closed.") logger.Log(0, "REST Server closed.")
logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay))) logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay)))
srv.Shutdown(context.TODO())
} }

View file

@ -1,7 +1,6 @@
package logic package logic
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -421,35 +420,6 @@ func updateProNodeACLS(node *models.Node) error {
return nil return nil
} }
func PurgePendingNodes(ctx context.Context) {
ticker := time.NewTicker(NodePurgeCheckTime)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
nodes, err := GetAllNodes()
if err != nil {
logger.Log(0, "PurgePendingNodes failed to retrieve nodes", err.Error())
continue
}
for _, node := range nodes {
if node.PendingDelete {
modified := node.LastModified
if time.Since(modified) > NodePurgeTime {
if err := DeleteNode(&node, true); err != nil {
logger.Log(0, "failed to purge node", node.ID.String(), err.Error())
} else {
logger.Log(0, "purged node ", node.ID.String())
}
}
}
}
}
}
}
// createNode - creates a node in database // createNode - creates a node in database
func createNode(node *models.Node) error { func createNode(node *models.Node) error {
host, err := GetHost(node.HostID.String()) host, err := GetHost(node.HostID.String())

52
main.go
View file

@ -36,12 +36,16 @@ func main() {
setupConfig(*absoluteConfigPath) setupConfig(*absoluteConfigPath)
servercfg.SetVersion(version) servercfg.SetVersion(version)
fmt.Println(models.RetrieveLogo()) // print the logo fmt.Println(models.RetrieveLogo()) // print the logo
// fmt.Println(models.ProLogo()) initialize() // initial db and acls
initialize() // initial db and acls; gen cert if required
setGarbageCollection() setGarbageCollection()
setVerbosity() setVerbosity()
defer database.CloseDB() defer database.CloseDB()
startControllers() // start the api endpoint and mq ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, os.Interrupt)
defer stop()
var waitGroup sync.WaitGroup
startControllers(&waitGroup, ctx) // start the api endpoint and mq and stun
<-ctx.Done()
waitGroup.Wait()
} }
func setupConfig(absoluteConfigPath string) { func setupConfig(absoluteConfigPath string) {
@ -110,8 +114,7 @@ func initialize() { // Client Mode Prereq Check
} }
} }
func startControllers() { func startControllers(wg *sync.WaitGroup, ctx context.Context) {
var waitnetwork sync.WaitGroup
if servercfg.IsDNSMode() { if servercfg.IsDNSMode() {
err := logic.SetDNS() err := logic.SetDNS()
if err != nil { if err != nil {
@ -127,13 +130,13 @@ func startControllers() {
logger.FatalLog("Unable to Set host. Exiting...", err.Error()) logger.FatalLog("Unable to Set host. Exiting...", err.Error())
} }
} }
waitnetwork.Add(1) wg.Add(1)
go controller.HandleRESTRequests(&waitnetwork) go controller.HandleRESTRequests(wg, ctx)
} }
//Run MessageQueue //Run MessageQueue
if servercfg.IsMessageQueueBackend() { if servercfg.IsMessageQueueBackend() {
waitnetwork.Add(1) wg.Add(1)
go runMessageQueue(&waitnetwork) go runMessageQueue(wg, ctx)
} }
if !servercfg.IsRestBackend() && !servercfg.IsMessageQueueBackend() { if !servercfg.IsRestBackend() && !servercfg.IsMessageQueueBackend() {
@ -141,34 +144,17 @@ func startControllers() {
} }
// starts the stun server // starts the stun server
waitnetwork.Add(1) wg.Add(1)
go stunserver.Start(&waitnetwork) go stunserver.Start(wg, ctx)
if servercfg.IsProxyEnabled() {
waitnetwork.Add(1)
go func() {
defer waitnetwork.Done()
_, cancel := context.WithCancel(context.Background())
waitnetwork.Add(1)
//go nmproxy.Start(ctx, logic.ProxyMgmChan, servercfg.GetAPIHost())
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
<-quit
cancel()
}()
}
waitnetwork.Wait()
} }
// Should we be using a context vice a waitgroup???????????? // Should we be using a context vice a waitgroup????????????
func runMessageQueue(wg *sync.WaitGroup) { func runMessageQueue(wg *sync.WaitGroup, ctx context.Context) {
defer wg.Done() defer wg.Done()
brokerHost, secure := servercfg.GetMessageQueueEndpoint() brokerHost, secure := servercfg.GetMessageQueueEndpoint()
logger.Log(0, "connecting to mq broker at", brokerHost, "with TLS?", fmt.Sprintf("%v", secure)) logger.Log(0, "connecting to mq broker at", brokerHost, "with TLS?", fmt.Sprintf("%v", secure))
mq.SetupMQTT() mq.SetupMQTT()
ctx, cancel := context.WithCancel(context.Background()) defer mq.CloseClient()
go mq.Keepalive(ctx) go mq.Keepalive(ctx)
go func() { go func() {
peerUpdate := make(chan *models.Node) peerUpdate := make(chan *models.Node)
@ -179,11 +165,7 @@ func runMessageQueue(wg *sync.WaitGroup) {
} }
} }
}() }()
go logic.PurgePendingNodes(ctx) <-ctx.Done()
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
<-quit
cancel()
logger.Log(0, "Message Queue shutting down") logger.Log(0, "Message Queue shutting down")
} }

View file

@ -100,3 +100,8 @@ func Keepalive(ctx context.Context) {
func IsConnected() bool { func IsConnected() bool {
return mqclient != nil && mqclient.IsConnected() return mqclient != nil && mqclient.IsConnected()
} }
// CloseClient - function to close the mq connection from server
func CloseClient() {
mqclient.Disconnect(250)
}

View file

@ -4,11 +4,8 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"os"
"os/signal"
"strings" "strings"
"sync" "sync"
"syscall"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
@ -23,7 +20,6 @@ import (
// backwards compatibility with RFC 3489. // backwards compatibility with RFC 3489.
type Server struct { type Server struct {
Addr string Addr string
Ctx context.Context
} }
var ( var (
@ -60,48 +56,58 @@ func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error {
) )
} }
func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message) error { func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message, ctx context.Context) error {
if c == nil { if c == nil {
return nil return nil
} }
go func(ctx context.Context) {
<-ctx.Done()
if c != nil {
// kill connection on server shutdown
c.Close()
}
}(ctx)
buf := make([]byte, 1024) buf := make([]byte, 1024)
n, addr, err := c.ReadFrom(buf) n, addr, err := c.ReadFrom(buf) // this be blocky af
if err != nil { if err != nil {
logger.Log(1, "ReadFrom: %v", err.Error()) if !strings.Contains(err.Error(), "use of closed network connection") {
logger.Log(1, "STUN read error:", err.Error())
}
return nil return nil
} }
if _, err = req.Write(buf[:n]); err != nil { if _, err = req.Write(buf[:n]); err != nil {
logger.Log(1, "Write: %v", err.Error()) logger.Log(1, "STUN write error:", err.Error())
return err return err
} }
if err = basicProcess(addr, buf[:n], req, res); err != nil { if err = basicProcess(addr, buf[:n], req, res); err != nil {
if err == errNotSTUNMessage { if err == errNotSTUNMessage {
return nil return nil
} }
logger.Log(1, "basicProcess: %v", err.Error()) logger.Log(1, "STUN process error:", err.Error())
return nil return nil
} }
_, err = c.WriteTo(res.Raw, addr) _, err = c.WriteTo(res.Raw, addr)
if err != nil { if err != nil {
logger.Log(1, "WriteTo: %v", err.Error()) logger.Log(1, "STUN response write error", err.Error())
} }
return err return err
} }
// Serve reads packets from connections and responds to BINDING requests. // Serve reads packets from connections and responds to BINDING requests.
func (s *Server) serve(c net.PacketConn) error { func (s *Server) serve(c net.PacketConn, ctx context.Context) error {
var ( var (
res = new(stun.Message) res = new(stun.Message)
req = new(stun.Message) req = new(stun.Message)
) )
for { for {
select { select {
case <-s.Ctx.Done(): case <-ctx.Done():
logger.Log(0, "Shutting down stun server...") logger.Log(0, "shut down STUN server")
c.Close()
return nil return nil
default: default:
if err := s.serveConn(c, res, req); err != nil { if err := s.serveConn(c, res, req, ctx); err != nil {
logger.Log(1, "serve: %v", err.Error()) logger.Log(1, "serve: %v", err.Error())
continue continue
} }
@ -119,9 +125,8 @@ func listenUDPAndServe(ctx context.Context, serverNet, laddr string) error {
} }
s := &Server{ s := &Server{
Addr: laddr, Addr: laddr,
Ctx: ctx,
} }
return s.serve(c) return s.serve(c, ctx)
} }
func normalize(address string) string { func normalize(address string) string {
@ -135,19 +140,15 @@ func normalize(address string) string {
} }
// Start - starts the stun server // Start - starts the stun server
func Start(wg *sync.WaitGroup) { func Start(wg *sync.WaitGroup, ctx context.Context) {
ctx, cancel := context.WithCancel(context.Background()) defer wg.Done()
go func(wg *sync.WaitGroup) {
defer wg.Done()
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
<-quit
cancel()
}(wg)
normalized := normalize(fmt.Sprintf("0.0.0.0:%d", servercfg.GetStunPort())) normalized := normalize(fmt.Sprintf("0.0.0.0:%d", servercfg.GetStunPort()))
logger.Log(0, "netmaker-stun listening on", normalized, "via udp") logger.Log(0, "netmaker-stun listening on", normalized, "via udp")
err := listenUDPAndServe(ctx, "udp", normalized) if err := listenUDPAndServe(ctx, "udp", normalized); err != nil {
if err != nil { if strings.Contains(err.Error(), "closed network connection") {
logger.Log(0, "failed to start stun server: ", err.Error()) logger.Log(0, "shutdown STUN server")
} else {
logger.Log(0, "server: ", err.Error())
}
} }
} }