diff --git a/cert/cert.go b/cert/cert.go index 216efcf..afc6b05 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -393,7 +393,7 @@ func (nc *NebulaCertificate) Expired(t time.Time) bool { // Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) { if ncp.IsBlocklisted(nc) { - return false, fmt.Errorf("certificate has been blocked") + return false, ErrBlockListed } signer, err := ncp.GetCAForCert(nc) @@ -402,15 +402,15 @@ func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error } if signer.Expired(t) { - return false, fmt.Errorf("root certificate is expired") + return false, ErrRootExpired } if nc.Expired(t) { - return false, fmt.Errorf("certificate is expired") + return false, ErrExpired } if !nc.CheckSignature(signer.Details.PublicKey) { - return false, fmt.Errorf("certificate signature did not match") + return false, ErrSignatureMismatch } if err := nc.CheckRootConstrains(signer); err != nil { diff --git a/cert/cert_test.go b/cert/cert_test.go index ece9f7f..0fe2f39 100644 --- a/cert/cert_test.go +++ b/cert/cert_test.go @@ -177,7 +177,7 @@ func TestNebulaCertificate_Verify(t *testing.T) { v, err := c.Verify(time.Now(), caPool) assert.False(t, v) - assert.EqualError(t, err, "certificate has been blocked") + assert.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() v, err = c.Verify(time.Now(), caPool) diff --git a/cert/errors.go b/cert/errors.go index 3135467..05b42d1 100644 --- a/cert/errors.go +++ b/cert/errors.go @@ -1,9 +1,14 @@ package cert -import "errors" +import ( + "errors" +) var ( - ErrExpired = errors.New("certificate is expired") - ErrNotCA = errors.New("certificate is not a CA") - ErrNotSelfSigned = errors.New("certificate is not self-signed") + ErrRootExpired = errors.New("root certificate is expired") + ErrExpired = errors.New("certificate is expired") + ErrNotCA = errors.New("certificate is not a CA") + ErrNotSelfSigned = errors.New("certificate is not self-signed") + ErrBlockListed = errors.New("certificate is in the block list") + ErrSignatureMismatch = errors.New("certificate signature did not match") ) diff --git a/connection_manager.go b/connection_manager.go index f94f54a..0a629a8 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -8,6 +8,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" @@ -419,12 +420,9 @@ func (n *connectionManager) swapPrimary(current, primary *HostInfo) { } // isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and -// the certificate is no longer valid +// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid +// check and return true. func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { - if !n.intf.disconnectInvalid { - return false - } - remoteCert := hostinfo.GetCert() if remoteCert == nil { return false @@ -435,6 +433,11 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn return false } + if !n.intf.disconnectInvalid && err != cert.ErrBlockListed { + // Block listed certificates should always be disconnected + return false + } + fingerprint, _ := remoteCert.Sha256Sum() hostinfo.logger(n.l).WithError(err). WithField("fingerprint", fingerprint).