diff --git a/api.go b/api.go index 50af5522..92186ee8 100644 --- a/api.go +++ b/api.go @@ -19,10 +19,13 @@ import ( ) const ( - reservedResponseHeaderSize = 4 - RegisterMethodAuthKey = "authKey" - RegisterMethodOIDC = "oidc" - RegisterMethodCLI = "cli" + reservedResponseHeaderSize = 4 + RegisterMethodAuthKey = "authKey" + RegisterMethodOIDC = "oidc" + RegisterMethodCLI = "cli" + ErrRegisterMethodCLIDoesNotSupportExpire = Error( + "machines registered with CLI does not support expire", + ) ) // KeyHandler provides the Headscale pub key @@ -441,7 +444,16 @@ func (h *Headscale) handleMachineRegistrationNew( } if !reqisterRequest.Expiry.IsZero() { - machine.Expiry = &reqisterRequest.Expiry + log.Trace(). + Caller(). + Str("machine", machine.Name). + Time("expiry", reqisterRequest.Expiry). + Msg("Non-zero expiry time requested, adding to cache") + h.requestedExpiryCache.Set( + idKey.HexString(), + reqisterRequest.Expiry, + requestedExpiryCacheExpiration, + ) } machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).HexString() // save the NodeKey diff --git a/app.go b/app.go index b2b545b4..0d3332dd 100644 --- a/app.go +++ b/app.go @@ -53,6 +53,9 @@ const ( updateInterval = 5000 HTTPReadTimeout = 30 * time.Second + requestedExpiryCacheExpiration = time.Minute * 5 + requestedExpiryCacheCleanupInterval = time.Minute * 10 + errUnsupportedDatabase = Error("unsupported DB") errUnsupportedLetsEncryptChallengeType = Error( "unknown value for Lets Encrypt challenge type", @@ -139,6 +142,8 @@ type Headscale struct { oidcProvider *oidc.Provider oauth2Config *oauth2.Config oidcStateCache *cache.Cache + + requestedExpiryCache *cache.Cache } // NewHeadscale returns the Headscale app. @@ -171,13 +176,19 @@ func NewHeadscale(cfg Config) (*Headscale, error) { return nil, errUnsupportedDatabase } + requestedExpiryCache := cache.New( + requestedExpiryCacheExpiration, + requestedExpiryCacheCleanupInterval, + ) + app := Headscale{ - cfg: cfg, - dbType: cfg.DBtype, - dbString: dbString, - privateKey: privKey, - publicKey: &pubKey, - aclRules: tailcfg.FilterAllowAll, // default allowall + cfg: cfg, + dbType: cfg.DBtype, + dbString: dbString, + privateKey: privKey, + publicKey: &pubKey, + aclRules: tailcfg.FilterAllowAll, // default allowall + requestedExpiryCache: requestedExpiryCache, } err = app.initDB() diff --git a/machine.go b/machine.go index dd7124ec..06fec743 100644 --- a/machine.go +++ b/machine.go @@ -616,6 +616,31 @@ func (h *Headscale) RegisterMachine( return nil, errMachineNotFound } + // TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set + // This means that if a user is to slow with register a machine, it will possibly not + // have the correct expiry. + requestedTime := time.Time{} + if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.HexString()); found { + log.Trace(). + Caller(). + Str("machine", machine.Name). + Msg("Expiry time found in cache, assigning to node") + if reqTime, ok := requestedTimeIf.(time.Time); ok { + requestedTime = reqTime + } + } + + if machine.isRegistered() { + log.Trace(). + Caller(). + Str("machine", machine.Name). + Msg("machine already registered, reauthenticating") + + h.RefreshMachine(&machine, requestedTime) + + return &machine, nil + } + log.Trace(). Caller(). Str("machine", machine.Name). diff --git a/oidc.go b/oidc.go index 2f7d6d61..9b0a3087 100644 --- a/oidc.go +++ b/oidc.go @@ -199,6 +199,14 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } + // TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set + requestedTime := time.Time{} + if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey); found { + if reqTime, ok := requestedTimeIf.(time.Time); ok { + requestedTime = reqTime + } + } + // retrieve machine information machine, err := h.GetMachineByMachineKey(machineKey) if err != nil {