mirror of
https://github.com/gravitl/netmaker.git
synced 2025-09-07 05:34:38 +08:00
adjusted main to use one single context
This commit is contained in:
parent
3244472481
commit
5014c389ca
5 changed files with 56 additions and 105 deletions
|
@ -4,11 +4,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/handlers"
|
"github.com/gorilla/handlers"
|
||||||
|
@ -32,7 +29,7 @@ var HttpHandlers = []interface{}{
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleRESTRequests - handles the rest requests
|
// HandleRESTRequests - handles the rest requests
|
||||||
func HandleRESTRequests(wg *sync.WaitGroup) {
|
func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
r := mux.NewRouter()
|
r := mux.NewRouter()
|
||||||
|
@ -58,18 +55,14 @@ func HandleRESTRequests(wg *sync.WaitGroup) {
|
||||||
}()
|
}()
|
||||||
logger.Log(0, "REST Server successfully started on port ", port, " (REST)")
|
logger.Log(0, "REST Server successfully started on port ", port, " (REST)")
|
||||||
|
|
||||||
// Relay os.Interrupt to our channel (os.Interrupt = CTRL+C)
|
|
||||||
// Ignore other incoming signals
|
|
||||||
ctx, stop := signal.NotifyContext(context.TODO(), syscall.SIGTERM, os.Interrupt)
|
|
||||||
defer stop()
|
|
||||||
|
|
||||||
// Block main routine until a signal is received
|
// Block main routine until a signal is received
|
||||||
// As long as user doesn't press CTRL+C a message is not passed and our main routine keeps running
|
// As long as user doesn't press CTRL+C a message is not passed and our main routine keeps running
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
|
|
||||||
// After receiving CTRL+C Properly stop the server
|
// After receiving CTRL+C Properly stop the server
|
||||||
logger.Log(0, "Stopping the REST server...")
|
logger.Log(0, "Stopping the REST server...")
|
||||||
|
if err := srv.Shutdown(context.TODO()); err != nil {
|
||||||
|
logger.Log(0, "REST shutdown error occurred -", err.Error())
|
||||||
|
}
|
||||||
logger.Log(0, "REST Server closed.")
|
logger.Log(0, "REST Server closed.")
|
||||||
logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay)))
|
logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay)))
|
||||||
srv.Shutdown(context.TODO())
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package logic
|
package logic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -421,35 +420,6 @@ func updateProNodeACLS(node *models.Node) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func PurgePendingNodes(ctx context.Context) {
|
|
||||||
ticker := time.NewTicker(NodePurgeCheckTime)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
nodes, err := GetAllNodes()
|
|
||||||
if err != nil {
|
|
||||||
logger.Log(0, "PurgePendingNodes failed to retrieve nodes", err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, node := range nodes {
|
|
||||||
if node.PendingDelete {
|
|
||||||
modified := node.LastModified
|
|
||||||
if time.Since(modified) > NodePurgeTime {
|
|
||||||
if err := DeleteNode(&node, true); err != nil {
|
|
||||||
logger.Log(0, "failed to purge node", node.ID.String(), err.Error())
|
|
||||||
} else {
|
|
||||||
logger.Log(0, "purged node ", node.ID.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// createNode - creates a node in database
|
// createNode - creates a node in database
|
||||||
func createNode(node *models.Node) error {
|
func createNode(node *models.Node) error {
|
||||||
host, err := GetHost(node.HostID.String())
|
host, err := GetHost(node.HostID.String())
|
||||||
|
|
52
main.go
52
main.go
|
@ -36,12 +36,16 @@ func main() {
|
||||||
setupConfig(*absoluteConfigPath)
|
setupConfig(*absoluteConfigPath)
|
||||||
servercfg.SetVersion(version)
|
servercfg.SetVersion(version)
|
||||||
fmt.Println(models.RetrieveLogo()) // print the logo
|
fmt.Println(models.RetrieveLogo()) // print the logo
|
||||||
// fmt.Println(models.ProLogo())
|
initialize() // initial db and acls
|
||||||
initialize() // initial db and acls; gen cert if required
|
|
||||||
setGarbageCollection()
|
setGarbageCollection()
|
||||||
setVerbosity()
|
setVerbosity()
|
||||||
defer database.CloseDB()
|
defer database.CloseDB()
|
||||||
startControllers() // start the api endpoint and mq
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, os.Interrupt)
|
||||||
|
defer stop()
|
||||||
|
var waitGroup sync.WaitGroup
|
||||||
|
startControllers(&waitGroup, ctx) // start the api endpoint and mq and stun
|
||||||
|
<-ctx.Done()
|
||||||
|
waitGroup.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupConfig(absoluteConfigPath string) {
|
func setupConfig(absoluteConfigPath string) {
|
||||||
|
@ -110,8 +114,7 @@ func initialize() { // Client Mode Prereq Check
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func startControllers() {
|
func startControllers(wg *sync.WaitGroup, ctx context.Context) {
|
||||||
var waitnetwork sync.WaitGroup
|
|
||||||
if servercfg.IsDNSMode() {
|
if servercfg.IsDNSMode() {
|
||||||
err := logic.SetDNS()
|
err := logic.SetDNS()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -127,13 +130,13 @@ func startControllers() {
|
||||||
logger.FatalLog("Unable to Set host. Exiting...", err.Error())
|
logger.FatalLog("Unable to Set host. Exiting...", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
waitnetwork.Add(1)
|
wg.Add(1)
|
||||||
go controller.HandleRESTRequests(&waitnetwork)
|
go controller.HandleRESTRequests(wg, ctx)
|
||||||
}
|
}
|
||||||
//Run MessageQueue
|
//Run MessageQueue
|
||||||
if servercfg.IsMessageQueueBackend() {
|
if servercfg.IsMessageQueueBackend() {
|
||||||
waitnetwork.Add(1)
|
wg.Add(1)
|
||||||
go runMessageQueue(&waitnetwork)
|
go runMessageQueue(wg, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !servercfg.IsRestBackend() && !servercfg.IsMessageQueueBackend() {
|
if !servercfg.IsRestBackend() && !servercfg.IsMessageQueueBackend() {
|
||||||
|
@ -141,34 +144,17 @@ func startControllers() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// starts the stun server
|
// starts the stun server
|
||||||
waitnetwork.Add(1)
|
wg.Add(1)
|
||||||
go stunserver.Start(&waitnetwork)
|
go stunserver.Start(wg, ctx)
|
||||||
if servercfg.IsProxyEnabled() {
|
|
||||||
|
|
||||||
waitnetwork.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer waitnetwork.Done()
|
|
||||||
_, cancel := context.WithCancel(context.Background())
|
|
||||||
waitnetwork.Add(1)
|
|
||||||
|
|
||||||
//go nmproxy.Start(ctx, logic.ProxyMgmChan, servercfg.GetAPIHost())
|
|
||||||
quit := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
|
|
||||||
<-quit
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
waitnetwork.Wait()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should we be using a context vice a waitgroup????????????
|
// Should we be using a context vice a waitgroup????????????
|
||||||
func runMessageQueue(wg *sync.WaitGroup) {
|
func runMessageQueue(wg *sync.WaitGroup, ctx context.Context) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
brokerHost, secure := servercfg.GetMessageQueueEndpoint()
|
brokerHost, secure := servercfg.GetMessageQueueEndpoint()
|
||||||
logger.Log(0, "connecting to mq broker at", brokerHost, "with TLS?", fmt.Sprintf("%v", secure))
|
logger.Log(0, "connecting to mq broker at", brokerHost, "with TLS?", fmt.Sprintf("%v", secure))
|
||||||
mq.SetupMQTT()
|
mq.SetupMQTT()
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
defer mq.CloseClient()
|
||||||
go mq.Keepalive(ctx)
|
go mq.Keepalive(ctx)
|
||||||
go func() {
|
go func() {
|
||||||
peerUpdate := make(chan *models.Node)
|
peerUpdate := make(chan *models.Node)
|
||||||
|
@ -179,11 +165,7 @@ func runMessageQueue(wg *sync.WaitGroup) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
go logic.PurgePendingNodes(ctx)
|
<-ctx.Done()
|
||||||
quit := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
|
|
||||||
<-quit
|
|
||||||
cancel()
|
|
||||||
logger.Log(0, "Message Queue shutting down")
|
logger.Log(0, "Message Queue shutting down")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
5
mq/mq.go
5
mq/mq.go
|
@ -100,3 +100,8 @@ func Keepalive(ctx context.Context) {
|
||||||
func IsConnected() bool {
|
func IsConnected() bool {
|
||||||
return mqclient != nil && mqclient.IsConnected()
|
return mqclient != nil && mqclient.IsConnected()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CloseClient - function to close the mq connection from server
|
||||||
|
func CloseClient() {
|
||||||
|
mqclient.Disconnect(250)
|
||||||
|
}
|
||||||
|
|
|
@ -4,11 +4,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"github.com/gravitl/netmaker/logger"
|
"github.com/gravitl/netmaker/logger"
|
||||||
"github.com/gravitl/netmaker/servercfg"
|
"github.com/gravitl/netmaker/servercfg"
|
||||||
|
@ -23,7 +20,6 @@ import (
|
||||||
// backwards compatibility with RFC 3489.
|
// backwards compatibility with RFC 3489.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
Addr string
|
Addr string
|
||||||
Ctx context.Context
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -60,48 +56,58 @@ func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message) error {
|
func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message, ctx context.Context) error {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
go func(ctx context.Context) {
|
||||||
|
<-ctx.Done()
|
||||||
|
if c != nil {
|
||||||
|
// kill connection on server shutdown
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
}(ctx)
|
||||||
|
|
||||||
buf := make([]byte, 1024)
|
buf := make([]byte, 1024)
|
||||||
n, addr, err := c.ReadFrom(buf)
|
n, addr, err := c.ReadFrom(buf) // this be blocky af
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Log(1, "ReadFrom: %v", err.Error())
|
if !strings.Contains(err.Error(), "use of closed network connection") {
|
||||||
|
logger.Log(1, "STUN read error:", err.Error())
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = req.Write(buf[:n]); err != nil {
|
if _, err = req.Write(buf[:n]); err != nil {
|
||||||
logger.Log(1, "Write: %v", err.Error())
|
logger.Log(1, "STUN write error:", err.Error())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = basicProcess(addr, buf[:n], req, res); err != nil {
|
if err = basicProcess(addr, buf[:n], req, res); err != nil {
|
||||||
if err == errNotSTUNMessage {
|
if err == errNotSTUNMessage {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
logger.Log(1, "basicProcess: %v", err.Error())
|
logger.Log(1, "STUN process error:", err.Error())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
_, err = c.WriteTo(res.Raw, addr)
|
_, err = c.WriteTo(res.Raw, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Log(1, "WriteTo: %v", err.Error())
|
logger.Log(1, "STUN response write error", err.Error())
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve reads packets from connections and responds to BINDING requests.
|
// Serve reads packets from connections and responds to BINDING requests.
|
||||||
func (s *Server) serve(c net.PacketConn) error {
|
func (s *Server) serve(c net.PacketConn, ctx context.Context) error {
|
||||||
var (
|
var (
|
||||||
res = new(stun.Message)
|
res = new(stun.Message)
|
||||||
req = new(stun.Message)
|
req = new(stun.Message)
|
||||||
)
|
)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-s.Ctx.Done():
|
case <-ctx.Done():
|
||||||
logger.Log(0, "Shutting down stun server...")
|
logger.Log(0, "shut down STUN server")
|
||||||
c.Close()
|
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
if err := s.serveConn(c, res, req); err != nil {
|
if err := s.serveConn(c, res, req, ctx); err != nil {
|
||||||
logger.Log(1, "serve: %v", err.Error())
|
logger.Log(1, "serve: %v", err.Error())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -119,9 +125,8 @@ func listenUDPAndServe(ctx context.Context, serverNet, laddr string) error {
|
||||||
}
|
}
|
||||||
s := &Server{
|
s := &Server{
|
||||||
Addr: laddr,
|
Addr: laddr,
|
||||||
Ctx: ctx,
|
|
||||||
}
|
}
|
||||||
return s.serve(c)
|
return s.serve(c, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalize(address string) string {
|
func normalize(address string) string {
|
||||||
|
@ -135,19 +140,15 @@ func normalize(address string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start - starts the stun server
|
// Start - starts the stun server
|
||||||
func Start(wg *sync.WaitGroup) {
|
func Start(wg *sync.WaitGroup, ctx context.Context) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
defer wg.Done()
|
||||||
go func(wg *sync.WaitGroup) {
|
|
||||||
defer wg.Done()
|
|
||||||
quit := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
|
|
||||||
<-quit
|
|
||||||
cancel()
|
|
||||||
}(wg)
|
|
||||||
normalized := normalize(fmt.Sprintf("0.0.0.0:%d", servercfg.GetStunPort()))
|
normalized := normalize(fmt.Sprintf("0.0.0.0:%d", servercfg.GetStunPort()))
|
||||||
logger.Log(0, "netmaker-stun listening on", normalized, "via udp")
|
logger.Log(0, "netmaker-stun listening on", normalized, "via udp")
|
||||||
err := listenUDPAndServe(ctx, "udp", normalized)
|
if err := listenUDPAndServe(ctx, "udp", normalized); err != nil {
|
||||||
if err != nil {
|
if strings.Contains(err.Error(), "closed network connection") {
|
||||||
logger.Log(0, "failed to start stun server: ", err.Error())
|
logger.Log(0, "shutdown STUN server")
|
||||||
|
} else {
|
||||||
|
logger.Log(0, "server: ", err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue