diff --git a/.golangci.yml b/.golangci.yml index cf01fa4..37cf7f2 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,8 +1,10 @@ run: deadline: 1m tests: false - #skip-files: - # - ".*\\.gen\\.go" + skip-files: + - "testing.go" + - ".*\\.pb\\.go" + - ".*\\.gen\\.go" linters-settings: golint: @@ -18,17 +20,36 @@ linters-settings: linters: disable-all: true enable: - - goconst - - misspell + - bodyclose - deadcode - - misspell - - structcheck + - depguard + - dogsled + #- dupl - errcheck - - unused - - varcheck - - staticcheck - - unconvert + #- funlen + - gochecknoinits + #- gocognit + - goconst + - gocritic + #- gocyclo - gofmt - goimports - golint + - gosimple + - govet - ineffassign + - interfacer + #- maligned + - misspell + - nakedret + - prealloc + - scopelint + - staticcheck + - structcheck + #- stylecheck + - typecheck + - unconvert + - unparam + - unused + - varcheck + - whitespace diff --git a/pkg/bastion/acl.go b/pkg/bastion/acl.go index a641aa1..fbecc95 100644 --- a/pkg/bastion/acl.go +++ b/pkg/bastion/acl.go @@ -12,7 +12,7 @@ func (a byWeight) Len() int { return len(a) } func (a byWeight) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a byWeight) Less(i, j int) bool { return a[i].Weight < a[j].Weight } -func checkACLs(user dbmodels.User, host dbmodels.Host) (string, error) { +func checkACLs(user dbmodels.User, host dbmodels.Host) string { // shared ACLs between user and host aclMap := map[uint]*dbmodels.ACL{} for _, userGroup := range user.Groups { @@ -30,7 +30,7 @@ func checkACLs(user dbmodels.User, host dbmodels.Host) (string, error) { // deny by default if no shared ACL if len(aclMap) == 0 { - return string(dbmodels.ACLActionDeny), nil // default action + return string(dbmodels.ACLActionDeny) // default action } // transform map to slice and sort it @@ -40,5 +40,5 @@ func checkACLs(user dbmodels.User, host dbmodels.Host) (string, error) { } sort.Sort(byWeight(acls)) - return acls[0].Action, nil + return acls[0].Action } diff --git a/pkg/bastion/acl_test.go b/pkg/bastion/acl_test.go index 794360f..3c49fb7 100644 --- a/pkg/bastion/acl_test.go +++ b/pkg/bastion/acl_test.go @@ -43,8 +43,7 @@ func TestCheckACLs(t *testing.T) { db.Preload("Groups").Preload("Groups.ACLs").Find(&users) // test - action, err := checkACLs(users[0], hosts[0]) - c.So(err, ShouldBeNil) + action := checkACLs(users[0], hosts[0]) c.So(action, ShouldEqual, dbmodels.ACLActionAllow) }) } diff --git a/pkg/bastion/dbinit.go b/pkg/bastion/dbinit.go index 5c4bf26..0ff46cc 100644 --- a/pkg/bastion/dbinit.go +++ b/pkg/bastion/dbinit.go @@ -261,14 +261,14 @@ func DBInit(db *gorm.DB) error { return err } - var users []dbmodels.User + var users []*dbmodels.User if err := db.Preload("Roles").Where("is_admin = ?", true).Find(&users).Error; err != nil { return err } for _, user := range users { user.Roles = append(user.Roles, &adminRole) - if err := tx.Save(&user).Error; err != nil { + if err := tx.Save(user).Error; err != nil { return err } } @@ -358,7 +358,7 @@ func DBInit(db *gorm.DB) error { }, { ID: "24", Migrate: func(tx *gorm.DB) error { - var userKeys []dbmodels.UserKey + var userKeys []*dbmodels.UserKey if err := db.Find(&userKeys).Error; err != nil { return err } @@ -369,7 +369,7 @@ func DBInit(db *gorm.DB) error { return err } userKey.AuthorizedKey = string(gossh.MarshalAuthorizedKey(key)) - if err := db.Model(&userKey).Updates(&userKey).Error; err != nil { + if err := db.Model(userKey).Updates(userKey).Error; err != nil { return err } } @@ -422,14 +422,14 @@ func DBInit(db *gorm.DB) error { }, { ID: "27", Migrate: func(tx *gorm.DB) error { - var sessions []dbmodels.Session + var sessions []*dbmodels.Session if err := db.Find(&sessions).Error; err != nil { return err } for _, session := range sessions { if session.StoppedAt != nil && session.StoppedAt.IsZero() { - if err := db.Model(&session).Updates(map[string]interface{}{"stopped_at": nil}).Error; err != nil { + if err := db.Model(session).Updates(map[string]interface{}{"stopped_at": nil}).Error; err != nil { return err } } diff --git a/pkg/bastion/session.go b/pkg/bastion/session.go index 31a2e39..fb40731 100644 --- a/pkg/bastion/session.go +++ b/pkg/bastion/session.go @@ -19,7 +19,7 @@ type sessionConfig struct { ClientConfig *gossh.ClientConfig } -func multiChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context, configs []sessionConfig, sessionID uint) error { +func multiChannelHandler(conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context, configs []sessionConfig, sessionID uint) error { var lastClient *gossh.Client switch newChan.ChannelType() { case "session": diff --git a/pkg/bastion/shell.go b/pkg/bastion/shell.go index 2f00a96..6d12893 100644 --- a/pkg/bastion/shell.go +++ b/pkg/bastion/shell.go @@ -67,6 +67,8 @@ GLOBAL OPTIONS: app.Writer = s app.HideVersion = true + dbmodels.InitValidator() + var ( myself = &actx.user db = actx.db @@ -173,10 +175,8 @@ GLOBAL OPTIONS: return err } acls = append(acls, &acl) - } else { - if err := query.Find(&acls).Error; err != nil { - return err - } + } else if err := query.Find(&acls).Error; err != nil { + return err } if c.Bool("quiet") { for _, acl := range acls { @@ -250,14 +250,14 @@ GLOBAL OPTIONS: return err } - var acls []dbmodels.ACL + var acls []*dbmodels.ACL if err := dbmodels.ACLsByIdentifiers(db, c.Args()).Find(&acls).Error; err != nil { return err } tx := db.Begin() for _, acl := range acls { - model := tx.Model(&acl) + model := tx.Model(acl) update := dbmodels.ACL{ Action: c.String("action"), HostPattern: c.String("pattern"), @@ -477,6 +477,7 @@ GLOBAL OPTIONS: } } for _, host := range config.Hosts { + host := host crypto.HostDecrypt(actx.aesKey, host) if !c.Bool("decrypt") { if err := crypto.HostEncrypt(actx.aesKey, host); err != nil { @@ -489,30 +490,35 @@ GLOBAL OPTIONS: } } for _, user := range config.Users { + user := user if err := tx.FirstOrCreate(&user).Error; err != nil { tx.Rollback() return err } } for _, acl := range config.ACLs { + acl := acl if err := tx.FirstOrCreate(&acl).Error; err != nil { tx.Rollback() return err } } for _, hostGroup := range config.HostGroups { + hostGroup := hostGroup if err := tx.FirstOrCreate(&hostGroup).Error; err != nil { tx.Rollback() return err } } for _, userGroup := range config.UserGroups { + userGroup := userGroup if err := tx.FirstOrCreate(&userGroup).Error; err != nil { tx.Rollback() return err } } for _, sshKey := range config.SSHKeys { + sshKey := sshKey crypto.SSHKeyDecrypt(actx.aesKey, sshKey) if !c.Bool("decrypt") { if err := crypto.SSHKeyEncrypt(actx.aesKey, sshKey); err != nil { @@ -525,24 +531,28 @@ GLOBAL OPTIONS: } } for _, userKey := range config.UserKeys { + userKey := userKey if err := tx.FirstOrCreate(&userKey).Error; err != nil { tx.Rollback() return err } } for _, setting := range config.Settings { + setting := setting if err := tx.FirstOrCreate(&setting).Error; err != nil { tx.Rollback() return err } } for _, session := range config.Sessions { + session := session if err := tx.FirstOrCreate(&session).Error; err != nil { tx.Rollback() return err } } for _, event := range config.Events { + event := event if err := tx.FirstOrCreate(&event).Error; err != nil { tx.Rollback() return err @@ -612,10 +622,8 @@ GLOBAL OPTIONS: return err } events = append(events, event) - } else { - if err := query.Find(&events).Error; err != nil { - return err - } + } else if err := query.Find(&events).Error; err != nil { + return err } if c.Bool("quiet") { @@ -799,10 +807,8 @@ GLOBAL OPTIONS: return err } hosts = append(hosts, &host) - } else { - if err := query.Find(&hosts).Error; err != nil { - return err - } + } else if err := query.Find(&hosts).Error; err != nil { + return err } if c.Bool("quiet") { @@ -820,7 +826,7 @@ GLOBAL OPTIONS: authKey := "" if host.SSHKeyID > 0 { var key dbmodels.SSHKey - db.Model(&host).Related(&key) + db.Model(host).Related(&key) authKey = key.Name } groupNames := []string{} @@ -830,7 +836,7 @@ GLOBAL OPTIONS: var hop string if host.HopID != 0 { var hopHost dbmodels.Host - db.Model(&host).Related(&hopHost, "HopID") + db.Model(host).Related(&hopHost, "HopID") hop = hopHost.Name } else { hop = "" @@ -900,6 +906,7 @@ GLOBAL OPTIONS: tx := db.Begin() for _, host := range hosts { + host := host model := tx.Model(&host) // simple fields for _, fieldname := range []string{"name", "comment"} { @@ -1063,10 +1070,8 @@ GLOBAL OPTIONS: return err } hostGroups = append(hostGroups, &hostGroup) - } else { - if err := query.Find(&hostGroups).Error; err != nil { - return err - } + } else if err := query.Find(&hostGroups).Error; err != nil { + return err } if c.Bool("quiet") { @@ -1127,7 +1132,7 @@ GLOBAL OPTIONS: return err } - var hostgroups []dbmodels.HostGroup + var hostgroups []*dbmodels.HostGroup if err := dbmodels.HostGroupsByIdentifiers(db, c.Args()).Find(&hostgroups).Error; err != nil { return err } @@ -1138,7 +1143,7 @@ GLOBAL OPTIONS: tx := db.Begin() for _, hostgroup := range hostgroups { - model := tx.Model(&hostgroup) + model := tx.Model(hostgroup) // simple fields for _, fieldname := range []string{"name", "comment"} { if c.String(fieldname) != "" { @@ -1342,10 +1347,8 @@ GLOBAL OPTIONS: return err } sshKeys = append(sshKeys, &sshKey) - } else { - if err := query.Find(&sshKeys).Error; err != nil { - return err - } + } else if err := query.Find(&sshKeys).Error; err != nil { + return err } if c.Bool("quiet") { for _, sshKey := range sshKeys { @@ -1584,10 +1587,8 @@ GLOBAL OPTIONS: return err } users = append(users, &user) - } else { - if err := query.Find(&users).Error; err != nil { - return err - } + } else if err := query.Find(&users).Error; err != nil { + return err } if c.Bool("quiet") { for _, user := range users { @@ -1661,7 +1662,7 @@ GLOBAL OPTIONS: } // FIXME: check if unset-admin + user == myself - var users []dbmodels.User + var users []*dbmodels.User if err := dbmodels.UsersByIdentifiers(db, c.Args()).Find(&users).Error; err != nil { return err } @@ -1676,7 +1677,7 @@ GLOBAL OPTIONS: tx := db.Begin() for _, user := range users { - model := tx.Model(&user) + model := tx.Model(user) // simple fields for _, fieldname := range []string{"name", "email", "comment"} { if c.String(fieldname) != "" { @@ -1814,10 +1815,8 @@ GLOBAL OPTIONS: return err } userGroups = append(userGroups, &userGroup) - } else { - if err := query.Find(&userGroups).Error; err != nil { - return err - } + } else if err := query.Find(&userGroups).Error; err != nil { + return err } if c.Bool("quiet") { for _, userGroup := range userGroups { @@ -1877,7 +1876,7 @@ GLOBAL OPTIONS: return err } - var usergroups []dbmodels.UserGroup + var usergroups []*dbmodels.UserGroup if err := dbmodels.UserGroupsByIdentifiers(db, c.Args()).Find(&usergroups).Error; err != nil { return err } @@ -1888,7 +1887,7 @@ GLOBAL OPTIONS: tx := db.Begin() for _, usergroup := range usergroups { - model := tx.Model(&usergroup) + model := tx.Model(usergroup) // simple fields for _, fieldname := range []string{"name", "comment"} { if c.String(fieldname) != "" { @@ -2001,10 +2000,8 @@ GLOBAL OPTIONS: return err } userKeys = append(userKeys, &userKey) - } else { - if err := query.Find(&userKeys).Error; err != nil { - return err - } + } else if err := query.Find(&userKeys).Error; err != nil { + return err } if c.Bool("quiet") { for _, userKey := range userKeys { @@ -2112,7 +2109,6 @@ GLOBAL OPTIONS: factor := 1 for len(sessions) >= limit*factor { - var additionnalSessions []*dbmodels.Session offset = limit * factor diff --git a/pkg/bastion/ssh.go b/pkg/bastion/ssh.go index a3826e6..c81cc02 100644 --- a/pkg/bastion/ssh.go +++ b/pkg/bastion/ssh.go @@ -149,7 +149,7 @@ func ChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh return } go func() { - err = multiChannelHandler(srv, conn, newChan, ctx, sessionConfigs, sess.ID) + err = multiChannelHandler(conn, newChan, ctx, sessionConfigs, sess.ID) if err != nil { log.Printf("Error: %v", err) } @@ -204,11 +204,8 @@ func bastionClientConfig(ctx ssh.Context, host *dbmodels.Host) (*gossh.ClientCon if err = actx.db.Preload("Groups").Preload("Groups.ACLs").Where("id = ?", host.ID).First(&tmpHost).Error; err != nil { return nil, err } - action, err2 := checkACLs(tmpUser, tmpHost) - if err2 != nil { - return nil, err2 - } + action := checkACLs(tmpUser, tmpHost) switch action { case string(dbmodels.ACLActionAllow): // do nothing diff --git a/pkg/dbmodels/dbmodels.go b/pkg/dbmodels/dbmodels.go index 967856a..d5198cb 100644 --- a/pkg/dbmodels/dbmodels.go +++ b/pkg/dbmodels/dbmodels.go @@ -5,12 +5,10 @@ import ( "fmt" "log" "net/url" - "regexp" "strconv" "strings" "time" - "github.com/asaskevich/govalidator" "github.com/jinzhu/gorm" gossh "golang.org/x/crypto/ssh" ) @@ -166,18 +164,6 @@ const ( BastionSchemeTelnet BastionScheme = "telnet" ) -func init() { - unixUserRegexp := regexp.MustCompile("[a-z_][a-z0-9_-]*") - - govalidator.CustomTypeTagMap.Set("unix_user", govalidator.CustomTypeValidator(func(i interface{}, context interface{}) bool { - name, ok := i.(string) - if !ok { - return false - } - return unixUserRegexp.MatchString(name) - })) -} - // Host helpers func (host *Host) DialAddr() string { diff --git a/pkg/dbmodels/validator.go b/pkg/dbmodels/validator.go new file mode 100644 index 0000000..7dd0efa --- /dev/null +++ b/pkg/dbmodels/validator.go @@ -0,0 +1,19 @@ +package dbmodels + +import ( + "regexp" + + "github.com/asaskevich/govalidator" +) + +func InitValidator() { + unixUserRegexp := regexp.MustCompile("[a-z_][a-z0-9_-]*") + + govalidator.CustomTypeTagMap.Set("unix_user", govalidator.CustomTypeValidator(func(i interface{}, context interface{}) bool { + name, ok := i.(string) + if !ok { + return false + } + return unixUserRegexp.MatchString(name) + })) +}