Add cache for requested expiry times

This commit adds a sentral cache to keep track of clients whom has
requested an expiry time, but were we need to keep hold of it until the
second request comes in.
This commit is contained in:
Kristoffer Dalby 2021-11-22 19:32:52 +00:00
parent e600ead3e9
commit 021c464148
4 changed files with 67 additions and 11 deletions

22
api.go
View file

@ -19,10 +19,13 @@ import (
)
const (
reservedResponseHeaderSize = 4
RegisterMethodAuthKey = "authKey"
RegisterMethodOIDC = "oidc"
RegisterMethodCLI = "cli"
reservedResponseHeaderSize = 4
RegisterMethodAuthKey = "authKey"
RegisterMethodOIDC = "oidc"
RegisterMethodCLI = "cli"
ErrRegisterMethodCLIDoesNotSupportExpire = Error(
"machines registered with CLI does not support expire",
)
)
// KeyHandler provides the Headscale pub key
@ -441,7 +444,16 @@ func (h *Headscale) handleMachineRegistrationNew(
}
if !reqisterRequest.Expiry.IsZero() {
machine.Expiry = &reqisterRequest.Expiry
log.Trace().
Caller().
Str("machine", machine.Name).
Time("expiry", reqisterRequest.Expiry).
Msg("Non-zero expiry time requested, adding to cache")
h.requestedExpiryCache.Set(
idKey.HexString(),
reqisterRequest.Expiry,
requestedExpiryCacheExpiration,
)
}
machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).HexString() // save the NodeKey

23
app.go
View file

@ -53,6 +53,9 @@ const (
updateInterval = 5000
HTTPReadTimeout = 30 * time.Second
requestedExpiryCacheExpiration = time.Minute * 5
requestedExpiryCacheCleanupInterval = time.Minute * 10
errUnsupportedDatabase = Error("unsupported DB")
errUnsupportedLetsEncryptChallengeType = Error(
"unknown value for Lets Encrypt challenge type",
@ -139,6 +142,8 @@ type Headscale struct {
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
oidcStateCache *cache.Cache
requestedExpiryCache *cache.Cache
}
// NewHeadscale returns the Headscale app.
@ -171,13 +176,19 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
return nil, errUnsupportedDatabase
}
requestedExpiryCache := cache.New(
requestedExpiryCacheExpiration,
requestedExpiryCacheCleanupInterval,
)
app := Headscale{
cfg: cfg,
dbType: cfg.DBtype,
dbString: dbString,
privateKey: privKey,
publicKey: &pubKey,
aclRules: tailcfg.FilterAllowAll, // default allowall
cfg: cfg,
dbType: cfg.DBtype,
dbString: dbString,
privateKey: privKey,
publicKey: &pubKey,
aclRules: tailcfg.FilterAllowAll, // default allowall
requestedExpiryCache: requestedExpiryCache,
}
err = app.initDB()

View file

@ -616,6 +616,31 @@ func (h *Headscale) RegisterMachine(
return nil, errMachineNotFound
}
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
// This means that if a user is to slow with register a machine, it will possibly not
// have the correct expiry.
requestedTime := time.Time{}
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.HexString()); found {
log.Trace().
Caller().
Str("machine", machine.Name).
Msg("Expiry time found in cache, assigning to node")
if reqTime, ok := requestedTimeIf.(time.Time); ok {
requestedTime = reqTime
}
}
if machine.isRegistered() {
log.Trace().
Caller().
Str("machine", machine.Name).
Msg("machine already registered, reauthenticating")
h.RefreshMachine(&machine, requestedTime)
return &machine, nil
}
log.Trace().
Caller().
Str("machine", machine.Name).

View file

@ -199,6 +199,14 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return
}
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
requestedTime := time.Time{}
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey); found {
if reqTime, ok := requestedTimeIf.(time.Time); ok {
requestedTime = reqTime
}
}
// retrieve machine information
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {