mirror of
https://github.com/gravitl/netmaker.git
synced 2025-09-06 05:04:27 +08:00
236 lines
5.4 KiB
Go
236 lines
5.4 KiB
Go
package wg
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"os/exec"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.zx2c4.com/wireguard/wgctrl"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
)
|
|
|
|
const (
|
|
DefaultMTU = 1280
|
|
DefaultWgPort = 51820
|
|
DefaultWgKeepAlive = 20 * time.Second
|
|
)
|
|
|
|
// WGIface represents a interface instance
|
|
type WGIface struct {
|
|
Name string
|
|
Port int
|
|
MTU int
|
|
Device *wgtypes.Device
|
|
Address WGAddress
|
|
Interface NetInterface
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// NetInterface represents a generic network tunnel interface
|
|
type NetInterface interface {
|
|
Close() error
|
|
}
|
|
|
|
// WGAddress Wireguard parsed address
|
|
type WGAddress struct {
|
|
IP net.IP
|
|
Network *net.IPNet
|
|
}
|
|
|
|
// NewWGIFace Creates a new Wireguard interface instance
|
|
func NewWGIFace(iface string, address string, mtu int) (*WGIface, error) {
|
|
wgIface := &WGIface{
|
|
Name: iface,
|
|
MTU: mtu,
|
|
mu: sync.Mutex{},
|
|
}
|
|
|
|
wgAddress, err := parseAddress(address)
|
|
if err != nil {
|
|
return wgIface, err
|
|
}
|
|
|
|
wgIface.Address = wgAddress
|
|
err = wgIface.GetWgIface(iface)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return wgIface, nil
|
|
}
|
|
|
|
func (w *WGIface) GetWgIface(iface string) error {
|
|
wgClient, err := wgctrl.New()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
dev, err := wgClient.Device(iface)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Printf("----> DEVICE: %+v\n", dev)
|
|
w.Device = dev
|
|
w.Port = dev.ListenPort
|
|
return nil
|
|
}
|
|
|
|
func GetWgIfacePubKey(iface string) string {
|
|
wgClient, err := wgctrl.New()
|
|
if err != nil {
|
|
log.Println("Error fetching pub key: ", iface, err)
|
|
return ""
|
|
}
|
|
dev, err := wgClient.Device(iface)
|
|
if err != nil {
|
|
log.Println("Error fetching pub key: ", iface, err)
|
|
return ""
|
|
}
|
|
return dev.PublicKey.String()
|
|
}
|
|
|
|
// parseAddress parse a string ("1.2.3.4/24") address to WG Address
|
|
func parseAddress(address string) (WGAddress, error) {
|
|
ip, network, err := net.ParseCIDR(address)
|
|
if err != nil {
|
|
return WGAddress{}, err
|
|
}
|
|
return WGAddress{
|
|
IP: ip,
|
|
Network: network,
|
|
}, nil
|
|
}
|
|
|
|
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
|
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
|
w.mu.Lock()
|
|
defer w.mu.Unlock()
|
|
|
|
log.Printf("updating interface %s peer %s: endpoint %s ", w.Name, peerKey, endpoint)
|
|
|
|
// //parse allowed ips
|
|
// _, ipNet, err := net.ParseCIDR(allowedIps)
|
|
// if err != nil {
|
|
// return err
|
|
// }
|
|
|
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
peer := wgtypes.PeerConfig{
|
|
PublicKey: peerKeyParsed,
|
|
// ReplaceAllowedIPs: true,
|
|
// AllowedIPs: allowedIps,
|
|
PersistentKeepaliveInterval: &keepAlive,
|
|
PresharedKey: preSharedKey,
|
|
Endpoint: endpoint,
|
|
}
|
|
|
|
config := wgtypes.Config{
|
|
Peers: []wgtypes.PeerConfig{peer},
|
|
}
|
|
err = w.configureDevice(config)
|
|
if err != nil {
|
|
return fmt.Errorf("received error \"%v\" while updating peer on interface %s with settings: allowed ips %s, endpoint %s", err, w.Name, allowedIps, endpoint.String())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// configureDevice configures the wireguard device
|
|
func (w *WGIface) configureDevice(config wgtypes.Config) error {
|
|
wg, err := wgctrl.New()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer wg.Close()
|
|
|
|
// validate if device with name exists
|
|
_, err = wg.Device(w.Name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Printf("got Wireguard device %s\n", w.Name)
|
|
|
|
return wg.ConfigureDevice(w.Name, config)
|
|
}
|
|
|
|
// GetListenPort returns the listening port of the Wireguard endpoint
|
|
func (w *WGIface) GetListenPort() (*int, error) {
|
|
log.Printf("getting Wireguard listen port of interface %s", w.Name)
|
|
|
|
//discover Wireguard current configuration
|
|
wg, err := wgctrl.New()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer wg.Close()
|
|
|
|
d, err := wg.Device(w.Name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
log.Printf("got Wireguard device listen port %s, %d", w.Name, d.ListenPort)
|
|
|
|
return &d.ListenPort, nil
|
|
}
|
|
|
|
// GetRealIface - retrieves tun iface based on reference iface name from config file
|
|
func GetRealIface(iface string) (string, error) {
|
|
RunCmd("wg show interfaces", false)
|
|
ifacePath := "/var/run/wireguard/" + iface + ".name"
|
|
if !(FileExists(ifacePath)) {
|
|
return "", errors.New(ifacePath + " does not exist")
|
|
}
|
|
realIfaceName, err := GetFileAsString(ifacePath)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
realIfaceName = strings.TrimSpace(realIfaceName)
|
|
if !(FileExists(fmt.Sprintf("/var/run/wireguard/%s.sock", realIfaceName))) {
|
|
return "", errors.New("interface file does not exist")
|
|
}
|
|
return realIfaceName, nil
|
|
}
|
|
|
|
// FileExists - checks if file exists locally
|
|
func FileExists(f string) bool {
|
|
info, err := os.Stat(f)
|
|
if os.IsNotExist(err) {
|
|
return false
|
|
}
|
|
if err != nil && strings.Contains(err.Error(), "not a directory") {
|
|
return false
|
|
}
|
|
if err != nil {
|
|
log.Println(0, "error reading file: "+f+", "+err.Error())
|
|
}
|
|
return !info.IsDir()
|
|
}
|
|
|
|
// GetFileAsString - returns the string contents of a given file
|
|
func GetFileAsString(path string) (string, error) {
|
|
content, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(content), err
|
|
}
|
|
|
|
// RunCmd - runs a local command
|
|
func RunCmd(command string, printerr bool) (string, error) {
|
|
args := strings.Fields(command)
|
|
cmd := exec.Command(args[0], args[1:]...)
|
|
cmd.Wait()
|
|
out, err := cmd.CombinedOutput()
|
|
if err != nil && printerr {
|
|
log.Println("error running command: ", command)
|
|
log.Println(strings.TrimSuffix(string(out), "\n"))
|
|
}
|
|
return string(out), err
|
|
}
|