diff --git a/lighthouse.go b/lighthouse.go index 460a1cb..da604f0 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -262,6 +262,18 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { //NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") { + // Clean up. Entries still in the static_host_map will be re-built. + // Entries no longer present must have their (possible) background DNS goroutines stopped. + if existingStaticList := lh.staticList.Load(); existingStaticList != nil { + lh.RLock() + for staticVpnIp := range *existingStaticList { + if am, ok := lh.addrMap[staticVpnIp]; ok && am != nil { + am.hr.Cancel() + } + } + lh.RUnlock() + } + // Build a new list based on current config. staticList := make(map[iputil.VpnIp]struct{}) err := lh.loadStaticMap(c, lh.myVpnNet, staticList) if err != nil { diff --git a/lighthouse_test.go b/lighthouse_test.go index aa4da4c..73632ac 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -12,6 +12,7 @@ import ( "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v2" ) //TODO: Add a test to ensure udpAddr is copied and not reused @@ -242,8 +243,17 @@ func TestLighthouse_reload(t *testing.T) { lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) assert.NoError(t, err) - c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}} - lh.reload(c, false) + nc := map[interface{}]interface{}{ + "static_host_map": map[interface{}]interface{}{ + "10.128.0.2": []interface{}{"1.1.1.1:4242"}, + }, + } + rc, err := yaml.Marshal(nc) + assert.NoError(t, err) + c.ReloadConfigString(string(rc)) + + err = lh.reload(c, false) + assert.NoError(t, err) } func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply { diff --git a/remote_list.go b/remote_list.go index f2b0d12..60a1afd 100644 --- a/remote_list.go +++ b/remote_list.go @@ -70,7 +70,7 @@ type hostnamesResults struct { hostnames []hostnamePort network string lookupTimeout time.Duration - stop chan struct{} + cancelFn func() l *logrus.Logger ips atomic.Pointer[map[netip.AddrPort]struct{}] } @@ -80,7 +80,6 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, hostnames: make([]hostnamePort, len(hostPorts)), network: network, lookupTimeout: timeout, - stop: make(chan (struct{})), l: l, } @@ -115,6 +114,8 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, // Time for the DNS lookup goroutine if performBackgroundLookup { + newCtx, cancel := context.WithCancel(ctx) + r.cancelFn = cancel ticker := time.NewTicker(d) go func() { defer ticker.Stop() @@ -154,9 +155,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, onUpdate() } select { - case <-ctx.Done(): - return - case <-r.stop: + case <-newCtx.Done(): return case <-ticker.C: continue @@ -169,8 +168,8 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, } func (hr *hostnamesResults) Cancel() { - if hr != nil { - hr.stop <- struct{}{} + if hr != nil && hr.cancelFn != nil { + hr.cancelFn() } }