diff --git a/main.go b/main.go index f484ecd3..7a126416 100644 --- a/main.go +++ b/main.go @@ -254,9 +254,8 @@ func genCerts() error { return err } - _, scErr := serverctl.ReadClientCertFromDB() serverClientCert, err := serverctl.ReadCertFromDB(tls.SERVER_CLIENT_PEM) - if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) || database.IsEmptyRecord(scErr) || serverClientCert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) { + if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) || serverClientCert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) { //gen new key logger.Log(0, "generating new server client key/certificate") _, key, err := ed25519.GenerateKey(rand.Reader) @@ -279,14 +278,13 @@ func genCerts() error { if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_PEM, serverClientCert); err != nil { return err } - return serverctl.SaveClientCertToDB( - functions.GetNetmakerPath()+ncutils.GetSeparator()+tls.SERVER_CLIENT_PEM, - functions.GetNetmakerPath()+ncutils.GetSeparator()+tls.SERVER_CLIENT_KEY, - ca, - ) } else if err != nil { return err } - return nil + return serverctl.SetClientTLSConf( + functions.GetNetmakerPath()+ncutils.GetSeparator()+tls.SERVER_CLIENT_PEM, + functions.GetNetmakerPath()+ncutils.GetSeparator()+tls.SERVER_CLIENT_KEY, + ca, + ) } diff --git a/mq/mq.go b/mq/mq.go index f2c4f779..3be9335a 100644 --- a/mq/mq.go +++ b/mq/mq.go @@ -2,7 +2,6 @@ package mq import ( "context" - "log" "time" mqtt "github.com/eclipse/paho.mqtt.golang" @@ -28,11 +27,7 @@ func SetupMQTT(publish bool) mqtt.Client { opts.AddBroker(servercfg.GetMessageQueueEndpoint()) id := ncutils.MakeRandomString(23) opts.ClientID = id - tlsConfig, err := serverctl.ReadClientCertFromDB() - if err != nil { - logger.Log(0, "failed to get TLS config for server to broker connection", err.Error()) - } - opts.SetTLSConfig(tlsConfig) + opts.SetTLSConfig(&serverctl.TlsConfig) opts.SetAutoReconnect(true) opts.SetConnectRetry(true) opts.SetConnectRetryInterval(time.Second << 2) @@ -64,9 +59,9 @@ func SetupMQTT(publish bool) mqtt.Client { logger.Log(2, "unable to connect to broker, retrying ...") if time.Now().After(tperiod) { if token.Error() == nil { - log.Fatal(0, "could not connect to broker, token timeout, exiting ...") + logger.FatalLog("could not connect to broker, token timeout, exiting ...") } else { - log.Fatal(0, "could not connect to broker, exiting ...", token.Error()) + logger.FatalLog("could not connect to broker, exiting ...", token.Error().Error()) } } } else { diff --git a/serverctl/tls.go b/serverctl/tls.go index 5ed0db02..2ccbadf3 100644 --- a/serverctl/tls.go +++ b/serverctl/tls.go @@ -13,6 +13,9 @@ import ( "github.com/gravitl/netmaker/tls" ) +// TlsConfig - holds this servers TLS conf in memory +var TlsConfig ssl.Config + // SaveCert - save a certificate to file and DB func SaveCert(path, name string, cert *x509.Certificate) error { if err := SaveCertToDB(name, cert); err != nil { @@ -105,41 +108,33 @@ func ReadKeyFromDB(name string) (*ed25519.PrivateKey, error) { return &private, nil } -// SaveClientCertToDB - saves client cert for servers to connect to MQ broker with -func SaveClientCertToDB(serverClientPemPath, serverClientKeyPath string, ca *x509.Certificate) error { +// SetClientTLSConf - saves client cert for servers to connect to MQ broker with +func SetClientTLSConf(serverClientPemPath, serverClientKeyPath string, ca *x509.Certificate) error { certpool := x509.NewCertPool() - ok := certpool.AppendCertsFromPEM(ca.Raw) - if !ok { - return fmt.Errorf("failed to append root cert to server client cert") + if caData := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: ca.Raw, + }); len(caData) <= 0 { + return fmt.Errorf("could not encode CA cert to memory for server client") + } else { + ok := certpool.AppendCertsFromPEM(caData) + if !ok { + return fmt.Errorf("failed to append root cert to server client cert") + } } clientKeyPair, err := ssl.LoadX509KeyPair(serverClientPemPath, serverClientKeyPath) if err != nil { return err } certs := []ssl.Certificate{clientKeyPair} - netmakerClientCert := ssl.Config{ + + TlsConfig = ssl.Config{ RootCAs: certpool, ClientAuth: ssl.NoClientCert, ClientCAs: nil, Certificates: certs, InsecureSkipVerify: false, } - data, err := json.Marshal(netmakerClientCert) - if err != nil { - return err - } - return database.Insert(tls.SERVER_CLIENT_ENTRY, string(data), database.CERTS_TABLE_NAME) -} -// ReadClientCertFromDB - reads the client cert from the DB -func ReadClientCertFromDB() (*ssl.Config, error) { - var netmakerClientCert ssl.Config - record, err := database.FetchRecord(database.CERTS_TABLE_NAME, tls.SERVER_CLIENT_ENTRY) - if err != nil { - return nil, err - } - if err = json.Unmarshal([]byte(record), &netmakerClientCert); err != nil { - return nil, err - } - return &netmakerClientCert, err + return nil }