diff --git a/main.go b/main.go index b6f9a7b4..f484ecd3 100644 --- a/main.go +++ b/main.go @@ -120,7 +120,9 @@ func initialize() { // Client Mode Prereq Check } } - genCerts() + if err = genCerts(); err != nil { + logger.Log(0, "something went wrong when generating broker certs", err.Error()) + } if servercfg.IsMessageQueueBackend() { if err = mq.ServerStartNotify(); err != nil { @@ -251,5 +253,40 @@ func genCerts() error { } else if err != nil { return err } + + _, scErr := serverctl.ReadClientCertFromDB() + serverClientCert, err := serverctl.ReadCertFromDB(tls.SERVER_CLIENT_PEM) + if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) || database.IsEmptyRecord(scErr) || serverClientCert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) { + //gen new key + logger.Log(0, "generating new server client key/certificate") + _, key, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return err + } + serverName := tls.NewCName(servercfg.GetServer()) + csr, err := tls.NewCSR(key, serverName) + if err != nil { + return err + } + serverClientCert, err := tls.NewEndEntityCert(*private, csr, ca, tls.CERTIFICATE_VALIDITY) + if err != nil { + return err + } + + if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_KEY, key); err != nil { + return err + } + if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_PEM, serverClientCert); err != nil { + return err + } + return serverctl.SaveClientCertToDB( + functions.GetNetmakerPath()+ncutils.GetSeparator()+tls.SERVER_CLIENT_PEM, + functions.GetNetmakerPath()+ncutils.GetSeparator()+tls.SERVER_CLIENT_KEY, + ca, + ) + } else if err != nil { + return err + } + return nil } diff --git a/mq/mq.go b/mq/mq.go index 756ba93a..f2c4f779 100644 --- a/mq/mq.go +++ b/mq/mq.go @@ -9,6 +9,7 @@ import ( "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/servercfg" + "github.com/gravitl/netmaker/serverctl" ) // KEEPALIVE_TIMEOUT - time in seconds for timeout @@ -27,6 +28,11 @@ func SetupMQTT(publish bool) mqtt.Client { opts.AddBroker(servercfg.GetMessageQueueEndpoint()) id := ncutils.MakeRandomString(23) opts.ClientID = id + tlsConfig, err := serverctl.ReadClientCertFromDB() + if err != nil { + logger.Log(0, "failed to get TLS config for server to broker connection", err.Error()) + } + opts.SetTLSConfig(tlsConfig) opts.SetAutoReconnect(true) opts.SetConnectRetry(true) opts.SetConnectRetryInterval(time.Second << 2) @@ -68,6 +74,9 @@ func SetupMQTT(publish bool) mqtt.Client { } time.Sleep(2 * time.Second) } + if !publish { + logger.Log(0, "successfully connected to mq broker") + } return client } diff --git a/serverctl/tls.go b/serverctl/tls.go index 67b04018..5ed0db02 100644 --- a/serverctl/tls.go +++ b/serverctl/tls.go @@ -2,6 +2,7 @@ package serverctl import ( "crypto/ed25519" + ssl "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" @@ -103,3 +104,42 @@ func ReadKeyFromDB(name string) (*ed25519.PrivateKey, error) { private := key.(ed25519.PrivateKey) return &private, nil } + +// SaveClientCertToDB - saves client cert for servers to connect to MQ broker with +func SaveClientCertToDB(serverClientPemPath, serverClientKeyPath string, ca *x509.Certificate) error { + certpool := x509.NewCertPool() + ok := certpool.AppendCertsFromPEM(ca.Raw) + if !ok { + return fmt.Errorf("failed to append root cert to server client cert") + } + clientKeyPair, err := ssl.LoadX509KeyPair(serverClientPemPath, serverClientKeyPath) + if err != nil { + return err + } + certs := []ssl.Certificate{clientKeyPair} + netmakerClientCert := ssl.Config{ + RootCAs: certpool, + ClientAuth: ssl.NoClientCert, + ClientCAs: nil, + Certificates: certs, + InsecureSkipVerify: false, + } + data, err := json.Marshal(netmakerClientCert) + if err != nil { + return err + } + return database.Insert(tls.SERVER_CLIENT_ENTRY, string(data), database.CERTS_TABLE_NAME) +} + +// ReadClientCertFromDB - reads the client cert from the DB +func ReadClientCertFromDB() (*ssl.Config, error) { + var netmakerClientCert ssl.Config + record, err := database.FetchRecord(database.CERTS_TABLE_NAME, tls.SERVER_CLIENT_ENTRY) + if err != nil { + return nil, err + } + if err = json.Unmarshal([]byte(record), &netmakerClientCert); err != nil { + return nil, err + } + return &netmakerClientCert, err +} diff --git a/tls/tls.go b/tls/tls.go index 1cc6eb9a..2af5ebb1 100644 --- a/tls/tls.go +++ b/tls/tls.go @@ -18,7 +18,6 @@ import ( ) const ( - // CERTTIFICATE_VALIDITY duration of certificate validity in days CERTIFICATE_VALIDITY = 365 @@ -33,6 +32,15 @@ const ( // ROOT_PEM_NAME - name of root pem ROOT_PEM_NAME = "root.pem" + + // SERVER_CLIENT_PEM - the name of server client cert + SERVER_CLIENT_PEM = "serverclient.pem" + + // SERVER_CLIENT_KEY - the name of server client key + SERVER_CLIENT_KEY = "serverclient.key" + + // SERVER_CLIENT_ENTRY - the server client cert key for DB + SERVER_CLIENT_ENTRY = "servercliententry" ) type (