diff --git a/hostmap.go b/hostmap.go index 597d7e4..bef6e4e 100644 --- a/hostmap.go +++ b/hostmap.go @@ -261,7 +261,7 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang r := map[uint32]*HostInfo{} relays := map[uint32]*HostInfo{} m := HostMap{ - syncRWMutex: newSyncRWMutex("hostmap", name), + syncRWMutex: newSyncRWMutex(mutexKey{Type: "hostmap", SubType: name}), name: name, Indexes: i, Relays: relays, @@ -322,7 +322,7 @@ func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) ( if h, ok := hm.Hosts[vpnIp]; !ok { hm.RUnlock() h = &HostInfo{ - syncRWMutex: newSyncRWMutex("hostinfo"), + syncRWMutex: newSyncRWMutex(mutexKey{Type: "hostinfo", ID: uint32(vpnIp)}), vpnIp: vpnIp, HandshakePacket: make(map[uint8][]byte, 0), relayState: RelayState{ diff --git a/mutex.go b/mutex.go index b7ff88a..1817674 100644 --- a/mutex.go +++ b/mutex.go @@ -9,6 +9,12 @@ import ( type syncRWMutex = sync.RWMutex -func newSyncRWMutex(t ...string) syncRWMutex { +func newSyncRWMutex(mutexKey) syncRWMutex { return sync.RWMutex{} } + +type mutexKey struct { + Type string + SubType string + ID uint32 +} diff --git a/mutex_debug.go b/mutex_debug.go index 5f20092..c692cb8 100644 --- a/mutex_debug.go +++ b/mutex_debug.go @@ -4,63 +4,75 @@ package nebula import ( - "strings" + "fmt" "sync" "github.com/timandy/routine" ) -var threadLocal routine.ThreadLocal = routine.NewThreadLocalWithInitial(func() any { return map[string]bool{} }) +var threadLocal routine.ThreadLocal = routine.NewThreadLocalWithInitial(func() any { return map[mutexKey]bool{} }) + +type mutexKey struct { + Type string + SubType string + ID uint32 +} type syncRWMutex struct { sync.RWMutex - mutexType string + mutexKey } -func newSyncRWMutex(t ...string) syncRWMutex { +func newSyncRWMutex(key mutexKey) syncRWMutex { return syncRWMutex{ - mutexType: strings.Join(t, "-"), + mutexKey: key, } } -func checkMutex(state map[string]bool, add string) { - if add == "hostinfo" { - if state["hostmap-main"] { - panic("grabbing hostinfo lock and already have hostmap-main") +func checkMutex(state map[mutexKey]bool, add mutexKey) { + switch add.Type { + case "hostinfo": + // Check for any other hostinfo keys: + for k, v := range state { + if k.Type == "hostinfo" && v { + panic(fmt.Errorf("grabbing hostinfo lock and already have a hostinfo lock: state=%v add=%v", state, add)) + } } - if state["hostmap-pending"] { - panic("grabbing hostinfo lock and already have hostmap-pending") + if state[mutexKey{Type: "hostmap", SubType: "main"}] { + panic(fmt.Errorf("grabbing hostinfo lock and already have hostmap-main: state=%v add=%v", state, add)) } - } - if add == "hostmap-pending" { - if state["hostmap-main"] { - panic("grabbing hostmap-pending lock and already have hostmap-main") + if state[mutexKey{Type: "hostmap", SubType: "pending"}] { + panic(fmt.Errorf("grabbing hostinfo lock and already have hostmap-pending: state=%v add=%v", state, add)) + } + case "hostmap-pending": + if state[mutexKey{Type: "hostmap", SubType: "main"}] { + panic(fmt.Errorf("grabbing hostmap-pending lock and already have hostmap-main: state=%v add=%v", state, add)) } } } func (s *syncRWMutex) Lock() { - m := threadLocal.Get().(map[string]bool) - checkMutex(m, s.mutexType) - m[s.mutexType] = true + m := threadLocal.Get().(map[mutexKey]bool) + checkMutex(m, s.mutexKey) + m[s.mutexKey] = true s.RWMutex.Lock() } func (s *syncRWMutex) Unlock() { - m := threadLocal.Get().(map[string]bool) - m[s.mutexType] = false + m := threadLocal.Get().(map[mutexKey]bool) + m[s.mutexKey] = false s.RWMutex.Unlock() } func (s *syncRWMutex) RLock() { - m := threadLocal.Get().(map[string]bool) - checkMutex(m, s.mutexType) - m[s.mutexType] = true + m := threadLocal.Get().(map[mutexKey]bool) + checkMutex(m, s.mutexKey) + m[s.mutexKey] = true s.RWMutex.RLock() } func (s *syncRWMutex) RUnlock() { - m := threadLocal.Get().(map[string]bool) - m[s.mutexType] = false + m := threadLocal.Get().(map[mutexKey]bool) + m[s.mutexKey] = false s.RWMutex.RUnlock() }