mirror of
https://github.com/slackhq/nebula.git
synced 2024-09-20 06:46:11 +08:00
Switch most everything to netip in prep for ipv6 in the overlay (#1173)
This commit is contained in:
parent
00458302ca
commit
e264a0ff88
|
@ -2,17 +2,16 @@ package nebula
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
)
|
||||
|
||||
type AllowList struct {
|
||||
// The values of this cidrTree are `bool`, signifying allow/deny
|
||||
cidrTree *cidr.Tree6[bool]
|
||||
cidrTree *bart.Table[bool]
|
||||
}
|
||||
|
||||
type RemoteAllowList struct {
|
||||
|
@ -20,7 +19,7 @@ type RemoteAllowList struct {
|
|||
|
||||
// Inside Range Specific, keys of this tree are inside CIDRs and values
|
||||
// are *AllowList
|
||||
insideAllowLists *cidr.Tree6[*AllowList]
|
||||
insideAllowLists *bart.Table[*AllowList]
|
||||
}
|
||||
|
||||
type LocalAllowList struct {
|
||||
|
@ -88,7 +87,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
|
|||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
|
||||
}
|
||||
|
||||
tree := cidr.NewTree6[bool]()
|
||||
tree := new(bart.Table[bool])
|
||||
|
||||
// Keep track of the rules we have added for both ipv4 and ipv6
|
||||
type allowListRules struct {
|
||||
|
@ -122,18 +121,20 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
|
|||
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
|
||||
}
|
||||
|
||||
_, ipNet, err := net.ParseCIDR(rawCIDR)
|
||||
ipNet, err := netip.ParsePrefix(rawCIDR)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err)
|
||||
}
|
||||
|
||||
// TODO: should we error on duplicate CIDRs in the config?
|
||||
tree.AddCIDR(ipNet, value)
|
||||
ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits())
|
||||
|
||||
maskBits, maskSize := ipNet.Mask.Size()
|
||||
// TODO: should we error on duplicate CIDRs in the config?
|
||||
tree.Insert(ipNet, value)
|
||||
|
||||
maskBits := ipNet.Bits()
|
||||
|
||||
var rules *allowListRules
|
||||
if maskSize == 32 {
|
||||
if ipNet.Addr().Is4() {
|
||||
rules = &rules4
|
||||
} else {
|
||||
rules = &rules6
|
||||
|
@ -156,8 +157,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
|
|||
|
||||
if !rules4.defaultSet {
|
||||
if rules4.allValuesMatch {
|
||||
_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
|
||||
tree.AddCIDR(zeroCIDR, !rules4.allValues)
|
||||
tree.Insert(netip.PrefixFrom(netip.IPv4Unspecified(), 0), !rules4.allValues)
|
||||
} else {
|
||||
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
|
||||
}
|
||||
|
@ -165,8 +165,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
|
|||
|
||||
if !rules6.defaultSet {
|
||||
if rules6.allValuesMatch {
|
||||
_, zeroCIDR, _ := net.ParseCIDR("::/0")
|
||||
tree.AddCIDR(zeroCIDR, !rules6.allValues)
|
||||
tree.Insert(netip.PrefixFrom(netip.IPv6Unspecified(), 0), !rules6.allValues)
|
||||
} else {
|
||||
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
|
||||
}
|
||||
|
@ -218,13 +217,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error
|
|||
return nameRules, nil
|
||||
}
|
||||
|
||||
func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) {
|
||||
func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error) {
|
||||
value := c.Get(k)
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
remoteAllowRanges := cidr.NewTree6[*AllowList]()
|
||||
remoteAllowRanges := new(bart.Table[*AllowList])
|
||||
|
||||
rawMap, ok := value.(map[interface{}]interface{})
|
||||
if !ok {
|
||||
|
@ -241,45 +240,27 @@ func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error
|
|||
return nil, err
|
||||
}
|
||||
|
||||
_, ipNet, err := net.ParseCIDR(rawCIDR)
|
||||
ipNet, err := netip.ParsePrefix(rawCIDR)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err)
|
||||
}
|
||||
|
||||
remoteAllowRanges.AddCIDR(ipNet, allowList)
|
||||
remoteAllowRanges.Insert(netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()), allowList)
|
||||
}
|
||||
|
||||
return remoteAllowRanges, nil
|
||||
}
|
||||
|
||||
func (al *AllowList) Allow(ip net.IP) bool {
|
||||
func (al *AllowList) Allow(ip netip.Addr) bool {
|
||||
if al == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
_, result := al.cidrTree.MostSpecificContains(ip)
|
||||
result, _ := al.cidrTree.Lookup(ip)
|
||||
return result
|
||||
}
|
||||
|
||||
func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
|
||||
if al == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
_, result := al.cidrTree.MostSpecificContainsIpV4(ip)
|
||||
return result
|
||||
}
|
||||
|
||||
func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
|
||||
if al == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
_, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
|
||||
return result
|
||||
}
|
||||
|
||||
func (al *LocalAllowList) Allow(ip net.IP) bool {
|
||||
func (al *LocalAllowList) Allow(ip netip.Addr) bool {
|
||||
if al == nil {
|
||||
return true
|
||||
}
|
||||
|
@ -301,43 +282,23 @@ func (al *LocalAllowList) AllowName(name string) bool {
|
|||
return !al.nameRules[0].Allow
|
||||
}
|
||||
|
||||
func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool {
|
||||
func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool {
|
||||
if al == nil {
|
||||
return true
|
||||
}
|
||||
return al.AllowList.Allow(ip)
|
||||
}
|
||||
|
||||
func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool {
|
||||
func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool {
|
||||
if !al.getInsideAllowList(vpnIp).Allow(ip) {
|
||||
return false
|
||||
}
|
||||
return al.AllowList.Allow(ip)
|
||||
}
|
||||
|
||||
func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool {
|
||||
if al == nil {
|
||||
return true
|
||||
}
|
||||
if !al.getInsideAllowList(vpnIp).AllowIpV4(ip) {
|
||||
return false
|
||||
}
|
||||
return al.AllowList.AllowIpV4(ip)
|
||||
}
|
||||
|
||||
func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
|
||||
if al == nil {
|
||||
return true
|
||||
}
|
||||
if !al.getInsideAllowList(vpnIp).AllowIpV6(hi, lo) {
|
||||
return false
|
||||
}
|
||||
return al.AllowList.AllowIpV6(hi, lo)
|
||||
}
|
||||
|
||||
func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
|
||||
func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList {
|
||||
if al.insideAllowLists != nil {
|
||||
ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
|
||||
inside, ok := al.insideAllowLists.Lookup(vpnIp)
|
||||
if ok {
|
||||
return inside
|
||||
}
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -18,7 +18,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
|||
"192.168.0.0": true,
|
||||
}
|
||||
r, err := newAllowListFromConfig(c, "allowlist", nil)
|
||||
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
|
||||
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
|
||||
assert.Nil(t, r)
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
|
@ -98,26 +98,26 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAllowList_Allow(t *testing.T) {
|
||||
assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
|
||||
assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
|
||||
|
||||
tree := cidr.NewTree6[bool]()
|
||||
tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
|
||||
tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
|
||||
tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
|
||||
tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true)
|
||||
tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true)
|
||||
tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false)
|
||||
tree.AddCIDR(cidr.Parse("::1/128"), true)
|
||||
tree.AddCIDR(cidr.Parse("::2/128"), false)
|
||||
tree := new(bart.Table[bool])
|
||||
tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
|
||||
tree.Insert(netip.MustParsePrefix("10.0.0.0/8"), false)
|
||||
tree.Insert(netip.MustParsePrefix("10.42.42.42/32"), true)
|
||||
tree.Insert(netip.MustParsePrefix("10.42.0.0/16"), true)
|
||||
tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), true)
|
||||
tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), false)
|
||||
tree.Insert(netip.MustParsePrefix("::1/128"), true)
|
||||
tree.Insert(netip.MustParsePrefix("::2/128"), false)
|
||||
al := &AllowList{cidrTree: tree}
|
||||
|
||||
assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1")))
|
||||
assert.Equal(t, false, al.Allow(net.ParseIP("10.0.0.4")))
|
||||
assert.Equal(t, true, al.Allow(net.ParseIP("10.42.42.42")))
|
||||
assert.Equal(t, false, al.Allow(net.ParseIP("10.42.42.41")))
|
||||
assert.Equal(t, true, al.Allow(net.ParseIP("10.42.0.1")))
|
||||
assert.Equal(t, true, al.Allow(net.ParseIP("::1")))
|
||||
assert.Equal(t, false, al.Allow(net.ParseIP("::2")))
|
||||
assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1")))
|
||||
assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4")))
|
||||
assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42")))
|
||||
assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41")))
|
||||
assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1")))
|
||||
assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1")))
|
||||
assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2")))
|
||||
}
|
||||
|
||||
func TestLocalAllowList_AllowName(t *testing.T) {
|
||||
|
|
|
@ -1,41 +1,36 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
)
|
||||
|
||||
// This allows us to "guess" what the remote might be for a host while we wait
|
||||
// for the lighthouse response. See "lighthouse.calculated_remotes" in the
|
||||
// example config file.
|
||||
type calculatedRemote struct {
|
||||
ipNet net.IPNet
|
||||
maskIP iputil.VpnIp
|
||||
mask iputil.VpnIp
|
||||
port uint32
|
||||
ipNet netip.Prefix
|
||||
mask netip.Prefix
|
||||
port uint32
|
||||
}
|
||||
|
||||
func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) {
|
||||
// Ensure this is an IPv4 mask that we expect
|
||||
ones, bits := ipNet.Mask.Size()
|
||||
if ones == 0 || bits != 32 {
|
||||
return nil, fmt.Errorf("invalid mask: %v", ipNet)
|
||||
}
|
||||
func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
|
||||
masked := maskCidr.Masked()
|
||||
if port < 0 || port > math.MaxUint16 {
|
||||
return nil, fmt.Errorf("invalid port: %d", port)
|
||||
}
|
||||
|
||||
return &calculatedRemote{
|
||||
ipNet: *ipNet,
|
||||
maskIP: iputil.Ip2VpnIp(ipNet.IP),
|
||||
mask: iputil.Ip2VpnIp(ipNet.Mask),
|
||||
port: uint32(port),
|
||||
ipNet: maskCidr,
|
||||
mask: masked,
|
||||
port: uint32(port),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -43,21 +38,41 @@ func (c *calculatedRemote) String() string {
|
|||
return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
|
||||
}
|
||||
|
||||
func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort {
|
||||
func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort {
|
||||
// Combine the masked bytes of the "mask" IP with the unmasked bytes
|
||||
// of the overlay IP
|
||||
masked := (c.maskIP & c.mask) | (ip & ^c.mask)
|
||||
|
||||
return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
|
||||
if c.ipNet.Addr().Is4() {
|
||||
return c.apply4(ip)
|
||||
}
|
||||
return c.apply6(ip)
|
||||
}
|
||||
|
||||
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) {
|
||||
func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
|
||||
//TODO: IPV6-WORK this can be less crappy
|
||||
maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
|
||||
mask := binary.BigEndian.Uint32(maskb[:])
|
||||
|
||||
b := c.mask.Addr().As4()
|
||||
maskIp := binary.BigEndian.Uint32(b[:])
|
||||
|
||||
b = ip.As4()
|
||||
intIp := binary.BigEndian.Uint32(b[:])
|
||||
|
||||
return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port}
|
||||
}
|
||||
|
||||
func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort {
|
||||
//TODO: IPV6-WORK
|
||||
panic("Can not calculate ipv6 remote addresses")
|
||||
}
|
||||
|
||||
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
|
||||
value := c.Get(k)
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]()
|
||||
calculatedRemotes := new(bart.Table[[]*calculatedRemote])
|
||||
|
||||
rawMap, ok := value.(map[any]any)
|
||||
if !ok {
|
||||
|
@ -69,17 +84,18 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calcu
|
|||
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
||||
}
|
||||
|
||||
_, ipNet, err := net.ParseCIDR(rawCIDR)
|
||||
cidr, err := netip.ParsePrefix(rawCIDR)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
||||
}
|
||||
|
||||
//TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here
|
||||
entry, err := newCalculatedRemotesListFromConfig(rawValue)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
|
||||
}
|
||||
|
||||
calculatedRemotes.AddCIDR(ipNet, entry)
|
||||
calculatedRemotes.Insert(cidr, entry)
|
||||
}
|
||||
|
||||
return calculatedRemotes, nil
|
||||
|
@ -117,7 +133,7 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
|
|||
if !ok {
|
||||
return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue)
|
||||
}
|
||||
_, ipNet, err := net.ParseCIDR(rawMask)
|
||||
maskCidr, err := netip.ParsePrefix(rawMask)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid mask: %s", rawMask)
|
||||
}
|
||||
|
@ -139,5 +155,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
|
|||
return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
|
||||
}
|
||||
|
||||
return newCalculatedRemote(ipNet, port)
|
||||
return newCalculatedRemote(maskCidr, port)
|
||||
}
|
||||
|
|
|
@ -1,27 +1,25 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCalculatedRemoteApply(t *testing.T) {
|
||||
_, ipNet, err := net.ParseCIDR("192.168.1.0/24")
|
||||
ipNet, err := netip.ParsePrefix("192.168.1.0/24")
|
||||
require.NoError(t, err)
|
||||
|
||||
c, err := newCalculatedRemote(ipNet, 4242)
|
||||
require.NoError(t, err)
|
||||
|
||||
input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182})
|
||||
input, err := netip.ParseAddr("10.0.10.182")
|
||||
assert.NoError(t, err)
|
||||
|
||||
expected := &Ip4AndPort{
|
||||
Ip: uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})),
|
||||
Port: 4242,
|
||||
}
|
||||
expected, err := netip.ParseAddr("192.168.1.182")
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, expected, c.Apply(input))
|
||||
assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input))
|
||||
}
|
||||
|
|
|
@ -1,10 +0,0 @@
|
|||
package cidr
|
||||
|
||||
import "net"
|
||||
|
||||
// Parse is a convenience function that returns only the IPNet
|
||||
// This function ignores errors since it is primarily a test helper, the result could be nil
|
||||
func Parse(s string) *net.IPNet {
|
||||
_, c, _ := net.ParseCIDR(s)
|
||||
return c
|
||||
}
|
203
cidr/tree4.go
203
cidr/tree4.go
|
@ -1,203 +0,0 @@
|
|||
package cidr
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
)
|
||||
|
||||
type Node[T any] struct {
|
||||
left *Node[T]
|
||||
right *Node[T]
|
||||
parent *Node[T]
|
||||
hasValue bool
|
||||
value T
|
||||
}
|
||||
|
||||
type entry[T any] struct {
|
||||
CIDR *net.IPNet
|
||||
Value T
|
||||
}
|
||||
|
||||
type Tree4[T any] struct {
|
||||
root *Node[T]
|
||||
list []entry[T]
|
||||
}
|
||||
|
||||
const (
|
||||
startbit = iputil.VpnIp(0x80000000)
|
||||
)
|
||||
|
||||
func NewTree4[T any]() *Tree4[T] {
|
||||
tree := new(Tree4[T])
|
||||
tree.root = &Node[T]{}
|
||||
tree.list = []entry[T]{}
|
||||
return tree
|
||||
}
|
||||
|
||||
func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
|
||||
bit := startbit
|
||||
node := tree.root
|
||||
next := tree.root
|
||||
|
||||
ip := iputil.Ip2VpnIp(cidr.IP)
|
||||
mask := iputil.Ip2VpnIp(cidr.Mask)
|
||||
|
||||
// Find our last ancestor in the tree
|
||||
for bit&mask != 0 {
|
||||
if ip&bit != 0 {
|
||||
next = node.right
|
||||
} else {
|
||||
next = node.left
|
||||
}
|
||||
|
||||
if next == nil {
|
||||
break
|
||||
}
|
||||
|
||||
bit = bit >> 1
|
||||
node = next
|
||||
}
|
||||
|
||||
// We already have this range so update the value
|
||||
if next != nil {
|
||||
addCIDR := cidr.String()
|
||||
for i, v := range tree.list {
|
||||
if addCIDR == v.CIDR.String() {
|
||||
tree.list = append(tree.list[:i], tree.list[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
|
||||
node.value = val
|
||||
node.hasValue = true
|
||||
return
|
||||
}
|
||||
|
||||
// Build up the rest of the tree we don't already have
|
||||
for bit&mask != 0 {
|
||||
next = &Node[T]{}
|
||||
next.parent = node
|
||||
|
||||
if ip&bit != 0 {
|
||||
node.right = next
|
||||
} else {
|
||||
node.left = next
|
||||
}
|
||||
|
||||
bit >>= 1
|
||||
node = next
|
||||
}
|
||||
|
||||
// Final node marks our cidr, set the value
|
||||
node.value = val
|
||||
node.hasValue = true
|
||||
tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
|
||||
}
|
||||
|
||||
// Contains finds the first match, which may be the least specific
|
||||
func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
|
||||
bit := startbit
|
||||
node := tree.root
|
||||
|
||||
for node != nil {
|
||||
if node.hasValue {
|
||||
return true, node.value
|
||||
}
|
||||
|
||||
if ip&bit != 0 {
|
||||
node = node.right
|
||||
} else {
|
||||
node = node.left
|
||||
}
|
||||
|
||||
bit >>= 1
|
||||
|
||||
}
|
||||
|
||||
return false, value
|
||||
}
|
||||
|
||||
// MostSpecificContains finds the most specific match
|
||||
func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
|
||||
bit := startbit
|
||||
node := tree.root
|
||||
|
||||
for node != nil {
|
||||
if node.hasValue {
|
||||
value = node.value
|
||||
ok = true
|
||||
}
|
||||
|
||||
if ip&bit != 0 {
|
||||
node = node.right
|
||||
} else {
|
||||
node = node.left
|
||||
}
|
||||
|
||||
bit >>= 1
|
||||
}
|
||||
|
||||
return ok, value
|
||||
}
|
||||
|
||||
type eachFunc[T any] func(T) bool
|
||||
|
||||
// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete
|
||||
// The final return value will be true if the provided function returned true
|
||||
func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
|
||||
bit := startbit
|
||||
node := tree.root
|
||||
|
||||
for node != nil {
|
||||
if node.hasValue {
|
||||
// If the each func returns true then we can exit the loop
|
||||
if each(node.value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if ip&bit != 0 {
|
||||
node = node.right
|
||||
} else {
|
||||
node = node.left
|
||||
}
|
||||
|
||||
bit >>= 1
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetCIDR returns the entry added by the most recent matching AddCIDR call
|
||||
func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
|
||||
bit := startbit
|
||||
node := tree.root
|
||||
|
||||
ip := iputil.Ip2VpnIp(cidr.IP)
|
||||
mask := iputil.Ip2VpnIp(cidr.Mask)
|
||||
|
||||
// Find our last ancestor in the tree
|
||||
for node != nil && bit&mask != 0 {
|
||||
if ip&bit != 0 {
|
||||
node = node.right
|
||||
} else {
|
||||
node = node.left
|
||||
}
|
||||
|
||||
bit = bit >> 1
|
||||
}
|
||||
|
||||
if bit&mask == 0 && node != nil {
|
||||
value = node.value
|
||||
ok = node.hasValue
|
||||
}
|
||||
|
||||
return ok, value
|
||||
}
|
||||
|
||||
// List will return all CIDRs and their current values. Do not modify the contents!
|
||||
func (tree *Tree4[T]) List() []entry[T] {
|
||||
return tree.list
|
||||
}
|
|
@ -1,170 +0,0 @@
|
|||
package cidr
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCIDRTree_List(t *testing.T) {
|
||||
tree := NewTree4[string]()
|
||||
tree.AddCIDR(Parse("1.0.0.0/16"), "1")
|
||||
tree.AddCIDR(Parse("1.0.0.0/8"), "2")
|
||||
tree.AddCIDR(Parse("1.0.0.0/16"), "3")
|
||||
tree.AddCIDR(Parse("1.0.0.0/16"), "4")
|
||||
list := tree.List()
|
||||
assert.Len(t, list, 2)
|
||||
assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
|
||||
assert.Equal(t, "2", list[0].Value)
|
||||
assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
|
||||
assert.Equal(t, "4", list[1].Value)
|
||||
}
|
||||
|
||||
func TestCIDRTree_Contains(t *testing.T) {
|
||||
tree := NewTree4[string]()
|
||||
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
||||
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
||||
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
||||
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
|
||||
tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
|
||||
tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
|
||||
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
|
||||
|
||||
tests := []struct {
|
||||
Found bool
|
||||
Result interface{}
|
||||
IP string
|
||||
}{
|
||||
{true, "1", "1.0.0.0"},
|
||||
{true, "1", "1.255.255.255"},
|
||||
{true, "2", "2.1.0.0"},
|
||||
{true, "2", "2.1.255.255"},
|
||||
{true, "3", "3.1.1.0"},
|
||||
{true, "3", "3.1.1.255"},
|
||||
{true, "4a", "4.1.1.255"},
|
||||
{true, "4a", "4.1.1.1"},
|
||||
{true, "5", "240.0.0.0"},
|
||||
{true, "5", "255.255.255.255"},
|
||||
{false, "", "239.0.0.0"},
|
||||
{false, "", "4.1.2.2"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
|
||||
assert.Equal(t, tt.Found, ok)
|
||||
assert.Equal(t, tt.Result, r)
|
||||
}
|
||||
|
||||
tree = NewTree4[string]()
|
||||
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
||||
ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "cool", r)
|
||||
|
||||
ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "cool", r)
|
||||
}
|
||||
|
||||
func TestCIDRTree_MostSpecificContains(t *testing.T) {
|
||||
tree := NewTree4[string]()
|
||||
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
||||
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
||||
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
||||
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
|
||||
tree.AddCIDR(Parse("4.1.1.0/30"), "4b")
|
||||
tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
|
||||
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
|
||||
|
||||
tests := []struct {
|
||||
Found bool
|
||||
Result interface{}
|
||||
IP string
|
||||
}{
|
||||
{true, "1", "1.0.0.0"},
|
||||
{true, "1", "1.255.255.255"},
|
||||
{true, "2", "2.1.0.0"},
|
||||
{true, "2", "2.1.255.255"},
|
||||
{true, "3", "3.1.1.0"},
|
||||
{true, "3", "3.1.1.255"},
|
||||
{true, "4a", "4.1.1.255"},
|
||||
{true, "4b", "4.1.1.2"},
|
||||
{true, "4c", "4.1.1.1"},
|
||||
{true, "5", "240.0.0.0"},
|
||||
{true, "5", "255.255.255.255"},
|
||||
{false, "", "239.0.0.0"},
|
||||
{false, "", "4.1.2.2"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
|
||||
assert.Equal(t, tt.Found, ok)
|
||||
assert.Equal(t, tt.Result, r)
|
||||
}
|
||||
|
||||
tree = NewTree4[string]()
|
||||
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
||||
ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "cool", r)
|
||||
|
||||
ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "cool", r)
|
||||
}
|
||||
|
||||
func TestTree4_GetCIDR(t *testing.T) {
|
||||
tree := NewTree4[string]()
|
||||
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
||||
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
||||
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
||||
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
|
||||
tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
|
||||
tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
|
||||
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
|
||||
|
||||
tests := []struct {
|
||||
Found bool
|
||||
Result interface{}
|
||||
IPNet *net.IPNet
|
||||
}{
|
||||
{true, "1", Parse("1.0.0.0/8")},
|
||||
{true, "2", Parse("2.1.0.0/16")},
|
||||
{true, "3", Parse("3.1.1.0/24")},
|
||||
{true, "4a", Parse("4.1.1.0/24")},
|
||||
{true, "4b", Parse("4.1.1.1/32")},
|
||||
{true, "4c", Parse("4.1.2.1/32")},
|
||||
{true, "5", Parse("254.0.0.0/4")},
|
||||
{false, "", Parse("2.0.0.0/8")},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ok, r := tree.GetCIDR(tt.IPNet)
|
||||
assert.Equal(t, tt.Found, ok)
|
||||
assert.Equal(t, tt.Result, r)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCIDRTree_Contains(b *testing.B) {
|
||||
tree := NewTree4[string]()
|
||||
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
|
||||
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
|
||||
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
|
||||
tree.AddCIDR(Parse("172.2.1.1/32"), "1")
|
||||
|
||||
ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
|
||||
b.Run("found", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
tree.Contains(ip)
|
||||
}
|
||||
})
|
||||
|
||||
ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
|
||||
b.Run("not found", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
tree.Contains(ip)
|
||||
}
|
||||
})
|
||||
}
|
189
cidr/tree6.go
189
cidr/tree6.go
|
@ -1,189 +0,0 @@
|
|||
package cidr
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
)
|
||||
|
||||
const startbit6 = uint64(1 << 63)
|
||||
|
||||
type Tree6[T any] struct {
|
||||
root4 *Node[T]
|
||||
root6 *Node[T]
|
||||
}
|
||||
|
||||
func NewTree6[T any]() *Tree6[T] {
|
||||
tree := new(Tree6[T])
|
||||
tree.root4 = &Node[T]{}
|
||||
tree.root6 = &Node[T]{}
|
||||
return tree
|
||||
}
|
||||
|
||||
func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) {
|
||||
var node, next *Node[T]
|
||||
|
||||
cidrIP, ipv4 := isIPV4(cidr.IP)
|
||||
if ipv4 {
|
||||
node = tree.root4
|
||||
next = tree.root4
|
||||
|
||||
} else {
|
||||
node = tree.root6
|
||||
next = tree.root6
|
||||
}
|
||||
|
||||
for i := 0; i < len(cidrIP); i += 4 {
|
||||
ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
|
||||
mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
|
||||
bit := startbit
|
||||
|
||||
// Find our last ancestor in the tree
|
||||
for bit&mask != 0 {
|
||||
if ip&bit != 0 {
|
||||
next = node.right
|
||||
} else {
|
||||
next = node.left
|
||||
}
|
||||
|
||||
if next == nil {
|
||||
break
|
||||
}
|
||||
|
||||
bit = bit >> 1
|
||||
node = next
|
||||
}
|
||||
|
||||
// Build up the rest of the tree we don't already have
|
||||
for bit&mask != 0 {
|
||||
next = &Node[T]{}
|
||||
next.parent = node
|
||||
|
||||
if ip&bit != 0 {
|
||||
node.right = next
|
||||
} else {
|
||||
node.left = next
|
||||
}
|
||||
|
||||
bit >>= 1
|
||||
node = next
|
||||
}
|
||||
}
|
||||
|
||||
// Final node marks our cidr, set the value
|
||||
node.value = val
|
||||
node.hasValue = true
|
||||
}
|
||||
|
||||
// Finds the most specific match
|
||||
func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) {
|
||||
var node *Node[T]
|
||||
|
||||
wholeIP, ipv4 := isIPV4(ip)
|
||||
if ipv4 {
|
||||
node = tree.root4
|
||||
} else {
|
||||
node = tree.root6
|
||||
}
|
||||
|
||||
for i := 0; i < len(wholeIP); i += 4 {
|
||||
ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
|
||||
bit := startbit
|
||||
|
||||
for node != nil {
|
||||
if node.hasValue {
|
||||
value = node.value
|
||||
ok = true
|
||||
}
|
||||
|
||||
if bit == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if ip&bit != 0 {
|
||||
node = node.right
|
||||
} else {
|
||||
node = node.left
|
||||
}
|
||||
|
||||
bit >>= 1
|
||||
}
|
||||
}
|
||||
|
||||
return ok, value
|
||||
}
|
||||
|
||||
func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) {
|
||||
bit := startbit
|
||||
node := tree.root4
|
||||
|
||||
for node != nil {
|
||||
if node.hasValue {
|
||||
value = node.value
|
||||
ok = true
|
||||
}
|
||||
|
||||
if ip&bit != 0 {
|
||||
node = node.right
|
||||
} else {
|
||||
node = node.left
|
||||
}
|
||||
|
||||
bit >>= 1
|
||||
}
|
||||
|
||||
return ok, value
|
||||
}
|
||||
|
||||
func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) {
|
||||
ip := hi
|
||||
node := tree.root6
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
bit := startbit6
|
||||
|
||||
for node != nil {
|
||||
if node.hasValue {
|
||||
value = node.value
|
||||
ok = true
|
||||
}
|
||||
|
||||
if bit == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if ip&bit != 0 {
|
||||
node = node.right
|
||||
} else {
|
||||
node = node.left
|
||||
}
|
||||
|
||||
bit >>= 1
|
||||
}
|
||||
|
||||
ip = lo
|
||||
}
|
||||
|
||||
return ok, value
|
||||
}
|
||||
|
||||
func isIPV4(ip net.IP) (net.IP, bool) {
|
||||
if len(ip) == net.IPv4len {
|
||||
return ip, true
|
||||
}
|
||||
|
||||
if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
|
||||
return ip[12:16], true
|
||||
}
|
||||
|
||||
return ip, false
|
||||
}
|
||||
|
||||
func isZeros(p net.IP) bool {
|
||||
for i := 0; i < len(p); i++ {
|
||||
if p[i] != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
|
@ -1,98 +0,0 @@
|
|||
package cidr
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
|
||||
tree := NewTree6[string]()
|
||||
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
||||
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
||||
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
||||
tree.AddCIDR(Parse("4.1.1.1/24"), "4a")
|
||||
tree.AddCIDR(Parse("4.1.1.1/30"), "4b")
|
||||
tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
|
||||
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
|
||||
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
|
||||
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
|
||||
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
|
||||
|
||||
tests := []struct {
|
||||
Found bool
|
||||
Result interface{}
|
||||
IP string
|
||||
}{
|
||||
{true, "1", "1.0.0.0"},
|
||||
{true, "1", "1.255.255.255"},
|
||||
{true, "2", "2.1.0.0"},
|
||||
{true, "2", "2.1.255.255"},
|
||||
{true, "3", "3.1.1.0"},
|
||||
{true, "3", "3.1.1.255"},
|
||||
{true, "4a", "4.1.1.255"},
|
||||
{true, "4b", "4.1.1.2"},
|
||||
{true, "4c", "4.1.1.1"},
|
||||
{true, "5", "240.0.0.0"},
|
||||
{true, "5", "255.255.255.255"},
|
||||
{true, "6a", "1:2:0:4:1:1:1:1"},
|
||||
{true, "6b", "1:2:0:4:5:1:1:1"},
|
||||
{true, "6c", "1:2:0:4:5:0:0:0"},
|
||||
{false, "", "239.0.0.0"},
|
||||
{false, "", "4.1.2.2"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP))
|
||||
assert.Equal(t, tt.Found, ok)
|
||||
assert.Equal(t, tt.Result, r)
|
||||
}
|
||||
|
||||
tree = NewTree6[string]()
|
||||
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
||||
tree.AddCIDR(Parse("::/0"), "cool6")
|
||||
ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0"))
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "cool", r)
|
||||
|
||||
ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255"))
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "cool", r)
|
||||
|
||||
ok, r = tree.MostSpecificContains(net.ParseIP("::"))
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "cool6", r)
|
||||
|
||||
ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8"))
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "cool6", r)
|
||||
}
|
||||
|
||||
func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
|
||||
tree := NewTree6[string]()
|
||||
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
|
||||
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
|
||||
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
|
||||
|
||||
tests := []struct {
|
||||
Found bool
|
||||
Result interface{}
|
||||
IP string
|
||||
}{
|
||||
{true, "6a", "1:2:0:4:1:1:1:1"},
|
||||
{true, "6b", "1:2:0:4:5:1:1:1"},
|
||||
{true, "6c", "1:2:0:4:5:0:0:0"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.IP)
|
||||
hi := binary.BigEndian.Uint64(ip[:8])
|
||||
lo := binary.BigEndian.Uint64(ip[8:])
|
||||
|
||||
ok, r := tree.MostSpecificContainsIpV6(hi, lo)
|
||||
assert.Equal(t, tt.Found, ok)
|
||||
assert.Equal(t, tt.Result, r)
|
||||
}
|
||||
}
|
|
@ -3,6 +3,8 @@ package nebula
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -10,8 +12,6 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
type trafficDecision int
|
||||
|
@ -224,8 +224,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
|||
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
|
||||
|
||||
var index uint32
|
||||
var relayFrom iputil.VpnIp
|
||||
var relayTo iputil.VpnIp
|
||||
var relayFrom netip.Addr
|
||||
var relayTo netip.Addr
|
||||
switch {
|
||||
case ok && existing.State == Established:
|
||||
// This relay already exists in newhostinfo, then do nothing.
|
||||
|
@ -235,7 +235,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
|||
index = existing.LocalIndex
|
||||
switch r.Type {
|
||||
case TerminalType:
|
||||
relayFrom = n.intf.myVpnIp
|
||||
relayFrom = n.intf.myVpnNet.Addr()
|
||||
relayTo = existing.PeerIp
|
||||
case ForwardingType:
|
||||
relayFrom = existing.PeerIp
|
||||
|
@ -260,7 +260,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
|||
}
|
||||
switch r.Type {
|
||||
case TerminalType:
|
||||
relayFrom = n.intf.myVpnIp
|
||||
relayFrom = n.intf.myVpnNet.Addr()
|
||||
relayTo = r.PeerIp
|
||||
case ForwardingType:
|
||||
relayFrom = r.PeerIp
|
||||
|
@ -270,12 +270,16 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
|||
}
|
||||
}
|
||||
|
||||
//TODO: IPV6-WORK
|
||||
relayFromB := relayFrom.As4()
|
||||
relayToB := relayTo.As4()
|
||||
|
||||
// Send a CreateRelayRequest to the peer.
|
||||
req := NebulaControl{
|
||||
Type: NebulaControl_CreateRelayRequest,
|
||||
InitiatorRelayIndex: index,
|
||||
RelayFromIp: uint32(relayFrom),
|
||||
RelayToIp: uint32(relayTo),
|
||||
RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]),
|
||||
RelayToIp: binary.BigEndian.Uint32(relayToB[:]),
|
||||
}
|
||||
msg, err := req.Marshal()
|
||||
if err != nil {
|
||||
|
@ -283,8 +287,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
|||
} else {
|
||||
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
n.l.WithFields(logrus.Fields{
|
||||
"relayFrom": iputil.VpnIp(req.RelayFromIp),
|
||||
"relayTo": iputil.VpnIp(req.RelayToIp),
|
||||
"relayFrom": req.RelayFromIp,
|
||||
"relayTo": req.RelayToIp,
|
||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
||||
"responderRelayIndex": req.ResponderRelayIndex,
|
||||
"vpnIp": newhostinfo.vpnIp}).
|
||||
|
@ -403,7 +407,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
|
|||
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
|
||||
// Let's sort this out.
|
||||
|
||||
if current.vpnIp < n.intf.myVpnIp {
|
||||
if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
|
||||
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
|
||||
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
|
||||
// The remotes vpn ip is lower than mine. I will not flip.
|
||||
|
@ -457,12 +461,12 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
|
|||
}
|
||||
|
||||
if n.punchy.GetTargetEverything() {
|
||||
hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) {
|
||||
hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
|
||||
n.metricsTxPunchy.Inc(1)
|
||||
n.intf.outside.WriteTo([]byte{1}, addr)
|
||||
})
|
||||
|
||||
} else if hostinfo.remote != nil {
|
||||
} else if hostinfo.remote.IsValid() {
|
||||
n.metricsTxPunchy.Inc(1)
|
||||
n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
|
||||
}
|
||||
|
|
|
@ -5,28 +5,26 @@ import (
|
|||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var vpnIp iputil.VpnIp
|
||||
|
||||
func newTestLighthouse() *LightHouse {
|
||||
lh := &LightHouse{
|
||||
l: test.NewLogger(),
|
||||
addrMap: map[iputil.VpnIp]*RemoteList{},
|
||||
queryChan: make(chan iputil.VpnIp, 10),
|
||||
addrMap: map[netip.Addr]*RemoteList{},
|
||||
queryChan: make(chan netip.Addr, 10),
|
||||
}
|
||||
lighthouses := map[iputil.VpnIp]struct{}{}
|
||||
staticList := map[iputil.VpnIp]struct{}{}
|
||||
lighthouses := map[netip.Addr]struct{}{}
|
||||
staticList := map[netip.Addr]struct{}{}
|
||||
|
||||
lh.lighthouses.Store(&lighthouses)
|
||||
lh.staticList.Store(&staticList)
|
||||
|
@ -37,10 +35,10 @@ func newTestLighthouse() *LightHouse {
|
|||
func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
|
||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||
vpnIp := netip.MustParseAddr("172.1.1.2")
|
||||
preferredRanges := []netip.Prefix{localrange}
|
||||
|
||||
// Very incomplete mock objects
|
||||
hostMap := newHostMap(l, vpncidr)
|
||||
|
@ -120,9 +118,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||
func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
|
||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||
vpnIp := netip.MustParseAddr("172.1.1.2")
|
||||
preferredRanges := []netip.Prefix{localrange}
|
||||
|
||||
// Very incomplete mock objects
|
||||
hostMap := newHostMap(l, vpncidr)
|
||||
|
@ -211,9 +210,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||
IP: net.IPv4(172, 1, 1, 2),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
|
||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||
vpnIp := netip.MustParseAddr("172.1.1.2")
|
||||
preferredRanges := []netip.Prefix{localrange}
|
||||
hostMap := newHostMap(l, vpncidr)
|
||||
hostMap.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
|
|
40
control.go
40
control.go
|
@ -2,7 +2,7 @@ package nebula
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
@ -10,9 +10,7 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||
|
@ -21,10 +19,10 @@ import (
|
|||
type controlEach func(h *HostInfo)
|
||||
|
||||
type controlHostLister interface {
|
||||
QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo
|
||||
QueryVpnIp(vpnIp netip.Addr) *HostInfo
|
||||
ForEachIndex(each controlEach)
|
||||
ForEachVpnIp(each controlEach)
|
||||
GetPreferredRanges() []*net.IPNet
|
||||
GetPreferredRanges() []netip.Prefix
|
||||
}
|
||||
|
||||
type Control struct {
|
||||
|
@ -39,15 +37,15 @@ type Control struct {
|
|||
}
|
||||
|
||||
type ControlHostInfo struct {
|
||||
VpnIp net.IP `json:"vpnIp"`
|
||||
VpnIp netip.Addr `json:"vpnIp"`
|
||||
LocalIndex uint32 `json:"localIndex"`
|
||||
RemoteIndex uint32 `json:"remoteIndex"`
|
||||
RemoteAddrs []*udp.Addr `json:"remoteAddrs"`
|
||||
RemoteAddrs []netip.AddrPort `json:"remoteAddrs"`
|
||||
Cert *cert.NebulaCertificate `json:"cert"`
|
||||
MessageCounter uint64 `json:"messageCounter"`
|
||||
CurrentRemote *udp.Addr `json:"currentRemote"`
|
||||
CurrentRelaysToMe []iputil.VpnIp `json:"currentRelaysToMe"`
|
||||
CurrentRelaysThroughMe []iputil.VpnIp `json:"currentRelaysThroughMe"`
|
||||
CurrentRemote netip.AddrPort `json:"currentRemote"`
|
||||
CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"`
|
||||
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
||||
}
|
||||
|
||||
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
||||
|
@ -132,7 +130,8 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
|
|||
}
|
||||
|
||||
// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
|
||||
func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
|
||||
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
|
||||
func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo {
|
||||
var hl controlHostLister
|
||||
if pending {
|
||||
hl = c.f.handshakeManager
|
||||
|
@ -150,19 +149,21 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
|
|||
}
|
||||
|
||||
// SetRemoteForTunnel forces a tunnel to use a specific remote
|
||||
func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
|
||||
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
|
||||
func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
|
||||
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
|
||||
if hostInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
hostInfo.SetRemote(addr.Copy())
|
||||
hostInfo.SetRemote(addr)
|
||||
ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges())
|
||||
return &ch
|
||||
}
|
||||
|
||||
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
|
||||
func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
|
||||
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
|
||||
func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
|
||||
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
|
||||
if hostInfo == nil {
|
||||
return false
|
||||
|
@ -205,7 +206,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
|
|||
}
|
||||
|
||||
// Learn which hosts are being used as relays, so we can shut them down last.
|
||||
relayingHosts := map[iputil.VpnIp]*HostInfo{}
|
||||
relayingHosts := map[netip.Addr]*HostInfo{}
|
||||
// Grab the hostMap lock to access the Relays map
|
||||
c.f.hostMap.Lock()
|
||||
for _, relayingHost := range c.f.hostMap.Relays {
|
||||
|
@ -236,15 +237,16 @@ func (c *Control) Device() overlay.Device {
|
|||
return c.f.inside
|
||||
}
|
||||
|
||||
func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
|
||||
func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
|
||||
|
||||
chi := ControlHostInfo{
|
||||
VpnIp: h.vpnIp.ToIP(),
|
||||
VpnIp: h.vpnIp,
|
||||
LocalIndex: h.localIndexId,
|
||||
RemoteIndex: h.remoteIndexId,
|
||||
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
|
||||
CurrentRelaysToMe: h.relayState.CopyRelayIps(),
|
||||
CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(),
|
||||
CurrentRemote: h.remote,
|
||||
}
|
||||
|
||||
if h.ConnectionState != nil {
|
||||
|
@ -255,10 +257,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
|
|||
chi.Cert = c.Copy()
|
||||
}
|
||||
|
||||
if h.remote != nil {
|
||||
chi.CurrentRemote = h.remote.Copy()
|
||||
}
|
||||
|
||||
return chi
|
||||
}
|
||||
|
||||
|
|
|
@ -2,15 +2,14 @@ package nebula
|
|||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
@ -18,18 +17,19 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||
l := test.NewLogger()
|
||||
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
||||
// To properly ensure we are not exposing core memory to the caller
|
||||
hm := newHostMap(l, &net.IPNet{})
|
||||
hm.preferredRanges.Store(&[]*net.IPNet{})
|
||||
hm := newHostMap(l, netip.Prefix{})
|
||||
hm.preferredRanges.Store(&[]netip.Prefix{})
|
||||
|
||||
remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
|
||||
remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444")
|
||||
|
||||
remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
|
||||
remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
|
||||
ipNet := net.IPNet{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
IP: remote1.Addr().AsSlice(),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
ipNet2 := net.IPNet{
|
||||
IP: net.ParseIP("1:2:3:4:5:6:7:8"),
|
||||
IP: remote2.Addr().AsSlice(),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
|
@ -50,8 +50,12 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||
}
|
||||
|
||||
remotes := NewRemoteList(nil)
|
||||
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
|
||||
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
|
||||
remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
|
||||
remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
|
||||
|
||||
vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
|
||||
assert.True(t, ok)
|
||||
|
||||
hm.unlockedAddHostInfo(&HostInfo{
|
||||
remote: remote1,
|
||||
remotes: remotes,
|
||||
|
@ -60,14 +64,17 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||
},
|
||||
remoteIndexId: 200,
|
||||
localIndexId: 201,
|
||||
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
|
||||
vpnIp: vpnIp,
|
||||
relayState: RelayState{
|
||||
relays: map[iputil.VpnIp]struct{}{},
|
||||
relayForByIp: map[iputil.VpnIp]*Relay{},
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByIp: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
},
|
||||
}, &Interface{})
|
||||
|
||||
vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP)
|
||||
assert.True(t, ok)
|
||||
|
||||
hm.unlockedAddHostInfo(&HostInfo{
|
||||
remote: remote1,
|
||||
remotes: remotes,
|
||||
|
@ -76,10 +83,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||
},
|
||||
remoteIndexId: 200,
|
||||
localIndexId: 201,
|
||||
vpnIp: iputil.Ip2VpnIp(ipNet2.IP),
|
||||
vpnIp: vpnIp2,
|
||||
relayState: RelayState{
|
||||
relays: map[iputil.VpnIp]struct{}{},
|
||||
relayForByIp: map[iputil.VpnIp]*Relay{},
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByIp: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
},
|
||||
}, &Interface{})
|
||||
|
@ -91,27 +98,29 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||
l: logrus.New(),
|
||||
}
|
||||
|
||||
thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false)
|
||||
thi := c.GetHostInfoByVpnIp(vpnIp, false)
|
||||
|
||||
expectedInfo := ControlHostInfo{
|
||||
VpnIp: net.IPv4(1, 2, 3, 4).To4(),
|
||||
VpnIp: vpnIp,
|
||||
LocalIndex: 201,
|
||||
RemoteIndex: 200,
|
||||
RemoteAddrs: []*udp.Addr{remote2, remote1},
|
||||
RemoteAddrs: []netip.AddrPort{remote2, remote1},
|
||||
Cert: crt.Copy(),
|
||||
MessageCounter: 0,
|
||||
CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
|
||||
CurrentRelaysToMe: []iputil.VpnIp{},
|
||||
CurrentRelaysThroughMe: []iputil.VpnIp{},
|
||||
CurrentRemote: remote1,
|
||||
CurrentRelaysToMe: []netip.Addr{},
|
||||
CurrentRelaysThroughMe: []netip.Addr{},
|
||||
}
|
||||
|
||||
// Make sure we don't have any unexpected fields
|
||||
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
|
||||
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||
assert.EqualValues(t, &expectedInfo, thi)
|
||||
//TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here
|
||||
//test.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||
|
||||
// Make sure we don't panic if the host info doesn't have a cert yet
|
||||
assert.NotPanics(t, func() {
|
||||
thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false)
|
||||
thi = c.GetHostInfoByVpnIp(vpnIp2, false)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -4,14 +4,13 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
@ -50,37 +49,30 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
|
|||
|
||||
// InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp
|
||||
// This is necessary if you did not configure static hosts or are not running a lighthouse
|
||||
func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
|
||||
func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
|
||||
c.f.lightHouse.Lock()
|
||||
remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
|
||||
remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
|
||||
remoteList.Lock()
|
||||
defer remoteList.Unlock()
|
||||
c.f.lightHouse.Unlock()
|
||||
|
||||
iVpnIp := iputil.Ip2VpnIp(vpnIp)
|
||||
if v4 := toAddr.IP.To4(); v4 != nil {
|
||||
remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
|
||||
if toAddr.Addr().Is4() {
|
||||
remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
|
||||
} else {
|
||||
remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)))
|
||||
remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
|
||||
}
|
||||
}
|
||||
|
||||
// InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp
|
||||
// This is necessary to inform an initiator of possible relays for communicating with a responder
|
||||
func (c *Control) InjectRelays(vpnIp net.IP, relayVpnIps []net.IP) {
|
||||
func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
|
||||
c.f.lightHouse.Lock()
|
||||
remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
|
||||
remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
|
||||
remoteList.Lock()
|
||||
defer remoteList.Unlock()
|
||||
c.f.lightHouse.Unlock()
|
||||
|
||||
iVpnIp := iputil.Ip2VpnIp(vpnIp)
|
||||
uVpnIp := []uint32{}
|
||||
for _, rVPnIp := range relayVpnIps {
|
||||
uVpnIp = append(uVpnIp, uint32(iputil.Ip2VpnIp(rVPnIp)))
|
||||
}
|
||||
|
||||
remoteList.unlockedSetRelay(iVpnIp, iVpnIp, uVpnIp)
|
||||
remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps)
|
||||
}
|
||||
|
||||
// GetFromTun will pull a packet off the tun side of nebula
|
||||
|
@ -107,13 +99,14 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
|
|||
}
|
||||
|
||||
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
|
||||
func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16, data []byte) {
|
||||
func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) {
|
||||
//TODO: IPV6-WORK
|
||||
ip := layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
SrcIP: c.f.inside.Cidr().IP,
|
||||
DstIP: toIp,
|
||||
SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(),
|
||||
DstIP: toIp.Unmap().AsSlice(),
|
||||
}
|
||||
|
||||
udp := layers.UDP{
|
||||
|
@ -138,16 +131,16 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
|
|||
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
|
||||
}
|
||||
|
||||
func (c *Control) GetVpnIp() iputil.VpnIp {
|
||||
return c.f.myVpnIp
|
||||
func (c *Control) GetVpnIp() netip.Addr {
|
||||
return c.f.myVpnNet.Addr()
|
||||
}
|
||||
|
||||
func (c *Control) GetUDPAddr() string {
|
||||
return c.f.outside.(*udp.TesterConn).Addr.String()
|
||||
func (c *Control) GetUDPAddr() netip.AddrPort {
|
||||
return c.f.outside.(*udp.TesterConn).Addr
|
||||
}
|
||||
|
||||
func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
|
||||
hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp))
|
||||
func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
|
||||
hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp)
|
||||
if hostinfo == nil {
|
||||
return false
|
||||
}
|
||||
|
@ -164,6 +157,6 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
|
|||
return c.f.pki.GetCertState().Certificate
|
||||
}
|
||||
|
||||
func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
|
||||
func (c *Control) ReHandshake(vpnIp netip.Addr) {
|
||||
c.f.handshakeManager.StartHandshake(vpnIp, nil)
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package nebula
|
|||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -10,7 +11,6 @@ import (
|
|||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
)
|
||||
|
||||
// This whole thing should be rewritten to use context
|
||||
|
@ -42,19 +42,21 @@ func (d *dnsRecords) Query(data string) string {
|
|||
}
|
||||
|
||||
func (d *dnsRecords) QueryCert(data string) string {
|
||||
ip := net.ParseIP(data[:len(data)-1])
|
||||
if ip == nil {
|
||||
ip, err := netip.ParseAddr(data[:len(data)-1])
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
iip := iputil.Ip2VpnIp(ip)
|
||||
hostinfo := d.hostMap.QueryVpnIp(iip)
|
||||
|
||||
hostinfo := d.hostMap.QueryVpnIp(ip)
|
||||
if hostinfo == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
q := hostinfo.GetCert()
|
||||
if q == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
cert := q.Details
|
||||
c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
|
||||
return c
|
||||
|
@ -80,7 +82,11 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
|
|||
}
|
||||
case dns.TypeTXT:
|
||||
a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
|
||||
b := net.ParseIP(a)
|
||||
b, err := netip.ParseAddr(a)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// We don't answer these queries from non nebula nodes or localhost
|
||||
//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
|
||||
if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
|
||||
|
|
|
@ -5,7 +5,7 @@ package e2e
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -13,19 +13,18 @@ import (
|
|||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/e2e/router"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
func BenchmarkHotPath(b *testing.B) {
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
|
||||
|
||||
// Put their info in our lighthouse
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
|
@ -35,7 +34,7 @@ func BenchmarkHotPath(b *testing.B) {
|
|||
r.CancelFlowLogs()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
||||
_ = r.RouteForAllUntilTxTun(theirControl)
|
||||
}
|
||||
|
||||
|
@ -44,19 +43,19 @@ func BenchmarkHotPath(b *testing.B) {
|
|||
}
|
||||
|
||||
func TestGoodHandshake(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
|
||||
|
||||
// Put their info in our lighthouse
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
||||
|
||||
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
||||
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
||||
|
@ -77,16 +76,16 @@ func TestGoodHandshake(t *testing.T) {
|
|||
myControl.WaitForType(1, 0, theirControl)
|
||||
|
||||
t.Log("Make sure our host infos are correct")
|
||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
|
||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
|
||||
|
||||
t.Log("Get that cached packet and make sure it looks right")
|
||||
myCachedPacket := theirControl.GetFromTun(true)
|
||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
|
||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
||||
|
||||
t.Log("Do a bidirectional tunnel test")
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
myControl.Stop()
|
||||
|
@ -95,20 +94,20 @@ func TestGoodHandshake(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestWrongResponderHandshake(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
// The IPs here are chosen on purpose:
|
||||
// The current remote handling will sort by preference, public, and then lexically.
|
||||
// So we need them to have a higher address than evil (we could apply a preference though)
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
|
||||
evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil)
|
||||
evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil)
|
||||
|
||||
// Add their real udp addr, which should be tried after evil.
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
||||
|
||||
// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr)
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, theirControl, evilControl)
|
||||
|
@ -120,7 +119,7 @@ func TestWrongResponderHandshake(t *testing.T) {
|
|||
evilControl.Start()
|
||||
|
||||
t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
||||
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||
h := &header.H{}
|
||||
err := h.Parse(p.Data)
|
||||
|
@ -128,7 +127,7 @@ func TestWrongResponderHandshake(t *testing.T) {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 {
|
||||
if p.To == theirUdpAddr && h.Type == 1 {
|
||||
return router.RouteAndExit
|
||||
}
|
||||
|
||||
|
@ -139,18 +138,18 @@ func TestWrongResponderHandshake(t *testing.T) {
|
|||
|
||||
t.Log("My cached packet should be received by them")
|
||||
myCachedPacket := theirControl.GetFromTun(true)
|
||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
|
||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
||||
|
||||
t.Log("Test the tunnel with them")
|
||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
|
||||
t.Log("Flush all packets from all controllers")
|
||||
r.FlushAll()
|
||||
|
||||
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
|
||||
assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil")
|
||||
assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil")
|
||||
assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil")
|
||||
assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil")
|
||||
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
|
||||
|
||||
//TODO: assert hostmaps for everyone
|
||||
|
@ -164,13 +163,13 @@ func TestStage1Race(t *testing.T) {
|
|||
// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
|
||||
// But will eventually collapse down to a single tunnel
|
||||
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
|
||||
|
||||
// Put their info in our lighthouse and vice versa
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
|
@ -181,8 +180,8 @@ func TestStage1Race(t *testing.T) {
|
|||
theirControl.Start()
|
||||
|
||||
t.Log("Trigger a handshake to start on both me and them")
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
|
||||
theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
||||
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
|
||||
|
||||
t.Log("Get both stage 1 handshake packets")
|
||||
myHsForThem := myControl.GetFromUDP(true)
|
||||
|
@ -194,14 +193,14 @@ func TestStage1Race(t *testing.T) {
|
|||
|
||||
r.Log("Route until they receive a message packet")
|
||||
myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
|
||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
|
||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
||||
|
||||
r.Log("Their cached packet should be received by me")
|
||||
theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
|
||||
assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
|
||||
assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
|
||||
|
||||
r.Log("Do a bidirectional tunnel test")
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
|
||||
myHostmapHosts := myControl.ListHostmapHosts(false)
|
||||
myHostmapIndexes := myControl.ListHostmapIndexes(false)
|
||||
|
@ -219,7 +218,7 @@ func TestStage1Race(t *testing.T) {
|
|||
r.Log("Spin until connection manager tears down a tunnel")
|
||||
|
||||
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
t.Log("Connection manager hasn't ticked yet")
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
@ -241,13 +240,13 @@ func TestStage1Race(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUncleanShutdownRaceLoser(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
|
||||
|
||||
// Teach my how to get to the relay and that their can be reached via the relay
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
|
@ -258,28 +257,28 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
|
|||
theirControl.Start()
|
||||
|
||||
r.Log("Trigger a handshake from me to them")
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
||||
|
||||
p := r.RouteForAllUntilTxTun(theirControl)
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
||||
|
||||
r.Log("Nuke my hostmap")
|
||||
myHostmap := myControl.GetHostmap()
|
||||
myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
|
||||
myHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{}
|
||||
myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
|
||||
myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
|
||||
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again"))
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again"))
|
||||
p = r.RouteForAllUntilTxTun(theirControl)
|
||||
assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
|
||||
assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
||||
|
||||
r.Log("Assert the tunnel works")
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
|
||||
r.Log("Wait for the dead index to go away")
|
||||
start := len(theirControl.GetHostmap().Indexes)
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
if len(theirControl.GetHostmap().Indexes) < start {
|
||||
break
|
||||
}
|
||||
|
@ -290,13 +289,13 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUncleanShutdownRaceWinner(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
|
||||
|
||||
// Teach my how to get to the relay and that their can be reached via the relay
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
|
@ -307,30 +306,30 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
|
|||
theirControl.Start()
|
||||
|
||||
r.Log("Trigger a handshake from me to them")
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
||||
|
||||
p := r.RouteForAllUntilTxTun(theirControl)
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
r.Log("Nuke my hostmap")
|
||||
theirHostmap := theirControl.GetHostmap()
|
||||
theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
|
||||
theirHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{}
|
||||
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
|
||||
theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
|
||||
|
||||
theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again"))
|
||||
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again"))
|
||||
p = r.RouteForAllUntilTxTun(myControl)
|
||||
assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
|
||||
assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
|
||||
r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
|
||||
|
||||
r.Log("Assert the tunnel works")
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
|
||||
r.Log("Wait for the dead index to go away")
|
||||
start := len(myControl.GetHostmap().Indexes)
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
if len(myControl.GetHostmap().Indexes) < start {
|
||||
break
|
||||
}
|
||||
|
@ -341,15 +340,15 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRelays(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||
|
||||
// Teach my how to get to the relay and that their can be reached via the relay
|
||||
myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
|
||||
myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
|
||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
|
||||
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
|
||||
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
|
||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||
|
@ -361,31 +360,31 @@ func TestRelays(t *testing.T) {
|
|||
theirControl.Start()
|
||||
|
||||
t.Log("Trigger a handshake from me to them via the relay")
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
||||
|
||||
p := r.RouteForAllUntilTxTun(theirControl)
|
||||
r.Log("Assert the tunnel works")
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||
//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
|
||||
}
|
||||
|
||||
func TestStage1RaceRelays(t *testing.T) {
|
||||
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||
|
||||
// Teach my how to get to the relay and that their can be reached via the relay
|
||||
myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
|
||||
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
|
||||
|
||||
myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
|
||||
theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
|
||||
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
|
||||
theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
|
||||
|
||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
|
||||
relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
|
||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
||||
relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||
|
@ -397,14 +396,14 @@ func TestStage1RaceRelays(t *testing.T) {
|
|||
theirControl.Start()
|
||||
|
||||
r.Log("Get a tunnel between me and relay")
|
||||
assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
|
||||
|
||||
r.Log("Get a tunnel between them and relay")
|
||||
assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
|
||||
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
|
||||
|
||||
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
|
||||
theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
||||
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
|
||||
|
||||
r.Log("Wait for a packet from them to me")
|
||||
p := r.RouteForAllUntilTxTun(myControl)
|
||||
|
@ -421,21 +420,21 @@ func TestStage1RaceRelays(t *testing.T) {
|
|||
|
||||
func TestStage1RaceRelays2(t *testing.T) {
|
||||
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||
l := NewTestLogger()
|
||||
|
||||
// Teach my how to get to the relay and that their can be reached via the relay
|
||||
myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
|
||||
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
|
||||
|
||||
myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
|
||||
theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
|
||||
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
|
||||
theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
|
||||
|
||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
|
||||
relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
|
||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
||||
relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||
|
@ -448,16 +447,16 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||
|
||||
r.Log("Get a tunnel between me and relay")
|
||||
l.Info("Get a tunnel between me and relay")
|
||||
assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
|
||||
|
||||
r.Log("Get a tunnel between them and relay")
|
||||
l.Info("Get a tunnel between them and relay")
|
||||
assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
|
||||
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
|
||||
|
||||
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
||||
l.Info("Trigger a handshake from both them and me via relay to them and me")
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
|
||||
theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
||||
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
|
||||
|
||||
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
|
||||
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
|
||||
|
@ -470,7 +469,7 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||
|
||||
r.Log("Assert the tunnel works")
|
||||
l.Info("Assert the tunnel works")
|
||||
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
|
||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
||||
|
||||
t.Log("Wait until we remove extra tunnels")
|
||||
l.Info("Wait until we remove extra tunnels")
|
||||
|
@ -490,7 +489,7 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||
"theirControl": len(theirControl.GetHostmap().Indexes),
|
||||
"relayControl": len(relayControl.GetHostmap().Indexes),
|
||||
}).Info("Waiting for hostinfos to be removed...")
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
t.Log("Connection manager hasn't ticked yet")
|
||||
time.Sleep(time.Second)
|
||||
retries--
|
||||
|
@ -498,7 +497,7 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||
|
||||
r.Log("Assert the tunnel works")
|
||||
l.Info("Assert the tunnel works")
|
||||
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
|
||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
|
@ -507,16 +506,17 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||
//
|
||||
////TODO: assert hostmaps
|
||||
}
|
||||
|
||||
func TestRehandshakingRelays(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||
|
||||
// Teach my how to get to the relay and that their can be reached via the relay
|
||||
myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
|
||||
myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
|
||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
|
||||
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
|
||||
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
|
||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||
|
@ -528,11 +528,11 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||
theirControl.Start()
|
||||
|
||||
t.Log("Trigger a handshake from me to them via the relay")
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
||||
|
||||
p := r.RouteForAllUntilTxTun(theirControl)
|
||||
r.Log("Assert the tunnel works")
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
||||
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
|
||||
|
||||
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
|
||||
|
@ -556,8 +556,8 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||
|
||||
for {
|
||||
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
|
||||
assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
|
||||
c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
|
||||
c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
|
||||
if len(c.Cert.Details.Groups) != 0 {
|
||||
// We have a new certificate now
|
||||
r.Log("Certificate between my and relay is updated!")
|
||||
|
@ -569,8 +569,8 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||
|
||||
for {
|
||||
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
|
||||
assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
|
||||
c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
|
||||
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
|
||||
c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
|
||||
if len(c.Cert.Details.Groups) != 0 {
|
||||
// We have a new certificate now
|
||||
r.Log("Certificate between their and relay is updated!")
|
||||
|
@ -581,13 +581,13 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||
}
|
||||
|
||||
r.Log("Assert the relay tunnel still works")
|
||||
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
|
||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
||||
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
|
||||
// We should have two hostinfos on all sides
|
||||
for len(myControl.GetHostmap().Indexes) != 2 {
|
||||
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
|
||||
r.Log("Assert the relay tunnel still works")
|
||||
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
|
||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
||||
r.Log("yupitdoes")
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
@ -595,7 +595,7 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||
for len(theirControl.GetHostmap().Indexes) != 2 {
|
||||
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
|
||||
r.Log("Assert the relay tunnel still works")
|
||||
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
|
||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
||||
r.Log("yupitdoes")
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
@ -603,7 +603,7 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||
for len(relayControl.GetHostmap().Indexes) != 2 {
|
||||
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
|
||||
r.Log("Assert the relay tunnel still works")
|
||||
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
|
||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
||||
r.Log("yupitdoes")
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
@ -612,15 +612,15 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||
|
||||
func TestRehandshakingRelaysPrimary(t *testing.T) {
|
||||
// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||
|
||||
// Teach my how to get to the relay and that their can be reached via the relay
|
||||
myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
|
||||
myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
|
||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
|
||||
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
|
||||
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
|
||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||
|
@ -632,11 +632,11 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||
theirControl.Start()
|
||||
|
||||
t.Log("Trigger a handshake from me to them via the relay")
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
||||
|
||||
p := r.RouteForAllUntilTxTun(theirControl)
|
||||
r.Log("Assert the tunnel works")
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
||||
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
|
||||
|
||||
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
|
||||
|
@ -660,8 +660,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||
|
||||
for {
|
||||
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
|
||||
assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
|
||||
c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
|
||||
c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
|
||||
if len(c.Cert.Details.Groups) != 0 {
|
||||
// We have a new certificate now
|
||||
r.Log("Certificate between my and relay is updated!")
|
||||
|
@ -673,8 +673,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||
|
||||
for {
|
||||
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
|
||||
assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
|
||||
c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
|
||||
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
|
||||
c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
|
||||
if len(c.Cert.Details.Groups) != 0 {
|
||||
// We have a new certificate now
|
||||
r.Log("Certificate between their and relay is updated!")
|
||||
|
@ -685,13 +685,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||
}
|
||||
|
||||
r.Log("Assert the relay tunnel still works")
|
||||
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
|
||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
||||
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
|
||||
// We should have two hostinfos on all sides
|
||||
for len(myControl.GetHostmap().Indexes) != 2 {
|
||||
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
|
||||
r.Log("Assert the relay tunnel still works")
|
||||
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
|
||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
||||
r.Log("yupitdoes")
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
@ -699,7 +699,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||
for len(theirControl.GetHostmap().Indexes) != 2 {
|
||||
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
|
||||
r.Log("Assert the relay tunnel still works")
|
||||
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
|
||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
||||
r.Log("yupitdoes")
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
@ -707,7 +707,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||
for len(relayControl.GetHostmap().Indexes) != 2 {
|
||||
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
|
||||
r.Log("Assert the relay tunnel still works")
|
||||
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
|
||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
||||
r.Log("yupitdoes")
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
@ -715,13 +715,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRehandshaking(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
|
||||
|
||||
// Put their info in our lighthouse and vice versa
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
|
@ -732,7 +732,7 @@ func TestRehandshaking(t *testing.T) {
|
|||
theirControl.Start()
|
||||
|
||||
t.Log("Stand up a tunnel between me and them")
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
|
||||
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
|
||||
|
||||
|
@ -754,8 +754,8 @@ func TestRehandshaking(t *testing.T) {
|
|||
myConfig.ReloadConfigString(string(rc))
|
||||
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
|
||||
if len(c.Cert.Details.Groups) != 0 {
|
||||
// We have a new certificate now
|
||||
break
|
||||
|
@ -781,19 +781,19 @@ func TestRehandshaking(t *testing.T) {
|
|||
|
||||
r.Log("Spin until there is only 1 tunnel")
|
||||
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
t.Log("Connection manager hasn't ticked yet")
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
myFinalHostmapHosts := myControl.ListHostmapHosts(false)
|
||||
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
|
||||
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
|
||||
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
|
||||
|
||||
// Make sure the correct tunnel won
|
||||
c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
|
||||
c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
|
||||
assert.Contains(t, c.Cert.Details.Groups, "new group")
|
||||
|
||||
// We should only have a single tunnel now on both sides
|
||||
|
@ -811,13 +811,13 @@ func TestRehandshaking(t *testing.T) {
|
|||
func TestRehandshakingLoser(t *testing.T) {
|
||||
// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
|
||||
// Should be the one with the new certificate
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
|
||||
|
||||
// Put their info in our lighthouse and vice versa
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
|
@ -828,10 +828,10 @@ func TestRehandshakingLoser(t *testing.T) {
|
|||
theirControl.Start()
|
||||
|
||||
t.Log("Stand up a tunnel between me and them")
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
|
||||
tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
|
||||
tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
|
||||
tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
|
||||
tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
|
||||
fmt.Println(tt1.LocalIndex, tt2.LocalIndex)
|
||||
|
||||
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
|
||||
|
@ -854,8 +854,8 @@ func TestRehandshakingLoser(t *testing.T) {
|
|||
theirConfig.ReloadConfigString(string(rc))
|
||||
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
|
||||
|
||||
_, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"]
|
||||
if theirNewGroup {
|
||||
|
@ -882,19 +882,19 @@ func TestRehandshakingLoser(t *testing.T) {
|
|||
|
||||
r.Log("Spin until there is only 1 tunnel")
|
||||
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
t.Log("Connection manager hasn't ticked yet")
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
myFinalHostmapHosts := myControl.ListHostmapHosts(false)
|
||||
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
|
||||
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
|
||||
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
|
||||
|
||||
// Make sure the correct tunnel won
|
||||
theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
|
||||
theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
|
||||
assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group")
|
||||
|
||||
// We should only have a single tunnel now on both sides
|
||||
|
@ -912,13 +912,13 @@ func TestRaceRegression(t *testing.T) {
|
|||
// This test forces stage 1, stage 2, stage 1 to be received by me from them
|
||||
// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
|
||||
// caused a cross-linked hostinfo
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
|
||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
|
||||
|
||||
// Put their info in our lighthouse
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
|
@ -932,8 +932,8 @@ func TestRaceRegression(t *testing.T) {
|
|||
//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
|
||||
|
||||
t.Log("Start both handshakes")
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
|
||||
theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
||||
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
|
||||
|
||||
t.Log("Get both stage 1")
|
||||
myStage1ForThem := myControl.GetFromUDP(true)
|
||||
|
@ -963,7 +963,7 @@ func TestRaceRegression(t *testing.T) {
|
|||
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
|
||||
|
||||
t.Log("Make sure the tunnel still works")
|
||||
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
|
||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"crypto/rand"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
|
@ -12,7 +13,7 @@ import (
|
|||
)
|
||||
|
||||
// NewTestCaCert will generate a CA cert
|
||||
func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
|
||||
func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if before.IsZero() {
|
||||
before = time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
|
@ -33,11 +34,17 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
|
|||
}
|
||||
|
||||
if len(ips) > 0 {
|
||||
nc.Details.Ips = ips
|
||||
nc.Details.Ips = make([]*net.IPNet, len(ips))
|
||||
for i, ip := range ips {
|
||||
nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
|
||||
}
|
||||
}
|
||||
|
||||
if len(subnets) > 0 {
|
||||
nc.Details.Subnets = subnets
|
||||
nc.Details.Subnets = make([]*net.IPNet, len(subnets))
|
||||
for i, ip := range subnets {
|
||||
nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
|
||||
}
|
||||
}
|
||||
|
||||
if len(groups) > 0 {
|
||||
|
@ -59,7 +66,7 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
|
|||
|
||||
// NewTestCert will generate a signed certificate with the provided details.
|
||||
// Expiry times are defaulted if you do not pass them in
|
||||
func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
|
||||
func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
|
||||
issuer, err := ca.Sha256Sum()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -74,12 +81,12 @@ func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, af
|
|||
}
|
||||
|
||||
pub, rawPriv := x25519Keypair()
|
||||
|
||||
ipb := ip.Addr().AsSlice()
|
||||
nc := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: name,
|
||||
Ips: []*net.IPNet{ip},
|
||||
Subnets: subnets,
|
||||
Name: name,
|
||||
Ips: []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}},
|
||||
//Subnets: subnets,
|
||||
Groups: groups,
|
||||
NotBefore: time.Unix(before.Unix(), 0),
|
||||
NotAfter: time.Unix(after.Unix(), 0),
|
||||
|
|
|
@ -6,7 +6,7 @@ package e2e
|
|||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -19,7 +19,6 @@ import (
|
|||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/e2e/router"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
@ -27,15 +26,23 @@ import (
|
|||
type m map[string]interface{}
|
||||
|
||||
// newSimpleServer creates a nebula instance with many assumptions
|
||||
func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) {
|
||||
func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
|
||||
copy(vpnIpNet.IP, udpIp)
|
||||
vpnIpNet.IP[1] += 128
|
||||
udpAddr := net.UDPAddr{
|
||||
IP: udpIp,
|
||||
Port: 4242,
|
||||
vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var udpAddr netip.AddrPort
|
||||
if vpnIpNet.Addr().Is4() {
|
||||
budpIp := vpnIpNet.Addr().As4()
|
||||
budpIp[1] -= 128
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
||||
} else {
|
||||
budpIp := vpnIpNet.Addr().As16()
|
||||
budpIp[13] -= 128
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||
}
|
||||
_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
|
||||
|
||||
|
@ -67,8 +74,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
|
|||
// "try_interval": "1s",
|
||||
//},
|
||||
"listen": m{
|
||||
"host": udpAddr.IP.String(),
|
||||
"port": udpAddr.Port,
|
||||
"host": udpAddr.Addr().String(),
|
||||
"port": udpAddr.Port(),
|
||||
},
|
||||
"logging": m{
|
||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
|
||||
|
@ -102,7 +109,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
|
|||
panic(err)
|
||||
}
|
||||
|
||||
return control, vpnIpNet, &udpAddr, c
|
||||
return control, vpnIpNet, udpAddr, c
|
||||
}
|
||||
|
||||
type doneCb func()
|
||||
|
@ -123,7 +130,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
|
|||
}
|
||||
}
|
||||
|
||||
func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) {
|
||||
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
|
||||
// Send a packet from them to me
|
||||
controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
|
||||
bPacket := r.RouteForAllUntilTxTun(controlA)
|
||||
|
@ -135,23 +142,20 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul
|
|||
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
|
||||
}
|
||||
|
||||
func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
|
||||
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) {
|
||||
// Get both host infos
|
||||
hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false)
|
||||
hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false)
|
||||
assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
|
||||
|
||||
hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false)
|
||||
hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false)
|
||||
assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
|
||||
|
||||
// Check that both vpn and real addr are correct
|
||||
assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
|
||||
assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
|
||||
|
||||
assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A")
|
||||
assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B")
|
||||
|
||||
assert.Equal(t, addrB.Port, int(hBinA.CurrentRemote.Port), "Host B remote port is wrong in control A")
|
||||
assert.Equal(t, addrA.Port, int(hAinB.CurrentRemote.Port), "Host A remote port is wrong in control B")
|
||||
assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
|
||||
assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
|
||||
|
||||
// Check that our indexes match
|
||||
assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
|
||||
|
@ -174,13 +178,13 @@ func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB
|
|||
//checkIndexes("hmB", hmB, hAinB)
|
||||
}
|
||||
|
||||
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, fromPort, toPort uint16) {
|
||||
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
|
||||
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
|
||||
assert.NotNil(t, v4, "No ipv4 data found")
|
||||
|
||||
assert.Equal(t, fromIp, v4.SrcIP, "Source ip was incorrect")
|
||||
assert.Equal(t, toIp, v4.DstIP, "Dest ip was incorrect")
|
||||
assert.Equal(t, fromIp.AsSlice(), []byte(v4.SrcIP), "Source ip was incorrect")
|
||||
assert.Equal(t, toIp.AsSlice(), []byte(v4.DstIP), "Dest ip was incorrect")
|
||||
|
||||
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
|
||||
assert.NotNil(t, udp, "No udp data found")
|
||||
|
|
|
@ -5,11 +5,11 @@ package router
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
)
|
||||
|
||||
type edge struct {
|
||||
|
@ -118,14 +118,14 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
|
|||
return r, globalLines
|
||||
}
|
||||
|
||||
func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp {
|
||||
keys := make([]iputil.VpnIp, 0, len(hosts))
|
||||
func sortedHosts(hosts map[netip.Addr]*nebula.HostInfo) []netip.Addr {
|
||||
keys := make([]netip.Addr, 0, len(hosts))
|
||||
for key := range hosts {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
sort.SliceStable(keys, func(i, j int) bool {
|
||||
return keys[i] > keys[j]
|
||||
return keys[i].Compare(keys[j]) > 0
|
||||
})
|
||||
|
||||
return keys
|
||||
|
|
|
@ -6,12 +6,11 @@ package router
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
@ -21,7 +20,6 @@ import (
|
|||
"github.com/google/gopacket/layers"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
@ -29,18 +27,18 @@ import (
|
|||
type R struct {
|
||||
// Simple map of the ip:port registered on a control to the control
|
||||
// Basically a router, right?
|
||||
controls map[string]*nebula.Control
|
||||
controls map[netip.AddrPort]*nebula.Control
|
||||
|
||||
// A map for inbound packets for a control that doesn't know about this address
|
||||
inNat map[string]*nebula.Control
|
||||
inNat map[netip.AddrPort]*nebula.Control
|
||||
|
||||
// A last used map, if an inbound packet hit the inNat map then
|
||||
// all return packets should use the same last used inbound address for the outbound sender
|
||||
// map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver
|
||||
outNat map[string]net.UDPAddr
|
||||
outNat map[string]netip.AddrPort
|
||||
|
||||
// A map of vpn ip to the nebula control it belongs to
|
||||
vpnControls map[iputil.VpnIp]*nebula.Control
|
||||
vpnControls map[netip.Addr]*nebula.Control
|
||||
|
||||
ignoreFlows []ignoreFlow
|
||||
flow []flowEntry
|
||||
|
@ -118,10 +116,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
|
|||
}
|
||||
|
||||
r := &R{
|
||||
controls: make(map[string]*nebula.Control),
|
||||
vpnControls: make(map[iputil.VpnIp]*nebula.Control),
|
||||
inNat: make(map[string]*nebula.Control),
|
||||
outNat: make(map[string]net.UDPAddr),
|
||||
controls: make(map[netip.AddrPort]*nebula.Control),
|
||||
vpnControls: make(map[netip.Addr]*nebula.Control),
|
||||
inNat: make(map[netip.AddrPort]*nebula.Control),
|
||||
outNat: make(map[string]netip.AddrPort),
|
||||
flow: []flowEntry{},
|
||||
ignoreFlows: []ignoreFlow{},
|
||||
fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())),
|
||||
|
@ -135,7 +133,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
|
|||
for _, c := range controls {
|
||||
addr := c.GetUDPAddr()
|
||||
if _, ok := r.controls[addr]; ok {
|
||||
panic("Duplicate listen address: " + addr)
|
||||
panic("Duplicate listen address: " + addr.String())
|
||||
}
|
||||
|
||||
r.vpnControls[c.GetVpnIp()] = c
|
||||
|
@ -165,13 +163,13 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
|
|||
// It does not look at the addr attached to the instance.
|
||||
// If a route is used, this will behave like a NAT for the return path.
|
||||
// Rewriting the source ip:port to what was last sent to from the origin
|
||||
func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
|
||||
func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))
|
||||
inAddr := netip.AddrPortFrom(ip, port)
|
||||
if _, ok := r.inNat[inAddr]; ok {
|
||||
panic("Duplicate listen address inNat: " + inAddr)
|
||||
panic("Duplicate listen address inNat: " + inAddr.String())
|
||||
}
|
||||
r.inNat[inAddr] = c
|
||||
}
|
||||
|
@ -198,7 +196,7 @@ func (r *R) renderFlow() {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
var participants = map[string]struct{}{}
|
||||
var participants = map[netip.AddrPort]struct{}{}
|
||||
var participantsVals []string
|
||||
|
||||
fmt.Fprintln(f, "```mermaid")
|
||||
|
@ -215,7 +213,7 @@ func (r *R) renderFlow() {
|
|||
continue
|
||||
}
|
||||
participants[addr] = struct{}{}
|
||||
sanAddr := strings.Replace(addr, ":", "-", 1)
|
||||
sanAddr := strings.Replace(addr.String(), ":", "-", 1)
|
||||
participantsVals = append(participantsVals, sanAddr)
|
||||
fmt.Fprintf(
|
||||
f, " participant %s as Nebula: %s<br/>UDP: %s\n",
|
||||
|
@ -252,9 +250,9 @@ func (r *R) renderFlow() {
|
|||
|
||||
fmt.Fprintf(f,
|
||||
" %s%s%s: %s(%s), index %v, counter: %v\n",
|
||||
strings.Replace(p.from.GetUDPAddr(), ":", "-", 1),
|
||||
strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1),
|
||||
line,
|
||||
strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
|
||||
strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
|
||||
h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
|
||||
)
|
||||
}
|
||||
|
@ -305,7 +303,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
|
|||
func (r *R) renderHostmaps(title string) {
|
||||
c := maps.Values(r.controls)
|
||||
sort.SliceStable(c, func(i, j int) bool {
|
||||
return c[i].GetVpnIp() > c[j].GetVpnIp()
|
||||
return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0
|
||||
})
|
||||
|
||||
s := renderHostmaps(c...)
|
||||
|
@ -420,10 +418,8 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
|
|||
|
||||
// Nope, lets push the sender along
|
||||
case p := <-udpTx:
|
||||
outAddr := sender.GetUDPAddr()
|
||||
r.Lock()
|
||||
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
|
||||
c := r.getControl(outAddr, inAddr, p)
|
||||
c := r.getControl(sender.GetUDPAddr(), p.To, p)
|
||||
if c == nil {
|
||||
r.Unlock()
|
||||
panic("No control for udp tx")
|
||||
|
@ -479,10 +475,7 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
|
|||
} else {
|
||||
// we are a udp tx, route and continue
|
||||
p := rx.Interface().(*udp.Packet)
|
||||
outAddr := cm[x].GetUDPAddr()
|
||||
|
||||
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
|
||||
c := r.getControl(outAddr, inAddr, p)
|
||||
c := r.getControl(cm[x].GetUDPAddr(), p.To, p)
|
||||
if c == nil {
|
||||
r.Unlock()
|
||||
panic("No control for udp tx")
|
||||
|
@ -509,12 +502,10 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
outAddr := sender.GetUDPAddr()
|
||||
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
|
||||
receiver := r.getControl(outAddr, inAddr, p)
|
||||
receiver := r.getControl(sender.GetUDPAddr(), p.To, p)
|
||||
if receiver == nil {
|
||||
r.Unlock()
|
||||
panic("Can't route for host: " + inAddr)
|
||||
panic("Can't RouteExitFunc for host: " + p.To.String())
|
||||
}
|
||||
|
||||
e := whatDo(p, receiver)
|
||||
|
@ -590,13 +581,13 @@ func (r *R) InjectUDPPacket(sender, receiver *nebula.Control, packet *udp.Packet
|
|||
// RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
|
||||
// finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
|
||||
// If the router doesn't have the nebula controller for that address, we panic
|
||||
func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) {
|
||||
func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr netip.AddrPort, finish ExitType) {
|
||||
if finish == KeepRouting {
|
||||
finish = RouteAndExit
|
||||
}
|
||||
|
||||
r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
|
||||
if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
|
||||
if p.To == toAddr {
|
||||
return finish
|
||||
}
|
||||
|
||||
|
@ -630,13 +621,10 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
|
|||
r.Lock()
|
||||
|
||||
p := rx.Interface().(*udp.Packet)
|
||||
|
||||
outAddr := cm[x].GetUDPAddr()
|
||||
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
|
||||
receiver := r.getControl(outAddr, inAddr, p)
|
||||
receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p)
|
||||
if receiver == nil {
|
||||
r.Unlock()
|
||||
panic("Can't route for host: " + inAddr)
|
||||
panic("Can't RouteForAllExitFunc for host: " + p.To.String())
|
||||
}
|
||||
|
||||
e := whatDo(p, receiver)
|
||||
|
@ -697,12 +685,10 @@ func (r *R) FlushAll() {
|
|||
|
||||
p := rx.Interface().(*udp.Packet)
|
||||
|
||||
outAddr := cm[x].GetUDPAddr()
|
||||
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
|
||||
receiver := r.getControl(outAddr, inAddr, p)
|
||||
receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p)
|
||||
if receiver == nil {
|
||||
r.Unlock()
|
||||
panic("Can't route for host: " + inAddr)
|
||||
panic("Can't FlushAll for host: " + p.To.String())
|
||||
}
|
||||
r.Unlock()
|
||||
}
|
||||
|
@ -710,28 +696,14 @@ func (r *R) FlushAll() {
|
|||
|
||||
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
|
||||
// This is an internal router function, the caller must hold the lock
|
||||
func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control {
|
||||
if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
|
||||
p.FromIp = newAddr.IP
|
||||
p.FromPort = uint16(newAddr.Port)
|
||||
func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control {
|
||||
if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok {
|
||||
p.From = newAddr
|
||||
}
|
||||
|
||||
c, ok := r.inNat[toAddr]
|
||||
if ok {
|
||||
sHost, sPort, err := net.SplitHostPort(toAddr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(sPort)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{
|
||||
IP: net.ParseIP(sHost),
|
||||
Port: port,
|
||||
}
|
||||
r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr
|
||||
return c
|
||||
}
|
||||
|
||||
|
@ -746,8 +718,9 @@ func (r *R) formatUdpPacket(p *packet) string {
|
|||
}
|
||||
|
||||
from := "unknown"
|
||||
if c, ok := r.vpnControls[iputil.Ip2VpnIp(v4.SrcIP)]; ok {
|
||||
from = c.GetUDPAddr()
|
||||
srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
|
||||
if c, ok := r.vpnControls[srcAddr]; ok {
|
||||
from = c.GetUDPAddr().String()
|
||||
}
|
||||
|
||||
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
|
||||
|
@ -759,7 +732,7 @@ func (r *R) formatUdpPacket(p *packet) string {
|
|||
return fmt.Sprintf(
|
||||
" %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
|
||||
strings.Replace(from, ":", "-", 1),
|
||||
strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
|
||||
strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
|
||||
udp.SrcPort,
|
||||
udp.DstPort,
|
||||
string(data.Payload()),
|
||||
|
|
100
firewall.go
100
firewall.go
|
@ -6,23 +6,23 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"net"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
)
|
||||
|
||||
type FirewallInterface interface {
|
||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
|
||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
|
@ -52,8 +52,8 @@ type Firewall struct {
|
|||
DefaultTimeout time.Duration //linux: 600s
|
||||
|
||||
// Used to ensure we don't emit local packets for ips we don't own
|
||||
localIps *cidr.Tree4[struct{}]
|
||||
assignedCIDR *net.IPNet
|
||||
localIps *bart.Table[struct{}]
|
||||
assignedCIDR netip.Prefix
|
||||
hasSubnets bool
|
||||
|
||||
rules string
|
||||
|
@ -108,7 +108,7 @@ type FirewallRule struct {
|
|||
Any *firewallLocalCIDR
|
||||
Hosts map[string]*firewallLocalCIDR
|
||||
Groups []*firewallGroups
|
||||
CIDR *cidr.Tree4[*firewallLocalCIDR]
|
||||
CIDR *bart.Table[*firewallLocalCIDR]
|
||||
}
|
||||
|
||||
type firewallGroups struct {
|
||||
|
@ -122,7 +122,7 @@ type firewallPort map[int32]*FirewallCA
|
|||
|
||||
type firewallLocalCIDR struct {
|
||||
Any bool
|
||||
LocalCIDR *cidr.Tree4[struct{}]
|
||||
LocalCIDR *bart.Table[struct{}]
|
||||
}
|
||||
|
||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||
|
@ -144,20 +144,28 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
|||
max = defaultTimeout
|
||||
}
|
||||
|
||||
localIps := cidr.NewTree4[struct{}]()
|
||||
var assignedCIDR *net.IPNet
|
||||
localIps := new(bart.Table[struct{}])
|
||||
var assignedCIDR netip.Prefix
|
||||
var assignedSet bool
|
||||
for _, ip := range c.Details.Ips {
|
||||
ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}
|
||||
localIps.AddCIDR(ipNet, struct{}{})
|
||||
//TODO: IPV6-WORK the unmap is a bit unfortunate
|
||||
nip, _ := netip.AddrFromSlice(ip.IP)
|
||||
nip = nip.Unmap()
|
||||
nprefix := netip.PrefixFrom(nip, nip.BitLen())
|
||||
localIps.Insert(nprefix, struct{}{})
|
||||
|
||||
if assignedCIDR == nil {
|
||||
if !assignedSet {
|
||||
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
|
||||
assignedCIDR = ipNet
|
||||
assignedCIDR = nprefix
|
||||
assignedSet = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, n := range c.Details.Subnets {
|
||||
localIps.AddCIDR(n, struct{}{})
|
||||
nip, _ := netip.AddrFromSlice(n.IP)
|
||||
ones, _ := n.Mask.Size()
|
||||
nip = nip.Unmap()
|
||||
localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{})
|
||||
}
|
||||
|
||||
return &Firewall{
|
||||
|
@ -237,15 +245,15 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
|
|||
}
|
||||
|
||||
// AddRule properly creates the in memory rule structure for a firewall table.
|
||||
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
|
||||
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
|
||||
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
|
||||
// https://github.com/golang/go/issues/14131
|
||||
sIp := ""
|
||||
if ip != nil {
|
||||
if ip.IsValid() {
|
||||
sIp = ip.String()
|
||||
}
|
||||
lIp := ""
|
||||
if localIp != nil {
|
||||
if localIp.IsValid() {
|
||||
lIp = localIp.String()
|
||||
}
|
||||
|
||||
|
@ -382,17 +390,17 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
||||
}
|
||||
|
||||
var cidr *net.IPNet
|
||||
var cidr netip.Prefix
|
||||
if r.Cidr != "" {
|
||||
_, cidr, err = net.ParseCIDR(r.Cidr)
|
||||
cidr, err = netip.ParsePrefix(r.Cidr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
|
||||
}
|
||||
}
|
||||
|
||||
var localCidr *net.IPNet
|
||||
var localCidr netip.Prefix
|
||||
if r.LocalCidr != "" {
|
||||
_, localCidr, err = net.ParseCIDR(r.LocalCidr)
|
||||
localCidr, err = netip.ParsePrefix(r.LocalCidr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
|
||||
}
|
||||
|
@ -421,7 +429,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||
|
||||
// Make sure remote address matches nebula certificate
|
||||
if remoteCidr := h.remoteCidr; remoteCidr != nil {
|
||||
ok, _ := remoteCidr.Contains(fp.RemoteIP)
|
||||
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
|
||||
_, ok := remoteCidr.Lookup(fp.RemoteIP)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedRemoteIP.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
|
@ -435,7 +444,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||
}
|
||||
|
||||
// Make sure we are supposed to be handling this local ip address
|
||||
ok, _ := f.localIps.Contains(fp.LocalIP)
|
||||
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
|
||||
_, ok := f.localIps.Lookup(fp.LocalIP)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedLocalIP.Inc(1)
|
||||
return ErrInvalidLocalIP
|
||||
|
@ -589,7 +599,6 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
|||
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
||||
// Caller must own the connMutex lock!
|
||||
func (f *Firewall) evict(p firewall.Packet) {
|
||||
//TODO: report a stat if the tcp rtt tracking was never resolved?
|
||||
// Are we still tracking this conn?
|
||||
conntrack := f.Conntrack
|
||||
t, ok := conntrack.Conns[p]
|
||||
|
@ -633,7 +642,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
|
|||
return false
|
||||
}
|
||||
|
||||
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
|
||||
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
|
||||
if startPort > endPort {
|
||||
return fmt.Errorf("start port was lower than end port")
|
||||
}
|
||||
|
@ -677,12 +686,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
|
|||
return fp[firewall.PortAny].match(p, c, caPool)
|
||||
}
|
||||
|
||||
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
|
||||
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
|
||||
fr := func() *FirewallRule {
|
||||
return &FirewallRule{
|
||||
Hosts: make(map[string]*firewallLocalCIDR),
|
||||
Groups: make([]*firewallGroups, 0),
|
||||
CIDR: cidr.NewTree4[*firewallLocalCIDR](),
|
||||
CIDR: new(bart.Table[*firewallLocalCIDR]),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -740,10 +749,10 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
|
|||
return fc.CANames[s.Details.Name].match(p, c)
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error {
|
||||
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
|
||||
flc := func() *firewallLocalCIDR {
|
||||
return &firewallLocalCIDR{
|
||||
LocalCIDR: cidr.NewTree4[struct{}](),
|
||||
LocalCIDR: new(bart.Table[struct{}]),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -780,8 +789,8 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
|
|||
fr.Hosts[host] = nlc
|
||||
}
|
||||
|
||||
if ip != nil {
|
||||
_, nlc := fr.CIDR.GetCIDR(ip)
|
||||
if ip.IsValid() {
|
||||
nlc, _ := fr.CIDR.Get(ip)
|
||||
if nlc == nil {
|
||||
nlc = flc()
|
||||
}
|
||||
|
@ -789,14 +798,14 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fr.CIDR.AddCIDR(ip, nlc)
|
||||
fr.CIDR.Insert(ip, nlc)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
|
||||
if len(groups) == 0 && host == "" && ip == nil {
|
||||
func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
|
||||
if len(groups) == 0 && host == "" && !ip.IsValid() {
|
||||
return true
|
||||
}
|
||||
|
||||
|
@ -810,7 +819,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
|
|||
return true
|
||||
}
|
||||
|
||||
if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) {
|
||||
if ip.IsValid() && ip.Bits() == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
|
@ -853,24 +862,31 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
|
|||
}
|
||||
}
|
||||
|
||||
return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool {
|
||||
return flc.match(p, c)
|
||||
matched := false
|
||||
prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
|
||||
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
|
||||
if prefix.Contains(p.RemoteIP) && val.match(p, c) {
|
||||
matched = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return matched
|
||||
}
|
||||
|
||||
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error {
|
||||
if localIp == nil {
|
||||
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
|
||||
if !localIp.IsValid() {
|
||||
if !f.hasSubnets || f.defaultLocalCIDRAny {
|
||||
flc.Any = true
|
||||
return nil
|
||||
}
|
||||
|
||||
localIp = f.assignedCIDR
|
||||
} else if localIp.Contains(net.IPv4(0, 0, 0, 0)) {
|
||||
} else if localIp.Bits() == 0 {
|
||||
flc.Any = true
|
||||
}
|
||||
|
||||
flc.LocalCIDR.AddCIDR(localIp, struct{}{})
|
||||
flc.LocalCIDR.Insert(localIp, struct{}{})
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -883,7 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate
|
|||
return true
|
||||
}
|
||||
|
||||
ok, _ := flc.LocalCIDR.Contains(p.LocalIP)
|
||||
_, ok := flc.LocalCIDR.Lookup(p.LocalIP)
|
||||
return ok
|
||||
}
|
||||
|
||||
|
|
|
@ -3,8 +3,7 @@ package firewall
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type m map[string]interface{}
|
||||
|
@ -20,8 +19,8 @@ const (
|
|||
)
|
||||
|
||||
type Packet struct {
|
||||
LocalIP iputil.VpnIp
|
||||
RemoteIP iputil.VpnIp
|
||||
LocalIP netip.Addr
|
||||
RemoteIP netip.Addr
|
||||
LocalPort uint16
|
||||
RemotePort uint16
|
||||
Protocol uint8
|
||||
|
|
147
firewall_test.go
147
firewall_test.go
|
@ -5,13 +5,13 @@ import (
|
|||
"errors"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -65,59 +65,62 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||
assert.NotNil(t, fw.InRules)
|
||||
assert.NotNil(t, fw.OutRules)
|
||||
|
||||
_, ti, _ := net.ParseCIDR("1.2.3.4/32")
|
||||
ti, err := netip.ParsePrefix("1.2.3.4/32")
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
// An empty rule is any
|
||||
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
||||
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
||||
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
|
||||
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
|
||||
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
|
||||
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
||||
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
|
||||
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti)
|
||||
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti)
|
||||
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
|
||||
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
|
||||
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
|
||||
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
// Test error conditions
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", ""))
|
||||
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", ""))
|
||||
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop(t *testing.T) {
|
||||
|
@ -126,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) {
|
|||
l.SetOutput(ob)
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||
RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||
LocalIP: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteIP: netip.MustParseAddr("1.2.3.4"),
|
||||
LocalPort: 10,
|
||||
RemotePort: 90,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
|
@ -152,16 +155,16 @@ func TestFirewall_Drop(t *testing.T) {
|
|||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
},
|
||||
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
|
||||
vpnIp: netip.MustParseAddr("1.2.3.4"),
|
||||
}
|
||||
h.CreateRemoteCIDR(&c)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
|
@ -170,34 +173,34 @@ func TestFirewall_Drop(t *testing.T) {
|
|||
|
||||
// test remote mismatch
|
||||
oldRemote := p.RemoteIP
|
||||
p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
|
||||
p.RemoteIP = netip.MustParseAddr("1.2.3.10")
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
p.RemoteIP = oldRemote
|
||||
|
||||
// ensure signer doesn't get in the way of group checks
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caSha doesn't drop on match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
|
||||
// ensure ca name doesn't get in the way of group checks
|
||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caName doesn't drop on match
|
||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
|
@ -207,10 +210,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||
TCP: firewallPort{},
|
||||
}
|
||||
|
||||
_, n, _ := net.ParseCIDR("172.1.1.1/32")
|
||||
goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP)
|
||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
|
||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
|
||||
pfix := netip.MustParsePrefix("172.1.1.1/32")
|
||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
|
||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
b.Run("fail on proto", func(b *testing.B) {
|
||||
|
@ -231,10 +233,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||
|
||||
b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
|
||||
c := &cert.NebulaCertificate{}
|
||||
ip, _, _ := net.ParseCIDR("9.254.254.254/32")
|
||||
lip := iputil.Ip2VpnIp(ip)
|
||||
ip := netip.MustParsePrefix("9.254.254.254/32")
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp))
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -262,7 +263,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||
},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -286,7 +287,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||
},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
|
||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -363,8 +364,8 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||
l.SetOutput(ob)
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||
RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||
LocalIP: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteIP: netip.MustParseAddr("1.2.3.4"),
|
||||
LocalPort: 10,
|
||||
RemotePort: 90,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
|
@ -387,7 +388,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
},
|
||||
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
|
||||
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
|
||||
}
|
||||
h.CreateRemoteCIDR(&c)
|
||||
|
||||
|
@ -406,7 +407,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||
h1.CreateRemoteCIDR(&c1)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// h1/c1 lacks the proper groups
|
||||
|
@ -422,8 +423,8 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||
l.SetOutput(ob)
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||
RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||
LocalIP: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteIP: netip.MustParseAddr("1.2.3.4"),
|
||||
LocalPort: 1,
|
||||
RemotePort: 1,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
|
@ -453,7 +454,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c1,
|
||||
},
|
||||
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
|
||||
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
|
||||
}
|
||||
h1.CreateRemoteCIDR(&c1)
|
||||
|
||||
|
@ -468,7 +469,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c2,
|
||||
},
|
||||
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
|
||||
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
|
||||
}
|
||||
h2.CreateRemoteCIDR(&c2)
|
||||
|
||||
|
@ -483,13 +484,13 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c3,
|
||||
},
|
||||
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
|
||||
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
|
||||
}
|
||||
h3.CreateRemoteCIDR(&c3)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// c1 should pass because host match
|
||||
|
@ -508,8 +509,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||
l.SetOutput(ob)
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||
RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
|
||||
LocalIP: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteIP: netip.MustParseAddr("1.2.3.4"),
|
||||
LocalPort: 10,
|
||||
RemotePort: 90,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
|
@ -534,12 +535,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
},
|
||||
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
|
||||
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
|
||||
}
|
||||
h.CreateRemoteCIDR(&c)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
|
@ -552,7 +553,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||
|
||||
oldFw := fw
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
|
@ -561,7 +562,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||
|
||||
oldFw = fw
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
|
@ -725,13 +726,13 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
|||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||
|
||||
// Test local_cidr parse error
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh")
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||
|
||||
// Test both group and groups
|
||||
conf = config.NewC(l)
|
||||
|
@ -747,78 +748,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||
mf := &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding udp rule
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding icmp rule
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding any rule
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding rule with cidr
|
||||
cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)}
|
||||
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding rule with local_cidr
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
|
||||
|
||||
// Test adding rule with ca_sha
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
|
||||
|
||||
// Test adding rule with ca_name
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
|
||||
|
||||
// Test single group
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test single groups
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test multiple AND groups
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test Add error
|
||||
conf = config.NewC(l)
|
||||
|
@ -871,8 +872,8 @@ type addRuleCall struct {
|
|||
endPort int32
|
||||
groups []string
|
||||
host string
|
||||
ip *net.IPNet
|
||||
localIp *net.IPNet
|
||||
ip netip.Prefix
|
||||
localIp netip.Prefix
|
||||
caName string
|
||||
caSha string
|
||||
}
|
||||
|
@ -882,7 +883,7 @@ type mockFirewall struct {
|
|||
nextCallReturn error
|
||||
}
|
||||
|
||||
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
|
||||
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
|
||||
mf.lastCall = addRuleCall{
|
||||
incoming: incoming,
|
||||
proto: proto,
|
||||
|
|
2
go.mod
2
go.mod
|
@ -38,8 +38,10 @@ require (
|
|||
|
||||
require (
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bits-and-blooms/bitset v1.13.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/gaissmai/bart v0.11.1 // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.5.0 // indirect
|
||||
|
|
6
go.sum
6
go.sum
|
@ -14,6 +14,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
|
|||
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
|
||||
github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
|
||||
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
|
@ -24,6 +26,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
|||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||
github.com/gaissmai/bart v0.10.0 h1:yCZCYF8xzcRnqDe4jMk14NlJjL1WmMsE7ilBzvuHtiI=
|
||||
github.com/gaissmai/bart v0.10.0/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
|
||||
github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc=
|
||||
github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
|
||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
// NOISE IX Handshakes
|
||||
|
@ -63,7 +62,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
|
||||
func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
|
||||
certState := f.pki.GetCertState()
|
||||
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
|
||||
// Mark packet 1 as seen so it doesn't show up as missed
|
||||
|
@ -99,12 +98,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
|
|||
e.Info("Invalid certificate from host")
|
||||
return
|
||||
}
|
||||
vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
|
||||
|
||||
vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
|
||||
if !ok {
|
||||
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
|
||||
|
||||
if f.l.Level > logrus.DebugLevel {
|
||||
e = e.WithField("cert", remoteCert)
|
||||
}
|
||||
|
||||
e.Info("Invalid vpn ip from host")
|
||||
return
|
||||
}
|
||||
|
||||
vpnIp = vpnIp.Unmap()
|
||||
certName := remoteCert.Details.Name
|
||||
fingerprint, _ := remoteCert.Sha256Sum()
|
||||
issuer := remoteCert.Details.Issuer
|
||||
|
||||
if vpnIp == f.myVpnIp {
|
||||
if vpnIp == f.myVpnNet.Addr() {
|
||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
|
@ -113,8 +126,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
|
|||
return
|
||||
}
|
||||
|
||||
if addr != nil {
|
||||
if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.IP) {
|
||||
if addr.IsValid() {
|
||||
if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) {
|
||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
return
|
||||
}
|
||||
|
@ -138,8 +151,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
|
|||
HandshakePacket: make(map[uint8][]byte, 0),
|
||||
lastHandshakeTime: hs.Details.Time,
|
||||
relayState: RelayState{
|
||||
relays: map[iputil.VpnIp]struct{}{},
|
||||
relayForByIp: map[iputil.VpnIp]*Relay{},
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByIp: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
},
|
||||
}
|
||||
|
@ -218,7 +231,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
|
|||
|
||||
msg = existing.HandshakePacket[2]
|
||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||
if addr != nil {
|
||||
if addr.IsValid() {
|
||||
err := f.outside.WriteTo(msg, addr)
|
||||
if err != nil {
|
||||
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
|
||||
|
@ -284,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
|
|||
|
||||
// Do the send
|
||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||
if addr != nil {
|
||||
if addr.IsValid() {
|
||||
err = f.outside.WriteTo(msg, addr)
|
||||
if err != nil {
|
||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||
|
@ -326,7 +339,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
|
|||
return
|
||||
}
|
||||
|
||||
func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
|
||||
func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
|
||||
if hh == nil {
|
||||
// Nothing here to tear down, got a bogus stage 2 packet
|
||||
return true
|
||||
|
@ -336,8 +349,8 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
|
|||
defer hh.Unlock()
|
||||
|
||||
hostinfo := hh.hostinfo
|
||||
if addr != nil {
|
||||
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
|
||||
if addr.IsValid() {
|
||||
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) {
|
||||
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
return false
|
||||
}
|
||||
|
@ -389,7 +402,20 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
|
|||
return true
|
||||
}
|
||||
|
||||
vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
|
||||
vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
|
||||
if !ok {
|
||||
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
|
||||
|
||||
if f.l.Level > logrus.DebugLevel {
|
||||
e = e.WithField("cert", remoteCert)
|
||||
}
|
||||
|
||||
e.Info("Invalid vpn ip from host")
|
||||
return true
|
||||
}
|
||||
|
||||
vpnIp = vpnIp.Unmap()
|
||||
certName := remoteCert.Details.Name
|
||||
fingerprint, _ := remoteCert.Sha256Sum()
|
||||
issuer := remoteCert.Details.Issuer
|
||||
|
@ -453,7 +479,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
|
|||
ci.eKey = NewNebulaCipherState(eKey)
|
||||
|
||||
// Make sure the current udpAddr being used is set for responding
|
||||
if addr != nil {
|
||||
if addr.IsValid() {
|
||||
hostinfo.SetRemote(addr)
|
||||
} else {
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
|
||||
|
|
|
@ -6,15 +6,15 @@ import (
|
|||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -46,14 +46,14 @@ type HandshakeManager struct {
|
|||
// Mutex for interacting with the vpnIps and indexes maps
|
||||
sync.RWMutex
|
||||
|
||||
vpnIps map[iputil.VpnIp]*HandshakeHostInfo
|
||||
vpnIps map[netip.Addr]*HandshakeHostInfo
|
||||
indexes map[uint32]*HandshakeHostInfo
|
||||
|
||||
mainHostMap *HostMap
|
||||
lightHouse *LightHouse
|
||||
outside udp.Conn
|
||||
config HandshakeConfig
|
||||
OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp]
|
||||
OutboundHandshakeTimer *LockingTimerWheel[netip.Addr]
|
||||
messageMetrics *MessageMetrics
|
||||
metricInitiated metrics.Counter
|
||||
metricTimedOut metrics.Counter
|
||||
|
@ -61,17 +61,17 @@ type HandshakeManager struct {
|
|||
l *logrus.Logger
|
||||
|
||||
// can be used to trigger outbound handshake for the given vpnIp
|
||||
trigger chan iputil.VpnIp
|
||||
trigger chan netip.Addr
|
||||
}
|
||||
|
||||
type HandshakeHostInfo struct {
|
||||
sync.Mutex
|
||||
|
||||
startTime time.Time // Time that we first started trying with this handshake
|
||||
ready bool // Is the handshake ready
|
||||
counter int // How many attempts have we made so far
|
||||
lastRemotes []*udp.Addr // Remotes that we sent to during the previous attempt
|
||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||
startTime time.Time // Time that we first started trying with this handshake
|
||||
ready bool // Is the handshake ready
|
||||
counter int // How many attempts have we made so far
|
||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||
|
||||
hostinfo *HostInfo
|
||||
}
|
||||
|
@ -103,14 +103,14 @@ func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType,
|
|||
|
||||
func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
|
||||
return &HandshakeManager{
|
||||
vpnIps: map[iputil.VpnIp]*HandshakeHostInfo{},
|
||||
vpnIps: map[netip.Addr]*HandshakeHostInfo{},
|
||||
indexes: map[uint32]*HandshakeHostInfo{},
|
||||
mainHostMap: mainHostMap,
|
||||
lightHouse: lightHouse,
|
||||
outside: outside,
|
||||
config: config,
|
||||
trigger: make(chan iputil.VpnIp, config.triggerBuffer),
|
||||
OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
|
||||
trigger: make(chan netip.Addr, config.triggerBuffer),
|
||||
OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
|
||||
messageMetrics: config.messageMetrics,
|
||||
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
|
||||
metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
|
||||
|
@ -134,10 +134,10 @@ func (c *HandshakeManager) Run(ctx context.Context) {
|
|||
}
|
||||
}
|
||||
|
||||
func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
|
||||
func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
|
||||
// First remote allow list check before we know the vpnIp
|
||||
if addr != nil {
|
||||
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
|
||||
if addr.IsValid() {
|
||||
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) {
|
||||
hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
return
|
||||
}
|
||||
|
@ -170,7 +170,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
|
|||
}
|
||||
}
|
||||
|
||||
func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) {
|
||||
func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered bool) {
|
||||
hh := hm.queryVpnIp(vpnIp)
|
||||
if hh == nil {
|
||||
return
|
||||
|
@ -212,7 +212,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
|
|||
}
|
||||
|
||||
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
|
||||
remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)
|
||||
remotesHaveChanged := !slices.Equal(remotes, hh.lastRemotes)
|
||||
|
||||
// We only care about a lighthouse trigger if we have new remotes to send to.
|
||||
// This is a very specific optimization for a fast lighthouse reply.
|
||||
|
@ -234,8 +234,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
|
|||
}
|
||||
|
||||
// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
|
||||
var sentTo []*udp.Addr
|
||||
hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) {
|
||||
var sentTo []netip.AddrPort
|
||||
hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) {
|
||||
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
||||
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
|
||||
if err != nil {
|
||||
|
@ -268,13 +268,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
|
|||
// Send a RelayRequest to all known Relay IP's
|
||||
for _, relay := range hostinfo.remotes.relays {
|
||||
// Don't relay to myself, and don't relay through the host I'm trying to connect to
|
||||
if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp {
|
||||
if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() {
|
||||
continue
|
||||
}
|
||||
relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay)
|
||||
if relayHostInfo == nil || relayHostInfo.remote == nil {
|
||||
relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
|
||||
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
|
||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
|
||||
hm.f.Handshake(*relay)
|
||||
hm.f.Handshake(relay)
|
||||
continue
|
||||
}
|
||||
// Check the relay HostInfo to see if we already established a relay through it
|
||||
|
@ -285,12 +285,17 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
|
|||
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
|
||||
case Requested:
|
||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
|
||||
|
||||
//TODO: IPV6-WORK
|
||||
myVpnIpB := hm.f.myVpnNet.Addr().As4()
|
||||
theirVpnIpB := vpnIp.As4()
|
||||
|
||||
// Re-send the CreateRelay request, in case the previous one was lost.
|
||||
m := NebulaControl{
|
||||
Type: NebulaControl_CreateRelayRequest,
|
||||
InitiatorRelayIndex: existingRelay.LocalIndex,
|
||||
RelayFromIp: uint32(hm.lightHouse.myVpnIp),
|
||||
RelayToIp: uint32(vpnIp),
|
||||
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
|
||||
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
|
||||
}
|
||||
msg, err := m.Marshal()
|
||||
if err != nil {
|
||||
|
@ -301,10 +306,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
|
|||
// This must send over the hostinfo, not over hm.Hosts[ip]
|
||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
hm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": hm.lightHouse.myVpnIp,
|
||||
"relayFrom": hm.f.myVpnNet.Addr(),
|
||||
"relayTo": vpnIp,
|
||||
"initiatorRelayIndex": existingRelay.LocalIndex,
|
||||
"relay": *relay}).
|
||||
"relay": relay}).
|
||||
Info("send CreateRelayRequest")
|
||||
}
|
||||
default:
|
||||
|
@ -316,17 +321,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
|
|||
}
|
||||
} else {
|
||||
// No relays exist or requested yet.
|
||||
if relayHostInfo.remote != nil {
|
||||
if relayHostInfo.remote.IsValid() {
|
||||
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
|
||||
}
|
||||
|
||||
//TODO: IPV6-WORK
|
||||
myVpnIpB := hm.f.myVpnNet.Addr().As4()
|
||||
theirVpnIpB := vpnIp.As4()
|
||||
|
||||
m := NebulaControl{
|
||||
Type: NebulaControl_CreateRelayRequest,
|
||||
InitiatorRelayIndex: idx,
|
||||
RelayFromIp: uint32(hm.lightHouse.myVpnIp),
|
||||
RelayToIp: uint32(vpnIp),
|
||||
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
|
||||
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
|
||||
}
|
||||
msg, err := m.Marshal()
|
||||
if err != nil {
|
||||
|
@ -336,10 +345,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
|
|||
} else {
|
||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
hm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": hm.lightHouse.myVpnIp,
|
||||
"relayFrom": hm.f.myVpnNet.Addr(),
|
||||
"relayTo": vpnIp,
|
||||
"initiatorRelayIndex": idx,
|
||||
"relay": *relay}).
|
||||
"relay": relay}).
|
||||
Info("send CreateRelayRequest")
|
||||
}
|
||||
}
|
||||
|
@ -355,7 +364,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
|
|||
|
||||
// GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present
|
||||
// The 2nd argument will be true if the hostinfo is ready to transmit traffic
|
||||
func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||
func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||
hm.mainHostMap.RLock()
|
||||
h, ok := hm.mainHostMap.Hosts[vpnIp]
|
||||
hm.mainHostMap.RUnlock()
|
||||
|
@ -372,7 +381,7 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
|
|||
}
|
||||
|
||||
// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
|
||||
func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo {
|
||||
func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
|
||||
hm.Lock()
|
||||
|
||||
if hh, ok := hm.vpnIps[vpnIp]; ok {
|
||||
|
@ -388,8 +397,8 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
|
|||
vpnIp: vpnIp,
|
||||
HandshakePacket: make(map[uint8][]byte, 0),
|
||||
relayState: RelayState{
|
||||
relays: map[iputil.VpnIp]struct{}{},
|
||||
relayForByIp: map[iputil.VpnIp]*Relay{},
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByIp: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
},
|
||||
}
|
||||
|
@ -555,7 +564,7 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
|
|||
func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
||||
delete(c.vpnIps, hostinfo.vpnIp)
|
||||
if len(c.vpnIps) == 0 {
|
||||
c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{}
|
||||
c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
|
||||
}
|
||||
|
||||
delete(c.indexes, hostinfo.localIndexId)
|
||||
|
@ -570,7 +579,7 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
|||
}
|
||||
}
|
||||
|
||||
func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
|
||||
func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
|
||||
hh := hm.queryVpnIp(vpnIp)
|
||||
if hh != nil {
|
||||
return hh.hostinfo
|
||||
|
@ -579,7 +588,7 @@ func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
|
|||
|
||||
}
|
||||
|
||||
func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo {
|
||||
func (hm *HandshakeManager) queryVpnIp(vpnIp netip.Addr) *HandshakeHostInfo {
|
||||
hm.RLock()
|
||||
defer hm.RUnlock()
|
||||
return hm.vpnIps[vpnIp]
|
||||
|
@ -599,7 +608,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
|
|||
return hm.indexes[index]
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
|
||||
func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix {
|
||||
return c.mainHostMap.GetPreferredRanges()
|
||||
}
|
||||
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -15,10 +14,11 @@ import (
|
|||
|
||||
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
|
||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||
ip := netip.MustParseAddr("172.1.1.2")
|
||||
|
||||
preferredRanges := []netip.Prefix{localrange}
|
||||
mainHM := newHostMap(l, vpncidr)
|
||||
mainHM.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
|
@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
|||
assert.NotContains(t, blah.vpnIps, ip)
|
||||
}
|
||||
|
||||
func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
|
||||
func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
|
||||
for _, i := range tw.t.wheel {
|
||||
n := i.Head
|
||||
for n != nil {
|
||||
|
@ -80,7 +80,7 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
|
|||
type mockEncWriter struct {
|
||||
}
|
||||
|
||||
func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
|
||||
func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -92,4 +92,4 @@ func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
|
|||
return
|
||||
}
|
||||
|
||||
func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {}
|
||||
func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}
|
||||
|
|
146
hostmap.go
146
hostmap.go
|
@ -3,18 +3,17 @@ package nebula
|
|||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
// const ProbeLen = 100
|
||||
|
@ -49,7 +48,7 @@ type Relay struct {
|
|||
State int
|
||||
LocalIndex uint32
|
||||
RemoteIndex uint32
|
||||
PeerIp iputil.VpnIp
|
||||
PeerIp netip.Addr
|
||||
}
|
||||
|
||||
type HostMap struct {
|
||||
|
@ -57,9 +56,9 @@ type HostMap struct {
|
|||
Indexes map[uint32]*HostInfo
|
||||
Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
|
||||
RemoteIndexes map[uint32]*HostInfo
|
||||
Hosts map[iputil.VpnIp]*HostInfo
|
||||
preferredRanges atomic.Pointer[[]*net.IPNet]
|
||||
vpnCIDR *net.IPNet
|
||||
Hosts map[netip.Addr]*HostInfo
|
||||
preferredRanges atomic.Pointer[[]netip.Prefix]
|
||||
vpnCIDR netip.Prefix
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
|
@ -69,12 +68,12 @@ type HostMap struct {
|
|||
type RelayState struct {
|
||||
sync.RWMutex
|
||||
|
||||
relays map[iputil.VpnIp]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
|
||||
relayForByIp map[iputil.VpnIp]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
|
||||
relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
|
||||
relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
|
||||
relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
|
||||
relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
|
||||
}
|
||||
|
||||
func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) {
|
||||
func (rs *RelayState) DeleteRelay(ip netip.Addr) {
|
||||
rs.Lock()
|
||||
defer rs.Unlock()
|
||||
delete(rs.relays, ip)
|
||||
|
@ -90,33 +89,33 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
|
|||
return ret
|
||||
}
|
||||
|
||||
func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) {
|
||||
func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
|
||||
rs.RLock()
|
||||
defer rs.RUnlock()
|
||||
r, ok := rs.relayForByIp[ip]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
func (rs *RelayState) InsertRelayTo(ip iputil.VpnIp) {
|
||||
func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
|
||||
rs.Lock()
|
||||
defer rs.Unlock()
|
||||
rs.relays[ip] = struct{}{}
|
||||
}
|
||||
|
||||
func (rs *RelayState) CopyRelayIps() []iputil.VpnIp {
|
||||
func (rs *RelayState) CopyRelayIps() []netip.Addr {
|
||||
rs.RLock()
|
||||
defer rs.RUnlock()
|
||||
ret := make([]iputil.VpnIp, 0, len(rs.relays))
|
||||
ret := make([]netip.Addr, 0, len(rs.relays))
|
||||
for ip := range rs.relays {
|
||||
ret = append(ret, ip)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (rs *RelayState) CopyRelayForIps() []iputil.VpnIp {
|
||||
func (rs *RelayState) CopyRelayForIps() []netip.Addr {
|
||||
rs.RLock()
|
||||
defer rs.RUnlock()
|
||||
currentRelays := make([]iputil.VpnIp, 0, len(rs.relayForByIp))
|
||||
currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp))
|
||||
for relayIp := range rs.relayForByIp {
|
||||
currentRelays = append(currentRelays, relayIp)
|
||||
}
|
||||
|
@ -133,19 +132,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
|
|||
return ret
|
||||
}
|
||||
|
||||
func (rs *RelayState) RemoveRelay(localIdx uint32) (iputil.VpnIp, bool) {
|
||||
rs.Lock()
|
||||
defer rs.Unlock()
|
||||
r, ok := rs.relayForByIdx[localIdx]
|
||||
if !ok {
|
||||
return iputil.VpnIp(0), false
|
||||
}
|
||||
delete(rs.relayForByIdx, localIdx)
|
||||
delete(rs.relayForByIp, r.PeerIp)
|
||||
return r.PeerIp, true
|
||||
}
|
||||
|
||||
func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bool {
|
||||
func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
|
||||
rs.Lock()
|
||||
defer rs.Unlock()
|
||||
r, ok := rs.relayForByIp[vpnIp]
|
||||
|
@ -175,7 +162,7 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
|
|||
return &newRelay, true
|
||||
}
|
||||
|
||||
func (rs *RelayState) QueryRelayForByIp(vpnIp iputil.VpnIp) (*Relay, bool) {
|
||||
func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
|
||||
rs.RLock()
|
||||
defer rs.RUnlock()
|
||||
r, ok := rs.relayForByIp[vpnIp]
|
||||
|
@ -189,7 +176,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
|
|||
return r, ok
|
||||
}
|
||||
|
||||
func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
|
||||
func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
|
||||
rs.Lock()
|
||||
defer rs.Unlock()
|
||||
rs.relayForByIp[ip] = r
|
||||
|
@ -197,15 +184,15 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
|
|||
}
|
||||
|
||||
type HostInfo struct {
|
||||
remote *udp.Addr
|
||||
remote netip.AddrPort
|
||||
remotes *RemoteList
|
||||
promoteCounter atomic.Uint32
|
||||
ConnectionState *ConnectionState
|
||||
remoteIndexId uint32
|
||||
localIndexId uint32
|
||||
vpnIp iputil.VpnIp
|
||||
vpnIp netip.Addr
|
||||
recvError atomic.Uint32
|
||||
remoteCidr *cidr.Tree4[struct{}]
|
||||
remoteCidr *bart.Table[struct{}]
|
||||
relayState RelayState
|
||||
|
||||
// HandshakePacket records the packets used to create this hostinfo
|
||||
|
@ -227,7 +214,7 @@ type HostInfo struct {
|
|||
lastHandshakeTime uint64
|
||||
|
||||
lastRoam time.Time
|
||||
lastRoamRemote *udp.Addr
|
||||
lastRoamRemote netip.AddrPort
|
||||
|
||||
// Used to track other hostinfos for this vpn ip since only 1 can be primary
|
||||
// Synchronised via hostmap lock and not the hostinfo lock.
|
||||
|
@ -254,7 +241,7 @@ type cachedPacketMetrics struct {
|
|||
dropped metrics.Counter
|
||||
}
|
||||
|
||||
func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap {
|
||||
func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap {
|
||||
hm := newHostMap(l, vpnCIDR)
|
||||
|
||||
hm.reload(c, true)
|
||||
|
@ -269,12 +256,12 @@ func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *Ho
|
|||
return hm
|
||||
}
|
||||
|
||||
func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
|
||||
func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap {
|
||||
return &HostMap{
|
||||
Indexes: map[uint32]*HostInfo{},
|
||||
Relays: map[uint32]*HostInfo{},
|
||||
RemoteIndexes: map[uint32]*HostInfo{},
|
||||
Hosts: map[iputil.VpnIp]*HostInfo{},
|
||||
Hosts: map[netip.Addr]*HostInfo{},
|
||||
vpnCIDR: vpnCIDR,
|
||||
l: l,
|
||||
}
|
||||
|
@ -282,11 +269,11 @@ func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
|
|||
|
||||
func (hm *HostMap) reload(c *config.C, initial bool) {
|
||||
if initial || c.HasChanged("preferred_ranges") {
|
||||
var preferredRanges []*net.IPNet
|
||||
var preferredRanges []netip.Prefix
|
||||
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
|
||||
|
||||
for _, rawPreferredRange := range rawPreferredRanges {
|
||||
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
|
||||
preferredRange, err := netip.ParsePrefix(rawPreferredRange)
|
||||
|
||||
if err != nil {
|
||||
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
|
||||
|
@ -378,7 +365,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
|||
// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
|
||||
delete(hm.Hosts, hostinfo.vpnIp)
|
||||
if len(hm.Hosts) == 0 {
|
||||
hm.Hosts = map[iputil.VpnIp]*HostInfo{}
|
||||
hm.Hosts = map[netip.Addr]*HostInfo{}
|
||||
}
|
||||
|
||||
if hostinfo.next != nil {
|
||||
|
@ -461,11 +448,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
|
|||
}
|
||||
}
|
||||
|
||||
func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
|
||||
func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
|
||||
return hm.queryVpnIp(vpnIp, nil)
|
||||
}
|
||||
|
||||
func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) {
|
||||
func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
|
||||
hm.RLock()
|
||||
defer hm.RUnlock()
|
||||
|
||||
|
@ -483,7 +470,7 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host
|
|||
return nil, nil, errors.New("unable to find host with relay")
|
||||
}
|
||||
|
||||
func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo {
|
||||
func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
|
||||
hm.RLock()
|
||||
if h, ok := hm.Hosts[vpnIp]; ok {
|
||||
hm.RUnlock()
|
||||
|
@ -535,7 +522,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
|||
}
|
||||
}
|
||||
|
||||
func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
|
||||
func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
|
||||
//NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer
|
||||
return *hm.preferredRanges.Load()
|
||||
}
|
||||
|
@ -560,14 +547,14 @@ func (hm *HostMap) ForEachIndex(f controlEach) {
|
|||
|
||||
// TryPromoteBest handles re-querying lighthouses and probing for better paths
|
||||
// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
|
||||
func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
|
||||
func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) {
|
||||
c := i.promoteCounter.Add(1)
|
||||
if c%ifce.tryPromoteEvery.Load() == 0 {
|
||||
remote := i.remote
|
||||
|
||||
// return early if we are already on a preferred remote
|
||||
if remote != nil {
|
||||
rIP := remote.IP
|
||||
if remote.IsValid() {
|
||||
rIP := remote.Addr()
|
||||
for _, l := range preferredRanges {
|
||||
if l.Contains(rIP) {
|
||||
return
|
||||
|
@ -575,8 +562,8 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
|
|||
}
|
||||
}
|
||||
|
||||
i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) {
|
||||
if remote != nil && (addr == nil || !preferred) {
|
||||
i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) {
|
||||
if remote.IsValid() && (!addr.IsValid() || !preferred) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -605,23 +592,23 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (i *HostInfo) SetRemote(remote *udp.Addr) {
|
||||
func (i *HostInfo) SetRemote(remote netip.AddrPort) {
|
||||
// We copy here because we likely got this remote from a source that reuses the object
|
||||
if !i.remote.Equals(remote) {
|
||||
i.remote = remote.Copy()
|
||||
i.remotes.LearnRemote(i.vpnIp, remote.Copy())
|
||||
if i.remote != remote {
|
||||
i.remote = remote
|
||||
i.remotes.LearnRemote(i.vpnIp, remote)
|
||||
}
|
||||
}
|
||||
|
||||
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
|
||||
// time on the HostInfo will also be updated.
|
||||
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
|
||||
if newRemote == nil {
|
||||
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool {
|
||||
if !newRemote.IsValid() {
|
||||
// relays have nil udp Addrs
|
||||
return false
|
||||
}
|
||||
currentRemote := i.remote
|
||||
if currentRemote == nil {
|
||||
if !currentRemote.IsValid() {
|
||||
i.SetRemote(newRemote)
|
||||
return true
|
||||
}
|
||||
|
@ -631,11 +618,11 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
|
|||
newIsPreferred := false
|
||||
for _, l := range hm.GetPreferredRanges() {
|
||||
// return early if we are already on a preferred remote
|
||||
if l.Contains(currentRemote.IP) {
|
||||
if l.Contains(currentRemote.Addr()) {
|
||||
return false
|
||||
}
|
||||
|
||||
if l.Contains(newRemote.IP) {
|
||||
if l.Contains(newRemote.Addr()) {
|
||||
newIsPreferred = true
|
||||
}
|
||||
}
|
||||
|
@ -643,7 +630,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
|
|||
if newIsPreferred {
|
||||
// Consider this a roaming event
|
||||
i.lastRoam = time.Now()
|
||||
i.lastRoamRemote = currentRemote.Copy()
|
||||
i.lastRoamRemote = currentRemote
|
||||
|
||||
i.SetRemote(newRemote)
|
||||
|
||||
|
@ -666,13 +653,21 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
|
|||
return
|
||||
}
|
||||
|
||||
remoteCidr := cidr.NewTree4[struct{}]()
|
||||
remoteCidr := new(bart.Table[struct{}])
|
||||
for _, ip := range c.Details.Ips {
|
||||
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
||||
//TODO: IPV6-WORK what to do when ip is invalid?
|
||||
nip, _ := netip.AddrFromSlice(ip.IP)
|
||||
nip = nip.Unmap()
|
||||
bits, _ := ip.Mask.Size()
|
||||
remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
|
||||
}
|
||||
|
||||
for _, n := range c.Details.Subnets {
|
||||
remoteCidr.AddCIDR(n, struct{}{})
|
||||
//TODO: IPV6-WORK what to do when ip is invalid?
|
||||
nip, _ := netip.AddrFromSlice(n.IP)
|
||||
nip = nip.Unmap()
|
||||
bits, _ := n.Mask.Size()
|
||||
remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
|
||||
}
|
||||
i.remoteCidr = remoteCidr
|
||||
}
|
||||
|
@ -697,9 +692,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
|||
|
||||
// Utility functions
|
||||
|
||||
func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
|
||||
func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
||||
//FIXME: This function is pretty garbage
|
||||
var ips []net.IP
|
||||
var ips []netip.Addr
|
||||
ifaces, _ := net.Interfaces()
|
||||
for _, i := range ifaces {
|
||||
allow := allowList.AllowName(i.Name)
|
||||
|
@ -721,20 +716,29 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
|
|||
ip = v.IP
|
||||
}
|
||||
|
||||
nip, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("localIp", ip).Debug("ip was invalid for netip")
|
||||
}
|
||||
continue
|
||||
}
|
||||
nip = nip.Unmap()
|
||||
|
||||
//TODO: Filtering out link local for now, this is probably the most correct thing
|
||||
//TODO: Would be nice to filter out SLAAC MAC based ips as well
|
||||
if ip.IsLoopback() == false && !ip.IsLinkLocalUnicast() {
|
||||
allow := allowList.Allow(ip)
|
||||
if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
|
||||
allow := allowList.Allow(nip)
|
||||
if l.Level >= logrus.TraceLevel {
|
||||
l.WithField("localIp", ip).WithField("allow", allow).Trace("localAllowList.Allow")
|
||||
l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow")
|
||||
}
|
||||
if !allow {
|
||||
continue
|
||||
}
|
||||
|
||||
ips = append(ips, ip)
|
||||
ips = append(ips, nip)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &ips
|
||||
return ips
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
|
@ -13,18 +13,15 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
|||
l := test.NewLogger()
|
||||
hm := newHostMap(
|
||||
l,
|
||||
&net.IPNet{
|
||||
IP: net.IP{10, 0, 0, 1},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
netip.MustParsePrefix("10.0.0.1/24"),
|
||||
)
|
||||
|
||||
f := &Interface{}
|
||||
|
||||
h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
|
||||
h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
|
||||
h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
|
||||
h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
|
||||
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
|
||||
h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
|
||||
h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
|
||||
h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
|
||||
|
||||
hm.unlockedAddHostInfo(h4, f)
|
||||
hm.unlockedAddHostInfo(h3, f)
|
||||
|
@ -32,7 +29,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
|||
hm.unlockedAddHostInfo(h1, f)
|
||||
|
||||
// Make sure we go h1 -> h2 -> h3 -> h4
|
||||
prim := hm.QueryVpnIp(1)
|
||||
prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h1.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
|
@ -47,7 +44,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
|||
hm.MakePrimary(h3)
|
||||
|
||||
// Make sure we go h3 -> h1 -> h2 -> h4
|
||||
prim = hm.QueryVpnIp(1)
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h3.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
|
@ -62,7 +59,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
|||
hm.MakePrimary(h4)
|
||||
|
||||
// Make sure we go h4 -> h3 -> h1 -> h2
|
||||
prim = hm.QueryVpnIp(1)
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h4.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
|
@ -77,7 +74,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
|||
hm.MakePrimary(h4)
|
||||
|
||||
// Make sure we go h4 -> h3 -> h1 -> h2
|
||||
prim = hm.QueryVpnIp(1)
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h4.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
|
@ -93,20 +90,17 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||
l := test.NewLogger()
|
||||
hm := newHostMap(
|
||||
l,
|
||||
&net.IPNet{
|
||||
IP: net.IP{10, 0, 0, 1},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
netip.MustParsePrefix("10.0.0.1/24"),
|
||||
)
|
||||
|
||||
f := &Interface{}
|
||||
|
||||
h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
|
||||
h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
|
||||
h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
|
||||
h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
|
||||
h5 := &HostInfo{vpnIp: 1, localIndexId: 5}
|
||||
h6 := &HostInfo{vpnIp: 1, localIndexId: 6}
|
||||
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
|
||||
h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
|
||||
h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
|
||||
h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
|
||||
h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5}
|
||||
h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6}
|
||||
|
||||
hm.unlockedAddHostInfo(h6, f)
|
||||
hm.unlockedAddHostInfo(h5, f)
|
||||
|
@ -122,7 +116,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||
assert.Nil(t, h)
|
||||
|
||||
// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
|
||||
prim := hm.QueryVpnIp(1)
|
||||
prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h1.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
|
@ -141,7 +135,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||
assert.Nil(t, h1.next)
|
||||
|
||||
// Make sure we go h2 -> h3 -> h4 -> h5
|
||||
prim = hm.QueryVpnIp(1)
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h2.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
|
@ -159,7 +153,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||
assert.Nil(t, h3.next)
|
||||
|
||||
// Make sure we go h2 -> h4 -> h5
|
||||
prim = hm.QueryVpnIp(1)
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h2.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
|
@ -175,7 +169,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||
assert.Nil(t, h5.next)
|
||||
|
||||
// Make sure we go h2 -> h4
|
||||
prim = hm.QueryVpnIp(1)
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h2.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
|
@ -189,7 +183,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||
assert.Nil(t, h2.next)
|
||||
|
||||
// Make sure we only have h4
|
||||
prim = hm.QueryVpnIp(1)
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h4.localIndexId, prim.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
assert.Nil(t, prim.next)
|
||||
|
@ -201,7 +195,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||
assert.Nil(t, h4.next)
|
||||
|
||||
// Make sure we have nil
|
||||
prim = hm.QueryVpnIp(1)
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Nil(t, prim)
|
||||
}
|
||||
|
||||
|
@ -211,14 +205,11 @@ func TestHostMap_reload(t *testing.T) {
|
|||
|
||||
hm := NewHostMapFromConfig(
|
||||
l,
|
||||
&net.IPNet{
|
||||
IP: net.IP{10, 0, 0, 1},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
netip.MustParsePrefix("10.0.0.1/24"),
|
||||
c,
|
||||
)
|
||||
|
||||
toS := func(ipn []*net.IPNet) []string {
|
||||
toS := func(ipn []netip.Prefix) []string {
|
||||
var s []string
|
||||
for _, n := range ipn {
|
||||
s = append(s, n.String())
|
||||
|
|
|
@ -5,9 +5,11 @@ package nebula
|
|||
|
||||
// This file contains functions used to export information to the e2e testing framework
|
||||
|
||||
import "github.com/slackhq/nebula/iputil"
|
||||
import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func (i *HostInfo) GetVpnIp() iputil.VpnIp {
|
||||
func (i *HostInfo) GetVpnIp() netip.Addr {
|
||||
return i.vpnIp
|
||||
}
|
||||
|
||||
|
|
44
inside.go
44
inside.go
|
@ -1,12 +1,13 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||
|
@ -19,11 +20,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||
}
|
||||
|
||||
// Ignore local broadcast packets
|
||||
if f.dropLocalBroadcast && fwPacket.RemoteIP == f.localBroadcast {
|
||||
if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr {
|
||||
return
|
||||
}
|
||||
|
||||
if fwPacket.RemoteIP == f.myVpnIp {
|
||||
if fwPacket.RemoteIP == f.myVpnNet.Addr() {
|
||||
// Immediately forward packets from self to self.
|
||||
// This should only happen on Darwin-based and FreeBSD hosts, which
|
||||
// routes packets from the Nebula IP to the Nebula IP through the Nebula
|
||||
|
@ -39,8 +40,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||
return
|
||||
}
|
||||
|
||||
// Ignore broadcast packets
|
||||
if f.dropMulticast && isMulticast(fwPacket.RemoteIP) {
|
||||
// Ignore multicast packets
|
||||
if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -64,7 +65,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||
|
||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason == nil {
|
||||
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q)
|
||||
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||
|
||||
} else {
|
||||
f.rejectInside(packet, out, q)
|
||||
|
@ -113,19 +114,19 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
|||
return
|
||||
}
|
||||
|
||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q)
|
||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
||||
}
|
||||
|
||||
func (f *Interface) Handshake(vpnIp iputil.VpnIp) {
|
||||
func (f *Interface) Handshake(vpnIp netip.Addr) {
|
||||
f.getOrHandshake(vpnIp, nil)
|
||||
}
|
||||
|
||||
// getOrHandshake returns nil if the vpnIp is not routable.
|
||||
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
|
||||
func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||
if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) {
|
||||
func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||
if !f.myVpnNet.Contains(vpnIp) {
|
||||
vpnIp = f.inside.RouteFor(vpnIp)
|
||||
if vpnIp == 0 {
|
||||
if !vpnIp.IsValid() {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
@ -152,11 +153,11 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||
return
|
||||
}
|
||||
|
||||
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0)
|
||||
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||
}
|
||||
|
||||
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
|
||||
func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
|
||||
func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
|
||||
hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
||||
})
|
||||
|
@ -182,10 +183,10 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag
|
|||
|
||||
func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
|
||||
f.messageMetrics.Tx(t, st, 1)
|
||||
f.sendNoMetrics(t, st, ci, hostinfo, nil, p, nb, out, 0)
|
||||
f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||
}
|
||||
|
||||
func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) {
|
||||
func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
|
||||
f.messageMetrics.Tx(t, st, 1)
|
||||
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
|
||||
}
|
||||
|
@ -255,12 +256,12 @@ func (f *Interface) SendVia(via *HostInfo,
|
|||
f.connectionManager.RelayUsed(relay.LocalIndex)
|
||||
}
|
||||
|
||||
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) {
|
||||
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
|
||||
if ci.eKey == nil {
|
||||
//TODO: log warning
|
||||
return
|
||||
}
|
||||
useRelay := remote == nil && hostinfo.remote == nil
|
||||
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
|
||||
fullOut := out
|
||||
|
||||
if useRelay {
|
||||
|
@ -308,13 +309,13 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||
return
|
||||
}
|
||||
|
||||
if remote != nil {
|
||||
if remote.IsValid() {
|
||||
err = f.writers[q].WriteTo(out, remote)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||
}
|
||||
} else if hostinfo.remote != nil {
|
||||
} else if hostinfo.remote.IsValid() {
|
||||
err = f.writers[q].WriteTo(out, hostinfo.remote)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
|
@ -334,8 +335,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isMulticast(ip iputil.VpnIp) bool {
|
||||
// Class D multicast
|
||||
return (((ip >> 24) & 0xff) & 0xf0) == 0xe0
|
||||
}
|
||||
|
|
47
interface.go
47
interface.go
|
@ -2,10 +2,11 @@ package nebula
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
|
@ -16,7 +17,6 @@ import (
|
|||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
@ -63,8 +63,8 @@ type Interface struct {
|
|||
serveDns bool
|
||||
createTime time.Time
|
||||
lightHouse *LightHouse
|
||||
localBroadcast iputil.VpnIp
|
||||
myVpnIp iputil.VpnIp
|
||||
myBroadcastAddr netip.Addr
|
||||
myVpnNet netip.Prefix
|
||||
dropLocalBroadcast bool
|
||||
dropMulticast bool
|
||||
routines int
|
||||
|
@ -102,9 +102,9 @@ type EncWriter interface {
|
|||
out []byte,
|
||||
nocopy bool,
|
||||
)
|
||||
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
|
||||
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte)
|
||||
SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
|
||||
Handshake(vpnIp iputil.VpnIp)
|
||||
Handshake(vpnIp netip.Addr)
|
||||
}
|
||||
|
||||
type sendRecvErrorConfig uint8
|
||||
|
@ -115,10 +115,10 @@ const (
|
|||
sendRecvErrorPrivate
|
||||
)
|
||||
|
||||
func (s sendRecvErrorConfig) ShouldSendRecvError(ip net.IP) bool {
|
||||
func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool {
|
||||
switch s {
|
||||
case sendRecvErrorPrivate:
|
||||
return ip.IsPrivate()
|
||||
return ip.Addr().IsPrivate()
|
||||
case sendRecvErrorAlways:
|
||||
return true
|
||||
case sendRecvErrorNever:
|
||||
|
@ -156,7 +156,27 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||
}
|
||||
|
||||
certificate := c.pki.GetCertState().Certificate
|
||||
myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP)
|
||||
|
||||
myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP)
|
||||
}
|
||||
|
||||
myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask)
|
||||
}
|
||||
|
||||
myVpnAddr = myVpnAddr.Unmap()
|
||||
myVpnMask = myVpnMask.Unmap()
|
||||
|
||||
if myVpnAddr.BitLen() != myVpnMask.BitLen() {
|
||||
return nil, fmt.Errorf("ip address and mask are different lengths in certificate")
|
||||
}
|
||||
|
||||
ones, _ := certificate.Details.Ips[0].Mask.Size()
|
||||
myVpnNet := netip.PrefixFrom(myVpnAddr, ones)
|
||||
|
||||
ifce := &Interface{
|
||||
pki: c.pki,
|
||||
hostMap: c.HostMap,
|
||||
|
@ -168,14 +188,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||
handshakeManager: c.HandshakeManager,
|
||||
createTime: time.Now(),
|
||||
lightHouse: c.lightHouse,
|
||||
localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask),
|
||||
dropLocalBroadcast: c.DropLocalBroadcast,
|
||||
dropMulticast: c.DropMulticast,
|
||||
routines: c.routines,
|
||||
version: c.version,
|
||||
writers: make([]udp.Conn, c.routines),
|
||||
readers: make([]io.ReadWriteCloser, c.routines),
|
||||
myVpnIp: myVpnIp,
|
||||
myVpnNet: myVpnNet,
|
||||
relayManager: c.relayManager,
|
||||
|
||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||
|
@ -190,6 +209,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||
l: c.l,
|
||||
}
|
||||
|
||||
if myVpnAddr.Is4() {
|
||||
addr := myVpnNet.Masked().Addr().As4()
|
||||
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask))
|
||||
ifce.myBroadcastAddr = netip.AddrFrom4(addr)
|
||||
}
|
||||
|
||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
//TODO: IPV6-WORK can probably delete this
|
||||
|
||||
const (
|
||||
// Need 96 bytes for the largest reject packet:
|
||||
// - 20 byte ipv4 header
|
||||
|
|
|
@ -1,93 +0,0 @@
|
|||
package iputil
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type VpnIp uint32
|
||||
|
||||
const maxIPv4StringLen = len("255.255.255.255")
|
||||
|
||||
func (ip VpnIp) String() string {
|
||||
b := make([]byte, maxIPv4StringLen)
|
||||
|
||||
n := ubtoa(b, 0, byte(ip>>24))
|
||||
b[n] = '.'
|
||||
n++
|
||||
|
||||
n += ubtoa(b, n, byte(ip>>16&255))
|
||||
b[n] = '.'
|
||||
n++
|
||||
|
||||
n += ubtoa(b, n, byte(ip>>8&255))
|
||||
b[n] = '.'
|
||||
n++
|
||||
|
||||
n += ubtoa(b, n, byte(ip&255))
|
||||
return string(b[:n])
|
||||
}
|
||||
|
||||
func (ip VpnIp) MarshalJSON() ([]byte, error) {
|
||||
return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil
|
||||
}
|
||||
|
||||
func (ip VpnIp) ToIP() net.IP {
|
||||
nip := make(net.IP, 4)
|
||||
binary.BigEndian.PutUint32(nip, uint32(ip))
|
||||
return nip
|
||||
}
|
||||
|
||||
func (ip VpnIp) ToNetIpAddr() netip.Addr {
|
||||
var nip [4]byte
|
||||
binary.BigEndian.PutUint32(nip[:], uint32(ip))
|
||||
return netip.AddrFrom4(nip)
|
||||
}
|
||||
|
||||
func Ip2VpnIp(ip []byte) VpnIp {
|
||||
if len(ip) == 16 {
|
||||
return VpnIp(binary.BigEndian.Uint32(ip[12:16]))
|
||||
}
|
||||
return VpnIp(binary.BigEndian.Uint32(ip))
|
||||
}
|
||||
|
||||
func ToNetIpAddr(ip net.IP) (netip.Addr, error) {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return netip.Addr{}, fmt.Errorf("invalid net.IP: %v", ip)
|
||||
}
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
func ToNetIpPrefix(ipNet net.IPNet) (netip.Prefix, error) {
|
||||
addr, err := ToNetIpAddr(ipNet.IP)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, err
|
||||
}
|
||||
ones, bits := ipNet.Mask.Size()
|
||||
if ones == 0 && bits == 0 {
|
||||
return netip.Prefix{}, fmt.Errorf("invalid net.IP: %v", ipNet)
|
||||
}
|
||||
return netip.PrefixFrom(addr, ones), nil
|
||||
}
|
||||
|
||||
// ubtoa encodes the string form of the integer v to dst[start:] and
|
||||
// returns the number of bytes written to dst. The caller must ensure
|
||||
// that dst has sufficient length.
|
||||
func ubtoa(dst []byte, start int, v byte) int {
|
||||
if v < 10 {
|
||||
dst[start] = v + '0'
|
||||
return 1
|
||||
} else if v < 100 {
|
||||
dst[start+1] = v%10 + '0'
|
||||
dst[start] = v/10 + '0'
|
||||
return 2
|
||||
}
|
||||
|
||||
dst[start+2] = v%10 + '0'
|
||||
dst[start+1] = (v/10)%10 + '0'
|
||||
dst[start] = v/100 + '0'
|
||||
return 3
|
||||
}
|
|
@ -1,17 +0,0 @@
|
|||
package iputil
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestVpnIp_String(t *testing.T) {
|
||||
assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String())
|
||||
assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String())
|
||||
assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String())
|
||||
assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String())
|
||||
assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String())
|
||||
assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String())
|
||||
}
|
400
lighthouse.go
400
lighthouse.go
|
@ -7,16 +7,16 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
@ -26,25 +26,18 @@ import (
|
|||
|
||||
var ErrHostNotKnown = errors.New("host not known")
|
||||
|
||||
type netIpAndPort struct {
|
||||
ip net.IP
|
||||
port uint16
|
||||
}
|
||||
|
||||
type LightHouse struct {
|
||||
//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
|
||||
sync.RWMutex //Because we concurrently read and write to our maps
|
||||
ctx context.Context
|
||||
amLighthouse bool
|
||||
myVpnIp iputil.VpnIp
|
||||
myVpnZeros iputil.VpnIp
|
||||
myVpnNet *net.IPNet
|
||||
myVpnNet netip.Prefix
|
||||
punchConn udp.Conn
|
||||
punchy *Punchy
|
||||
|
||||
// Local cache of answers from light houses
|
||||
// map of vpn Ip to answers
|
||||
addrMap map[iputil.VpnIp]*RemoteList
|
||||
addrMap map[netip.Addr]*RemoteList
|
||||
|
||||
// filters remote addresses allowed for each host
|
||||
// - When we are a lighthouse, this filters what addresses we store and
|
||||
|
@ -57,26 +50,26 @@ type LightHouse struct {
|
|||
localAllowList atomic.Pointer[LocalAllowList]
|
||||
|
||||
// used to trigger the HandshakeManager when we receive HostQueryReply
|
||||
handshakeTrigger chan<- iputil.VpnIp
|
||||
handshakeTrigger chan<- netip.Addr
|
||||
|
||||
// staticList exists to avoid having a bool in each addrMap entry
|
||||
// since static should be rare
|
||||
staticList atomic.Pointer[map[iputil.VpnIp]struct{}]
|
||||
lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}]
|
||||
staticList atomic.Pointer[map[netip.Addr]struct{}]
|
||||
lighthouses atomic.Pointer[map[netip.Addr]struct{}]
|
||||
|
||||
interval atomic.Int64
|
||||
updateCancel context.CancelFunc
|
||||
ifce EncWriter
|
||||
nebulaPort uint32 // 32 bits because protobuf does not have a uint16
|
||||
|
||||
advertiseAddrs atomic.Pointer[[]netIpAndPort]
|
||||
advertiseAddrs atomic.Pointer[[]netip.AddrPort]
|
||||
|
||||
// IP's of relays that can be used by peers to access me
|
||||
relaysForMe atomic.Pointer[[]iputil.VpnIp]
|
||||
relaysForMe atomic.Pointer[[]netip.Addr]
|
||||
|
||||
queryChan chan iputil.VpnIp
|
||||
queryChan chan netip.Addr
|
||||
|
||||
calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
|
||||
calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
|
||||
|
||||
metrics *MessageMetrics
|
||||
metricHolepunchTx metrics.Counter
|
||||
|
@ -85,7 +78,7 @@ type LightHouse struct {
|
|||
|
||||
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
|
||||
// addrMap should be nil unless this is during a config reload
|
||||
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) {
|
||||
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet netip.Prefix, pc udp.Conn, p *Punchy) (*LightHouse, error) {
|
||||
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
|
||||
nebulaPort := uint32(c.GetInt("listen.port", 0))
|
||||
if amLighthouse && nebulaPort == 0 {
|
||||
|
@ -98,26 +91,23 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
|||
if err != nil {
|
||||
return nil, util.NewContextualError("Failed to get listening port", nil, err)
|
||||
}
|
||||
nebulaPort = uint32(uPort.Port)
|
||||
nebulaPort = uint32(uPort.Port())
|
||||
}
|
||||
|
||||
ones, _ := myVpnNet.Mask.Size()
|
||||
h := LightHouse{
|
||||
ctx: ctx,
|
||||
amLighthouse: amLighthouse,
|
||||
myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP),
|
||||
myVpnZeros: iputil.VpnIp(32 - ones),
|
||||
myVpnNet: myVpnNet,
|
||||
addrMap: make(map[iputil.VpnIp]*RemoteList),
|
||||
addrMap: make(map[netip.Addr]*RemoteList),
|
||||
nebulaPort: nebulaPort,
|
||||
punchConn: pc,
|
||||
punchy: p,
|
||||
queryChan: make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)),
|
||||
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
||||
l: l,
|
||||
}
|
||||
lighthouses := make(map[iputil.VpnIp]struct{})
|
||||
lighthouses := make(map[netip.Addr]struct{})
|
||||
h.lighthouses.Store(&lighthouses)
|
||||
staticList := make(map[iputil.VpnIp]struct{})
|
||||
staticList := make(map[netip.Addr]struct{})
|
||||
h.staticList.Store(&staticList)
|
||||
|
||||
if c.GetBool("stats.lighthouse_metrics", false) {
|
||||
|
@ -147,11 +137,11 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
|||
return &h, nil
|
||||
}
|
||||
|
||||
func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} {
|
||||
func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
|
||||
return *lh.staticList.Load()
|
||||
}
|
||||
|
||||
func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} {
|
||||
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
|
||||
return *lh.lighthouses.Load()
|
||||
}
|
||||
|
||||
|
@ -163,15 +153,15 @@ func (lh *LightHouse) GetLocalAllowList() *LocalAllowList {
|
|||
return lh.localAllowList.Load()
|
||||
}
|
||||
|
||||
func (lh *LightHouse) GetAdvertiseAddrs() []netIpAndPort {
|
||||
func (lh *LightHouse) GetAdvertiseAddrs() []netip.AddrPort {
|
||||
return *lh.advertiseAddrs.Load()
|
||||
}
|
||||
|
||||
func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp {
|
||||
func (lh *LightHouse) GetRelaysForMe() []netip.Addr {
|
||||
return *lh.relaysForMe.Load()
|
||||
}
|
||||
|
||||
func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] {
|
||||
func (lh *LightHouse) getCalculatedRemotes() *bart.Table[[]*calculatedRemote] {
|
||||
return lh.calculatedRemotes.Load()
|
||||
}
|
||||
|
||||
|
@ -182,25 +172,40 @@ func (lh *LightHouse) GetUpdateInterval() int64 {
|
|||
func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
if initial || c.HasChanged("lighthouse.advertise_addrs") {
|
||||
rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{})
|
||||
advAddrs := make([]netIpAndPort, 0)
|
||||
advAddrs := make([]netip.AddrPort, 0)
|
||||
|
||||
for i, rawAddr := range rawAdvAddrs {
|
||||
fIp, fPort, err := udp.ParseIPAndPort(rawAddr)
|
||||
host, sport, err := net.SplitHostPort(rawAddr)
|
||||
if err != nil {
|
||||
return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
|
||||
}
|
||||
|
||||
if fPort == 0 {
|
||||
fPort = uint16(lh.nebulaPort)
|
||||
ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host)
|
||||
if err != nil {
|
||||
return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil)
|
||||
}
|
||||
|
||||
if ip4 := fIp.To4(); ip4 != nil && lh.myVpnNet.Contains(fIp) {
|
||||
port, err := strconv.Atoi(sport)
|
||||
if err != nil {
|
||||
return util.NewContextualError("Unable to parse port in lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
|
||||
}
|
||||
|
||||
if port == 0 {
|
||||
port = int(lh.nebulaPort)
|
||||
}
|
||||
|
||||
//TODO: we could technically insert all returned ips instead of just the first one if a dns lookup was used
|
||||
ip := ips[0].Unmap()
|
||||
if lh.myVpnNet.Contains(ip) {
|
||||
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
|
||||
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
|
||||
continue
|
||||
}
|
||||
|
||||
advAddrs = append(advAddrs, netIpAndPort{ip: fIp, port: fPort})
|
||||
advAddrs = append(advAddrs, netip.AddrPortFrom(ip, uint16(port)))
|
||||
}
|
||||
|
||||
lh.advertiseAddrs.Store(&advAddrs)
|
||||
|
@ -278,8 +283,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||
lh.RUnlock()
|
||||
}
|
||||
// Build a new list based on current config.
|
||||
staticList := make(map[iputil.VpnIp]struct{})
|
||||
err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
|
||||
staticList := make(map[netip.Addr]struct{})
|
||||
err := lh.loadStaticMap(c, staticList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -303,8 +308,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||
}
|
||||
|
||||
if initial || c.HasChanged("lighthouse.hosts") {
|
||||
lhMap := make(map[iputil.VpnIp]struct{})
|
||||
err := lh.parseLighthouses(c, lh.myVpnNet, lhMap)
|
||||
lhMap := make(map[netip.Addr]struct{})
|
||||
err := lh.parseLighthouses(c, lhMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -323,16 +328,17 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||
if len(c.GetStringSlice("relay.relays", nil)) > 0 {
|
||||
lh.l.Info("Ignoring relays from config because am_relay is true")
|
||||
}
|
||||
relaysForMe := []iputil.VpnIp{}
|
||||
relaysForMe := []netip.Addr{}
|
||||
lh.relaysForMe.Store(&relaysForMe)
|
||||
case false:
|
||||
relaysForMe := []iputil.VpnIp{}
|
||||
relaysForMe := []netip.Addr{}
|
||||
for _, v := range c.GetStringSlice("relay.relays", nil) {
|
||||
lh.l.WithField("relay", v).Info("Read relay from config")
|
||||
|
||||
configRIP := net.ParseIP(v)
|
||||
if configRIP != nil {
|
||||
relaysForMe = append(relaysForMe, iputil.Ip2VpnIp(configRIP))
|
||||
configRIP, err := netip.ParseAddr(v)
|
||||
//TODO: We could print the error here
|
||||
if err == nil {
|
||||
relaysForMe = append(relaysForMe, configRIP)
|
||||
}
|
||||
}
|
||||
lh.relaysForMe.Store(&relaysForMe)
|
||||
|
@ -342,21 +348,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap map[iputil.VpnIp]struct{}) error {
|
||||
func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
|
||||
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
|
||||
if lh.amLighthouse && len(lhs) != 0 {
|
||||
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
|
||||
}
|
||||
|
||||
for i, host := range lhs {
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
|
||||
ip, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
|
||||
}
|
||||
if !tunCidr.Contains(ip) {
|
||||
return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
||||
if !lh.myVpnNet.Contains(ip) {
|
||||
return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": lh.myVpnNet}, nil)
|
||||
}
|
||||
lhMap[iputil.Ip2VpnIp(ip)] = struct{}{}
|
||||
lhMap[ip] = struct{}{}
|
||||
}
|
||||
|
||||
if !lh.amLighthouse && len(lhMap) == 0 {
|
||||
|
@ -399,7 +405,7 @@ func getStaticMapNetwork(c *config.C) (string, error) {
|
|||
return network, nil
|
||||
}
|
||||
|
||||
func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error {
|
||||
func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struct{}) error {
|
||||
d, err := getStaticMapCadence(c)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -410,7 +416,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
|
|||
return err
|
||||
}
|
||||
|
||||
lookup_timeout, err := getStaticMapLookupTimeout(c)
|
||||
lookupTimeout, err := getStaticMapLookupTimeout(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -419,16 +425,15 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
|
|||
i := 0
|
||||
|
||||
for k, v := range shm {
|
||||
rip := net.ParseIP(fmt.Sprintf("%v", k))
|
||||
if rip == nil {
|
||||
return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, nil)
|
||||
vpnIp, err := netip.ParseAddr(fmt.Sprintf("%v", k))
|
||||
if err != nil {
|
||||
return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
|
||||
}
|
||||
|
||||
if !tunCidr.Contains(rip) {
|
||||
return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": rip, "network": tunCidr.String(), "entry": i + 1}, nil)
|
||||
if !lh.myVpnNet.Contains(vpnIp) {
|
||||
return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": lh.myVpnNet, "entry": i + 1}, nil)
|
||||
}
|
||||
|
||||
vpnIp := iputil.Ip2VpnIp(rip)
|
||||
vals, ok := v.([]interface{})
|
||||
if !ok {
|
||||
vals = []interface{}{v}
|
||||
|
@ -438,7 +443,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
|
|||
remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v))
|
||||
}
|
||||
|
||||
err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList)
|
||||
err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnIp, remoteAddrs, staticList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -448,7 +453,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
|
|||
return nil
|
||||
}
|
||||
|
||||
func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList {
|
||||
func (lh *LightHouse) Query(ip netip.Addr) *RemoteList {
|
||||
if !lh.IsLighthouseIP(ip) {
|
||||
lh.QueryServer(ip)
|
||||
}
|
||||
|
@ -462,7 +467,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList {
|
|||
}
|
||||
|
||||
// QueryServer is asynchronous so no reply should be expected
|
||||
func (lh *LightHouse) QueryServer(ip iputil.VpnIp) {
|
||||
func (lh *LightHouse) QueryServer(ip netip.Addr) {
|
||||
// Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses
|
||||
if lh.amLighthouse || lh.IsLighthouseIP(ip) {
|
||||
return
|
||||
|
@ -471,7 +476,7 @@ func (lh *LightHouse) QueryServer(ip iputil.VpnIp) {
|
|||
lh.queryChan <- ip
|
||||
}
|
||||
|
||||
func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
|
||||
func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList {
|
||||
lh.RLock()
|
||||
if v, ok := lh.addrMap[ip]; ok {
|
||||
lh.RUnlock()
|
||||
|
@ -488,7 +493,7 @@ func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
|
|||
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
|
||||
// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
|
||||
// If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
|
||||
func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) {
|
||||
func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) {
|
||||
lh.RLock()
|
||||
// Do we have an entry in the main cache?
|
||||
if v, ok := lh.addrMap[vpnIp]; ok {
|
||||
|
@ -511,7 +516,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (in
|
|||
return false, 0, nil
|
||||
}
|
||||
|
||||
func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
|
||||
func (lh *LightHouse) DeleteVpnIp(vpnIp netip.Addr) {
|
||||
// First we check the static mapping
|
||||
// and do nothing if it is there
|
||||
if _, ok := lh.GetStaticHostList()[vpnIp]; ok {
|
||||
|
@ -532,7 +537,7 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
|
|||
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
|
||||
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
|
||||
// NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
|
||||
func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error {
|
||||
func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error {
|
||||
lh.Lock()
|
||||
am := lh.unlockedGetRemoteList(vpnIp)
|
||||
am.Lock()
|
||||
|
@ -553,20 +558,14 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
|
|||
am.unlockedSetHostnamesResults(hr)
|
||||
|
||||
for _, addrPort := range hr.GetIPs() {
|
||||
|
||||
if !lh.shouldAdd(vpnIp, addrPort.Addr()) {
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
case addrPort.Addr().Is4():
|
||||
to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
|
||||
if !lh.unlockedShouldAddV4(vpnIp, to) {
|
||||
continue
|
||||
}
|
||||
am.unlockedPrependV4(lh.myVpnIp, to)
|
||||
am.unlockedPrependV4(lh.myVpnNet.Addr(), NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()))
|
||||
case addrPort.Addr().Is6():
|
||||
to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
|
||||
if !lh.unlockedShouldAddV6(vpnIp, to) {
|
||||
continue
|
||||
}
|
||||
am.unlockedPrependV6(lh.myVpnIp, to)
|
||||
am.unlockedPrependV6(lh.myVpnNet.Addr(), NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -578,12 +577,12 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
|
|||
// addCalculatedRemotes adds any calculated remotes based on the
|
||||
// lighthouse.calculated_remotes configuration. It returns true if any
|
||||
// calculated remotes were added
|
||||
func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
|
||||
func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool {
|
||||
tree := lh.getCalculatedRemotes()
|
||||
if tree == nil {
|
||||
return false
|
||||
}
|
||||
ok, calculatedRemotes := tree.MostSpecificContains(vpnIp)
|
||||
calculatedRemotes, ok := tree.Lookup(vpnIp)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
@ -602,13 +601,13 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
|
|||
defer am.Unlock()
|
||||
lh.Unlock()
|
||||
|
||||
am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4)
|
||||
am.unlockedSetV4(lh.myVpnNet.Addr(), vpnIp, calculated, lh.unlockedShouldAddV4)
|
||||
|
||||
return len(calculated) > 0
|
||||
}
|
||||
|
||||
// unlockedGetRemoteList assumes you have the lh lock
|
||||
func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
|
||||
func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList {
|
||||
am, ok := lh.addrMap[vpnIp]
|
||||
if !ok {
|
||||
am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) })
|
||||
|
@ -617,44 +616,27 @@ func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
|
|||
return am
|
||||
}
|
||||
|
||||
func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool {
|
||||
switch {
|
||||
case to.Is4():
|
||||
ipBytes := to.As4()
|
||||
ip := iputil.Ip2VpnIp(ipBytes[:])
|
||||
allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip)
|
||||
if lh.l.Level >= logrus.TraceLevel {
|
||||
lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
|
||||
}
|
||||
if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) {
|
||||
return false
|
||||
}
|
||||
case to.Is6():
|
||||
ipBytes := to.As16()
|
||||
|
||||
hi := binary.BigEndian.Uint64(ipBytes[:8])
|
||||
lo := binary.BigEndian.Uint64(ipBytes[8:])
|
||||
allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo)
|
||||
if lh.l.Level >= logrus.TraceLevel {
|
||||
lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow")
|
||||
}
|
||||
|
||||
// We don't check our vpn network here because nebula does not support ipv6 on the inside
|
||||
if !allow {
|
||||
return false
|
||||
}
|
||||
func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool {
|
||||
allow := lh.GetRemoteAllowList().Allow(vpnIp, to)
|
||||
if lh.l.Level >= logrus.TraceLevel {
|
||||
lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
|
||||
}
|
||||
if !allow || lh.myVpnNet.Contains(to) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// unlockedShouldAddV4 checks if to is allowed by our allow list
|
||||
func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
|
||||
allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
|
||||
func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool {
|
||||
ip := AddrPortFromIp4AndPort(to)
|
||||
allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr())
|
||||
if lh.l.Level >= logrus.TraceLevel {
|
||||
lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
|
||||
}
|
||||
|
||||
if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) {
|
||||
if !allow || lh.myVpnNet.Contains(ip.Addr()) {
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -662,14 +644,14 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bo
|
|||
}
|
||||
|
||||
// unlockedShouldAddV6 checks if to is allowed by our allow list
|
||||
func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool {
|
||||
allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, to.Hi, to.Lo)
|
||||
func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *Ip6AndPort) bool {
|
||||
ip := AddrPortFromIp6AndPort(to)
|
||||
allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr())
|
||||
if lh.l.Level >= logrus.TraceLevel {
|
||||
lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
|
||||
}
|
||||
|
||||
// We don't check our vpn network here because nebula does not support ipv6 on the inside
|
||||
if !allow {
|
||||
if !allow || lh.myVpnNet.Contains(ip.Addr()) {
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -683,26 +665,39 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
|
|||
return ip
|
||||
}
|
||||
|
||||
func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool {
|
||||
func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool {
|
||||
if _, ok := lh.GetLighthouses()[vpnIp]; ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta {
|
||||
func NewLhQueryByInt(vpnIp netip.Addr) *NebulaMeta {
|
||||
if vpnIp.Is6() {
|
||||
//TODO: need to support ipv6
|
||||
panic("ipv6 is not yet supported")
|
||||
}
|
||||
|
||||
b := vpnIp.As4()
|
||||
return &NebulaMeta{
|
||||
Type: NebulaMeta_HostQuery,
|
||||
Details: &NebulaMetaDetails{
|
||||
VpnIp: uint32(VpnIp),
|
||||
VpnIp: binary.BigEndian.Uint32(b[:]),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
|
||||
ipp := Ip4AndPort{Port: port}
|
||||
ipp.Ip = uint32(iputil.Ip2VpnIp(ip))
|
||||
return &ipp
|
||||
func AddrPortFromIp4AndPort(ip *Ip4AndPort) netip.AddrPort {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], ip.Ip)
|
||||
return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ip.Port))
|
||||
}
|
||||
|
||||
func AddrPortFromIp6AndPort(ip *Ip6AndPort) netip.AddrPort {
|
||||
b := [16]byte{}
|
||||
binary.BigEndian.PutUint64(b[:8], ip.Hi)
|
||||
binary.BigEndian.PutUint64(b[8:], ip.Lo)
|
||||
return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ip.Port))
|
||||
}
|
||||
|
||||
func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
|
||||
|
@ -713,14 +708,7 @@ func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
|
|||
}
|
||||
}
|
||||
|
||||
func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
|
||||
return &Ip6AndPort{
|
||||
Hi: binary.BigEndian.Uint64(ip[:8]),
|
||||
Lo: binary.BigEndian.Uint64(ip[8:]),
|
||||
Port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: IPV6-WORK we can delete some more of these
|
||||
func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
|
||||
ip6Addr := ip.As16()
|
||||
return &Ip6AndPort{
|
||||
|
@ -729,17 +717,6 @@ func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
|
|||
Port: uint32(port),
|
||||
}
|
||||
}
|
||||
func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
|
||||
ip := ipp.Ip
|
||||
return udp.NewAddr(
|
||||
net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)),
|
||||
uint16(ipp.Port),
|
||||
)
|
||||
}
|
||||
|
||||
func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
|
||||
return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
|
||||
}
|
||||
|
||||
func (lh *LightHouse) startQueryWorker() {
|
||||
if lh.amLighthouse {
|
||||
|
@ -761,7 +738,7 @@ func (lh *LightHouse) startQueryWorker() {
|
|||
}()
|
||||
}
|
||||
|
||||
func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) {
|
||||
func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) {
|
||||
if lh.IsLighthouseIP(ip) {
|
||||
return
|
||||
}
|
||||
|
@ -812,36 +789,41 @@ func (lh *LightHouse) SendUpdate() {
|
|||
var v6 []*Ip6AndPort
|
||||
|
||||
for _, e := range lh.GetAdvertiseAddrs() {
|
||||
if ip := e.ip.To4(); ip != nil {
|
||||
v4 = append(v4, NewIp4AndPort(e.ip, uint32(e.port)))
|
||||
if e.Addr().Is4() {
|
||||
v4 = append(v4, NewIp4AndPortFromNetIP(e.Addr(), e.Port()))
|
||||
} else {
|
||||
v6 = append(v6, NewIp6AndPort(e.ip, uint32(e.port)))
|
||||
v6 = append(v6, NewIp6AndPortFromNetIP(e.Addr(), e.Port()))
|
||||
}
|
||||
}
|
||||
|
||||
lal := lh.GetLocalAllowList()
|
||||
for _, e := range *localIps(lh.l, lal) {
|
||||
if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) {
|
||||
for _, e := range localIps(lh.l, lal) {
|
||||
if lh.myVpnNet.Contains(e) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Only add IPs that aren't my VPN/tun IP
|
||||
if ip := e.To4(); ip != nil {
|
||||
v4 = append(v4, NewIp4AndPort(e, lh.nebulaPort))
|
||||
if e.Is4() {
|
||||
v4 = append(v4, NewIp4AndPortFromNetIP(e, uint16(lh.nebulaPort)))
|
||||
} else {
|
||||
v6 = append(v6, NewIp6AndPort(e, lh.nebulaPort))
|
||||
v6 = append(v6, NewIp6AndPortFromNetIP(e, uint16(lh.nebulaPort)))
|
||||
}
|
||||
}
|
||||
|
||||
var relays []uint32
|
||||
for _, r := range lh.GetRelaysForMe() {
|
||||
relays = append(relays, (uint32)(r))
|
||||
//TODO: IPV6-WORK both relays and vpnip need ipv6 support
|
||||
b := r.As4()
|
||||
relays = append(relays, binary.BigEndian.Uint32(b[:]))
|
||||
}
|
||||
|
||||
//TODO: IPV6-WORK both relays and vpnip need ipv6 support
|
||||
b := lh.myVpnNet.Addr().As4()
|
||||
|
||||
m := &NebulaMeta{
|
||||
Type: NebulaMeta_HostUpdateNotification,
|
||||
Details: &NebulaMetaDetails{
|
||||
VpnIp: uint32(lh.myVpnIp),
|
||||
VpnIp: binary.BigEndian.Uint32(b[:]),
|
||||
Ip4AndPorts: v4,
|
||||
Ip6AndPorts: v6,
|
||||
RelayVpnIp: relays,
|
||||
|
@ -913,12 +895,12 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
|
|||
}
|
||||
|
||||
func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc {
|
||||
return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) {
|
||||
return func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) {
|
||||
lhh.HandleRequest(rAddr, vpnIp, p, f)
|
||||
}
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) {
|
||||
func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte, w EncWriter) {
|
||||
n := lhh.resetMeta()
|
||||
err := n.Unmarshal(p)
|
||||
if err != nil {
|
||||
|
@ -956,7 +938,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp,
|
|||
}
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) {
|
||||
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, addr netip.AddrPort, w EncWriter) {
|
||||
// Exit if we don't answer queries
|
||||
if !lhh.lh.amLighthouse {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
|
@ -967,8 +949,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
|
|||
|
||||
//TODO: we can DRY this further
|
||||
reqVpnIp := n.Details.VpnIp
|
||||
|
||||
//TODO: IPV6-WORK
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
|
||||
queryVpnIp := netip.AddrFrom4(b)
|
||||
|
||||
//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
|
||||
found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) {
|
||||
found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnIp, func(c *cache) (int, error) {
|
||||
n = lhh.resetMeta()
|
||||
n.Type = NebulaMeta_HostQueryReply
|
||||
n.Details.VpnIp = reqVpnIp
|
||||
|
@ -994,8 +982,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
|
|||
found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
|
||||
n = lhh.resetMeta()
|
||||
n.Type = NebulaMeta_HostPunchNotification
|
||||
n.Details.VpnIp = uint32(vpnIp)
|
||||
|
||||
//TODO: IPV6-WORK
|
||||
b = vpnIp.As4()
|
||||
n.Details.VpnIp = binary.BigEndian.Uint32(b[:])
|
||||
lhh.coalesceAnswers(c, n)
|
||||
|
||||
return n.MarshalTo(lhh.pb)
|
||||
|
@ -1011,7 +1000,11 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
|
|||
}
|
||||
|
||||
lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
|
||||
w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||
|
||||
//TODO: IPV6-WORK
|
||||
binary.BigEndian.PutUint32(b[:], reqVpnIp)
|
||||
sendTo := netip.AddrFrom4(b)
|
||||
w.SendMessageToVpnIp(header.LightHouse, 0, sendTo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
|
||||
|
@ -1034,34 +1027,52 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
|
|||
}
|
||||
|
||||
if c.relay != nil {
|
||||
n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, c.relay.relay...)
|
||||
//TODO: IPV6-WORK
|
||||
relays := make([]uint32, len(c.relay.relay))
|
||||
b := [4]byte{}
|
||||
for i, _ := range relays {
|
||||
b = c.relay.relay[i].As4()
|
||||
relays[i] = binary.BigEndian.Uint32(b[:])
|
||||
}
|
||||
n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, relays...)
|
||||
}
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) {
|
||||
func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Addr) {
|
||||
if !lhh.lh.IsLighthouseIP(vpnIp) {
|
||||
return
|
||||
}
|
||||
|
||||
lhh.lh.Lock()
|
||||
am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp))
|
||||
//TODO: IPV6-WORK
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
|
||||
certVpnIp := netip.AddrFrom4(b)
|
||||
am := lhh.lh.unlockedGetRemoteList(certVpnIp)
|
||||
am.Lock()
|
||||
lhh.lh.Unlock()
|
||||
|
||||
certVpnIp := iputil.VpnIp(n.Details.VpnIp)
|
||||
//TODO: IPV6-WORK
|
||||
am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
|
||||
am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
|
||||
am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp)
|
||||
|
||||
//TODO: IPV6-WORK
|
||||
relays := make([]netip.Addr, len(n.Details.RelayVpnIp))
|
||||
for i, _ := range n.Details.RelayVpnIp {
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i])
|
||||
relays[i] = netip.AddrFrom4(b)
|
||||
}
|
||||
am.unlockedSetRelay(vpnIp, certVpnIp, relays)
|
||||
am.Unlock()
|
||||
|
||||
// Non-blocking attempt to trigger, skip if it would block
|
||||
select {
|
||||
case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp):
|
||||
case lhh.lh.handshakeTrigger <- certVpnIp:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
|
||||
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) {
|
||||
if !lhh.lh.amLighthouse {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
|
||||
|
@ -1070,9 +1081,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
|
|||
}
|
||||
|
||||
//Simple check that the host sent this not someone else
|
||||
if n.Details.VpnIp != uint32(vpnIp) {
|
||||
//TODO: IPV6-WORK
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
|
||||
detailsVpnIp := netip.AddrFrom4(b)
|
||||
if detailsVpnIp != vpnIp {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
|
||||
lhh.l.WithField("vpnIp", vpnIp).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -1082,15 +1097,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
|
|||
am.Lock()
|
||||
lhh.lh.Unlock()
|
||||
|
||||
certVpnIp := iputil.VpnIp(n.Details.VpnIp)
|
||||
am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
|
||||
am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
|
||||
am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp)
|
||||
am.unlockedSetV4(vpnIp, detailsVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
|
||||
am.unlockedSetV6(vpnIp, detailsVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
|
||||
|
||||
//TODO: IPV6-WORK
|
||||
relays := make([]netip.Addr, len(n.Details.RelayVpnIp))
|
||||
for i, _ := range n.Details.RelayVpnIp {
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i])
|
||||
relays[i] = netip.AddrFrom4(b)
|
||||
}
|
||||
am.unlockedSetRelay(vpnIp, detailsVpnIp, relays)
|
||||
am.Unlock()
|
||||
|
||||
n = lhh.resetMeta()
|
||||
n.Type = NebulaMeta_HostUpdateNotificationAck
|
||||
n.Details.VpnIp = uint32(vpnIp)
|
||||
|
||||
//TODO: IPV6-WORK
|
||||
vpnIpB := vpnIp.As4()
|
||||
n.Details.VpnIp = binary.BigEndian.Uint32(vpnIpB[:])
|
||||
ln, err := n.MarshalTo(lhh.pb)
|
||||
|
||||
if err != nil {
|
||||
|
@ -1102,14 +1126,14 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
|
|||
w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
|
||||
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) {
|
||||
if !lhh.lh.IsLighthouseIP(vpnIp) {
|
||||
return
|
||||
}
|
||||
|
||||
empty := []byte{0}
|
||||
punch := func(vpnPeer *udp.Addr) {
|
||||
if vpnPeer == nil {
|
||||
punch := func(vpnPeer netip.AddrPort) {
|
||||
if !vpnPeer.IsValid() {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -1121,23 +1145,29 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i
|
|||
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
|
||||
lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp))
|
||||
//TODO: IPV6-WORK, make this debug line not suck
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
|
||||
lhh.l.Debugf("Punching on %d for %v", vpnPeer.Port(), netip.AddrFrom4(b))
|
||||
}
|
||||
}
|
||||
|
||||
for _, a := range n.Details.Ip4AndPorts {
|
||||
punch(NewUDPAddrFromLH4(a))
|
||||
punch(AddrPortFromIp4AndPort(a))
|
||||
}
|
||||
|
||||
for _, a := range n.Details.Ip6AndPorts {
|
||||
punch(NewUDPAddrFromLH6(a))
|
||||
punch(AddrPortFromIp6AndPort(a))
|
||||
}
|
||||
|
||||
// This sends a nebula test packet to the host trying to contact us. In the case
|
||||
// of a double nat or other difficult scenario, this may help establish
|
||||
// a tunnel.
|
||||
if lhh.lh.punchy.GetRespond() {
|
||||
queryVpnIp := iputil.VpnIp(n.Details.VpnIp)
|
||||
//TODO: IPV6-WORK
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
|
||||
queryVpnIp := netip.AddrFrom4(b)
|
||||
go func() {
|
||||
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
|
@ -1150,9 +1180,3 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i
|
|||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// ipMaskContains checks if testIp is contained by ip after applying a cidr.
|
||||
// zeros is 32 - bits from net.IPMask.Size()
|
||||
func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool {
|
||||
return (testIp^ip)>>zeros == 0
|
||||
}
|
||||
|
|
|
@ -2,15 +2,14 @@ package nebula
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
@ -23,15 +22,17 @@ func TestOldIPv4Only(t *testing.T) {
|
|||
var m Ip4AndPort
|
||||
err := m.Unmarshal(b)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String())
|
||||
ip := netip.MustParseAddr("10.1.1.1")
|
||||
bp := ip.As4()
|
||||
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp())
|
||||
}
|
||||
|
||||
func TestNewLhQuery(t *testing.T) {
|
||||
myIp := net.ParseIP("192.1.1.1")
|
||||
myIpint := iputil.Ip2VpnIp(myIp)
|
||||
myIp, err := netip.ParseAddr("192.1.1.1")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Generating a new lh query should work
|
||||
a := NewLhQueryByInt(myIpint)
|
||||
a := NewLhQueryByInt(myIp)
|
||||
|
||||
// The result should be a nebulameta protobuf
|
||||
assert.IsType(t, &NebulaMeta{}, a)
|
||||
|
@ -49,7 +50,7 @@ func TestNewLhQuery(t *testing.T) {
|
|||
|
||||
func Test_lhStaticMapping(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
||||
lh1 := "10.128.0.2"
|
||||
|
||||
c := config.NewC(l)
|
||||
|
@ -68,7 +69,7 @@ func Test_lhStaticMapping(t *testing.T) {
|
|||
|
||||
func TestReloadLighthouseInterval(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
||||
lh1 := "10.128.0.2"
|
||||
|
||||
c := config.NewC(l)
|
||||
|
@ -83,21 +84,21 @@ func TestReloadLighthouseInterval(t *testing.T) {
|
|||
lh.ifce = &mockEncWriter{}
|
||||
|
||||
// The first one routine is kicked off by main.go currently, lets make sure that one dies
|
||||
c.ReloadConfigString("lighthouse:\n interval: 5")
|
||||
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
|
||||
assert.Equal(t, int64(5), lh.interval.Load())
|
||||
|
||||
// Subsequent calls are killed off by the LightHouse.Reload function
|
||||
c.ReloadConfigString("lighthouse:\n interval: 10")
|
||||
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
|
||||
assert.Equal(t, int64(10), lh.interval.Load())
|
||||
|
||||
// If this completes then nothing is stealing our reload routine
|
||||
c.ReloadConfigString("lighthouse:\n interval: 11")
|
||||
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
|
||||
assert.Equal(t, int64(11), lh.interval.Load())
|
||||
}
|
||||
|
||||
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||
l := test.NewLogger()
|
||||
_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
|
||||
|
||||
c := config.NewC(l)
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
|
||||
|
@ -105,30 +106,33 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||
b.Fatal()
|
||||
}
|
||||
|
||||
hAddr := udp.NewAddrFromString("4.5.6.7:12345")
|
||||
hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
|
||||
lh.addrMap[3] = NewRemoteList(nil)
|
||||
lh.addrMap[3].unlockedSetV4(
|
||||
3,
|
||||
3,
|
||||
hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
|
||||
hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
|
||||
|
||||
vpnIp3 := netip.MustParseAddr("0.0.0.3")
|
||||
lh.addrMap[vpnIp3] = NewRemoteList(nil)
|
||||
lh.addrMap[vpnIp3].unlockedSetV4(
|
||||
vpnIp3,
|
||||
vpnIp3,
|
||||
[]*Ip4AndPort{
|
||||
NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
|
||||
NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
|
||||
NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()),
|
||||
NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()),
|
||||
},
|
||||
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
|
||||
func(netip.Addr, *Ip4AndPort) bool { return true },
|
||||
)
|
||||
|
||||
rAddr := udp.NewAddrFromString("1.2.2.3:12345")
|
||||
rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
|
||||
lh.addrMap[2] = NewRemoteList(nil)
|
||||
lh.addrMap[2].unlockedSetV4(
|
||||
3,
|
||||
3,
|
||||
rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
|
||||
rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
|
||||
vpnIp2 := netip.MustParseAddr("0.0.0.3")
|
||||
lh.addrMap[vpnIp2] = NewRemoteList(nil)
|
||||
lh.addrMap[vpnIp2].unlockedSetV4(
|
||||
vpnIp3,
|
||||
vpnIp3,
|
||||
[]*Ip4AndPort{
|
||||
NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
|
||||
NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
|
||||
NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()),
|
||||
NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()),
|
||||
},
|
||||
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
|
||||
func(netip.Addr, *Ip4AndPort) bool { return true },
|
||||
)
|
||||
|
||||
mw := &mockEncWriter{}
|
||||
|
@ -145,7 +149,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||
p, err := req.Marshal()
|
||||
assert.NoError(b, err)
|
||||
for n := 0; n < b.N; n++ {
|
||||
lhh.HandleRequest(rAddr, 2, p, mw)
|
||||
lhh.HandleRequest(rAddr, vpnIp2, p, mw)
|
||||
}
|
||||
})
|
||||
b.Run("found", func(b *testing.B) {
|
||||
|
@ -161,7 +165,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||
assert.NoError(b, err)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
lhh.HandleRequest(rAddr, 2, p, mw)
|
||||
lhh.HandleRequest(rAddr, vpnIp2, p, mw)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -169,51 +173,51 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||
func TestLighthouse_Memory(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
|
||||
myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
|
||||
myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
|
||||
myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
|
||||
myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
|
||||
myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
|
||||
myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
|
||||
myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
|
||||
myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
|
||||
myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
|
||||
myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
|
||||
myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
|
||||
myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
|
||||
myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2"))
|
||||
myUdpAddr0 := netip.MustParseAddrPort("10.0.0.2:4242")
|
||||
myUdpAddr1 := netip.MustParseAddrPort("192.168.0.2:4242")
|
||||
myUdpAddr2 := netip.MustParseAddrPort("172.16.0.2:4242")
|
||||
myUdpAddr3 := netip.MustParseAddrPort("100.152.0.2:4242")
|
||||
myUdpAddr4 := netip.MustParseAddrPort("24.15.0.2:4242")
|
||||
myUdpAddr5 := netip.MustParseAddrPort("192.168.0.2:4243")
|
||||
myUdpAddr6 := netip.MustParseAddrPort("192.168.0.2:4244")
|
||||
myUdpAddr7 := netip.MustParseAddrPort("192.168.0.2:4245")
|
||||
myUdpAddr8 := netip.MustParseAddrPort("192.168.0.2:4246")
|
||||
myUdpAddr9 := netip.MustParseAddrPort("192.168.0.2:4247")
|
||||
myUdpAddr10 := netip.MustParseAddrPort("192.168.0.2:4248")
|
||||
myUdpAddr11 := netip.MustParseAddrPort("192.168.0.2:4249")
|
||||
myVpnIp := netip.MustParseAddr("10.128.0.2")
|
||||
|
||||
theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
|
||||
theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
|
||||
theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
|
||||
theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
|
||||
theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
|
||||
theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3"))
|
||||
theirUdpAddr0 := netip.MustParseAddrPort("10.0.0.3:4242")
|
||||
theirUdpAddr1 := netip.MustParseAddrPort("192.168.0.3:4242")
|
||||
theirUdpAddr2 := netip.MustParseAddrPort("172.16.0.3:4242")
|
||||
theirUdpAddr3 := netip.MustParseAddrPort("100.152.0.3:4242")
|
||||
theirUdpAddr4 := netip.MustParseAddrPort("24.15.0.3:4242")
|
||||
theirVpnIp := netip.MustParseAddr("10.128.0.3")
|
||||
|
||||
c := config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
||||
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
|
||||
assert.NoError(t, err)
|
||||
lhh := lh.NewRequestHandler()
|
||||
|
||||
// Test that my first update responds with just that
|
||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh)
|
||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
|
||||
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
|
||||
|
||||
// Ensure we don't accumulate addresses
|
||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh)
|
||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
|
||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
|
||||
|
||||
// Grow it back to 2
|
||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh)
|
||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
|
||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
|
||||
|
||||
// Update a different host and ask about it
|
||||
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
|
||||
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
|
||||
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
|
||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
|
||||
|
||||
|
@ -233,7 +237,7 @@ func TestLighthouse_Memory(t *testing.T) {
|
|||
newLHHostUpdate(
|
||||
myUdpAddr0,
|
||||
myVpnIp,
|
||||
[]*udp.Addr{
|
||||
[]netip.AddrPort{
|
||||
myUdpAddr1,
|
||||
myUdpAddr2,
|
||||
myUdpAddr3,
|
||||
|
@ -256,10 +260,10 @@ func TestLighthouse_Memory(t *testing.T) {
|
|||
)
|
||||
|
||||
// Make sure we won't add ips in our vpn network
|
||||
bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
|
||||
bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
|
||||
good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
|
||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh)
|
||||
bad1 := netip.MustParseAddrPort("10.128.0.99:4242")
|
||||
bad2 := netip.MustParseAddrPort("10.128.0.100:4242")
|
||||
good := netip.MustParseAddrPort("1.128.0.99:4242")
|
||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
|
||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
|
||||
}
|
||||
|
@ -269,7 +273,7 @@ func TestLighthouse_reload(t *testing.T) {
|
|||
c := config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
||||
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nc := map[interface{}]interface{}{
|
||||
|
@ -285,11 +289,13 @@ func TestLighthouse_reload(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {
|
||||
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
|
||||
//TODO: IPV6-WORK
|
||||
bip := queryVpnIp.As4()
|
||||
req := &NebulaMeta{
|
||||
Type: NebulaMeta_HostQuery,
|
||||
Details: &NebulaMetaDetails{
|
||||
VpnIp: uint32(queryVpnIp),
|
||||
VpnIp: binary.BigEndian.Uint32(bip[:]),
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -306,17 +312,19 @@ func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh
|
|||
return w.lastReply
|
||||
}
|
||||
|
||||
func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) {
|
||||
func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
|
||||
//TODO: IPV6-WORK
|
||||
bip := vpnIp.As4()
|
||||
req := &NebulaMeta{
|
||||
Type: NebulaMeta_HostUpdateNotification,
|
||||
Details: &NebulaMetaDetails{
|
||||
VpnIp: uint32(vpnIp),
|
||||
VpnIp: binary.BigEndian.Uint32(bip[:]),
|
||||
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
|
||||
},
|
||||
}
|
||||
|
||||
for k, v := range addrs {
|
||||
req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)}
|
||||
req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port())
|
||||
}
|
||||
|
||||
b, err := req.Marshal()
|
||||
|
@ -394,16 +402,10 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr,
|
|||
// )
|
||||
//}
|
||||
|
||||
func Test_ipMaskContains(t *testing.T) {
|
||||
assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255"))))
|
||||
assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
|
||||
assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
|
||||
}
|
||||
|
||||
type testLhReply struct {
|
||||
nebType header.MessageType
|
||||
nebSubType header.MessageSubType
|
||||
vpnIp iputil.VpnIp
|
||||
vpnIp netip.Addr
|
||||
msg *NebulaMeta
|
||||
}
|
||||
|
||||
|
@ -414,7 +416,7 @@ type testEncWriter struct {
|
|||
|
||||
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
|
||||
}
|
||||
func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) {
|
||||
func (tw *testEncWriter) Handshake(vpnIp netip.Addr) {
|
||||
}
|
||||
|
||||
func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) {
|
||||
|
@ -434,7 +436,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
|
|||
}
|
||||
}
|
||||
|
||||
func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) {
|
||||
func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
|
||||
msg := &NebulaMeta{}
|
||||
err := msg.Unmarshal(p)
|
||||
if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
|
||||
|
@ -452,35 +454,16 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
|
|||
}
|
||||
|
||||
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
|
||||
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) {
|
||||
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) {
|
||||
if !assert.Len(t, have, len(want)) {
|
||||
return
|
||||
}
|
||||
|
||||
for k, w := range want {
|
||||
if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) {
|
||||
assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
|
||||
//TODO: IPV6-WORK
|
||||
h := AddrPortFromIp4AndPort(have[k])
|
||||
if !(h == w) {
|
||||
assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
|
||||
func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) {
|
||||
if !assert.Len(t, have, len(want)) {
|
||||
return
|
||||
}
|
||||
|
||||
for k, w := range want {
|
||||
if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
|
||||
assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr {
|
||||
addrs := make([]*udp.Addr, len(ips))
|
||||
for k, v := range ips {
|
||||
addrs[k] = NewUDPAddrFromLH4(v)
|
||||
}
|
||||
return addrs
|
||||
}
|
||||
|
|
30
main.go
30
main.go
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -67,8 +68,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
}
|
||||
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
|
||||
|
||||
// TODO: make sure mask is 4 bytes
|
||||
tunCidr := certificate.Details.Ips[0]
|
||||
ones, _ := certificate.Details.Ips[0].Mask.Size()
|
||||
addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
|
||||
if !ok {
|
||||
err = util.NewContextualError(
|
||||
"Invalid ip address in certificate",
|
||||
m{"vpnIp": certificate.Details.Ips[0].IP},
|
||||
nil,
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
tunCidr := netip.PrefixFrom(addr, ones)
|
||||
|
||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||
if err != nil {
|
||||
|
@ -150,21 +160,25 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
|
||||
if !configTest {
|
||||
rawListenHost := c.GetString("listen.host", "0.0.0.0")
|
||||
var listenHost *net.IPAddr
|
||||
var listenHost netip.Addr
|
||||
if rawListenHost == "[::]" {
|
||||
// Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve.
|
||||
listenHost = &net.IPAddr{IP: net.IPv6zero}
|
||||
listenHost = netip.IPv6Unspecified()
|
||||
|
||||
} else {
|
||||
listenHost, err = net.ResolveIPAddr("ip", rawListenHost)
|
||||
ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", rawListenHost)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
|
||||
}
|
||||
listenHost = ips[0].Unmap()
|
||||
}
|
||||
|
||||
for i := 0; i < routines; i++ {
|
||||
l.Infof("listening %q %d", listenHost.IP, port)
|
||||
udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64))
|
||||
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
||||
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||
}
|
||||
|
@ -178,7 +192,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
if err != nil {
|
||||
return nil, util.NewContextualError("Failed to get listening port", nil, err)
|
||||
}
|
||||
port = int(uPort.Port)
|
||||
port = int(uPort.Port())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
95
outside.go
95
outside.go
|
@ -4,6 +4,7 @@ import (
|
|||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
|
@ -11,7 +12,6 @@ import (
|
|||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"golang.org/x/net/ipv4"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
@ -21,9 +21,10 @@ const (
|
|||
minFwPacketLen = 4
|
||||
)
|
||||
|
||||
// TODO: IPV6-WORK this can likely be removed now
|
||||
func readOutsidePackets(f *Interface) udp.EncReader {
|
||||
return func(
|
||||
addr *udp.Addr,
|
||||
addr netip.AddrPort,
|
||||
out []byte,
|
||||
packet []byte,
|
||||
header *header.H,
|
||||
|
@ -37,27 +38,25 @@ func readOutsidePackets(f *Interface) udp.EncReader {
|
|||
}
|
||||
}
|
||||
|
||||
func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||
err := h.Parse(packet)
|
||||
if err != nil {
|
||||
// TODO: best if we return this and let caller log
|
||||
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
|
||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||
if len(packet) > 1 {
|
||||
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
|
||||
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||
if addr != nil {
|
||||
if ip4 := addr.IP.To4(); ip4 != nil {
|
||||
if ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, iputil.VpnIp(binary.BigEndian.Uint32(ip4))) {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet")
|
||||
}
|
||||
return
|
||||
if ip.IsValid() {
|
||||
if f.myVpnNet.Contains(ip.Addr()) {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -77,7 +76,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
|
|||
switch h.Type {
|
||||
case header.Message:
|
||||
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
|
||||
if !f.handleEncrypted(ci, addr, h) {
|
||||
if !f.handleEncrypted(ci, ip, h) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -101,7 +100,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
|
|||
// Successfully validated the thing. Get rid of the Relay header.
|
||||
signedPayload = signedPayload[header.Len:]
|
||||
// Pull the Roaming parts up here, and return in all call paths.
|
||||
f.handleHostRoaming(hostinfo, addr)
|
||||
f.handleHostRoaming(hostinfo, ip)
|
||||
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
|
||||
f.connectionManager.In(hostinfo.localIndexId)
|
||||
f.connectionManager.RelayUsed(h.RemoteIndex)
|
||||
|
@ -118,7 +117,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
|
|||
case TerminalType:
|
||||
// If I am the target of this relay, process the unwrapped packet
|
||||
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
||||
f.readOutsidePackets(nil, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
||||
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
||||
return
|
||||
case ForwardingType:
|
||||
// Find the target HostInfo relay object
|
||||
|
@ -148,13 +147,13 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
|
|||
|
||||
case header.LightHouse:
|
||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||
if !f.handleEncrypted(ci, addr, h) {
|
||||
if !f.handleEncrypted(ci, ip, h) {
|
||||
return
|
||||
}
|
||||
|
||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
|
||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
||||
WithField("packet", packet).
|
||||
Error("Failed to decrypt lighthouse packet")
|
||||
|
||||
|
@ -163,19 +162,19 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
|
|||
return
|
||||
}
|
||||
|
||||
lhf(addr, hostinfo.vpnIp, d)
|
||||
lhf(ip, hostinfo.vpnIp, d)
|
||||
|
||||
// Fallthrough to the bottom to record incoming traffic
|
||||
|
||||
case header.Test:
|
||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||
if !f.handleEncrypted(ci, addr, h) {
|
||||
if !f.handleEncrypted(ci, ip, h) {
|
||||
return
|
||||
}
|
||||
|
||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
|
||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
||||
WithField("packet", packet).
|
||||
Error("Failed to decrypt test packet")
|
||||
|
||||
|
@ -187,7 +186,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
|
|||
if h.Subtype == header.TestRequest {
|
||||
// This testRequest might be from TryPromoteBest, so we should roam
|
||||
// to the new IP address before responding
|
||||
f.handleHostRoaming(hostinfo, addr)
|
||||
f.handleHostRoaming(hostinfo, ip)
|
||||
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
|
||||
}
|
||||
|
||||
|
@ -198,34 +197,34 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
|
|||
|
||||
case header.Handshake:
|
||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||
f.handshakeManager.HandleIncoming(addr, via, packet, h)
|
||||
f.handshakeManager.HandleIncoming(ip, via, packet, h)
|
||||
return
|
||||
|
||||
case header.RecvError:
|
||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||
f.handleRecvError(addr, h)
|
||||
f.handleRecvError(ip, h)
|
||||
return
|
||||
|
||||
case header.CloseTunnel:
|
||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||
if !f.handleEncrypted(ci, addr, h) {
|
||||
if !f.handleEncrypted(ci, ip, h) {
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo.logger(f.l).WithField("udpAddr", addr).
|
||||
hostinfo.logger(f.l).WithField("udpAddr", ip).
|
||||
Info("Close tunnel received, tearing down.")
|
||||
|
||||
f.closeTunnel(hostinfo)
|
||||
return
|
||||
|
||||
case header.Control:
|
||||
if !f.handleEncrypted(ci, addr, h) {
|
||||
if !f.handleEncrypted(ci, ip, h) {
|
||||
return
|
||||
}
|
||||
|
||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
|
||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
||||
WithField("packet", packet).
|
||||
Error("Failed to decrypt Control packet")
|
||||
return
|
||||
|
@ -241,11 +240,11 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
|
|||
|
||||
default:
|
||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
|
||||
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
|
||||
return
|
||||
}
|
||||
|
||||
f.handleHostRoaming(hostinfo, addr)
|
||||
f.handleHostRoaming(hostinfo, ip)
|
||||
|
||||
f.connectionManager.In(hostinfo.localIndexId)
|
||||
}
|
||||
|
@ -264,34 +263,34 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
|
|||
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||
}
|
||||
|
||||
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) {
|
||||
if addr != nil && !hostinfo.remote.Equals(addr) {
|
||||
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
|
||||
hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
|
||||
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) {
|
||||
if ip.IsValid() && hostinfo.remote != ip {
|
||||
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) {
|
||||
hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming")
|
||||
return
|
||||
}
|
||||
if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
||||
if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
|
||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
|
||||
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
|
||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
|
||||
Info("Host roamed to new udp ip/port.")
|
||||
hostinfo.lastRoam = time.Now()
|
||||
hostinfo.lastRoamRemote = hostinfo.remote
|
||||
hostinfo.SetRemote(addr)
|
||||
hostinfo.SetRemote(ip)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool {
|
||||
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
|
||||
// If connectionstate exists and the replay protector allows, process packet
|
||||
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
|
||||
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
|
||||
if addr != nil {
|
||||
if addr.IsValid() {
|
||||
f.maybeSendRecvError(addr, h.RemoteIndex)
|
||||
return false
|
||||
} else {
|
||||
|
@ -340,8 +339,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||
|
||||
// Firewall packets are locally oriented
|
||||
if incoming {
|
||||
fp.RemoteIP = iputil.Ip2VpnIp(data[12:16])
|
||||
fp.LocalIP = iputil.Ip2VpnIp(data[16:20])
|
||||
//TODO: IPV6-WORK
|
||||
fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
|
||||
fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
|
||||
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
||||
fp.RemotePort = 0
|
||||
fp.LocalPort = 0
|
||||
|
@ -350,8 +350,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
||||
}
|
||||
} else {
|
||||
fp.LocalIP = iputil.Ip2VpnIp(data[12:16])
|
||||
fp.RemoteIP = iputil.Ip2VpnIp(data[16:20])
|
||||
//TODO: IPV6-WORK
|
||||
fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
|
||||
fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
|
||||
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
||||
fp.RemotePort = 0
|
||||
fp.LocalPort = 0
|
||||
|
@ -425,13 +426,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||
return true
|
||||
}
|
||||
|
||||
func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) {
|
||||
if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint.IP) {
|
||||
func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
|
||||
if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint) {
|
||||
f.sendRecvError(endpoint, index)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
|
||||
func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
|
||||
f.messageMetrics.Tx(header.RecvError, 0, 1)
|
||||
|
||||
//TODO: this should be a signed message so we can trust that we should drop the index
|
||||
|
@ -444,7 +445,7 @@ func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
|
|||
}
|
||||
}
|
||||
|
||||
func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
|
||||
func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("index", h.RemoteIndex).
|
||||
WithField("udpAddr", addr).
|
||||
|
@ -461,7 +462,7 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
|
|||
return
|
||||
}
|
||||
|
||||
if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) {
|
||||
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
||||
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -2,10 +2,10 @@ package nebula
|
|||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
@ -55,8 +55,8 @@ func Test_newPacket(t *testing.T) {
|
|||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
|
||||
assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
|
||||
assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
|
||||
assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2"))
|
||||
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1"))
|
||||
assert.Equal(t, p.RemotePort, uint16(3))
|
||||
assert.Equal(t, p.LocalPort, uint16(4))
|
||||
|
||||
|
@ -76,8 +76,8 @@ func Test_newPacket(t *testing.T) {
|
|||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, p.Protocol, uint8(2))
|
||||
assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
|
||||
assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
|
||||
assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1"))
|
||||
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2"))
|
||||
assert.Equal(t, p.RemotePort, uint16(6))
|
||||
assert.Equal(t, p.LocalPort, uint16(5))
|
||||
}
|
||||
|
|
|
@ -2,16 +2,14 @@ package overlay
|
|||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Device interface {
|
||||
io.ReadWriteCloser
|
||||
Activate() error
|
||||
Cidr() *net.IPNet
|
||||
Cidr() netip.Prefix
|
||||
Name() string
|
||||
RouteFor(iputil.VpnIp) iputil.VpnIp
|
||||
RouteFor(netip.Addr) netip.Addr
|
||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||
}
|
||||
|
|
|
@ -1,34 +1,30 @@
|
|||
package overlay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
)
|
||||
|
||||
type Route struct {
|
||||
MTU int
|
||||
Metric int
|
||||
Cidr *net.IPNet
|
||||
Via *iputil.VpnIp
|
||||
Cidr netip.Prefix
|
||||
Via netip.Addr
|
||||
Install bool
|
||||
}
|
||||
|
||||
// Equal determines if a route that could be installed in the system route table is equal to another
|
||||
// Via is ignored since that is only consumed within nebula itself
|
||||
func (r Route) Equal(t Route) bool {
|
||||
if !r.Cidr.IP.Equal(t.Cidr.IP) {
|
||||
return false
|
||||
}
|
||||
if !bytes.Equal(r.Cidr.Mask, t.Cidr.Mask) {
|
||||
if r.Cidr != t.Cidr {
|
||||
return false
|
||||
}
|
||||
if r.Metric != t.Metric {
|
||||
|
@ -51,21 +47,21 @@ func (r Route) String() string {
|
|||
return s
|
||||
}
|
||||
|
||||
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) {
|
||||
routeTree := cidr.NewTree4[iputil.VpnIp]()
|
||||
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) {
|
||||
routeTree := new(bart.Table[netip.Addr])
|
||||
for _, r := range routes {
|
||||
if !allowMTU && r.MTU > 0 {
|
||||
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
if r.Via != nil {
|
||||
routeTree.AddCIDR(r.Cidr, *r.Via)
|
||||
if r.Via.IsValid() {
|
||||
routeTree.Insert(r.Cidr, r.Via)
|
||||
}
|
||||
}
|
||||
return routeTree, nil
|
||||
}
|
||||
|
||||
func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
|
||||
func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
|
||||
var err error
|
||||
|
||||
r := c.Get("tun.routes")
|
||||
|
@ -116,12 +112,12 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
|
|||
MTU: mtu,
|
||||
}
|
||||
|
||||
_, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
|
||||
r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
|
||||
}
|
||||
|
||||
if !ipWithin(network, r.Cidr) {
|
||||
if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() {
|
||||
return nil, fmt.Errorf(
|
||||
"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
|
||||
i+1,
|
||||
|
@ -136,7 +132,7 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
|
||||
func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
|
||||
var err error
|
||||
|
||||
r := c.Get("tun.unsafe_routes")
|
||||
|
@ -202,9 +198,9 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
|
|||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
|
||||
}
|
||||
|
||||
nVia := net.ParseIP(via)
|
||||
if nVia == nil {
|
||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via)
|
||||
viaVpnIp, err := netip.ParseAddr(via)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
|
||||
}
|
||||
|
||||
rRoute, ok := m["route"]
|
||||
|
@ -212,8 +208,6 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
|
|||
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1)
|
||||
}
|
||||
|
||||
viaVpnIp := iputil.Ip2VpnIp(nVia)
|
||||
|
||||
install := true
|
||||
rInstall, ok := m["install"]
|
||||
if ok {
|
||||
|
@ -224,18 +218,18 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
|
|||
}
|
||||
|
||||
r := Route{
|
||||
Via: &viaVpnIp,
|
||||
Via: viaVpnIp,
|
||||
MTU: mtu,
|
||||
Metric: metric,
|
||||
Install: install,
|
||||
}
|
||||
|
||||
_, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
|
||||
r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
|
||||
}
|
||||
|
||||
if ipWithin(network, r.Cidr) {
|
||||
if network.Contains(r.Cidr.Addr()) {
|
||||
return nil, fmt.Errorf(
|
||||
"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
|
||||
i+1,
|
||||
|
|
|
@ -2,11 +2,10 @@ package overlay
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -14,7 +13,8 @@ import (
|
|||
func Test_parseRoutes(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test no routes config
|
||||
routes, err := parseRoutes(c, n)
|
||||
|
@ -67,7 +67,7 @@ func Test_parseRoutes(t *testing.T) {
|
|||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
|
||||
routes, err = parseRoutes(c, n)
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope")
|
||||
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
||||
|
||||
// below network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
|
||||
|
@ -112,7 +112,8 @@ func Test_parseRoutes(t *testing.T) {
|
|||
func Test_parseUnsafeRoutes(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test no routes config
|
||||
routes, err := parseUnsafeRoutes(c, n)
|
||||
|
@ -157,7 +158,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
|||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope")
|
||||
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
|
||||
|
||||
// missing route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
|
||||
|
@ -169,7 +170,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
|||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope")
|
||||
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
||||
|
||||
// within network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
|
||||
|
@ -252,7 +253,8 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
|||
func Test_makeRouteTree(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||
assert.NoError(t, err)
|
||||
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
|
||||
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
|
||||
|
@ -264,17 +266,26 @@ func Test_makeRouteTree(t *testing.T) {
|
|||
routeTree, err := makeRouteTree(l, routes, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
|
||||
ok, r := routeTree.MostSpecificContains(ip)
|
||||
ip, err := netip.ParseAddr("1.0.0.2")
|
||||
assert.NoError(t, err)
|
||||
r, ok := routeTree.Lookup(ip)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
|
||||
|
||||
ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1"))
|
||||
ok, r = routeTree.MostSpecificContains(ip)
|
||||
nip, err := netip.ParseAddr("192.168.0.1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, nip, r)
|
||||
|
||||
ip, err = netip.ParseAddr("1.0.0.1")
|
||||
assert.NoError(t, err)
|
||||
r, ok = routeTree.Lookup(ip)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
|
||||
|
||||
ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1"))
|
||||
ok, r = routeTree.MostSpecificContains(ip)
|
||||
nip, err = netip.ParseAddr("192.168.0.2")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, nip, r)
|
||||
|
||||
ip, err = netip.ParseAddr("1.1.0.1")
|
||||
assert.NoError(t, err)
|
||||
r, ok = routeTree.Lookup(ip)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package overlay
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
|
@ -11,9 +11,9 @@ import (
|
|||
const DefaultMTU = 1300
|
||||
|
||||
// TODO: We may be able to remove routines
|
||||
type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error)
|
||||
type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error)
|
||||
|
||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
|
||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
|
||||
switch {
|
||||
case c.GetBool("tun.disabled", false):
|
||||
tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||
|
@ -25,12 +25,12 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, rout
|
|||
}
|
||||
|
||||
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
||||
return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
|
||||
return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
|
||||
return newTunFromFd(c, l, *fd, tunCidr)
|
||||
}
|
||||
}
|
||||
|
||||
func getAllRoutesFromConfig(c *config.C, cidr *net.IPNet, initial bool) (bool, []Route, error) {
|
||||
func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) {
|
||||
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
|
|
@ -6,27 +6,26 @@ package overlay
|
|||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
fd int
|
||||
cidr *net.IPNet
|
||||
cidr netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
|
||||
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
||||
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
@ -53,12 +52,12 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet)
|
|||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) {
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in Android")
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
|
@ -87,7 +86,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() *net.IPNet {
|
||||
func (t *tun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
}
|
||||
|
||||
|
|
|
@ -8,15 +8,15 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
@ -25,10 +25,10 @@ import (
|
|||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
cidr netip.Prefix
|
||||
DefaultMTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
|
||||
|
@ -73,7 +73,7 @@ type ifreqMTU struct {
|
|||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
|
||||
name := c.GetString("tun.dev", "")
|
||||
ifIndex := -1
|
||||
if name != "" && name != "utun" {
|
||||
|
@ -172,7 +172,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
|||
return
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||
}
|
||||
|
||||
|
@ -188,8 +188,13 @@ func (t *tun) Activate() error {
|
|||
|
||||
var addr, mask [4]byte
|
||||
|
||||
copy(addr[:], t.cidr.IP.To4())
|
||||
copy(mask[:], t.cidr.Mask)
|
||||
if !t.cidr.Addr().Is4() {
|
||||
//TODO: IPV6-WORK
|
||||
panic("need ipv6")
|
||||
}
|
||||
|
||||
addr = t.cidr.Addr().As4()
|
||||
copy(mask[:], prefixToMask(t.cidr))
|
||||
|
||||
s, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
|
@ -329,13 +334,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
ok, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
r, ok := t.routeTree.Load().Lookup(ip)
|
||||
if ok {
|
||||
return r
|
||||
}
|
||||
|
||||
return 0
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
// Get the LinkAddr for the interface of the given name
|
||||
|
@ -384,13 +388,19 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||
maskAddr := &netroute.Inet4Addr{}
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
if !r.Via.IsValid() || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
copy(routeAddr.IP[:], r.Cidr.IP.To4())
|
||||
copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
|
||||
if !r.Cidr.Addr().Is4() {
|
||||
//TODO: implement ipv6
|
||||
panic("Cant handle ipv6 routes yet")
|
||||
}
|
||||
|
||||
routeAddr.IP = r.Cidr.Addr().As4()
|
||||
//TODO: we could avoid the copy
|
||||
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
|
||||
|
||||
err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
|
||||
if err != nil {
|
||||
|
@ -435,8 +445,13 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||
continue
|
||||
}
|
||||
|
||||
copy(routeAddr.IP[:], r.Cidr.IP.To4())
|
||||
copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
|
||||
if r.Cidr.Addr().Is6() {
|
||||
//TODO: implement ipv6
|
||||
panic("Cant handle ipv6 routes yet")
|
||||
}
|
||||
|
||||
routeAddr.IP = r.Cidr.Addr().As4()
|
||||
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
|
||||
|
||||
err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
|
||||
if err != nil {
|
||||
|
@ -536,7 +551,7 @@ func (t *tun) Write(from []byte) (int, error) {
|
|||
return n - 4, err
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() *net.IPNet {
|
||||
func (t *tun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
}
|
||||
|
||||
|
@ -547,3 +562,11 @@ func (t *tun) Name() string {
|
|||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||
}
|
||||
|
||||
func prefixToMask(prefix netip.Prefix) []byte {
|
||||
pLen := 128
|
||||
if prefix.Addr().Is4() {
|
||||
pLen = 32
|
||||
}
|
||||
return net.CIDRMask(prefix.Bits(), pLen)
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ package overlay
|
|||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
|
@ -13,7 +13,7 @@ import (
|
|||
|
||||
type disabledTun struct {
|
||||
read chan []byte
|
||||
cidr *net.IPNet
|
||||
cidr netip.Prefix
|
||||
|
||||
// Track these metrics since we don't have the tun device to do it for us
|
||||
tx metrics.Counter
|
||||
|
@ -21,7 +21,7 @@ type disabledTun struct {
|
|||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
||||
func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
||||
tun := &disabledTun{
|
||||
cidr: cidr,
|
||||
read: make(chan []byte, queueLen),
|
||||
|
@ -43,11 +43,11 @@ func (*disabledTun) Activate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
|
||||
return 0
|
||||
func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (t *disabledTun) Cidr() *net.IPNet {
|
||||
func (t *disabledTun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
}
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
|
@ -17,10 +17,9 @@ import (
|
|||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
|
@ -48,10 +47,10 @@ type ifreqDestroy struct {
|
|||
|
||||
type tun struct {
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
cidr netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
|
@ -79,11 +78,11 @@ func (t *tun) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open existing tun device
|
||||
var file *os.File
|
||||
var err error
|
||||
|
@ -174,7 +173,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error
|
|||
func (t *tun) Activate() error {
|
||||
var err error
|
||||
// TODO use syscalls instead of exec.Command
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
|
@ -233,12 +232,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() *net.IPNet {
|
||||
func (t *tun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
}
|
||||
|
||||
|
@ -253,7 +252,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
if !r.Via.IsValid() || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -7,32 +7,31 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
cidr *net.IPNet
|
||||
cidr netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) {
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in iOS")
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||
t := &tun{
|
||||
cidr: cidr,
|
||||
|
@ -80,8 +79,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
|
@ -143,7 +142,7 @@ func (tr *tunReadCloser) Close() error {
|
|||
return tr.f.Close()
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() *net.IPNet {
|
||||
func (t *tun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
}
|
||||
|
||||
|
|
|
@ -4,19 +4,18 @@
|
|||
package overlay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
|
@ -26,7 +25,7 @@ type tun struct {
|
|||
io.ReadWriteCloser
|
||||
fd int
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
cidr netip.Prefix
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
|
@ -34,7 +33,7 @@ type tun struct {
|
|||
ioctlFd uintptr
|
||||
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
routeChan chan struct{}
|
||||
useSystemRoutes bool
|
||||
|
||||
|
@ -65,7 +64,7 @@ type ifreqQLEN struct {
|
|||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
t, err := newTunGeneric(c, l, file, cidr)
|
||||
|
@ -78,7 +77,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet)
|
|||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||
|
@ -123,7 +122,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*t
|
|||
return t, nil
|
||||
}
|
||||
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet) (*tun, error) {
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
|
@ -231,8 +230,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||
return file, nil
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
|
@ -275,8 +274,10 @@ func (t *tun) Activate() error {
|
|||
|
||||
var addr, mask [4]byte
|
||||
|
||||
copy(addr[:], t.cidr.IP.To4())
|
||||
copy(mask[:], t.cidr.Mask)
|
||||
//TODO: IPV6-WORK
|
||||
addr = t.cidr.Addr().As4()
|
||||
tmask := net.CIDRMask(t.cidr.Bits(), 32)
|
||||
copy(mask[:], tmask)
|
||||
|
||||
s, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
|
@ -364,14 +365,19 @@ func (t *tun) setMTU() {
|
|||
|
||||
func (t *tun) setDefaultRoute() error {
|
||||
// Default route
|
||||
dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask}
|
||||
|
||||
dr := &net.IPNet{
|
||||
IP: t.cidr.Masked().Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
|
||||
}
|
||||
|
||||
nr := netlink.Route{
|
||||
LinkIndex: t.deviceIndex,
|
||||
Dst: dr,
|
||||
MTU: t.DefaultMTU,
|
||||
AdvMSS: t.advMSS(Route{}),
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Src: t.cidr.IP,
|
||||
Src: net.IP(t.cidr.Addr().AsSlice()),
|
||||
Protocol: unix.RTPROT_KERNEL,
|
||||
Table: unix.RT_TABLE_MAIN,
|
||||
Type: unix.RTN_UNICAST,
|
||||
|
@ -392,9 +398,14 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||
continue
|
||||
}
|
||||
|
||||
dr := &net.IPNet{
|
||||
IP: r.Cidr.Masked().Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
|
||||
}
|
||||
|
||||
nr := netlink.Route{
|
||||
LinkIndex: t.deviceIndex,
|
||||
Dst: r.Cidr,
|
||||
Dst: dr,
|
||||
MTU: r.MTU,
|
||||
AdvMSS: t.advMSS(r),
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
|
@ -426,9 +437,14 @@ func (t *tun) removeRoutes(routes []Route) {
|
|||
continue
|
||||
}
|
||||
|
||||
dr := &net.IPNet{
|
||||
IP: r.Cidr.Masked().Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
|
||||
}
|
||||
|
||||
nr := netlink.Route{
|
||||
LinkIndex: t.deviceIndex,
|
||||
Dst: r.Cidr,
|
||||
Dst: dr,
|
||||
MTU: r.MTU,
|
||||
AdvMSS: t.advMSS(r),
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
|
@ -447,7 +463,7 @@ func (t *tun) removeRoutes(routes []Route) {
|
|||
}
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() *net.IPNet {
|
||||
func (t *tun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
}
|
||||
|
||||
|
@ -499,7 +515,15 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||
return
|
||||
}
|
||||
|
||||
if !t.cidr.Contains(r.Gw) {
|
||||
//TODO: IPV6-WORK what if not ok?
|
||||
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
||||
if !ok {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
||||
return
|
||||
}
|
||||
|
||||
gwAddr = gwAddr.Unmap()
|
||||
if !t.cidr.Contains(gwAddr) {
|
||||
// Gateway isn't in our overlay network, ignore
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||
return
|
||||
|
@ -511,28 +535,25 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||
return
|
||||
}
|
||||
|
||||
newTree := cidr.NewTree4[iputil.VpnIp]()
|
||||
if r.Type == unix.RTM_NEWROUTE {
|
||||
for _, oldR := range t.routeTree.Load().List() {
|
||||
newTree.AddCIDR(oldR.CIDR, oldR.Value)
|
||||
}
|
||||
|
||||
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
|
||||
newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw))
|
||||
|
||||
} else {
|
||||
gw := iputil.Ip2VpnIp(r.Gw)
|
||||
for _, oldR := range t.routeTree.Load().List() {
|
||||
if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw {
|
||||
// This is the record to delete
|
||||
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
|
||||
continue
|
||||
}
|
||||
|
||||
newTree.AddCIDR(oldR.CIDR, oldR.Value)
|
||||
}
|
||||
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
|
||||
if !ok {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
|
||||
return
|
||||
}
|
||||
|
||||
ones, _ := r.Dst.Mask.Size()
|
||||
dst := netip.PrefixFrom(dstAddr, ones)
|
||||
|
||||
newTree := t.routeTree.Load().Clone()
|
||||
|
||||
if r.Type == unix.RTM_NEWROUTE {
|
||||
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
|
||||
newTree.Insert(dst, gwAddr)
|
||||
|
||||
} else {
|
||||
newTree.Delete(dst)
|
||||
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
|
||||
}
|
||||
t.routeTree.Store(newTree)
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ package overlay
|
|||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
|
@ -15,10 +15,9 @@ import (
|
|||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
|
@ -29,10 +28,10 @@ type ifreqDestroy struct {
|
|||
|
||||
type tun struct {
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
cidr netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
|
@ -59,13 +58,13 @@ func (t *tun) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var file *os.File
|
||||
var err error
|
||||
|
@ -109,13 +108,13 @@ func (t *tun) Activate() error {
|
|||
var err error
|
||||
|
||||
// TODO use syscalls instead of exec.Command
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String())
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
|
@ -168,12 +167,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() *net.IPNet {
|
||||
func (t *tun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
}
|
||||
|
||||
|
@ -188,12 +187,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
if !r.Via.IsValid() || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String())
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
||||
|
@ -214,7 +213,7 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.IP.String())
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
|
|
|
@ -6,7 +6,7 @@ package overlay
|
|||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
|
@ -14,19 +14,18 @@ import (
|
|||
"sync/atomic"
|
||||
"syscall"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
cidr netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
|
@ -43,13 +42,13 @@ func (t *tun) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName == "" {
|
||||
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
|
||||
|
@ -127,7 +126,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||
func (t *tun) Activate() error {
|
||||
var err error
|
||||
// TODO use syscalls instead of exec.Command
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
|
@ -139,7 +138,7 @@ func (t *tun) Activate() error {
|
|||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String())
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
|
@ -149,20 +148,20 @@ func (t *tun) Activate() error {
|
|||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
if !r.Via.IsValid() || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String())
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
||||
|
@ -183,7 +182,7 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.IP.String())
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
|
@ -194,7 +193,7 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() *net.IPNet {
|
||||
func (t *tun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
}
|
||||
|
||||
|
|
|
@ -6,21 +6,20 @@ package overlay
|
|||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
)
|
||||
|
||||
type TestTun struct {
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
cidr netip.Prefix
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
routeTree *bart.Table[netip.Addr]
|
||||
l *logrus.Logger
|
||||
|
||||
closed atomic.Bool
|
||||
|
@ -28,7 +27,7 @@ type TestTun struct {
|
|||
TxPackets chan []byte // Packets transmitted outside by nebula
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) {
|
||||
_, routes, err := getAllRoutesFromConfig(c, cidr, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -49,7 +48,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, e
|
|||
}, nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*TestTun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported")
|
||||
}
|
||||
|
||||
|
@ -87,8 +86,8 @@ func (t *TestTun) Get(block bool) []byte {
|
|||
// Below this is boilerplate implementation to make nebula actually work
|
||||
//********************************************************************************************************************//
|
||||
|
||||
func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.MostSpecificContains(ip)
|
||||
func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
r, _ := t.routeTree.Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
|
@ -96,7 +95,7 @@ func (t *TestTun) Activate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *TestTun) Cidr() *net.IPNet {
|
||||
func (t *TestTun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
}
|
||||
|
||||
|
|
|
@ -4,30 +4,30 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
type waterTun struct {
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
cidr netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
f *net.Interface
|
||||
*water.Interface
|
||||
}
|
||||
|
||||
func newWaterTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*waterTun, error) {
|
||||
func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) {
|
||||
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
|
||||
t := &waterTun{
|
||||
cidr: cidr,
|
||||
|
@ -70,8 +70,8 @@ func (t *waterTun) Activate() error {
|
|||
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
|
||||
fmt.Sprintf("name=%s", t.Device),
|
||||
"source=static",
|
||||
fmt.Sprintf("addr=%s", t.cidr.IP),
|
||||
fmt.Sprintf("mask=%s", net.IP(t.cidr.Mask)),
|
||||
fmt.Sprintf("addr=%s", t.cidr.Addr()),
|
||||
fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())),
|
||||
"gateway=none",
|
||||
).Run()
|
||||
if err != nil {
|
||||
|
@ -141,7 +141,7 @@ func (t *waterTun) addRoutes(logErrors bool) error {
|
|||
// Path routes
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
if !r.Via.IsValid() || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
@ -182,12 +182,12 @@ func (t *waterTun) removeRoutes(routes []Route) {
|
|||
}
|
||||
}
|
||||
|
||||
func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *waterTun) Cidr() *net.IPNet {
|
||||
func (t *waterTun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ package overlay
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
@ -15,11 +15,11 @@ import (
|
|||
"github.com/slackhq/nebula/config"
|
||||
)
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (Device, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (Device, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) {
|
||||
useWintun := true
|
||||
if err := checkWinTunExists(); err != nil {
|
||||
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
|
||||
|
|
|
@ -4,15 +4,13 @@ import (
|
|||
"crypto"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/slackhq/nebula/wintun"
|
||||
"golang.org/x/sys/windows"
|
||||
|
@ -23,11 +21,10 @@ const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
|
|||
|
||||
type winTun struct {
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
prefix netip.Prefix
|
||||
cidr netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
|
||||
tun *wintun.NativeTun
|
||||
|
@ -52,22 +49,16 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
|
|||
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
|
||||
}
|
||||
|
||||
func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) {
|
||||
func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) {
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
guid, err := generateGUIDByDeviceName(deviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate GUID failed: %w", err)
|
||||
}
|
||||
|
||||
prefix, err := iputil.ToNetIpPrefix(*cidr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := &winTun{
|
||||
Device: deviceName,
|
||||
cidr: cidr,
|
||||
prefix: prefix,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
@ -140,7 +131,7 @@ func (t *winTun) reload(c *config.C, initial bool) error {
|
|||
func (t *winTun) Activate() error {
|
||||
luid := winipcfg.LUID(t.tun.LUID())
|
||||
|
||||
err := luid.SetIPAddresses([]netip.Prefix{t.prefix})
|
||||
err := luid.SetIPAddresses([]netip.Prefix{t.cidr})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set address: %w", err)
|
||||
}
|
||||
|
@ -159,24 +150,13 @@ func (t *winTun) addRoutes(logErrors bool) error {
|
|||
foundDefault4 := false
|
||||
|
||||
for _, r := range routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
if !r.Via.IsValid() || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
continue
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
}
|
||||
|
||||
// Add our unsafe route
|
||||
err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric))
|
||||
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
||||
if logErrors {
|
||||
|
@ -190,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error {
|
|||
}
|
||||
|
||||
if !foundDefault4 {
|
||||
if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 {
|
||||
if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
|
||||
foundDefault4 = true
|
||||
}
|
||||
}
|
||||
|
@ -221,13 +201,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
|
|||
continue
|
||||
}
|
||||
|
||||
prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix")
|
||||
continue
|
||||
}
|
||||
|
||||
err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr())
|
||||
err := luid.DeleteRoute(r.Cidr, r.Via)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
|
@ -237,12 +211,12 @@ func (t *winTun) removeRoutes(routes []Route) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *winTun) Cidr() *net.IPNet {
|
||||
func (t *winTun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
}
|
||||
|
||||
|
|
|
@ -2,18 +2,17 @@ package overlay
|
|||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
)
|
||||
|
||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
|
||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
|
||||
return NewUserDevice(tunCidr)
|
||||
}
|
||||
|
||||
func NewUserDevice(tunCidr *net.IPNet) (Device, error) {
|
||||
func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
|
||||
// these pipes guarantee each write/read will match 1:1
|
||||
or, ow := io.Pipe()
|
||||
ir, iw := io.Pipe()
|
||||
|
@ -27,7 +26,7 @@ func NewUserDevice(tunCidr *net.IPNet) (Device, error) {
|
|||
}
|
||||
|
||||
type UserDevice struct {
|
||||
tunCidr *net.IPNet
|
||||
tunCidr netip.Prefix
|
||||
|
||||
outboundReader *io.PipeReader
|
||||
outboundWriter *io.PipeWriter
|
||||
|
@ -39,9 +38,9 @@ type UserDevice struct {
|
|||
func (d *UserDevice) Activate() error {
|
||||
return nil
|
||||
}
|
||||
func (d *UserDevice) Cidr() *net.IPNet { return d.tunCidr }
|
||||
func (d *UserDevice) Name() string { return "faketun0" }
|
||||
func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip }
|
||||
func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr }
|
||||
func (d *UserDevice) Name() string { return "faketun0" }
|
||||
func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
|
||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return d, nil
|
||||
}
|
||||
|
|
2
pki.go
2
pki.go
|
@ -80,6 +80,8 @@ func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
|
|||
}
|
||||
|
||||
if !initial {
|
||||
//TODO: include check for mask equality as well
|
||||
|
||||
// did IP in cert change? if so, don't set
|
||||
currentCert := p.cs.Load().Certificate
|
||||
oldIPs := currentCert.Details.Ips
|
||||
|
|
|
@ -2,14 +2,15 @@ package nebula
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
)
|
||||
|
||||
type relayManager struct {
|
||||
|
@ -50,7 +51,7 @@ func (rm *relayManager) setAmRelay(v bool) {
|
|||
|
||||
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
|
||||
// relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
|
||||
func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iputil.VpnIp, remoteIdx *uint32, relayType int, state int) (uint32, error) {
|
||||
func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
|
||||
hm.Lock()
|
||||
defer hm.Unlock()
|
||||
for i := 0; i < 32; i++ {
|
||||
|
@ -113,13 +114,17 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Inter
|
|||
|
||||
func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) {
|
||||
rm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": iputil.VpnIp(m.RelayFromIp),
|
||||
"relayTo": iputil.VpnIp(m.RelayToIp),
|
||||
"relayFrom": m.RelayFromIp,
|
||||
"relayTo": m.RelayToIp,
|
||||
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
||||
"responderRelayIndex": m.ResponderRelayIndex,
|
||||
"vpnIp": h.vpnIp}).
|
||||
Info("handleCreateRelayResponse")
|
||||
target := iputil.VpnIp(m.RelayToIp)
|
||||
target := m.RelayToIp
|
||||
//TODO: IPV6-WORK
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], m.RelayToIp)
|
||||
targetAddr := netip.AddrFrom4(b)
|
||||
|
||||
relay, err := rm.EstablishRelay(h, m)
|
||||
if err != nil {
|
||||
|
@ -136,18 +141,20 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
|
|||
rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
|
||||
return
|
||||
}
|
||||
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target)
|
||||
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
|
||||
if !ok {
|
||||
rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo")
|
||||
return
|
||||
}
|
||||
if peerRelay.State == PeerRequested {
|
||||
//TODO: IPV6-WORK
|
||||
b = peerHostInfo.vpnIp.As4()
|
||||
peerRelay.State = Established
|
||||
resp := NebulaControl{
|
||||
Type: NebulaControl_CreateRelayResponse,
|
||||
ResponderRelayIndex: peerRelay.LocalIndex,
|
||||
InitiatorRelayIndex: peerRelay.RemoteIndex,
|
||||
RelayFromIp: uint32(peerHostInfo.vpnIp),
|
||||
RelayFromIp: binary.BigEndian.Uint32(b[:]),
|
||||
RelayToIp: uint32(target),
|
||||
}
|
||||
msg, err := resp.Marshal()
|
||||
|
@ -157,8 +164,8 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
|
|||
} else {
|
||||
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
rm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": iputil.VpnIp(resp.RelayFromIp),
|
||||
"relayTo": iputil.VpnIp(resp.RelayToIp),
|
||||
"relayFrom": resp.RelayFromIp,
|
||||
"relayTo": resp.RelayToIp,
|
||||
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
||||
"responderRelayIndex": resp.ResponderRelayIndex,
|
||||
"vpnIp": peerHostInfo.vpnIp}).
|
||||
|
@ -168,9 +175,13 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
|
|||
}
|
||||
|
||||
func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) {
|
||||
//TODO: IPV6-WORK
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], m.RelayFromIp)
|
||||
from := netip.AddrFrom4(b)
|
||||
|
||||
from := iputil.VpnIp(m.RelayFromIp)
|
||||
target := iputil.VpnIp(m.RelayToIp)
|
||||
binary.BigEndian.PutUint32(b[:], m.RelayToIp)
|
||||
target := netip.AddrFrom4(b)
|
||||
|
||||
logMsg := rm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": from,
|
||||
|
@ -181,12 +192,12 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||
logMsg.Info("handleCreateRelayRequest")
|
||||
// Is the source of the relay me? This should never happen, but did happen due to
|
||||
// an issue migrating relays over to newly re-handshaked host info objects.
|
||||
if from == f.myVpnIp {
|
||||
logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself")
|
||||
if from == f.myVpnNet.Addr() {
|
||||
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
|
||||
return
|
||||
}
|
||||
// Is the target of the relay me?
|
||||
if target == f.myVpnIp {
|
||||
if target == f.myVpnNet.Addr() {
|
||||
existingRelay, ok := h.relayState.QueryRelayForByIp(from)
|
||||
if ok {
|
||||
switch existingRelay.State {
|
||||
|
@ -219,12 +230,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||
return
|
||||
}
|
||||
|
||||
//TODO: IPV6-WORK
|
||||
fromB := from.As4()
|
||||
targetB := target.As4()
|
||||
|
||||
resp := NebulaControl{
|
||||
Type: NebulaControl_CreateRelayResponse,
|
||||
ResponderRelayIndex: relay.LocalIndex,
|
||||
InitiatorRelayIndex: relay.RemoteIndex,
|
||||
RelayFromIp: uint32(from),
|
||||
RelayToIp: uint32(target),
|
||||
RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
|
||||
RelayToIp: binary.BigEndian.Uint32(targetB[:]),
|
||||
}
|
||||
msg, err := resp.Marshal()
|
||||
if err != nil {
|
||||
|
@ -233,8 +248,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||
} else {
|
||||
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
|
||||
rm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": iputil.VpnIp(resp.RelayFromIp),
|
||||
"relayTo": iputil.VpnIp(resp.RelayToIp),
|
||||
//TODO: IPV6-WORK, this used to use the resp object but I am getting lazy now
|
||||
"relayFrom": from,
|
||||
"relayTo": target,
|
||||
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
||||
"responderRelayIndex": resp.ResponderRelayIndex,
|
||||
"vpnIp": h.vpnIp}).
|
||||
|
@ -253,7 +269,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||
f.Handshake(target)
|
||||
return
|
||||
}
|
||||
if peer.remote == nil {
|
||||
if !peer.remote.IsValid() {
|
||||
// Only create relays to peers for whom I have a direct connection
|
||||
return
|
||||
}
|
||||
|
@ -275,12 +291,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||
sendCreateRequest = true
|
||||
}
|
||||
if sendCreateRequest {
|
||||
//TODO: IPV6-WORK
|
||||
fromB := h.vpnIp.As4()
|
||||
targetB := target.As4()
|
||||
|
||||
// Send a CreateRelayRequest to the peer.
|
||||
req := NebulaControl{
|
||||
Type: NebulaControl_CreateRelayRequest,
|
||||
InitiatorRelayIndex: index,
|
||||
RelayFromIp: uint32(h.vpnIp),
|
||||
RelayToIp: uint32(target),
|
||||
RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
|
||||
RelayToIp: binary.BigEndian.Uint32(targetB[:]),
|
||||
}
|
||||
msg, err := req.Marshal()
|
||||
if err != nil {
|
||||
|
@ -289,8 +309,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||
} else {
|
||||
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
|
||||
rm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": iputil.VpnIp(req.RelayFromIp),
|
||||
"relayTo": iputil.VpnIp(req.RelayToIp),
|
||||
//TODO: IPV6-WORK another lazy used to use the req object
|
||||
"relayFrom": h.vpnIp,
|
||||
"relayTo": target,
|
||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
||||
"responderRelayIndex": req.ResponderRelayIndex,
|
||||
"vpnIp": target}).
|
||||
|
@ -321,12 +342,15 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||
"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
|
||||
return
|
||||
}
|
||||
//TODO: IPV6-WORK
|
||||
fromB := h.vpnIp.As4()
|
||||
targetB := target.As4()
|
||||
resp := NebulaControl{
|
||||
Type: NebulaControl_CreateRelayResponse,
|
||||
ResponderRelayIndex: relay.LocalIndex,
|
||||
InitiatorRelayIndex: relay.RemoteIndex,
|
||||
RelayFromIp: uint32(h.vpnIp),
|
||||
RelayToIp: uint32(target),
|
||||
RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
|
||||
RelayToIp: binary.BigEndian.Uint32(targetB[:]),
|
||||
}
|
||||
msg, err := resp.Marshal()
|
||||
if err != nil {
|
||||
|
@ -335,8 +359,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||
} else {
|
||||
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
|
||||
rm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": iputil.VpnIp(resp.RelayFromIp),
|
||||
"relayTo": iputil.VpnIp(resp.RelayToIp),
|
||||
//TODO: IPV6-WORK more lazy, used to use resp object
|
||||
"relayFrom": h.vpnIp,
|
||||
"relayTo": target,
|
||||
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
||||
"responderRelayIndex": resp.ResponderRelayIndex,
|
||||
"vpnIp": h.vpnIp}).
|
||||
|
@ -349,7 +374,3 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *relayManager) RemoveRelay(localIdx uint32) {
|
||||
rm.hostmap.RemoveRelay(localIdx)
|
||||
}
|
||||
|
|
166
remote_list.go
166
remote_list.go
|
@ -1,7 +1,6 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
@ -12,16 +11,14 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
// forEachFunc is used to benefit folks that want to do work inside the lock
|
||||
type forEachFunc func(addr *udp.Addr, preferred bool)
|
||||
type forEachFunc func(addr netip.AddrPort, preferred bool)
|
||||
|
||||
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
|
||||
type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool
|
||||
type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool
|
||||
type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool
|
||||
type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool
|
||||
|
||||
// CacheMap is a struct that better represents the lighthouse cache for humans
|
||||
// The string key is the owners vpnIp
|
||||
|
@ -30,9 +27,9 @@ type CacheMap map[string]*Cache
|
|||
// Cache is the other part of CacheMap to better represent the lighthouse cache for humans
|
||||
// We don't reason about ipv4 vs ipv6 here
|
||||
type Cache struct {
|
||||
Learned []*udp.Addr `json:"learned,omitempty"`
|
||||
Reported []*udp.Addr `json:"reported,omitempty"`
|
||||
Relay []*net.IP `json:"relay"`
|
||||
Learned []netip.AddrPort `json:"learned,omitempty"`
|
||||
Reported []netip.AddrPort `json:"reported,omitempty"`
|
||||
Relay []netip.Addr `json:"relay"`
|
||||
}
|
||||
|
||||
//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
|
||||
|
@ -46,7 +43,7 @@ type cache struct {
|
|||
}
|
||||
|
||||
type cacheRelay struct {
|
||||
relay []uint32
|
||||
relay []netip.Addr
|
||||
}
|
||||
|
||||
// cacheV4 stores learned and reported ipv4 records under cache
|
||||
|
@ -130,7 +127,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
|
|||
continue
|
||||
}
|
||||
for _, a := range addrs {
|
||||
netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{}
|
||||
netipAddrs[netip.AddrPortFrom(a.Unmap(), hostPort.port)] = struct{}{}
|
||||
}
|
||||
}
|
||||
origSet := r.ips.Load()
|
||||
|
@ -193,22 +190,22 @@ type RemoteList struct {
|
|||
sync.RWMutex
|
||||
|
||||
// A deduplicated set of addresses. Any accessor should lock beforehand.
|
||||
addrs []*udp.Addr
|
||||
addrs []netip.AddrPort
|
||||
|
||||
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
|
||||
relays []*iputil.VpnIp
|
||||
relays []netip.Addr
|
||||
|
||||
// These are maps to store v4 and v6 addresses per lighthouse
|
||||
// Map key is the vpnIp of the person that told us about this the cached entries underneath.
|
||||
// For learned addresses, this is the vpnIp that sent the packet
|
||||
cache map[iputil.VpnIp]*cache
|
||||
cache map[netip.Addr]*cache
|
||||
|
||||
hr *hostnamesResults
|
||||
shouldAdd func(netip.Addr) bool
|
||||
|
||||
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
|
||||
// They should not be tried again during a handshake
|
||||
badRemotes []*udp.Addr
|
||||
badRemotes []netip.AddrPort
|
||||
|
||||
// A flag that the cache may have changed and addrs needs to be rebuilt
|
||||
shouldRebuild bool
|
||||
|
@ -217,9 +214,9 @@ type RemoteList struct {
|
|||
// NewRemoteList creates a new empty RemoteList
|
||||
func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
|
||||
return &RemoteList{
|
||||
addrs: make([]*udp.Addr, 0),
|
||||
relays: make([]*iputil.VpnIp, 0),
|
||||
cache: make(map[iputil.VpnIp]*cache),
|
||||
addrs: make([]netip.AddrPort, 0),
|
||||
relays: make([]netip.Addr, 0),
|
||||
cache: make(map[netip.Addr]*cache),
|
||||
shouldAdd: shouldAdd,
|
||||
}
|
||||
}
|
||||
|
@ -232,7 +229,7 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
|
|||
|
||||
// Len locks and reports the size of the deduplicated address list
|
||||
// The deduplication work may need to occur here, so you must pass preferredRanges
|
||||
func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
|
||||
func (r *RemoteList) Len(preferredRanges []netip.Prefix) int {
|
||||
r.Rebuild(preferredRanges)
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
|
@ -241,18 +238,18 @@ func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
|
|||
|
||||
// ForEach locks and will call the forEachFunc for every deduplicated address in the list
|
||||
// The deduplication work may need to occur here, so you must pass preferredRanges
|
||||
func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) {
|
||||
func (r *RemoteList) ForEach(preferredRanges []netip.Prefix, forEach forEachFunc) {
|
||||
r.Rebuild(preferredRanges)
|
||||
r.RLock()
|
||||
for _, v := range r.addrs {
|
||||
forEach(v, isPreferred(v.IP, preferredRanges))
|
||||
forEach(v, isPreferred(v.Addr(), preferredRanges))
|
||||
}
|
||||
r.RUnlock()
|
||||
}
|
||||
|
||||
// CopyAddrs locks and makes a deep copy of the deduplicated address list
|
||||
// The deduplication work may need to occur here, so you must pass preferredRanges
|
||||
func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
|
||||
func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -261,9 +258,9 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
|
|||
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
c := make([]*udp.Addr, len(r.addrs))
|
||||
c := make([]netip.AddrPort, len(r.addrs))
|
||||
for i, v := range r.addrs {
|
||||
c[i] = v.Copy()
|
||||
c[i] = v
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
@ -272,13 +269,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
|
|||
// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
|
||||
// It will mark the deduplicated address list as dirty, so do not call it unless new information is available
|
||||
// TODO: this needs to support the allow list list
|
||||
func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) {
|
||||
func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
if v4 := addr.IP.To4(); v4 != nil {
|
||||
r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port)))
|
||||
if remote.Addr().Is4() {
|
||||
r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port()))
|
||||
} else {
|
||||
r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port)))
|
||||
r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -293,9 +290,9 @@ func (r *RemoteList) CopyCache() *CacheMap {
|
|||
c := cm[vpnIp]
|
||||
if c == nil {
|
||||
c = &Cache{
|
||||
Learned: make([]*udp.Addr, 0),
|
||||
Reported: make([]*udp.Addr, 0),
|
||||
Relay: make([]*net.IP, 0),
|
||||
Learned: make([]netip.AddrPort, 0),
|
||||
Reported: make([]netip.AddrPort, 0),
|
||||
Relay: make([]netip.Addr, 0),
|
||||
}
|
||||
cm[vpnIp] = c
|
||||
}
|
||||
|
@ -307,28 +304,27 @@ func (r *RemoteList) CopyCache() *CacheMap {
|
|||
|
||||
if mc.v4 != nil {
|
||||
if mc.v4.learned != nil {
|
||||
c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned))
|
||||
c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned))
|
||||
}
|
||||
|
||||
for _, a := range mc.v4.reported {
|
||||
c.Reported = append(c.Reported, NewUDPAddrFromLH4(a))
|
||||
c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a))
|
||||
}
|
||||
}
|
||||
|
||||
if mc.v6 != nil {
|
||||
if mc.v6.learned != nil {
|
||||
c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned))
|
||||
c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned))
|
||||
}
|
||||
|
||||
for _, a := range mc.v6.reported {
|
||||
c.Reported = append(c.Reported, NewUDPAddrFromLH6(a))
|
||||
c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a))
|
||||
}
|
||||
}
|
||||
|
||||
if mc.relay != nil {
|
||||
for _, a := range mc.relay.relay {
|
||||
nip := iputil.VpnIp(a).ToIP()
|
||||
c.Relay = append(c.Relay, &nip)
|
||||
c.Relay = append(c.Relay, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -337,8 +333,8 @@ func (r *RemoteList) CopyCache() *CacheMap {
|
|||
}
|
||||
|
||||
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
|
||||
func (r *RemoteList) BlockRemote(bad *udp.Addr) {
|
||||
if bad == nil {
|
||||
func (r *RemoteList) BlockRemote(bad netip.AddrPort) {
|
||||
if !bad.IsValid() {
|
||||
// relays can have nil udp Addrs
|
||||
return
|
||||
}
|
||||
|
@ -351,20 +347,20 @@ func (r *RemoteList) BlockRemote(bad *udp.Addr) {
|
|||
}
|
||||
|
||||
// We copy here because we are taking something else's memory and we can't trust everything
|
||||
r.badRemotes = append(r.badRemotes, bad.Copy())
|
||||
r.badRemotes = append(r.badRemotes, bad)
|
||||
|
||||
// Mark the next interaction must recollect/dedupe
|
||||
r.shouldRebuild = true
|
||||
}
|
||||
|
||||
// CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
|
||||
func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr {
|
||||
func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
|
||||
c := make([]*udp.Addr, len(r.badRemotes))
|
||||
c := make([]netip.AddrPort, len(r.badRemotes))
|
||||
for i, v := range r.badRemotes {
|
||||
c[i] = v.Copy()
|
||||
c[i] = v
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
@ -378,7 +374,7 @@ func (r *RemoteList) ResetBlockedRemotes() {
|
|||
|
||||
// Rebuild locks and generates the deduplicated address list only if there is work to be done
|
||||
// There is generally no reason to call this directly but it is safe to do so
|
||||
func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
|
||||
func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
|
@ -394,9 +390,9 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
|
|||
}
|
||||
|
||||
// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
|
||||
func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
|
||||
func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
|
||||
for _, v := range r.badRemotes {
|
||||
if v.Equals(remote) {
|
||||
if v == remote {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -405,14 +401,14 @@ func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
|
|||
|
||||
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
|
||||
// deduplicated address list as dirty
|
||||
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
|
||||
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
|
||||
r.shouldRebuild = true
|
||||
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
|
||||
}
|
||||
|
||||
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
|
||||
// and marks the deduplicated address list as dirty
|
||||
func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) {
|
||||
func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) {
|
||||
r.shouldRebuild = true
|
||||
c := r.unlockedGetOrMakeV4(ownerVpnIp)
|
||||
|
||||
|
@ -427,7 +423,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp,
|
|||
}
|
||||
}
|
||||
|
||||
func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []uint32) {
|
||||
func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) {
|
||||
r.shouldRebuild = true
|
||||
c := r.unlockedGetOrMakeRelay(ownerVpnIp)
|
||||
|
||||
|
@ -440,7 +436,7 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnI
|
|||
|
||||
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
|
||||
// This is only useful for establishing static hosts
|
||||
func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
|
||||
func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
|
||||
r.shouldRebuild = true
|
||||
c := r.unlockedGetOrMakeV4(ownerVpnIp)
|
||||
|
||||
|
@ -453,14 +449,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort)
|
|||
|
||||
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
|
||||
// deduplicated address list as dirty
|
||||
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
|
||||
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
|
||||
r.shouldRebuild = true
|
||||
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
|
||||
}
|
||||
|
||||
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
|
||||
// and marks the deduplicated address list as dirty
|
||||
func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) {
|
||||
func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) {
|
||||
r.shouldRebuild = true
|
||||
c := r.unlockedGetOrMakeV6(ownerVpnIp)
|
||||
|
||||
|
@ -477,7 +473,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp,
|
|||
|
||||
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
|
||||
// This is only useful for establishing static hosts
|
||||
func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
|
||||
func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
|
||||
r.shouldRebuild = true
|
||||
c := r.unlockedGetOrMakeV6(ownerVpnIp)
|
||||
|
||||
|
@ -488,7 +484,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort)
|
|||
}
|
||||
}
|
||||
|
||||
func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay {
|
||||
func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp netip.Addr) *cacheRelay {
|
||||
am := r.cache[ownerVpnIp]
|
||||
if am == nil {
|
||||
am = &cache{}
|
||||
|
@ -503,7 +499,7 @@ func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay
|
|||
|
||||
// unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
|
||||
// The caller must dirty the learned address cache if required
|
||||
func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
|
||||
func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp netip.Addr) *cacheV4 {
|
||||
am := r.cache[ownerVpnIp]
|
||||
if am == nil {
|
||||
am = &cache{}
|
||||
|
@ -518,7 +514,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
|
|||
|
||||
// unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
|
||||
// The caller must dirty the learned address cache if required
|
||||
func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 {
|
||||
func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp netip.Addr) *cacheV6 {
|
||||
am := r.cache[ownerVpnIp]
|
||||
if am == nil {
|
||||
am = &cache{}
|
||||
|
@ -540,14 +536,14 @@ func (r *RemoteList) unlockedCollect() {
|
|||
for _, c := range r.cache {
|
||||
if c.v4 != nil {
|
||||
if c.v4.learned != nil {
|
||||
u := NewUDPAddrFromLH4(c.v4.learned)
|
||||
u := AddrPortFromIp4AndPort(c.v4.learned)
|
||||
if !r.unlockedIsBad(u) {
|
||||
addrs = append(addrs, u)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range c.v4.reported {
|
||||
u := NewUDPAddrFromLH4(v)
|
||||
u := AddrPortFromIp4AndPort(v)
|
||||
if !r.unlockedIsBad(u) {
|
||||
addrs = append(addrs, u)
|
||||
}
|
||||
|
@ -556,14 +552,14 @@ func (r *RemoteList) unlockedCollect() {
|
|||
|
||||
if c.v6 != nil {
|
||||
if c.v6.learned != nil {
|
||||
u := NewUDPAddrFromLH6(c.v6.learned)
|
||||
u := AddrPortFromIp6AndPort(c.v6.learned)
|
||||
if !r.unlockedIsBad(u) {
|
||||
addrs = append(addrs, u)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range c.v6.reported {
|
||||
u := NewUDPAddrFromLH6(v)
|
||||
u := AddrPortFromIp6AndPort(v)
|
||||
if !r.unlockedIsBad(u) {
|
||||
addrs = append(addrs, u)
|
||||
}
|
||||
|
@ -572,8 +568,7 @@ func (r *RemoteList) unlockedCollect() {
|
|||
|
||||
if c.relay != nil {
|
||||
for _, v := range c.relay.relay {
|
||||
ip := iputil.VpnIp(v)
|
||||
relays = append(relays, &ip)
|
||||
relays = append(relays, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -581,11 +576,7 @@ func (r *RemoteList) unlockedCollect() {
|
|||
dnsAddrs := r.hr.GetIPs()
|
||||
for _, addr := range dnsAddrs {
|
||||
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
|
||||
v6 := addr.Addr().As16()
|
||||
addrs = append(addrs, &udp.Addr{
|
||||
IP: v6[:],
|
||||
Port: addr.Port(),
|
||||
})
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -595,7 +586,7 @@ func (r *RemoteList) unlockedCollect() {
|
|||
}
|
||||
|
||||
// unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
|
||||
func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
|
||||
func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) {
|
||||
n := len(r.addrs)
|
||||
if n < 2 {
|
||||
return
|
||||
|
@ -606,8 +597,8 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
|
|||
b := r.addrs[j]
|
||||
// Preferred addresses first
|
||||
|
||||
aPref := isPreferred(a.IP, preferredRanges)
|
||||
bPref := isPreferred(b.IP, preferredRanges)
|
||||
aPref := isPreferred(a.Addr(), preferredRanges)
|
||||
bPref := isPreferred(b.Addr(), preferredRanges)
|
||||
switch {
|
||||
case aPref && !bPref:
|
||||
// If i is preferred and j is not, i is less than j
|
||||
|
@ -622,21 +613,21 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
|
|||
}
|
||||
|
||||
// ipv6 addresses 2nd
|
||||
a4 := a.IP.To4()
|
||||
b4 := b.IP.To4()
|
||||
a4 := a.Addr().Is4()
|
||||
b4 := b.Addr().Is4()
|
||||
switch {
|
||||
case a4 == nil && b4 != nil:
|
||||
case a4 == false && b4 == true:
|
||||
// If i is v6 and j is v4, i is less than j
|
||||
return true
|
||||
|
||||
case a4 != nil && b4 == nil:
|
||||
case a4 == true && b4 == false:
|
||||
// If j is v6 and i is v4, i is not less than j
|
||||
return false
|
||||
|
||||
case a4 != nil && b4 != nil:
|
||||
// Special case for ipv4, a4 and b4 are not nil
|
||||
aPrivate := isPrivateIP(a4)
|
||||
bPrivate := isPrivateIP(b4)
|
||||
case a4 == true && b4 == true:
|
||||
// i and j are both ipv4
|
||||
aPrivate := a.Addr().IsPrivate()
|
||||
bPrivate := b.Addr().IsPrivate()
|
||||
switch {
|
||||
case !aPrivate && bPrivate:
|
||||
// If i is a public ip (not private) and j is a private ip, i is less then j
|
||||
|
@ -655,10 +646,10 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
|
|||
}
|
||||
|
||||
// lexical order of ips 3rd
|
||||
c := bytes.Compare(a.IP, b.IP)
|
||||
c := a.Addr().Compare(b.Addr())
|
||||
if c == 0 {
|
||||
// Ips are the same, Lexical order of ports 4th
|
||||
return a.Port < b.Port
|
||||
return a.Port() < b.Port()
|
||||
}
|
||||
|
||||
// Ip wasn't the same
|
||||
|
@ -671,7 +662,7 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
|
|||
// Deduplicate
|
||||
a, b := 0, 1
|
||||
for b < n {
|
||||
if !r.addrs[a].Equals(r.addrs[b]) {
|
||||
if r.addrs[a] != r.addrs[b] {
|
||||
a++
|
||||
if a != b {
|
||||
r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a]
|
||||
|
@ -693,7 +684,7 @@ func minInt(a, b int) int {
|
|||
}
|
||||
|
||||
// isPreferred returns true of the ip is contained in the preferredRanges list
|
||||
func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
|
||||
func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool {
|
||||
//TODO: this would be better in a CIDR6Tree
|
||||
for _, p := range preferredRanges {
|
||||
if p.Contains(ip) {
|
||||
|
@ -702,14 +693,3 @@ func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8")
|
||||
var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12")
|
||||
var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16")
|
||||
|
||||
// isPrivateIP returns true if the ip is contained by a rfc 1918 private range
|
||||
func isPrivateIP(ip net.IP) bool {
|
||||
//TODO: another great cidrtree option
|
||||
//TODO: Private for ipv6 or just let it ride?
|
||||
return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
|
||||
}
|
||||
|
|
|
@ -1,47 +1,47 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"net"
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRemoteList_Rebuild(t *testing.T) {
|
||||
rl := NewRemoteList(nil)
|
||||
rl.unlockedSetV4(
|
||||
0,
|
||||
0,
|
||||
netip.MustParseAddr("0.0.0.0"),
|
||||
netip.MustParseAddr("0.0.0.0"),
|
||||
[]*Ip4AndPort{
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe
|
||||
newIp4AndPortFromString("70.199.182.92:1475"), // this is duped
|
||||
newIp4AndPortFromString("172.17.0.182:10101"),
|
||||
newIp4AndPortFromString("172.17.1.1:10101"), // this is duped
|
||||
newIp4AndPortFromString("172.18.0.1:10101"), // this is duped
|
||||
newIp4AndPortFromString("172.18.0.1:10101"), // this is a dupe
|
||||
newIp4AndPortFromString("172.19.0.1:10101"),
|
||||
newIp4AndPortFromString("172.31.0.1:10101"),
|
||||
newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
|
||||
newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port
|
||||
newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe
|
||||
},
|
||||
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
|
||||
func(netip.Addr, *Ip4AndPort) bool { return true },
|
||||
)
|
||||
|
||||
rl.unlockedSetV6(
|
||||
1,
|
||||
1,
|
||||
netip.MustParseAddr("0.0.0.1"),
|
||||
netip.MustParseAddr("0.0.0.1"),
|
||||
[]*Ip6AndPort{
|
||||
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped
|
||||
NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped
|
||||
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
|
||||
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
|
||||
NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
|
||||
newIp6AndPortFromString("[1::1]:1"), // this is duped
|
||||
newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped
|
||||
newIp6AndPortFromString("[1:100::1]:1"),
|
||||
newIp6AndPortFromString("[1::1]:1"), // this is a dupe
|
||||
newIp6AndPortFromString("[1::1]:2"), // this is a dupe
|
||||
},
|
||||
func(iputil.VpnIp, *Ip6AndPort) bool { return true },
|
||||
func(netip.Addr, *Ip6AndPort) bool { return true },
|
||||
)
|
||||
|
||||
rl.Rebuild([]*net.IPNet{})
|
||||
rl.Rebuild([]netip.Prefix{})
|
||||
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
|
||||
|
||||
// ipv6 first, sorted lexically within
|
||||
|
@ -59,9 +59,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
|
|||
assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
|
||||
|
||||
// Now ensure we can hoist ipv4 up
|
||||
_, ipNet, err := net.ParseCIDR("0.0.0.0/0")
|
||||
assert.NoError(t, err)
|
||||
rl.Rebuild([]*net.IPNet{ipNet})
|
||||
rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")})
|
||||
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
|
||||
|
||||
// ipv4 first, public then private, lexically within them
|
||||
|
@ -79,9 +77,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
|
|||
assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
|
||||
|
||||
// Ensure we can hoist a specific ipv4 range over anything else
|
||||
_, ipNet, err = net.ParseCIDR("172.17.0.0/16")
|
||||
assert.NoError(t, err)
|
||||
rl.Rebuild([]*net.IPNet{ipNet})
|
||||
rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")})
|
||||
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
|
||||
|
||||
// Preferred ipv4 first
|
||||
|
@ -104,64 +100,61 @@ func TestRemoteList_Rebuild(t *testing.T) {
|
|||
func BenchmarkFullRebuild(b *testing.B) {
|
||||
rl := NewRemoteList(nil)
|
||||
rl.unlockedSetV4(
|
||||
0,
|
||||
0,
|
||||
netip.MustParseAddr("0.0.0.0"),
|
||||
netip.MustParseAddr("0.0.0.0"),
|
||||
[]*Ip4AndPort{
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
|
||||
newIp4AndPortFromString("70.199.182.92:1475"),
|
||||
newIp4AndPortFromString("172.17.0.182:10101"),
|
||||
newIp4AndPortFromString("172.17.1.1:10101"),
|
||||
newIp4AndPortFromString("172.18.0.1:10101"),
|
||||
newIp4AndPortFromString("172.19.0.1:10101"),
|
||||
newIp4AndPortFromString("172.31.0.1:10101"),
|
||||
newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
|
||||
newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
|
||||
},
|
||||
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
|
||||
func(netip.Addr, *Ip4AndPort) bool { return true },
|
||||
)
|
||||
|
||||
rl.unlockedSetV6(
|
||||
0,
|
||||
0,
|
||||
netip.MustParseAddr("0.0.0.0"),
|
||||
netip.MustParseAddr("0.0.0.0"),
|
||||
[]*Ip6AndPort{
|
||||
NewIp6AndPort(net.ParseIP("1::1"), 1),
|
||||
NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
|
||||
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
|
||||
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
|
||||
newIp6AndPortFromString("[1::1]:1"),
|
||||
newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
|
||||
newIp6AndPortFromString("[1:100::1]:1"),
|
||||
newIp6AndPortFromString("[1::1]:1"), // this is a dupe
|
||||
},
|
||||
func(iputil.VpnIp, *Ip6AndPort) bool { return true },
|
||||
func(netip.Addr, *Ip6AndPort) bool { return true },
|
||||
)
|
||||
|
||||
b.Run("no preferred", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
rl.shouldRebuild = true
|
||||
rl.Rebuild([]*net.IPNet{})
|
||||
rl.Rebuild([]netip.Prefix{})
|
||||
}
|
||||
})
|
||||
|
||||
_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
|
||||
assert.NoError(b, err)
|
||||
ipNet1 := netip.MustParsePrefix("172.17.0.0/16")
|
||||
b.Run("1 preferred", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
rl.shouldRebuild = true
|
||||
rl.Rebuild([]*net.IPNet{ipNet})
|
||||
rl.Rebuild([]netip.Prefix{ipNet1})
|
||||
}
|
||||
})
|
||||
|
||||
_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
|
||||
assert.NoError(b, err)
|
||||
ipNet2 := netip.MustParsePrefix("70.0.0.0/8")
|
||||
b.Run("2 preferred", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
rl.shouldRebuild = true
|
||||
rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
|
||||
rl.Rebuild([]netip.Prefix{ipNet2})
|
||||
}
|
||||
})
|
||||
|
||||
_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
|
||||
assert.NoError(b, err)
|
||||
ipNet3 := netip.MustParsePrefix("0.0.0.0/0")
|
||||
b.Run("3 preferred", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
rl.shouldRebuild = true
|
||||
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
|
||||
rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -169,67 +162,83 @@ func BenchmarkFullRebuild(b *testing.B) {
|
|||
func BenchmarkSortRebuild(b *testing.B) {
|
||||
rl := NewRemoteList(nil)
|
||||
rl.unlockedSetV4(
|
||||
0,
|
||||
0,
|
||||
netip.MustParseAddr("0.0.0.0"),
|
||||
netip.MustParseAddr("0.0.0.0"),
|
||||
[]*Ip4AndPort{
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
|
||||
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
|
||||
newIp4AndPortFromString("70.199.182.92:1475"),
|
||||
newIp4AndPortFromString("172.17.0.182:10101"),
|
||||
newIp4AndPortFromString("172.17.1.1:10101"),
|
||||
newIp4AndPortFromString("172.18.0.1:10101"),
|
||||
newIp4AndPortFromString("172.19.0.1:10101"),
|
||||
newIp4AndPortFromString("172.31.0.1:10101"),
|
||||
newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
|
||||
newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
|
||||
},
|
||||
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
|
||||
func(netip.Addr, *Ip4AndPort) bool { return true },
|
||||
)
|
||||
|
||||
rl.unlockedSetV6(
|
||||
0,
|
||||
0,
|
||||
netip.MustParseAddr("0.0.0.0"),
|
||||
netip.MustParseAddr("0.0.0.0"),
|
||||
[]*Ip6AndPort{
|
||||
NewIp6AndPort(net.ParseIP("1::1"), 1),
|
||||
NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
|
||||
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
|
||||
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
|
||||
newIp6AndPortFromString("[1::1]:1"),
|
||||
newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
|
||||
newIp6AndPortFromString("[1:100::1]:1"),
|
||||
newIp6AndPortFromString("[1::1]:1"), // this is a dupe
|
||||
},
|
||||
func(iputil.VpnIp, *Ip6AndPort) bool { return true },
|
||||
func(netip.Addr, *Ip6AndPort) bool { return true },
|
||||
)
|
||||
|
||||
b.Run("no preferred", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
rl.shouldRebuild = true
|
||||
rl.Rebuild([]*net.IPNet{})
|
||||
rl.Rebuild([]netip.Prefix{})
|
||||
}
|
||||
})
|
||||
|
||||
_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
|
||||
rl.Rebuild([]*net.IPNet{ipNet})
|
||||
ipNet1 := netip.MustParsePrefix("172.17.0.0/16")
|
||||
rl.Rebuild([]netip.Prefix{ipNet1})
|
||||
|
||||
assert.NoError(b, err)
|
||||
b.Run("1 preferred", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
rl.Rebuild([]*net.IPNet{ipNet})
|
||||
rl.Rebuild([]netip.Prefix{ipNet1})
|
||||
}
|
||||
})
|
||||
|
||||
_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
|
||||
rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
|
||||
ipNet2 := netip.MustParsePrefix("70.0.0.0/8")
|
||||
rl.Rebuild([]netip.Prefix{ipNet1, ipNet2})
|
||||
|
||||
assert.NoError(b, err)
|
||||
b.Run("2 preferred", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
|
||||
rl.Rebuild([]netip.Prefix{ipNet1, ipNet2})
|
||||
}
|
||||
})
|
||||
|
||||
_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
|
||||
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
|
||||
ipNet3 := netip.MustParsePrefix("0.0.0.0/0")
|
||||
rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
|
||||
|
||||
assert.NoError(b, err)
|
||||
b.Run("3 preferred", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
|
||||
rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func newIp4AndPortFromString(s string) *Ip4AndPort {
|
||||
a := netip.MustParseAddrPort(s)
|
||||
v4Addr := a.Addr().As4()
|
||||
return &Ip4AndPort{
|
||||
Ip: binary.BigEndian.Uint32(v4Addr[:]),
|
||||
Port: uint32(a.Port()),
|
||||
}
|
||||
}
|
||||
|
||||
func newIp6AndPortFromString(s string) *Ip6AndPort {
|
||||
a := netip.MustParseAddrPort(s)
|
||||
v6Addr := a.Addr().As16()
|
||||
return &Ip6AndPort{
|
||||
Hi: binary.BigEndian.Uint64(v6Addr[:8]),
|
||||
Lo: binary.BigEndian.Uint64(v6Addr[8:]),
|
||||
Port: uint32(a.Port()),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -91,7 +91,7 @@ func New(config *config.C) (*Service, error) {
|
|||
|
||||
ipNet := device.Cidr()
|
||||
pa := tcpip.ProtocolAddress{
|
||||
AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(),
|
||||
AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(),
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
}
|
||||
if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -18,12 +18,8 @@ import (
|
|||
|
||||
type m map[string]interface{}
|
||||
|
||||
func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service {
|
||||
|
||||
vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
|
||||
copy(vpnIpNet.IP, udpIp)
|
||||
|
||||
_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
|
||||
func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
|
||||
_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{})
|
||||
caB, err := caCrt.MarshalToPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -83,8 +79,8 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string,
|
|||
}
|
||||
|
||||
func TestService(t *testing.T) {
|
||||
ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{
|
||||
ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{
|
||||
"static_host_map": m{},
|
||||
"lighthouse": m{
|
||||
"am_lighthouse": true,
|
||||
|
@ -94,7 +90,7 @@ func TestService(t *testing.T) {
|
|||
"port": 4243,
|
||||
},
|
||||
})
|
||||
b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{
|
||||
b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{
|
||||
"static_host_map": m{
|
||||
"10.0.0.1": []string{"localhost:4243"},
|
||||
},
|
||||
|
|
65
ssh.go
65
ssh.go
|
@ -7,6 +7,7 @@ import (
|
|||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
|
@ -18,9 +19,7 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/sshd"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
type sshListHostMapFlags struct {
|
||||
|
@ -431,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
|
|||
}
|
||||
|
||||
sort.Slice(hm, func(i, j int) bool {
|
||||
return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
|
||||
return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0
|
||||
})
|
||||
|
||||
if fs.Json || fs.Pretty {
|
||||
|
@ -545,13 +544,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
|
|||
return w.WriteLine("No vpn ip was provided")
|
||||
}
|
||||
|
||||
parsedIp := net.ParseIP(a[0])
|
||||
if parsedIp == nil {
|
||||
vpnIp, err := netip.ParseAddr(a[0])
|
||||
if err != nil {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
vpnIp := iputil.Ip2VpnIp(parsedIp)
|
||||
if vpnIp == 0 {
|
||||
if !vpnIp.IsValid() {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
|
@ -574,13 +572,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
|||
return w.WriteLine("No vpn ip was provided")
|
||||
}
|
||||
|
||||
parsedIp := net.ParseIP(a[0])
|
||||
if parsedIp == nil {
|
||||
vpnIp, err := netip.ParseAddr(a[0])
|
||||
if err != nil {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
vpnIp := iputil.Ip2VpnIp(parsedIp)
|
||||
if vpnIp == 0 {
|
||||
if !vpnIp.IsValid() {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
|
@ -616,13 +613,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
|||
return w.WriteLine("No vpn ip was provided")
|
||||
}
|
||||
|
||||
parsedIp := net.ParseIP(a[0])
|
||||
if parsedIp == nil {
|
||||
vpnIp, err := netip.ParseAddr(a[0])
|
||||
if err != nil {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
vpnIp := iputil.Ip2VpnIp(parsedIp)
|
||||
if vpnIp == 0 {
|
||||
if !vpnIp.IsValid() {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
|
@ -636,16 +632,16 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
|||
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
|
||||
}
|
||||
|
||||
var addr *udp.Addr
|
||||
var addr netip.AddrPort
|
||||
if flags.Address != "" {
|
||||
addr = udp.NewAddrFromString(flags.Address)
|
||||
if addr == nil {
|
||||
addr, err = netip.ParseAddrPort(flags.Address)
|
||||
if err != nil {
|
||||
return w.WriteLine("Address could not be parsed")
|
||||
}
|
||||
}
|
||||
|
||||
hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil)
|
||||
if addr != nil {
|
||||
if addr.IsValid() {
|
||||
hostInfo.SetRemote(addr)
|
||||
}
|
||||
|
||||
|
@ -667,18 +663,17 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
|||
return w.WriteLine("No address was provided")
|
||||
}
|
||||
|
||||
addr := udp.NewAddrFromString(flags.Address)
|
||||
if addr == nil {
|
||||
addr, err := netip.ParseAddrPort(flags.Address)
|
||||
if err != nil {
|
||||
return w.WriteLine("Address could not be parsed")
|
||||
}
|
||||
|
||||
parsedIp := net.ParseIP(a[0])
|
||||
if parsedIp == nil {
|
||||
vpnIp, err := netip.ParseAddr(a[0])
|
||||
if err != nil {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
vpnIp := iputil.Ip2VpnIp(parsedIp)
|
||||
if vpnIp == 0 {
|
||||
if !vpnIp.IsValid() {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
|
@ -792,13 +787,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
|
|||
|
||||
cert := ifce.pki.GetCertState().Certificate
|
||||
if len(a) > 0 {
|
||||
parsedIp := net.ParseIP(a[0])
|
||||
if parsedIp == nil {
|
||||
vpnIp, err := netip.ParseAddr(a[0])
|
||||
if err != nil {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
vpnIp := iputil.Ip2VpnIp(parsedIp)
|
||||
if vpnIp == 0 {
|
||||
if !vpnIp.IsValid() {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
|
@ -862,14 +856,14 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
|||
Error error
|
||||
Type string
|
||||
State string
|
||||
PeerIp iputil.VpnIp
|
||||
PeerIp netip.Addr
|
||||
LocalIndex uint32
|
||||
RemoteIndex uint32
|
||||
RelayedThrough []iputil.VpnIp
|
||||
RelayedThrough []netip.Addr
|
||||
}
|
||||
|
||||
type RelayOutput struct {
|
||||
NebulaIp iputil.VpnIp
|
||||
NebulaIp netip.Addr
|
||||
RelayForIps []RelayFor
|
||||
}
|
||||
|
||||
|
@ -952,13 +946,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
|||
return w.WriteLine("No vpn ip was provided")
|
||||
}
|
||||
|
||||
parsedIp := net.ParseIP(a[0])
|
||||
if parsedIp == nil {
|
||||
vpnIp, err := netip.ParseAddr(a[0])
|
||||
if err != nil {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
vpnIp := iputil.Ip2VpnIp(parsedIp)
|
||||
if vpnIp == 0 {
|
||||
if !vpnIp.IsValid() {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
|
|
12
test/tun.go
12
test/tun.go
|
@ -3,23 +3,21 @@ package test
|
|||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type NoopTun struct{}
|
||||
|
||||
func (NoopTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
|
||||
return 0
|
||||
func (NoopTun) RouteFor(addr netip.Addr) netip.Addr {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (NoopTun) Activate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (NoopTun) Cidr() *net.IPNet {
|
||||
return nil
|
||||
func (NoopTun) Cidr() netip.Prefix {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
func (NoopTun) Name() string {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -115,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) {
|
|||
assert.Equal(t, 0, tw.current)
|
||||
|
||||
fps := []firewall.Packet{
|
||||
{LocalIP: 1},
|
||||
{LocalIP: 2},
|
||||
{LocalIP: 3},
|
||||
{LocalIP: 4},
|
||||
{LocalIP: netip.MustParseAddr("0.0.0.1")},
|
||||
{LocalIP: netip.MustParseAddr("0.0.0.2")},
|
||||
{LocalIP: netip.MustParseAddr("0.0.0.3")},
|
||||
{LocalIP: netip.MustParseAddr("0.0.0.4")},
|
||||
}
|
||||
|
||||
tw.Add(fps[0], time.Second*1)
|
||||
|
|
14
udp/conn.go
14
udp/conn.go
|
@ -1,6 +1,8 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
|
@ -9,7 +11,7 @@ import (
|
|||
const MTU = 9001
|
||||
|
||||
type EncReader func(
|
||||
addr *Addr,
|
||||
addr netip.AddrPort,
|
||||
out []byte,
|
||||
packet []byte,
|
||||
header *header.H,
|
||||
|
@ -22,9 +24,9 @@ type EncReader func(
|
|||
|
||||
type Conn interface {
|
||||
Rebind() error
|
||||
LocalAddr() (*Addr, error)
|
||||
LocalAddr() (netip.AddrPort, error)
|
||||
ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
|
||||
WriteTo(b []byte, addr *Addr) error
|
||||
WriteTo(b []byte, addr netip.AddrPort) error
|
||||
ReloadConfig(c *config.C)
|
||||
Close() error
|
||||
}
|
||||
|
@ -34,13 +36,13 @@ type NoopConn struct{}
|
|||
func (NoopConn) Rebind() error {
|
||||
return nil
|
||||
}
|
||||
func (NoopConn) LocalAddr() (*Addr, error) {
|
||||
return nil, nil
|
||||
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
||||
return netip.AddrPort{}, nil
|
||||
}
|
||||
func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
|
||||
return
|
||||
}
|
||||
func (NoopConn) WriteTo(_ []byte, _ *Addr) error {
|
||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||
return nil
|
||||
}
|
||||
func (NoopConn) ReloadConfig(_ *config.C) {
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
|
||||
|
||||
type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte)
|
||||
// TODO: IPV6-WORK this can likely be removed now
|
||||
type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte)
|
||||
|
|
100
udp/udp_all.go
100
udp/udp_all.go
|
@ -1,100 +0,0 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type m map[string]interface{}
|
||||
|
||||
type Addr struct {
|
||||
IP net.IP
|
||||
Port uint16
|
||||
}
|
||||
|
||||
func NewAddr(ip net.IP, port uint16) *Addr {
|
||||
addr := Addr{IP: make([]byte, net.IPv6len), Port: port}
|
||||
copy(addr.IP, ip.To16())
|
||||
return &addr
|
||||
}
|
||||
|
||||
func NewAddrFromString(s string) *Addr {
|
||||
ip, port, err := ParseIPAndPort(s)
|
||||
//TODO: handle err
|
||||
_ = err
|
||||
return &Addr{IP: ip.To16(), Port: port}
|
||||
}
|
||||
|
||||
func (ua *Addr) Equals(t *Addr) bool {
|
||||
if t == nil || ua == nil {
|
||||
return t == nil && ua == nil
|
||||
}
|
||||
return ua.IP.Equal(t.IP) && ua.Port == t.Port
|
||||
}
|
||||
|
||||
func (ua *Addr) String() string {
|
||||
if ua == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
|
||||
return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port))
|
||||
}
|
||||
|
||||
func (ua *Addr) MarshalJSON() ([]byte, error) {
|
||||
if ua == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return json.Marshal(m{"ip": ua.IP, "port": ua.Port})
|
||||
}
|
||||
|
||||
func (ua *Addr) Copy() *Addr {
|
||||
if ua == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
nu := Addr{
|
||||
Port: ua.Port,
|
||||
IP: make(net.IP, len(ua.IP)),
|
||||
}
|
||||
|
||||
copy(nu.IP, ua.IP)
|
||||
return &nu
|
||||
}
|
||||
|
||||
type AddrSlice []*Addr
|
||||
|
||||
func (a AddrSlice) Equal(b AddrSlice) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := range a {
|
||||
if !a[i].Equals(b[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func ParseIPAndPort(s string) (net.IP, uint16, error) {
|
||||
rIp, sPort, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
addr, err := net.ResolveIPAddr("ip", rIp)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
iPort, err := strconv.Atoi(sPort)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return addr.IP, uint16(iPort), nil
|
||||
}
|
|
@ -6,13 +6,14 @@ package udp
|
|||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
return NewGenericListener(l, ip, port, multi, batch)
|
||||
}
|
||||
|
||||
|
|
|
@ -9,13 +9,14 @@ package udp
|
|||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
return NewGenericListener(l, ip, port, multi, batch)
|
||||
}
|
||||
|
||||
|
|
|
@ -8,13 +8,14 @@ package udp
|
|||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
return NewGenericListener(l, ip, port, multi, batch)
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
|
@ -25,7 +26,7 @@ type GenericConn struct {
|
|||
|
||||
var _ Conn = &GenericConn{}
|
||||
|
||||
func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
|
||||
func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
lc := NewListenConfig(multi)
|
||||
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
|
||||
if err != nil {
|
||||
|
@ -37,23 +38,24 @@ func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch
|
|||
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
|
||||
}
|
||||
|
||||
func (u *GenericConn) WriteTo(b []byte, addr *Addr) error {
|
||||
_, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
|
||||
func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
||||
_, err := u.UDPConn.WriteToUDPAddrPort(b, addr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *GenericConn) LocalAddr() (*Addr, error) {
|
||||
func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
|
||||
a := u.UDPConn.LocalAddr()
|
||||
|
||||
switch v := a.(type) {
|
||||
case *net.UDPAddr:
|
||||
addr := &Addr{IP: make([]byte, len(v.IP))}
|
||||
copy(addr.IP, v.IP)
|
||||
addr.Port = uint16(v.Port)
|
||||
return addr, nil
|
||||
addr, ok := netip.AddrFromSlice(v.IP)
|
||||
if !ok {
|
||||
return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
|
||||
}
|
||||
return netip.AddrPortFrom(addr, uint16(v.Port)), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("LocalAddr returned: %#v", a)
|
||||
return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -75,19 +77,26 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f
|
|||
buffer := make([]byte, MTU)
|
||||
h := &header.H{}
|
||||
fwPacket := &firewall.Packet{}
|
||||
udpAddr := &Addr{IP: make([]byte, 16)}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
for {
|
||||
// Just read one packet at a time
|
||||
n, rua, err := u.ReadFromUDP(buffer)
|
||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||
if err != nil {
|
||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||
return
|
||||
}
|
||||
|
||||
udpAddr.IP = rua.IP
|
||||
udpAddr.Port = uint16(rua.Port)
|
||||
r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||
r(
|
||||
netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()),
|
||||
plaintext[:0],
|
||||
buffer[:n],
|
||||
h,
|
||||
fwPacket,
|
||||
lhf,
|
||||
nb,
|
||||
q,
|
||||
cache.Get(u.l),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
|
@ -35,10 +36,9 @@ func maybeIPV4(ip net.IP) (net.IP, bool) {
|
|||
return ip, false
|
||||
}
|
||||
|
||||
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
|
||||
ipV4, isV4 := maybeIPV4(ip)
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
af := unix.AF_INET6
|
||||
if isV4 {
|
||||
if ip.Is4() {
|
||||
af = unix.AF_INET
|
||||
}
|
||||
syscall.ForkLock.RLock()
|
||||
|
@ -61,13 +61,13 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
|
|||
|
||||
//TODO: support multiple listening IPs (for limiting ipv6)
|
||||
var sa unix.Sockaddr
|
||||
if isV4 {
|
||||
if ip.Is4() {
|
||||
sa4 := &unix.SockaddrInet4{Port: port}
|
||||
copy(sa4.Addr[:], ipV4)
|
||||
sa4.Addr = ip.As4()
|
||||
sa = sa4
|
||||
} else {
|
||||
sa6 := &unix.SockaddrInet6{Port: port}
|
||||
copy(sa6.Addr[:], ip.To16())
|
||||
sa6.Addr = ip.As16()
|
||||
sa = sa6
|
||||
}
|
||||
if err = unix.Bind(fd, sa); err != nil {
|
||||
|
@ -79,7 +79,7 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
|
|||
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
|
||||
//l.Println(v, err)
|
||||
|
||||
return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err
|
||||
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
|
||||
}
|
||||
|
||||
func (u *StdConn) Rebind() error {
|
||||
|
@ -102,30 +102,29 @@ func (u *StdConn) GetSendBuffer() (int, error) {
|
|||
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
|
||||
}
|
||||
|
||||
func (u *StdConn) LocalAddr() (*Addr, error) {
|
||||
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
||||
sa, err := unix.Getsockname(u.sysFd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
|
||||
addr := &Addr{}
|
||||
switch sa := sa.(type) {
|
||||
case *unix.SockaddrInet4:
|
||||
addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
|
||||
addr.Port = uint16(sa.Port)
|
||||
case *unix.SockaddrInet6:
|
||||
addr.IP = sa.Addr[0:]
|
||||
addr.Port = uint16(sa.Port)
|
||||
}
|
||||
return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil
|
||||
|
||||
return addr, nil
|
||||
case *unix.SockaddrInet6:
|
||||
return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil
|
||||
|
||||
default:
|
||||
return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
|
||||
plaintext := make([]byte, MTU)
|
||||
h := &header.H{}
|
||||
fwPacket := &firewall.Packet{}
|
||||
udpAddr := &Addr{}
|
||||
var ip netip.Addr
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
//TODO: should we track this?
|
||||
|
@ -146,12 +145,23 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
|
|||
//metric.Update(int64(n))
|
||||
for i := 0; i < n; i++ {
|
||||
if u.isV4 {
|
||||
udpAddr.IP = names[i][4:8]
|
||||
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
||||
//TODO: IPV6-WORK what is not ok?
|
||||
} else {
|
||||
udpAddr.IP = names[i][8:24]
|
||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||
//TODO: IPV6-WORK what is not ok?
|
||||
}
|
||||
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
|
||||
r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||
r(
|
||||
netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])),
|
||||
plaintext[:0],
|
||||
buffers[i][:msgs[i].Len],
|
||||
h,
|
||||
fwPacket,
|
||||
lhf,
|
||||
nb,
|
||||
q,
|
||||
cache.Get(u.l),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -197,19 +207,20 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) WriteTo(b []byte, addr *Addr) error {
|
||||
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
||||
if u.isV4 {
|
||||
return u.writeTo4(b, addr)
|
||||
return u.writeTo4(b, ip)
|
||||
}
|
||||
return u.writeTo6(b, addr)
|
||||
return u.writeTo6(b, ip)
|
||||
}
|
||||
|
||||
func (u *StdConn) writeTo6(b []byte, addr *Addr) error {
|
||||
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
||||
var rsa unix.RawSockaddrInet6
|
||||
rsa.Family = unix.AF_INET6
|
||||
rsa.Addr = ip.Addr().As16()
|
||||
port := ip.Port()
|
||||
// Little Endian -> Network Endian
|
||||
rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8)
|
||||
copy(rsa.Addr[:], addr.IP.To16())
|
||||
rsa.Port = (port >> 8) | ((port & 0xff) << 8)
|
||||
|
||||
for {
|
||||
_, _, err := unix.Syscall6(
|
||||
|
@ -232,17 +243,17 @@ func (u *StdConn) writeTo6(b []byte, addr *Addr) error {
|
|||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) writeTo4(b []byte, addr *Addr) error {
|
||||
addrV4, isAddrV4 := maybeIPV4(addr.IP)
|
||||
if !isAddrV4 {
|
||||
func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
|
||||
if !ip.Addr().Is4() {
|
||||
return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
|
||||
}
|
||||
|
||||
var rsa unix.RawSockaddrInet4
|
||||
rsa.Family = unix.AF_INET
|
||||
rsa.Addr = ip.Addr().As4()
|
||||
port := ip.Port()
|
||||
// Little Endian -> Network Endian
|
||||
rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8)
|
||||
copy(rsa.Addr[:], addrV4)
|
||||
rsa.Port = (port >> 8) | ((port & 0xff) << 8)
|
||||
|
||||
for {
|
||||
_, _, err := unix.Syscall6(
|
||||
|
|
|
@ -8,13 +8,14 @@ package udp
|
|||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
return NewGenericListener(l, ip, port, multi, batch)
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
|
@ -61,16 +62,14 @@ type RIOConn struct {
|
|||
results [packetsPerRing]winrio.Result
|
||||
}
|
||||
|
||||
func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) {
|
||||
func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) {
|
||||
if !winrio.Initialize() {
|
||||
return nil, errors.New("could not initialize winrio")
|
||||
}
|
||||
|
||||
u := &RIOConn{l: l}
|
||||
|
||||
addr := [16]byte{}
|
||||
copy(addr[:], ip.To16())
|
||||
err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port})
|
||||
err := u.bind(&windows.SockaddrInet6{Addr: addr.As16(), Port: port})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bind: %w", err)
|
||||
}
|
||||
|
@ -124,7 +123,6 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
|
|||
buffer := make([]byte, MTU)
|
||||
h := &header.H{}
|
||||
fwPacket := &firewall.Packet{}
|
||||
udpAddr := &Addr{IP: make([]byte, 16)}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
for {
|
||||
|
@ -135,11 +133,17 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
|
|||
return
|
||||
}
|
||||
|
||||
udpAddr.IP = rua.Addr[:]
|
||||
p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port))
|
||||
p[0] = byte(rua.Port >> 8)
|
||||
p[1] = byte(rua.Port)
|
||||
r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||
r(
|
||||
netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)),
|
||||
plaintext[:0],
|
||||
buffer[:n],
|
||||
h,
|
||||
fwPacket,
|
||||
lhf,
|
||||
nb,
|
||||
q,
|
||||
cache.Get(u.l),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -231,7 +235,7 @@ retry:
|
|||
return n, ep, nil
|
||||
}
|
||||
|
||||
func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
|
||||
func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
|
||||
if !u.isOpen.Load() {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
@ -274,10 +278,9 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
|
|||
|
||||
packet := u.tx.Push()
|
||||
packet.addr.Family = windows.AF_INET6
|
||||
p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port))
|
||||
p[0] = byte(addr.Port >> 8)
|
||||
p[1] = byte(addr.Port)
|
||||
copy(packet.addr.Addr[:], addr.IP.To16())
|
||||
packet.addr.Addr = ip.Addr().As16()
|
||||
port := ip.Port()
|
||||
packet.addr.Port = (port >> 8) | ((port & 0xff) << 8)
|
||||
copy(packet.data[:], buf)
|
||||
|
||||
dataBuffer := &winrio.Buffer{
|
||||
|
@ -295,17 +298,15 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
|
|||
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
||||
}
|
||||
|
||||
func (u *RIOConn) LocalAddr() (*Addr, error) {
|
||||
func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
|
||||
sa, err := windows.Getsockname(u.sock)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
|
||||
v6 := sa.(*windows.SockaddrInet6)
|
||||
return &Addr{
|
||||
IP: v6.Addr[:],
|
||||
Port: uint16(v6.Port),
|
||||
}, nil
|
||||
return netip.AddrPortFrom(netip.AddrFrom16(v6.Addr).Unmap(), uint16(v6.Port)), nil
|
||||
|
||||
}
|
||||
|
||||
func (u *RIOConn) Rebind() error {
|
||||
|
|
|
@ -4,9 +4,8 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -16,30 +15,24 @@ import (
|
|||
)
|
||||
|
||||
type Packet struct {
|
||||
ToIp net.IP
|
||||
ToPort uint16
|
||||
FromIp net.IP
|
||||
FromPort uint16
|
||||
Data []byte
|
||||
To netip.AddrPort
|
||||
From netip.AddrPort
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (u *Packet) Copy() *Packet {
|
||||
n := &Packet{
|
||||
ToIp: make(net.IP, len(u.ToIp)),
|
||||
ToPort: u.ToPort,
|
||||
FromIp: make(net.IP, len(u.FromIp)),
|
||||
FromPort: u.FromPort,
|
||||
Data: make([]byte, len(u.Data)),
|
||||
To: u.To,
|
||||
From: u.From,
|
||||
Data: make([]byte, len(u.Data)),
|
||||
}
|
||||
|
||||
copy(n.ToIp, u.ToIp)
|
||||
copy(n.FromIp, u.FromIp)
|
||||
copy(n.Data, u.Data)
|
||||
return n
|
||||
}
|
||||
|
||||
type TesterConn struct {
|
||||
Addr *Addr
|
||||
Addr netip.AddrPort
|
||||
|
||||
RxPackets chan *Packet // Packets to receive into nebula
|
||||
TxPackets chan *Packet // Packets transmitted outside by nebula
|
||||
|
@ -48,9 +41,9 @@ type TesterConn struct {
|
|||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) {
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) {
|
||||
return &TesterConn{
|
||||
Addr: &Addr{ip, uint16(port)},
|
||||
Addr: netip.AddrPortFrom(ip, uint16(port)),
|
||||
RxPackets: make(chan *Packet, 10),
|
||||
TxPackets: make(chan *Packet, 10),
|
||||
l: l,
|
||||
|
@ -71,7 +64,7 @@ func (u *TesterConn) Send(packet *Packet) {
|
|||
}
|
||||
if u.l.Level >= logrus.DebugLevel {
|
||||
u.l.WithField("header", h).
|
||||
WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
|
||||
WithField("udpAddr", packet.From).
|
||||
WithField("dataLen", len(packet.Data)).
|
||||
Debug("UDP receiving injected packet")
|
||||
}
|
||||
|
@ -98,23 +91,18 @@ func (u *TesterConn) Get(block bool) *Packet {
|
|||
// Below this is boilerplate implementation to make nebula actually work
|
||||
//********************************************************************************************************************//
|
||||
|
||||
func (u *TesterConn) WriteTo(b []byte, addr *Addr) error {
|
||||
func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
||||
if u.closed.Load() {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
|
||||
p := &Packet{
|
||||
Data: make([]byte, len(b), len(b)),
|
||||
FromIp: make([]byte, 16),
|
||||
FromPort: u.Addr.Port,
|
||||
ToIp: make([]byte, 16),
|
||||
ToPort: addr.Port,
|
||||
Data: make([]byte, len(b), len(b)),
|
||||
From: u.Addr,
|
||||
To: addr,
|
||||
}
|
||||
|
||||
copy(p.Data, b)
|
||||
copy(p.ToIp, addr.IP.To16())
|
||||
copy(p.FromIp, u.Addr.IP.To16())
|
||||
|
||||
u.TxPackets <- p
|
||||
return nil
|
||||
}
|
||||
|
@ -123,7 +111,6 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi
|
|||
plaintext := make([]byte, MTU)
|
||||
h := &header.H{}
|
||||
fwPacket := &firewall.Packet{}
|
||||
ua := &Addr{IP: make([]byte, 16)}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
for {
|
||||
|
@ -131,9 +118,7 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi
|
|||
if !ok {
|
||||
return
|
||||
}
|
||||
ua.Port = p.FromPort
|
||||
copy(ua.IP, p.FromIp.To16())
|
||||
r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||
r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -144,7 +129,7 @@ func NewUDPStatsEmitter(_ []Conn) func() {
|
|||
return func() {}
|
||||
}
|
||||
|
||||
func (u *TesterConn) LocalAddr() (*Addr, error) {
|
||||
func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
|
||||
return u.Addr, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -6,12 +6,13 @@ package udp
|
|||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
if multi {
|
||||
//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
|
||||
// The udp stack would need to be reworked to hide away the implementation differences between
|
||||
|
|
Loading…
Reference in a new issue