diff --git a/netclient/daemon/common.go b/netclient/daemon/common.go
index da24e268..c36d91a6 100644
--- a/netclient/daemon/common.go
+++ b/netclient/daemon/common.go
@@ -2,8 +2,13 @@ package daemon
 
 import (
 	"errors"
+	"fmt"
+	"os"
 	"runtime"
+	"syscall"
 	"time"
+
+	"github.com/gravitl/netmaker/netclient/ncutils"
 )
 
 // InstallDaemon - Calls the correct function to install the netclient as a daemon service on the given operating system.
@@ -28,24 +33,25 @@ func InstallDaemon() error {
 
 // Restart - restarts a system daemon
+// On Windows the service is restarted through RestartWindowsDaemon, because
+// os.Process.Signal cannot deliver SIGHUP there. On every other OS the pid
+// saved by ncutils.SavePID is read back and that process is sent SIGHUP.
 func Restart() error {
-	os := runtime.GOOS
-	var err error
-
-	time.Sleep(time.Second)
-
-	switch os {
-	case "windows":
-		RestartWindowsDaemon()
-	case "darwin":
-		RestartLaunchD()
-	case "linux":
-		RestartSystemD()
-	case "freebsd":
-		FreebsdDaemon("restart")
-	default:
-		err = errors.New("this os is not yet supported for daemon mode. Run join cmd with flag '--daemon off'")
+	if runtime.GOOS == "windows" {
+		RestartWindowsDaemon()
+		return nil
+	}
+	pid, err := ncutils.ReadPID()
+	if err != nil {
+		return fmt.Errorf("failed to find pid %w", err)
 	}
-	return err
+	p, err := os.FindProcess(pid)
+	if err != nil {
+		return fmt.Errorf("failed to find running process for pid %d -- %w", pid, err)
+	}
+	if err := p.Signal(syscall.SIGHUP); err != nil {
+		return fmt.Errorf("SIGHUP failed -- %w", err)
+	}
+	return nil
 }
 
 // Stop - stops a system daemon
diff --git a/netclient/functions/daemon.go b/netclient/functions/daemon.go
index 06e15bad..02457c7d 100644
--- a/netclient/functions/daemon.go
+++ b/netclient/functions/daemon.go
@@ -30,7 +30,7 @@ import (
 )
 
 var messageCache = new(sync.Map)
-var networkcontext = new(sync.Map)
+var serverSet map[string]bool
 
 const lastNodeUpdate = "lnu"
 const lastPeerUpdate = "lpu"
@@ -43,19 +43,62 @@ type cachedMessage struct {
 // Daemon runs netclient daemon from command line
 func Daemon() error {
 	UpdateClientConfig()
-	serverSet := make(map[string]bool)
+	// record our pid so that daemon.Restart can signal this process
+	if err := ncutils.SavePID(); err != nil {
+		return err
+	}
+	serverSet = make(map[string]bool)
 	// == initial pull of all networks ==
 	networks, _ := ncutils.GetSystemNetworks()
 	if len(networks) == 0 {
 		return errors.New("no networks")
 	}
-	pubNetworks = append(pubNetworks, networks...)
 	// set ipforwarding on startup
 	err := local.SetIPForwarding()
 	if err != nil {
 		logger.Log(0, err.Error())
 	}
 
+	// == add waitgroup and cancel for checkin routine ==
+	wg := sync.WaitGroup{}
+	quit := make(chan os.Signal, 1)
+	reset := make(chan os.Signal, 1)
+	signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
+	signal.Notify(reset, syscall.SIGHUP)
+	cancel := startGoRoutines(&wg)
+	for {
+		select {
+		case <-quit:
+			cancel()
+			logger.Log(0, "shutting down netclient daemon")
+			wg.Wait()
+			logger.Log(0, "shutdown complete")
+			return nil
+		case <-reset:
+			// SIGHUP: stop every routine, then start a fresh set
+			logger.Log(0, "received reset")
+			cancel()
+			wg.Wait()
+			logger.Log(0, "restarting daemon")
+			cancel = startGoRoutines(&wg)
+		}
+	}
+}
+
+// startGoRoutines - starts the checkin routine and a message queue per
+// server found in the system networks, all sharing one cancellable context.
+// wg is incremented before each goroutine is launched; messageQueue calls
+// wg.Done on return and Checkin is handed wg to do the same (assumed - verify).
+// Call the returned cancel func, then wg.Wait, to stop every routine.
+func startGoRoutines(wg *sync.WaitGroup) context.CancelFunc {
+	ctx, cancel := context.WithCancel(context.Background())
+	wg.Add(1)
+	go Checkin(ctx, wg)
+	serverSet := make(map[string]bool)
+	networks, err := ncutils.GetSystemNetworks()
+	if err != nil {
+		logger.Log(0, "failed to read system networks", err.Error())
+	}
 	for _, network := range networks {
 		logger.Log(3, "initializing network", network)
 		cfg := config.ClientConfig{}
@@ -69,30 +112,12 @@ func Daemon() error {
 			// == subscribe to all nodes for each on machine ==
 			serverSet[server] = true
 			logger.Log(1, "started daemon for server ", server)
-			ctx, cancel := context.WithCancel(context.Background())
-			networkcontext.Store(server, cancel)
-			go messageQueue(ctx, &cfg)
+			// pair the Add with the deferred wg.Done inside messageQueue
+			wg.Add(1)
+			go messageQueue(ctx, wg, &cfg)
 		}
 	}
-
-	// == add waitgroup and cancel for checkin routine ==
-	wg := sync.WaitGroup{}
-	ctx, cancel := context.WithCancel(context.Background())
-	wg.Add(1)
-	go Checkin(ctx, &wg)
-	quit := make(chan os.Signal, 1)
-	signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
-	<-quit
-	for server := range serverSet {
-		if cancel, ok := networkcontext.Load(server); ok {
-			cancel.(context.CancelFunc)()
-		}
-	}
-	cancel()
-	logger.Log(0, "shutting down netclient daemon")
-	wg.Wait()
-	logger.Log(0, "shutdown complete")
-	return nil
+	return cancel
 }
 
 // UpdateKeys -- updates private key and returns new publickey
@@ -167,7 +192,9 @@ func unsubscribeNode(client mqtt.Client, nodeCfg *config.ClientConfig) {
 
 // sets up Message Queue and subsribes/publishes updates to/from server
 // the client should subscribe to ALL nodes that exist on server locally
-func messageQueue(ctx context.Context, cfg *config.ClientConfig) {
+// wg.Done is called on return, so callers must wg.Add(1) before starting it
+func messageQueue(ctx context.Context, wg *sync.WaitGroup, cfg *config.ClientConfig) {
+	defer wg.Done()
 	logger.Log(0, "netclient daemon started for server: ", cfg.Server.Server)
 	client, err := setupMQTT(cfg, false)
 	if err != nil {
diff --git a/netclient/functions/install.go b/netclient/functions/install.go
index bd8dec7f..1564d450 100644
--- a/netclient/functions/install.go
+++ b/netclient/functions/install.go
@@ -1,6 +1,8 @@
 package functions
 
 import (
+	"time"
+
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/netclient/daemon"
 )
@@ -12,5 +14,7 @@ func Install() error {
 		logger.Log(0, "error installing daemon", err.Error())
 		return err
 	}
+	// NOTE(review): presumably gives the new daemon time to write its pid file before Restart reads it - confirm
+	time.Sleep(time.Second * 5)
 	return daemon.Restart()
 }
diff --git a/netclient/ncutils/pid.go b/netclient/ncutils/pid.go
new file mode 100644
index 00000000..f4eb969b
--- /dev/null
+++ b/netclient/ncutils/pid.go
@@ -0,0 +1,46 @@
+package ncutils
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"runtime"
+	"strconv"
+	"strings"
+)
+
+// PIDFILE - path/name of pid file
+const PIDFILE = "/var/run/netclient.pid"
+
+// pidFile - returns the pid file location for the current OS.
+// The PIDFILE directory normally does not exist on Windows, so the
+// OS temp directory is used there instead.
+func pidFile() string {
+	if runtime.GOOS == "windows" {
+		return filepath.Join(os.TempDir(), "netclient.pid")
+	}
+	return PIDFILE
+}
+
+// SavePID - saves the pid of running program to disk
+func SavePID() error {
+	pid := os.Getpid()
+	if err := os.WriteFile(pidFile(), []byte(strconv.Itoa(pid)), 0644); err != nil {
+		return fmt.Errorf("could not write to pid file %w", err)
+	}
+	return nil
+}
+
+// ReadPID - reads a previously saved pid from disk
+func ReadPID() (int, error) {
+	data, err := os.ReadFile(pidFile())
+	if err != nil {
+		return 0, fmt.Errorf("could not read pid file %w", err)
+	}
+	// tolerate a trailing newline or other surrounding whitespace
+	pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
+	if err != nil {
+		return 0, fmt.Errorf("pid file contents invalid %w", err)
+	}
+	return pid, nil
+}