mirror of
https://github.com/slackhq/nebula.git
synced 2024-11-10 09:12:39 +08:00
3aca576b07
* update to go1.21 Since the first minor version update has already been released, we can probably feel comfortable updating to go1.21. This version now enforces that the go version on the system is compatible with the version specified in go.mod, so we can remove the old logic around checking the minimum version in the Makefile. - https://go.dev/doc/go1.21#tools > To improve forwards compatibility, Go 1.21 now reads the go line in a go.work or go.mod file as a strict minimum requirement: go 1.21.0 means that the workspace or module cannot be used with Go 1.20 or with Go 1.21rc1. This allows projects that depend on fixes made in later versions of Go to ensure that they are not used with earlier versions. It also gives better error reporting for projects that make use of new Go features: when the problem is that a newer Go version is needed, that problem is reported clearly, instead of attempting to build the code and printing errors about unresolved imports or syntax errors. * update to go1.22 * bump gvisor * fix merge conflicts * use latest gvisor `go` branch Need to use the latest commit on the `go` branch, see: - https://github.com/google/gvisor?tab=readme-ov-file#using-go-get * mod tidy * more fixes * give smoketest more time Is this why it is failing? * also a little more sleep here --------- Co-authored-by: Jack Doan <me@jackdoan.com>
248 lines
6.1 KiB
Go
248 lines
6.1 KiB
Go
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"math"
|
|
"net"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/slackhq/nebula"
|
|
"github.com/slackhq/nebula/config"
|
|
"github.com/slackhq/nebula/overlay"
|
|
"golang.org/x/sync/errgroup"
|
|
"gvisor.dev/gvisor/pkg/buffer"
|
|
"gvisor.dev/gvisor/pkg/tcpip"
|
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
|
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
|
"gvisor.dev/gvisor/pkg/waiter"
|
|
)
|
|
|
|
// nicID identifies the single virtual NIC registered with the gvisor stack.
// The NIC creation, the route table, the local address and every dial target
// all refer to this one NIC.
const (
	nicID = 1
)
|
|
|
|
type Service struct {
|
|
eg *errgroup.Group
|
|
control *nebula.Control
|
|
ipstack *stack.Stack
|
|
|
|
mu struct {
|
|
sync.Mutex
|
|
|
|
listeners map[uint16]*tcpListener
|
|
}
|
|
}
|
|
|
|
func New(config *config.C) (*Service, error) {
|
|
logger := logrus.New()
|
|
logger.Out = os.Stdout
|
|
|
|
control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
control.Start()
|
|
|
|
ctx := control.Context()
|
|
eg, ctx := errgroup.WithContext(ctx)
|
|
s := Service{
|
|
eg: eg,
|
|
control: control,
|
|
}
|
|
s.mu.listeners = map[uint16]*tcpListener{}
|
|
|
|
device, ok := control.Device().(*overlay.UserDevice)
|
|
if !ok {
|
|
return nil, errors.New("must be using user device")
|
|
}
|
|
|
|
s.ipstack = stack.New(stack.Options{
|
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
|
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
|
|
})
|
|
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
|
|
tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
|
|
if tcpipErr != nil {
|
|
return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
|
|
}
|
|
linkEP := channel.New( /*size*/ 512 /*mtu*/, 1280, "")
|
|
if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
|
|
return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
|
|
}
|
|
ipv4Subnet, _ := tcpip.NewSubnet(tcpip.AddrFrom4([4]byte{0x00, 0x00, 0x00, 0x00}), tcpip.MaskFrom(strings.Repeat("\x00", 4)))
|
|
s.ipstack.SetRouteTable([]tcpip.Route{
|
|
{
|
|
Destination: ipv4Subnet,
|
|
NIC: nicID,
|
|
},
|
|
})
|
|
|
|
ipNet := device.Cidr()
|
|
pa := tcpip.ProtocolAddress{
|
|
AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(),
|
|
Protocol: ipv4.ProtocolNumber,
|
|
}
|
|
if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
|
|
PEB: stack.CanBePrimaryEndpoint, // zero value default
|
|
ConfigType: stack.AddressConfigStatic, // zero value default
|
|
}); err != nil {
|
|
return nil, fmt.Errorf("error creating IP: %s", err)
|
|
}
|
|
|
|
const tcpReceiveBufferSize = 0
|
|
const maxInFlightConnectionAttempts = 1024
|
|
tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler)
|
|
s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
|
|
|
|
reader, writer := device.Pipe()
|
|
|
|
go func() {
|
|
<-ctx.Done()
|
|
reader.Close()
|
|
writer.Close()
|
|
}()
|
|
|
|
// create Goroutines to forward packets between Nebula and Gvisor
|
|
eg.Go(func() error {
|
|
buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize)
|
|
for {
|
|
// this will read exactly one packet
|
|
n, err := reader.Read(buf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
|
Payload: buffer.MakeWithData(bytes.Clone(buf[:n])),
|
|
})
|
|
linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf)
|
|
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
})
|
|
eg.Go(func() error {
|
|
for {
|
|
packet := linkEP.ReadContext(ctx)
|
|
if packet == nil {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
continue
|
|
}
|
|
bufView := packet.ToView()
|
|
if _, err := bufView.WriteTo(writer); err != nil {
|
|
return err
|
|
}
|
|
bufView.Release()
|
|
}
|
|
})
|
|
|
|
return &s, nil
|
|
}
|
|
|
|
// DialContext dials the provided address. Currently only TCP is supported.
|
|
func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
|
if network != "tcp" && network != "tcp4" {
|
|
return nil, errors.New("only tcp is supported")
|
|
}
|
|
|
|
addr, err := net.ResolveTCPAddr(network, address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
fullAddr := tcpip.FullAddress{
|
|
NIC: nicID,
|
|
Addr: tcpip.AddrFromSlice(addr.IP),
|
|
Port: uint16(addr.Port),
|
|
}
|
|
|
|
return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber)
|
|
}
|
|
|
|
// Listen listens on the provided address. Currently only TCP with wildcard
|
|
// addresses are supported.
|
|
func (s *Service) Listen(network, address string) (net.Listener, error) {
|
|
if network != "tcp" && network != "tcp4" {
|
|
return nil, errors.New("only tcp is supported")
|
|
}
|
|
addr, err := net.ResolveTCPAddr(network, address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) {
|
|
return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP)
|
|
}
|
|
if addr.Port == 0 {
|
|
return nil, errors.New("specific port required, got 0")
|
|
}
|
|
if addr.Port < 0 || addr.Port >= math.MaxUint16 {
|
|
return nil, fmt.Errorf("invalid port %d", addr.Port)
|
|
}
|
|
port := uint16(addr.Port)
|
|
|
|
l := &tcpListener{
|
|
port: port,
|
|
s: s,
|
|
addr: addr,
|
|
accept: make(chan net.Conn),
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if _, ok := s.mu.listeners[port]; ok {
|
|
return nil, fmt.Errorf("already listening on port %d", port)
|
|
}
|
|
s.mu.listeners[port] = l
|
|
|
|
return l, nil
|
|
}
|
|
|
|
func (s *Service) Wait() error {
|
|
return s.eg.Wait()
|
|
}
|
|
|
|
func (s *Service) Close() error {
|
|
s.control.Stop()
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) tcpHandler(r *tcp.ForwarderRequest) {
|
|
endpointID := r.ID()
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
l, ok := s.mu.listeners[endpointID.LocalPort]
|
|
if !ok {
|
|
r.Complete(true)
|
|
return
|
|
}
|
|
|
|
var wq waiter.Queue
|
|
ep, err := r.CreateEndpoint(&wq)
|
|
if err != nil {
|
|
log.Printf("got error creating endpoint %q", err)
|
|
r.Complete(true)
|
|
return
|
|
}
|
|
r.Complete(false)
|
|
ep.SocketOptions().SetKeepAlive(true)
|
|
|
|
conn := gonet.NewTCPConn(&wq, ep)
|
|
l.accept <- conn
|
|
}
|