diff --git a/main.go b/main.go index 4886e451..77c13a2f 100644 --- a/main.go +++ b/main.go @@ -192,6 +192,9 @@ func genCerts() error { logger.Log(0, "checking keys and certificates") var private *ed25519.PrivateKey var err error + + // == ROOT key handling == + private, err = serverctl.ReadKeyFromDB(tls.ROOT_KEY_NAME) if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) { logger.Log(0, "generating new root key") @@ -199,13 +202,17 @@ func genCerts() error { if err != nil { return err } - if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.ROOT_KEY_NAME, newKey); err != nil { - return err - } private = &newKey } else if err != nil { return err } + logger.Log(2, "saving root.key") + if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.ROOT_KEY_NAME, *private); err != nil { + return err + } + + // == ROOT cert handling == + ca, err := serverctl.ReadCertFromDB(tls.ROOT_PEM_NAME) //if cert doesn't exist or will expire within 10 days --- but can't do this as clients won't be able to connect //if errors.Is(err, os.ErrNotExist) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) { @@ -220,13 +227,17 @@ func genCerts() error { if err != nil { return err } - if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.ROOT_PEM_NAME, rootCA); err != nil { - return err - } ca = rootCA } else if err != nil { return err } + logger.Log(2, "saving root.pem") + if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.ROOT_PEM_NAME, ca); err != nil { + return err + } + + // == SERVER cert handling == + cert, err := serverctl.ReadCertFromDB(tls.SERVER_PEM_NAME) if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) { //gen new key @@ -240,21 +251,32 @@ func genCerts() error { if err != nil { return err } - cert, err := tls.NewEndEntityCert(*private, csr, ca, tls.CERTIFICATE_VALIDITY) + newCert, err := tls.NewEndEntityCert(*private, csr, ca, tls.CERTIFICATE_VALIDITY) if err != nil { return err } if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_KEY_NAME, key); err != nil { return err } - if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_PEM_NAME, cert); err != nil { + cert = newCert + } else if err != nil { + return err + } else if err == nil { + if serverKey, err := serverctl.ReadKeyFromDB(tls.SERVER_KEY_NAME); err == nil { + logger.Log(2, "saving server.key") + if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_KEY_NAME, *serverKey); err != nil { + return err + } + } else { return err } - } else if err != nil { + } + logger.Log(2, "saving server.pem") + if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_PEM_NAME, cert); err != nil { return err } - logger.Log(2, "ensure the root.pem, root.key, server.pem, and server.key files are updated on your broker") + // == SERVER-CLIENT connection cert handling == serverClientCert, err := serverctl.ReadCertFromDB(tls.SERVER_CLIENT_PEM) if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) || serverClientCert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) { @@ -269,7 +291,7 @@ func genCerts() error { if err != nil { return err } - serverClientCert, err := tls.NewEndEntityCert(*private, csr, ca, tls.CERTIFICATE_VALIDITY) + newServerClientCert, err := tls.NewEndEntityCert(*private, csr, ca, tls.CERTIFICATE_VALIDITY) if err != nil { return err } @@ -277,25 +299,27 @@ func genCerts() error { if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_KEY, key); err != nil { return err } - if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_PEM, serverClientCert); err != nil { - return err - } + serverClientCert = newServerClientCert } else if err != nil { return err } else if err == nil { - logger.Log(0, "detected valid server client cert, re-saving for future consumption") - key, err := serverctl.ReadKeyFromDB(tls.SERVER_CLIENT_KEY) - if err != nil { - return err - } - if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_KEY, *key); err != nil { - return err - } - if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_PEM, serverClientCert); err != nil { + logger.Log(2, "saving serverclient.key") + if serverClientKey, err := serverctl.ReadKeyFromDB(tls.SERVER_CLIENT_KEY); err == nil { + if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_KEY, *serverClientKey); err != nil { + return err + } + } else { return err } } + logger.Log(2, "saving serverclient.pem") + if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_PEM, serverClientCert); err != nil { + return err + } + + logger.Log(1, "ensure the root.pem, root.key, server.pem, and server.key files are updated on your broker") + return serverctl.SetClientTLSConf( functions.GetNetmakerPath()+ncutils.GetSeparator()+tls.SERVER_CLIENT_PEM, functions.GetNetmakerPath()+ncutils.GetSeparator()+tls.SERVER_CLIENT_KEY,