From b34e90c45dba328c114094d2ebd55f59191ad824 Mon Sep 17 00:00:00 2001 From: Ward Vandewege Date: Sun, 2 May 2021 14:47:36 -0400 Subject: [PATCH] Fix bug in preauthkeys: namespace object was not populated in the return value from CreatePreAuthKey and GetPreAuthKeys. Add tests for that bug, and the rest of the preauthkeys functionality. Fix path in `compress` Makefile target. --- Makefile | 2 +- app.go | 5 ++- db.go | 11 +++++-- go.mod | 3 +- go.sum | 2 ++ preauth_keys.go | 3 +- preauth_keys_test.go | 75 ++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 94 insertions(+), 7 deletions(-) create mode 100644 preauth_keys_test.go diff --git a/Makefile b/Makefile index f6bf5ba5..18498462 100644 --- a/Makefile +++ b/Makefile @@ -20,5 +20,5 @@ lint: golangci-lint run compress: build - upx --brute cmd/headscale/headscale + upx --brute headscale diff --git a/app.go b/app.go index 9047c9b2..d7fd3f2e 100644 --- a/app.go +++ b/app.go @@ -40,6 +40,8 @@ type Config struct { type Headscale struct { cfg Config dbString string + dbType string + dbDebug bool publicKey *wgcfg.Key privateKey *wgcfg.PrivateKey @@ -59,7 +61,8 @@ func NewHeadscale(cfg Config) (*Headscale, error) { } pubKey := privKey.Public() h := Headscale{ - cfg: cfg, + cfg: cfg, + dbType: "postgres", dbString: fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", cfg.DBhost, cfg.DBport, cfg.DBname, cfg.DBuser, cfg.DBpass), privateKey: privKey, diff --git a/db.go b/db.go index 45e7dd26..0723c529 100644 --- a/db.go +++ b/db.go @@ -16,11 +16,13 @@ type KV struct { } func (h *Headscale) initDB() error { - db, err := gorm.Open("postgres", h.dbString) + db, err := gorm.Open(h.dbType, h.dbString) if err != nil { return err } - db.Exec("create extension if not exists \"uuid-ossp\";") + if h.dbType == "postgres" { + db.Exec("create extension if not exists \"uuid-ossp\";") + } db.AutoMigrate(&Machine{}) db.AutoMigrate(&KV{}) db.AutoMigrate(&Namespace{}) @@ -32,10 +34,13 @@ func (h *Headscale) initDB() error { } func (h *Headscale) db() (*gorm.DB, error) { - db, err := gorm.Open("postgres", h.dbString) + db, err := gorm.Open(h.dbType, h.dbString) if err != nil { return nil, err } + if h.dbDebug { + db.LogMode(true) + } return db, nil } diff --git a/go.mod b/go.mod index 2e7af86e..1718d06c 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/klauspost/compress v1.11.12 github.com/kr/text v0.2.0 // indirect github.com/lib/pq v1.9.0 // indirect + github.com/mattn/go-sqlite3 v1.14.7 // indirect github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/spf13/cobra v1.1.3 github.com/spf13/viper v1.7.1 @@ -19,7 +20,7 @@ require ( golang.org/x/text v0.3.5 // indirect golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 // indirect google.golang.org/appengine v1.6.6 // indirect - gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect + gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f gopkg.in/yaml.v2 v2.4.0 inet.af/netaddr v0.0.0-20210317195617-2d42ec05f8a1 tailscale.com v1.6.0 diff --git a/go.sum b/go.sum index 2eba6d73..edae5615 100644 --- a/go.sum +++ b/go.sum @@ -298,6 +298,8 @@ github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHX github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-sqlite3 v1.14.0 h1:mLyGNKR8+Vv9CAU7PphKa2hkEqxxhn8i32J6FPj1/QA= github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= +github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA= +github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-zglob v0.0.1/go.mod h1:9fxibJccNxU2cnpIKLRRFA7zX7qhkJIQWBb449FYHOo= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mdlayher/ethtool v0.0.0-20210210192532-2b88debcdd43/go.mod h1:+t7E0lkKfbBsebllff1xdTmyJt8lH37niI6kwFk9OTo= diff --git a/preauth_keys.go b/preauth_keys.go index 50db149e..de89b04d 100644 --- a/preauth_keys.go +++ b/preauth_keys.go @@ -42,6 +42,7 @@ func (h *Headscale) CreatePreAuthKey(namespaceName string, reusable bool, expira k := PreAuthKey{ Key: kstr, NamespaceID: n.ID, + Namespace: *n, Reusable: reusable, CreatedAt: &now, Expiration: expiration, @@ -65,7 +66,7 @@ func (h *Headscale) GetPreAuthKeys(namespaceName string) (*[]PreAuthKey, error) defer db.Close() keys := []PreAuthKey{} - if err := db.Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil { + if err := db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil { return nil, err } return &keys, nil diff --git a/preauth_keys_test.go b/preauth_keys_test.go new file mode 100644 index 00000000..5ac3bcfb --- /dev/null +++ b/preauth_keys_test.go @@ -0,0 +1,75 @@ +package headscale + +import ( + "fmt" + "io/ioutil" + "os" + "testing" + + _ "github.com/jinzhu/gorm/dialects/sqlite" // sql driver + + "gopkg.in/check.v1" +) + +func Test(t *testing.T) { + check.TestingT(t) +} + +var _ = check.Suite(&Suite{}) + +type Suite struct{} + +var tmpDir string +var h Headscale + +func (s *Suite) SetUpSuite(c *check.C) { + var err error + tmpDir, err = ioutil.TempDir("", "autoygg-client-test") + if err != nil { + c.Fatal(err) + } + fmt.Printf("tmpDir is %s\n", tmpDir) + cfg := Config{} + + h = Headscale{ + cfg: cfg, + dbType: "sqlite3", + dbString: tmpDir + "/headscale_test.db", + } + err = h.initDB() + if err != nil { + c.Fatal(err) + } +} + +func (s *Suite) TearDownSuite(c *check.C) { + os.RemoveAll(tmpDir) +} + +func (*Suite) TestCreatePreAuthKey(c *check.C) { + _, err := h.CreatePreAuthKey("bogus", true, nil) + c.Assert(err, check.NotNil) + + n, err := h.CreateNamespace("test") + c.Assert(err, check.IsNil) + + k, err := h.CreatePreAuthKey(n.Name, true, nil) + c.Assert(err, check.IsNil) + + // Did we get a valid key? + c.Assert(k.Key, check.NotNil) + c.Assert(len(k.Key), check.Equals, 48) + + // Make sure the Namespace association is populated + c.Assert(k.Namespace.Name, check.Equals, n.Name) + + _, err = h.GetPreAuthKeys("bogus") + c.Assert(err, check.NotNil) + + keys, err := h.GetPreAuthKeys(n.Name) + c.Assert(err, check.IsNil) + c.Assert(len(*keys), check.Equals, 1) + + // Make sure the Namespace association is populated + c.Assert((*keys)[0].Namespace.Name, check.Equals, n.Name) +}