changes from code review

Signed-off-by: Matthew R. Kasun <mkasun@nusak.ca>
This commit is contained in:
Matthew R. Kasun 2022-04-18 17:19:26 -04:00
parent 924403d5b4
commit 2b1f20e94b
10 changed files with 124 additions and 51 deletions

View file

@ -14,7 +14,6 @@ import (
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/config" "github.com/gravitl/netmaker/netclient/config"
"github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
"github.com/gravitl/netmaker/tls" "github.com/gravitl/netmaker/tls"
) )
@ -173,9 +172,6 @@ func register(w http.ResponseWriter, r *http.Request) {
returnErrorResponse(w, r, errorResponse) returnErrorResponse(w, r, errorResponse)
return return
} }
tls.SaveCert("/tmp/sent/", "root.pem", ca)
tls.SaveCert("/tmp/sent/", "client.pem", cert)
//x509.Certificate.PublicKey is an interface therefore json encoding/decoding result in a string value rather than a []byte //x509.Certificate.PublicKey is an interface therefore json encoding/decoding result in a string value rather than a []byte
//include the actual public key so the certificate can be properly reassembled on the other end. //include the actual public key so the certificate can be properly reassembled on the other end.
response := config.RegisterResponse{ response := config.RegisterResponse{
@ -212,28 +208,3 @@ func genCerts(clientKey *ed25519.PrivateKey, name *pkix.Name) (*x509.Certificate
} }
return cert, ca, nil return cert, ca, nil
} }
// genOpenSSLCerts generates a client certificate using calls to openssl and returns the certificate and root CA
func genOpenSSLCerts(key *ed25519.PrivateKey, name *pkix.Name) (*x509.Certificate, *x509.Certificate, error) {
if err := tls.SaveKey("/tmp/", "client.key", *key); err != nil {
return nil, nil, fmt.Errorf("failed to store client key %w", err)
}
cmd2 := fmt.Sprintf("openssl req -new -out /tmp/client.csr -key /tmp/client.key -subj /CN=%s", name.CommonName)
cmd3 := "openssl x509 -req -in /tmp/client.csr -days 365 -CA /etc/netmaker/root.pem -CAkey /etc/netmaker/root.key -CAcreateserial -out /tmp/client.pem"
if _, err := ncutils.RunCmd(cmd2, true); err != nil {
return nil, nil, fmt.Errorf("client csr error %w", err)
}
if _, err := ncutils.RunCmd(cmd3, true); err != nil {
return nil, nil, fmt.Errorf("client cert error %w", err)
}
cert, err := tls.ReadCert("/tmp/client.pem")
if err != nil {
return nil, nil, fmt.Errorf("read client cert error %w", err)
}
ca, err := tls.ReadCert("/etc/netmaker/root.pem")
if err != nil {
return nil, nil, fmt.Errorf("read ca cert error %w", err)
}
return cert, ca, nil
}

View file

@ -7,6 +7,9 @@ import (
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
) )
// LINUX_APP_DATA_PATH - linux path
const LINUX_APP_DATA_PATH = "/etc/netmaker"
// FileExists - checks if file exists // FileExists - checks if file exists
func FileExists(f string) bool { func FileExists(f string) bool {
info, err := os.Stat(f) info, err := os.Stat(f)
@ -49,3 +52,8 @@ func SetDNSDir() error {
} }
return nil return nil
} }
// GetNetmakerPath - gets netmaker path locally
func GetNetmakerPath() string {
return LINUX_APP_DATA_PATH
}

66
main.go
View file

@ -2,6 +2,9 @@ package main
import ( import (
"context" "context"
"crypto/ed25519"
"crypto/rand"
"errors"
"flag" "flag"
"fmt" "fmt"
"net" "net"
@ -11,6 +14,7 @@ import (
"strconv" "strconv"
"sync" "sync"
"syscall" "syscall"
"time"
"github.com/gravitl/netmaker/auth" "github.com/gravitl/netmaker/auth"
"github.com/gravitl/netmaker/config" "github.com/gravitl/netmaker/config"
@ -25,6 +29,7 @@ import (
"github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
"github.com/gravitl/netmaker/serverctl" "github.com/gravitl/netmaker/serverctl"
"github.com/gravitl/netmaker/tls"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
@ -117,6 +122,7 @@ func initialize() { // Client Mode Prereq Check
logger.FatalLog(err.Error()) logger.FatalLog(err.Error())
} }
} }
genCerts()
} }
func startControllers() { func startControllers() {
@ -235,3 +241,63 @@ func setGarbageCollection() {
debug.SetGCPercent(ncutils.DEFAULT_GC_PERCENT) debug.SetGCPercent(ncutils.DEFAULT_GC_PERCENT)
} }
} }
func genCerts() error {
private, err := tls.ReadKey(functions.GetNetmakerPath())
if errors.Is(err, os.ErrNotExist) {
_, *private, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
return err
}
if err := tls.SaveKey(functions.GetNetmakerPath(), "/root.key", *private); err != nil {
return err
}
} else if err != nil {
return err
}
ca, err := tls.ReadCert(functions.GetNetmakerPath() + "/root.pem")
//if cert doesn't exist or will expire within 10 days --- but can't do this as clients won't be able to connect
//if errors.Is(err, os.ErrNotExist) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
if errors.Is(err, os.ErrNotExist) {
caName := tls.NewName("CA Root", "US", "Gravitl")
csr, err := tls.NewCSR(*private, caName)
if err != nil {
return err
}
rootCA, err := tls.SelfSignedCA(*private, csr, tls.CERTIFICATE_VALIDITY)
if err != nil {
return err
}
if err := tls.SaveCert(functions.GetNetmakerPath(), "/root.pem", rootCA); err != nil {
return err
}
} else if err != nil {
return err
}
cert, err := tls.ReadCert(functions.GetNetmakerPath() + "/server.pem")
if errors.Is(err, os.ErrNotExist) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
//gen new key
_, key, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return err
}
serverName := tls.NewCName(servercfg.GetServer())
csr, err := tls.NewCSR(key, serverName)
if err != nil {
return err
}
cert, err := tls.NewEndEntityCert(*private, csr, ca, tls.CERTIFICATE_VALIDITY)
if err != nil {
return err
}
if err := tls.SaveKey(functions.GetNetmakerPath(), "/server.key", key); err != nil {
return err
}
if err := tls.SaveCert(functions.GetNetmakerPath(), "/server.pem", cert); err != nil {
return err
}
} else if err != nil {
return err
}
return nil
}

View file

@ -30,7 +30,7 @@ func GetCommands(cliFlags []cli.Flag) []*cli.Command {
err = errors.New("no server address provided") err = errors.New("no server address provided")
return err return err
} }
err = command.Join(&cfg, pvtKey) err = command.Register(&cfg, pvtKey)
return err return err
}, },
}, },
@ -105,18 +105,6 @@ func GetCommands(cliFlags []cli.Flag) []*cli.Command {
return err return err
}, },
}, },
{
Name: "register",
Usage: "register with netmaker",
Flags: cliFlags,
Action: func(c *cli.Context) error {
cfg, _, err := config.GetCLIConfig(c)
if err != nil {
return err
}
return command.Register(&cfg)
},
},
} }
} }

View file

@ -155,6 +155,6 @@ func Daemon() error {
return err return err
} }
func Register(cfg *config.ClientConfig) error { func Register(cfg *config.ClientConfig, key string) error {
return functions.Register(cfg) return functions.Register(cfg, key)
} }

View file

@ -23,6 +23,9 @@ import (
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
) )
// LINUX_APP_DATA_PATH - linux path
const LINUX_APP_DATA_PATH = "/etc/netmaker"
// ListPorts - lists ports of WireGuard devices // ListPorts - lists ports of WireGuard devices
func ListPorts() error { func ListPorts() error {
wgclient, err := wgctrl.New() wgclient, err := wgctrl.New()
@ -321,3 +324,8 @@ func WipeLocal(network string) error {
} }
return err return err
} }
// GetNetmakerPath - gets netmaker path locally
func GetNetmakerPath() string {
return LINUX_APP_DATA_PATH
}

View file

@ -186,7 +186,6 @@ func setupMQTTSub(server string) mqtt.Client {
opts := mqtt.NewClientOptions() opts := mqtt.NewClientOptions()
opts.AddBroker("ssl://" + server + ":8883") // TODO get the appropriate port of the comms mq server opts.AddBroker("ssl://" + server + ":8883") // TODO get the appropriate port of the comms mq server
opts.TLSConfig = NewTLSConfig(nil, server) opts.TLSConfig = NewTLSConfig(nil, server)
opts.ClientID = ncutils.MakeRandomString(23) // helps avoid id duplication on broker
opts.SetDefaultPublishHandler(All) opts.SetDefaultPublishHandler(All)
opts.SetAutoReconnect(true) opts.SetAutoReconnect(true)
opts.SetConnectRetry(true) opts.SetConnectRetry(true)
@ -328,7 +327,6 @@ func setupMQTT(cfg *config.ClientConfig, publish bool) mqtt.Client {
server := cfg.Server.Server server := cfg.Server.Server
opts.AddBroker("ssl://" + server + ":8883") // TODO get the appropriate port of the comms mq server opts.AddBroker("ssl://" + server + ":8883") // TODO get the appropriate port of the comms mq server
opts.TLSConfig = NewTLSConfig(cfg, "") opts.TLSConfig = NewTLSConfig(cfg, "")
opts.ClientID = ncutils.MakeRandomString(23) // helps avoid id duplication on broker
opts.SetDefaultPublishHandler(All) opts.SetDefaultPublishHandler(All)
opts.SetAutoReconnect(true) opts.SetAutoReconnect(true)
opts.SetConnectRetry(true) opts.SetConnectRetry(true)

View file

@ -3,7 +3,9 @@ package functions
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"os"
"sync" "sync"
"time" "time"
@ -11,6 +13,7 @@ import (
"github.com/gravitl/netmaker/netclient/auth" "github.com/gravitl/netmaker/netclient/auth"
"github.com/gravitl/netmaker/netclient/config" "github.com/gravitl/netmaker/netclient/config"
"github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/tls"
) )
// Checkin -- go routine that checks for public or local ip changes, publishes changes // Checkin -- go routine that checks for public or local ip changes, publishes changes
@ -75,6 +78,7 @@ func Checkin(ctx context.Context, wg *sync.WaitGroup, currentComms map[string]st
} else { } else {
Hello(&nodeCfg) Hello(&nodeCfg)
} }
checkCertExpiry(&nodeCfg)
} }
} }
} }
@ -135,3 +139,19 @@ func publish(nodeCfg *config.ClientConfig, dest string, msg []byte, qos byte) er
} }
return nil return nil
} }
func checkCertExpiry(cfg *config.ClientConfig) error {
cert, err := tls.ReadCert(ncutils.GetNetclientServerPath(cfg.Server.Server) + "/client.pem")
//if cert doesn't exist or will expire within 10 days
if errors.Is(err, os.ErrNotExist) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
key, err := tls.ReadKey(ncutils.GetNetclientPath() + "/client.key")
if err != nil {
return err
}
return RegisterWithServer(key, cfg)
}
if err != nil {
return err
}
return nil
}

View file

@ -17,7 +17,7 @@ import (
) )
// Register - the function responsible for registering with the server and acquiring certs // Register - the function responsible for registering with the server and acquiring certs
func Register(cfg *config.ClientConfig) error { func Register(cfg *config.ClientConfig, key string) error {
if cfg.Server.Server == "" { if cfg.Server.Server == "" {
return errors.New("no server provided") return errors.New("no server provided")
} }
@ -35,6 +35,20 @@ func Register(cfg *config.ClientConfig) error {
return err return err
} }
} }
//check if cert exists
_, err = tls.ReadCert(ncutils.GetNetclientServerPath(cfg.Server.Server) + "/client.pem")
if err != os.ErrNotExist {
if err := RegisterWithServer(private, cfg); err != nil {
return err
}
}
if err != nil {
return err
}
return JoinNetwork(cfg, key, false)
}
func RegisterWithServer(private *ed25519.PrivateKey, cfg *config.ClientConfig) error {
data := config.RegisterRequest{ data := config.RegisterRequest{
Key: *private, Key: *private,
CommonName: tls.NewCName(os.Getenv("HOSTNAME")), CommonName: tls.NewCName(os.Getenv("HOSTNAME")),
@ -75,5 +89,5 @@ func Register(cfg *config.ClientConfig) error {
} }
logger.Log(0, "certificates/key saved ") logger.Log(0, "certificates/key saved ")
//join the network defined in the token //join the network defined in the token
return JoinNetwork(cfg, "", false) return nil
} }

View file

@ -170,7 +170,7 @@ func NewEndEntityCert(key ed25519.PrivateKey, req *x509.CertificateRequest, pare
// SaveRequest saves a certificate request to the specified path // SaveRequest saves a certificate request to the specified path
func SaveRequest(path, name string, csr *x509.CertificateRequest) error { func SaveRequest(path, name string, csr *x509.CertificateRequest) error {
if err := os.MkdirAll(path, 0644); err != nil { if err := os.MkdirAll(path, 0600); err != nil {
return err return err
} }
requestOut, err := os.Create(path + name) requestOut, err := os.Create(path + name)
@ -190,7 +190,7 @@ func SaveRequest(path, name string, csr *x509.CertificateRequest) error {
// SaveCert save a certificate to the specified path // SaveCert save a certificate to the specified path
func SaveCert(path, name string, cert *x509.Certificate) error { func SaveCert(path, name string, cert *x509.Certificate) error {
//certbytes, err := x509.ParseCertificate(cert) //certbytes, err := x509.ParseCertificate(cert)
if err := os.MkdirAll(path, 0644); err != nil { if err := os.MkdirAll(path, 0600); err != nil {
return fmt.Errorf("failed to create dir %s %w", path, err) return fmt.Errorf("failed to create dir %s %w", path, err)
} }
certOut, err := os.Create(path + name) certOut, err := os.Create(path + name)
@ -210,7 +210,7 @@ func SaveCert(path, name string, cert *x509.Certificate) error {
// SaveKey save a private key (ed25519) to the specified path // SaveKey save a private key (ed25519) to the specified path
func SaveKey(path, name string, key ed25519.PrivateKey) error { func SaveKey(path, name string, key ed25519.PrivateKey) error {
//func SaveKey(name string, key *ecdsa.PrivateKey) error { //func SaveKey(name string, key *ecdsa.PrivateKey) error {
if err := os.MkdirAll(path, 0644); err != nil { if err := os.MkdirAll(path, 0600); err != nil {
return fmt.Errorf("failed to create dir %s %w", path, err) return fmt.Errorf("failed to create dir %s %w", path, err)
} }
keyOut, err := os.Create(path + name) keyOut, err := os.Create(path + name)