diff --git a/examples/go_service/main.go b/examples/go_service/main.go index f46273a..30178c0 100644 --- a/examples/go_service/main.go +++ b/examples/go_service/main.go @@ -4,6 +4,7 @@ import ( "bufio" "fmt" "log" + "net" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/service" @@ -54,16 +55,16 @@ pki: cert: /home/rice/Developer/nebula-config/app.crt key: /home/rice/Developer/nebula-config/app.key ` - var config config.C - if err := config.LoadString(configStr); err != nil { + var cfg config.C + if err := cfg.LoadString(configStr); err != nil { return err } - service, err := service.New(&config) + svc, err := service.New(&cfg) if err != nil { return err } - ln, err := service.Listen("tcp", ":1234") + ln, err := svc.Listen("tcp", ":1234") if err != nil { return err } @@ -73,16 +74,24 @@ pki: log.Printf("accept error: %s", err) break } - defer conn.Close() + defer func(conn net.Conn) { + _ = conn.Close() + }(conn) log.Printf("got connection") - conn.Write([]byte("hello world\n")) + _, err = conn.Write([]byte("hello world\n")) + if err != nil { + log.Printf("write error: %s", err) + } scanner := bufio.NewScanner(conn) for scanner.Scan() { message := scanner.Text() - fmt.Fprintf(conn, "echo: %q\n", message) + _, err = fmt.Fprintf(conn, "echo: %q\n", message) + if err != nil { + log.Printf("write error: %s", err) + } log.Printf("got message %q", message) } @@ -92,8 +101,8 @@ pki: } } - service.Close() - if err := service.Wait(); err != nil { + _ = svc.Close() + if err := svc.Wait(); err != nil { return err } return nil diff --git a/service/service.go b/service/service.go index 50c1d4a..4ddd301 100644 --- a/service/service.go +++ b/service/service.go @@ -8,6 +8,7 @@ import ( "log" "math" "net" + "net/netip" "os" "strings" "sync" @@ -153,24 +154,48 @@ func New(config *config.C) (*Service, error) { return &s, nil } -// DialContext dials the provided address. Currently only TCP is supported. +func getProtocolNumber(addr netip.Addr) tcpip.NetworkProtocolNumber { + if addr.Is6() { + return ipv6.ProtocolNumber + } + return ipv4.ProtocolNumber +} + +// DialContext dials the provided address. func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - if network != "tcp" && network != "tcp4" { - return nil, errors.New("only tcp is supported") + switch network { + case "udp", "udp4", "udp6": + addr, err := net.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + fullAddr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(addr.IP), + Port: uint16(addr.Port), + } + num := getProtocolNumber(addr.AddrPort().Addr()) + return gonet.DialUDP(s.ipstack, nil, &fullAddr, num) + case "tcp", "tcp4", "tcp6": + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + fullAddr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(addr.IP), + Port: uint16(addr.Port), + } + num := getProtocolNumber(addr.AddrPort().Addr()) + return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, num) + default: + return nil, fmt.Errorf("unknown network type: %s", network) } +} - addr, err := net.ResolveTCPAddr(network, address) - if err != nil { - return nil, err - } - - fullAddr := tcpip.FullAddress{ - NIC: nicID, - Addr: tcpip.AddrFromSlice(addr.IP), - Port: uint16(addr.Port), - } - - return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber) +// Dial dials the provided address +func (s *Service) Dial(network, address string) (net.Conn, error) { + return s.DialContext(context.Background(), network, address) } // Listen listens on the provided address. Currently only TCP with wildcard