headscale/utils.go

366 lines
8.4 KiB
Go
Raw Normal View History

2020-06-21 18:32:08 +08:00
// Code here is mostly taken from github.com/tailscale/tailscale
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package headscale
import (
2021-10-30 22:29:03 +08:00
"context"
2022-01-26 06:11:15 +08:00
"crypto/rand"
"encoding/base64"
2020-06-21 18:32:08 +08:00
"encoding/json"
"fmt"
"io/fs"
2021-10-30 22:29:03 +08:00
"net"
"os"
"path/filepath"
"reflect"
"strconv"
2021-08-13 17:33:19 +08:00
"strings"
2020-06-21 18:32:08 +08:00
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"inet.af/netaddr"
2021-08-13 17:33:19 +08:00
"tailscale.com/tailcfg"
"tailscale.com/types/key"
2020-06-21 18:32:08 +08:00
)
2021-11-16 03:18:14 +08:00
const (
ErrCannotDecryptResponse = Error("cannot decrypt response")
ErrCouldNotAllocateIP = Error("could not find any suitable IP")
// These constants are copied from the upstream tailscale.com/types/key
// library, because they are not exported.
// https://github.com/tailscale/tailscale/tree/main/types/key
// nodePublicHexPrefix is the prefix used to identify a
// hex-encoded node public key.
//
// This prefix is used in the control protocol, so cannot be
// changed.
nodePublicHexPrefix = "nodekey:"
// machinePublicHexPrefix is the prefix used to identify a
// hex-encoded machine public key.
//
// This prefix is used in the control protocol, so cannot be
// changed.
machinePublicHexPrefix = "mkey:"
// discoPublicHexPrefix is the prefix used to identify a
// hex-encoded disco public key.
//
// This prefix is used in the control protocol, so cannot be
// changed.
discoPublicHexPrefix = "discokey:"
// privateKey prefix.
privateHexPrefix = "privkey:"
PermissionFallback = 0o700
2022-08-14 23:04:07 +08:00
ZstdCompression = "zstd"
2021-11-16 03:18:14 +08:00
)
// MachinePublicKeyStripPrefix returns the String form of machineKey with the
// leading "mkey:" prefix removed (the text is unchanged if it is absent).
func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string {
	encoded := machineKey.String()

	return strings.TrimPrefix(encoded, machinePublicHexPrefix)
}
// NodePublicKeyStripPrefix returns the String form of nodeKey with the
// leading "nodekey:" prefix removed (the text is unchanged if it is absent).
func NodePublicKeyStripPrefix(nodeKey key.NodePublic) string {
	encoded := nodeKey.String()

	return strings.TrimPrefix(encoded, nodePublicHexPrefix)
}
// DiscoPublicKeyStripPrefix returns the String form of discoKey with the
// leading "discokey:" prefix removed (the text is unchanged if it is absent).
func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string {
	encoded := discoKey.String()

	return strings.TrimPrefix(encoded, discoPublicHexPrefix)
}
// MachinePublicKeyEnsurePrefix returns machineKey carrying the "mkey:"
// prefix, adding it only when it is not already present.
func MachinePublicKeyEnsurePrefix(machineKey string) string {
	if strings.HasPrefix(machineKey, machinePublicHexPrefix) {
		return machineKey
	}

	return machinePublicHexPrefix + machineKey
}
// NodePublicKeyEnsurePrefix returns nodeKey carrying the "nodekey:" prefix,
// adding it only when it is not already present.
func NodePublicKeyEnsurePrefix(nodeKey string) string {
	if strings.HasPrefix(nodeKey, nodePublicHexPrefix) {
		return nodeKey
	}

	return nodePublicHexPrefix + nodeKey
}
// DiscoPublicKeyEnsurePrefix returns discoKey carrying the "discokey:"
// prefix, adding it only when it is not already present.
func DiscoPublicKeyEnsurePrefix(discoKey string) string {
	if strings.HasPrefix(discoKey, discoPublicHexPrefix) {
		return discoKey
	}

	return discoPublicHexPrefix + discoKey
}
// PrivateKeyEnsurePrefix returns privateKey carrying the "privkey:" prefix,
// adding it only when it is not already present.
func PrivateKeyEnsurePrefix(privateKey string) string {
	if strings.HasPrefix(privateKey, privateHexPrefix) {
		return privateKey
	}

	return privateHexPrefix + privateKey
}
2021-05-06 07:01:45 +08:00
// Error is a string-backed error type that lets errors be declared as
// constants, as described in
// https://dave.cheney.net/2016/04/07/constant-errors.
// Two Error values with the same text compare equal with ==.
type Error string

// Error implements the error interface by returning the message text.
func (e Error) Error() string {
	return string(e)
}
2021-11-13 16:36:45 +08:00
func decode(
msg []byte,
output interface{},
pubKey *key.MachinePublic,
privKey *key.MachinePrivate,
2021-11-13 16:36:45 +08:00
) error {
2022-08-15 16:43:39 +08:00
log.Trace().
Str("pubkey", pubKey.ShortString()).
Int("length", len(msg)).
Msg("Trying to decrypt")
decrypted, ok := privKey.OpenFrom(*pubKey, msg)
if !ok {
return ErrCannotDecryptResponse
2020-06-21 18:32:08 +08:00
}
if err := json.Unmarshal(decrypted, output); err != nil {
2021-11-16 03:18:14 +08:00
return err
2020-06-21 18:32:08 +08:00
}
2021-11-14 23:46:09 +08:00
2020-06-21 18:32:08 +08:00
return nil
}
2022-05-16 20:59:46 +08:00
func (h *Headscale) getAvailableIPs() (MachineAddresses, error) {
var ips MachineAddresses
var err error
2022-01-16 21:16:59 +08:00
ipPrefixes := h.cfg.IPPrefixes
for _, ipPrefix := range ipPrefixes {
var ip *netaddr.IP
ip, err = h.getAvailableIP(ipPrefix)
if err != nil {
2022-05-16 20:59:46 +08:00
return ips, err
2022-01-16 21:16:59 +08:00
}
ips = append(ips, *ip)
}
2022-05-16 20:59:46 +08:00
return ips, err
2022-01-16 21:16:59 +08:00
}
2022-05-16 20:59:46 +08:00
func GetIPPrefixEndpoints(na netaddr.IPPrefix) (netaddr.IP, netaddr.IP) {
var network, broadcast netaddr.IP
ipRange := na.Range()
network = ipRange.From()
broadcast = ipRange.To()
2022-01-30 16:35:10 +08:00
2022-05-16 20:59:46 +08:00
return network, broadcast
}
2022-01-16 21:16:59 +08:00
func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) {
usedIps, err := h.getUsedIPs()
if err != nil {
return nil, err
}
ipPrefixNetworkAddress, ipPrefixBroadcastAddress := GetIPPrefixEndpoints(ipPrefix)
// Get the first IP in our prefix
ip := ipPrefixNetworkAddress.Next()
2020-06-21 18:32:08 +08:00
for {
if !ipPrefix.Contains(ip) {
2022-07-29 23:35:21 +08:00
return nil, ErrCouldNotAllocateIP
2020-06-21 18:32:08 +08:00
}
switch {
case ip.Compare(ipPrefixBroadcastAddress) == 0:
fallthrough
case usedIps.Contains(ip):
fallthrough
case ip.IsZero() || ip.IsLoopback():
ip = ip.Next()
2021-11-14 23:46:09 +08:00
continue
default:
return &ip, nil
2020-06-21 18:32:08 +08:00
}
}
}
func (h *Headscale) getUsedIPs() (*netaddr.IPSet, error) {
2022-01-16 21:16:59 +08:00
// FIXME: This really deserves a better data model,
// but this was quick to get running and it should be enough
// to begin experimenting with a dual stack tailnet.
var addressesSlices []string
h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
var ips netaddr.IPSetBuilder
2022-01-16 21:16:59 +08:00
for _, slice := range addressesSlices {
var machineAddresses MachineAddresses
err := machineAddresses.Scan(slice)
2022-01-16 21:16:59 +08:00
if err != nil {
return &netaddr.IPSet{}, fmt.Errorf(
"failed to read ip from database: %w",
err,
)
}
for _, ip := range machineAddresses {
ips.Add(ip)
}
}
ipSet, err := ips.IPSet()
if err != nil {
return &netaddr.IPSet{}, fmt.Errorf(
"failed to build IP Set: %w",
err,
)
}
return ipSet, nil
}
2021-08-13 17:33:19 +08:00
// tailNodesToString renders the names of nodes as "[ a, b ](n)", where n is
// the number of nodes. Every element of nodes must be non-nil.
func tailNodesToString(nodes []*tailcfg.Node) string {
	names := make([]string, 0, len(nodes))
	for _, node := range nodes {
		names = append(names, node.Name)
	}

	joined := strings.Join(names, ", ")

	return fmt.Sprintf("[ %s ](%d)", joined, len(names))
}
func tailMapResponseToString(resp tailcfg.MapResponse) string {
2021-11-13 16:36:45 +08:00
return fmt.Sprintf(
"{ Node: %s, Peers: %s }",
resp.Node.Name,
tailNodesToString(resp.Peers),
)
2021-08-13 17:33:19 +08:00
}
2021-10-30 22:29:03 +08:00
// GrpcSocketDialer connects to addr as a Unix domain socket path, using a
// default net.Dialer. ctx controls cancellation and deadlines for the dial.
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
	dialer := &net.Dialer{}

	return dialer.DialContext(ctx, "unix", addr)
}
2021-11-05 06:17:44 +08:00
// ipPrefixToString returns the String form of each prefix, in order. The
// result is a non-nil slice even when prefixes is empty.
func ipPrefixToString(prefixes []netaddr.IPPrefix) []string {
	rendered := make([]string, 0, len(prefixes))
	for _, prefix := range prefixes {
		rendered = append(rendered, prefix.String())
	}

	return rendered
}
func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
2021-11-05 06:17:44 +08:00
result := make([]netaddr.IPPrefix, len(prefixes))
for index, prefixStr := range prefixes {
prefix, err := netaddr.ParseIPPrefix(prefixStr)
if err != nil {
return []netaddr.IPPrefix{}, err
}
result[index] = prefix
}
return result, nil
}
func contains[T string | netaddr.IPPrefix](ts []T, t T) bool {
for _, v := range ts {
if reflect.DeepEqual(v, t) {
2021-11-05 06:17:44 +08:00
return true
}
}
return false
}
2022-01-26 06:11:15 +08:00
// GenerateRandomBytes returns n bytes read from crypto/rand.
// It returns a nil slice and an error if the system's secure random number
// generator fails, in which case the caller should not continue.
func GenerateRandomBytes(n int) ([]byte, error) {
	buf := make([]byte, n)

	// rand.Read reports a nil error only after filling the whole buffer.
	_, err := rand.Read(buf)
	if err != nil {
		return nil, err
	}

	return buf, nil
}
// GenerateRandomStringURLSafe returns a URL-safe, base64 encoded
// securely generated random string.
// It will return an error if the system's secure random
// number generator fails to function correctly, in which
// case the caller should not continue.
func GenerateRandomStringURLSafe(n int) (string, error) {
b, err := GenerateRandomBytes(n)
2022-02-13 03:42:55 +08:00
2022-01-26 06:11:15 +08:00
return base64.RawURLEncoding.EncodeToString(b), err
}
// GenerateRandomStringDNSSafe returns a DNS-safe
// securely generated random string.
// It will return an error if the system's secure random
// number generator fails to function correctly, in which
// case the caller should not continue.
2022-06-26 17:55:37 +08:00
func GenerateRandomStringDNSSafe(size int) (string, error) {
var str string
var err error
2022-06-26 17:55:37 +08:00
for len(str) < size {
str, err = GenerateRandomStringURLSafe(size)
if err != nil {
return "", err
}
str = strings.ToLower(strings.ReplaceAll(strings.ReplaceAll(str, "_", ""), "-", ""))
}
2022-06-26 17:55:37 +08:00
return str[:size], nil
}
2022-05-17 03:41:46 +08:00
// IsStringInSlice reports whether str occurs in slice (exact, case-sensitive
// comparison). A nil or empty slice never contains anything.
func IsStringInSlice(slice []string, str string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == str {
			return true
		}
	}

	return false
}
// AbsolutePathFromConfigPath resolves a relative path against the directory
// of the config file reported by viper.ConfigFileUsed.
//
// The path is returned unchanged when it is empty, when it already starts
// with the OS path separator, or when viper reports no config directory.
//
// NOTE(review): the "already absolute" test only checks for a leading
// separator, so Windows drive paths such as C:\x are treated as relative.
// filepath.IsAbs would be the portable check.
func AbsolutePathFromConfigPath(path string) string {
	if path == "" || strings.HasPrefix(path, string(os.PathSeparator)) {
		return path
	}

	configDir, _ := filepath.Split(viper.ConfigFileUsed())
	if configDir == "" {
		return path
	}

	return filepath.Join(configDir, path)
}
// GetFileMode reads the viper string setting named key and parses it with
// strconv.ParseUint using base Base8 (presumably 8, i.e. octal, e.g. "0700").
// If the value is empty or does not parse, PermissionFallback is returned.
func GetFileMode(key string) fs.FileMode {
	raw := viper.GetString(key)

	parsed, err := strconv.ParseUint(raw, Base8, BitSize64)
	if err == nil {
		return fs.FileMode(parsed)
	}

	return PermissionFallback
}