main: remove globalContext, and move some functions outside of the main

This commit is contained in:
Quentin Perez 2018-01-06 19:46:00 +01:00
parent 9cc09b320d
commit 5720123576
No known key found for this signature in database
GPG key ID: 7C6DCB859CF22206
11 changed files with 295 additions and 248 deletions

2
acl.go
View file

@ -30,7 +30,7 @@ func CheckACLs(user User, host Host) (string, error) {
}
// transform map to slice and sort it
acls := []*ACL{}
acls := make([]*ACL, 0, len(aclMap))
for _, acl := range aclMap {
acls = append(acls, acl)
}

49
config.go Normal file
View file

@ -0,0 +1,49 @@
package main
import (
"fmt"
"os"
"github.com/urfave/cli"
)
type configServe struct {
aesKey string
dbDriver, dbURL string
logsLocation string
bindAddr string
debug, demo bool
}
func parseServeConfig(c *cli.Context) (*configServe, error) {
ret := &configServe{
aesKey: c.String("aes-key"),
dbDriver: c.String("db-driver"),
dbURL: c.String("db-conn"),
bindAddr: c.String("bind-address"),
debug: c.Bool("debug"),
demo: c.Bool("demo"),
logsLocation: c.String("logs-location"),
}
switch len(ret.aesKey) {
case 0, 16, 24, 32:
default:
return nil, fmt.Errorf("invalid aes key size, should be 16 or 24, 32")
}
return ret, nil
}
func ensureLogDirectory(location string) error {
// check for the logdir existence
logsLocation, err := os.Stat(location)
if err != nil {
if os.IsNotExist(err) {
return os.MkdirAll(location, os.ModeDir|os.FileMode(0750))
}
return err
}
if !logsLocation.IsDir() {
return fmt.Errorf("log directory cannot be created")
}
return nil
}

View file

@ -95,17 +95,14 @@ func safeDecrypt(key []byte, cryptoText string) string {
return out
}
func HostEncrypt(aesKey string, host *Host) error {
func HostEncrypt(aesKey string, host *Host) (err error) {
if aesKey == "" {
return nil
}
var err error
if host.Password != "" {
if host.Password, err = encrypt([]byte(aesKey), host.Password); err != nil {
return err
host.Password, err = encrypt([]byte(aesKey), host.Password)
}
}
return nil
return
}
func HostDecrypt(aesKey string, host *Host) {
if aesKey == "" {
@ -116,13 +113,12 @@ func HostDecrypt(aesKey string, host *Host) {
}
}
func SSHKeyEncrypt(aesKey string, key *SSHKey) error {
func SSHKeyEncrypt(aesKey string, key *SSHKey) (err error) {
if aesKey == "" {
return nil
}
var err error
key.PrivKey, err = encrypt([]byte(aesKey), key.PrivKey)
return err
return
}
func SSHKeyDecrypt(aesKey string, key *SSHKey) {
if aesKey == "" {

7
db.go
View file

@ -366,16 +366,11 @@ func (u *User) HasRole(name string) bool {
return false
}
func (u *User) CheckRoles(names []string) error {
ok := false
for _, name := range names {
if u.HasRole(name) {
ok = true
break
}
}
if ok {
return nil
}
}
return fmt.Errorf("you don't have permission to access this feature (requires any of these roles: '%s')", strings.Join(names, "', '"))
}

View file

@ -535,7 +535,9 @@ func dbInit(db *gorm.DB) error {
// create admin user
var defaultUserGroup UserGroup
db.Where("name = ?", "default").First(&defaultUserGroup)
db.Table("users").Count(&count)
if err := db.Table("users").Count(&count).Error; err != nil {
return err
}
if count == 0 {
// if no admin, create an account for the first connection
inviteToken := randStringBytes(16)
@ -588,14 +590,10 @@ func dbInit(db *gorm.DB) error {
}
// close unclosed connections
if err := db.Table("sessions").Where("status = ?", "active").Updates(&Session{
return db.Table("sessions").Where("status = ?", "active").Updates(&Session{
Status: SessionStatusClosed,
ErrMsg: "sshportal was halted while the connection was still active",
}).Error; err != nil {
return err
}
return nil
}).Error
}
func hardDeleteCallback(scope *gorm.Scope) {

73
healthcheck.go Normal file
View file

@ -0,0 +1,73 @@
package main
import (
"bytes"
"fmt"
"log"
"net"
"strings"
"time"
"github.com/urfave/cli"
gossh "golang.org/x/crypto/ssh"
)
// perform a healthcheck test without requiring an ssh client or an ssh key (used for Docker's HEALTHCHECK)
func healthcheck(addr string, wait, quiet bool) error {
cfg := gossh.ClientConfig{
User: "healthcheck",
HostKeyCallback: func(hostname string, remote net.Addr, key gossh.PublicKey) error { return nil },
Auth: []gossh.AuthMethod{gossh.Password("healthcheck")},
}
if wait {
for {
if err := healthcheckOnce(addr, cfg, quiet); err != nil {
if !quiet {
log.Printf("error: %v", err)
}
time.Sleep(time.Second)
continue
}
return nil
}
}
if err := healthcheckOnce(addr, cfg, quiet); err != nil {
if quiet {
return cli.NewExitError("", 1)
}
return err
}
return nil
}
func healthcheckOnce(addr string, config gossh.ClientConfig, quiet bool) error {
client, err := gossh.Dial("tcp", addr, &config)
if err != nil {
return err
}
session, err := client.NewSession()
if err != nil {
return err
}
defer func() {
if err := session.Close(); err != nil {
if !quiet {
log.Printf("failed to close session: %v", err)
}
}
}()
var b bytes.Buffer
session.Stdout = &b
if err := session.Run(""); err != nil {
return err
}
stdout := strings.TrimSpace(b.String())
if stdout != "OK" {
return fmt.Errorf("invalid stdout: %q expected 'OK'", stdout)
}
return nil
}

78
hidden.go Normal file
View file

@ -0,0 +1,78 @@
package main
import (
"encoding/json"
"fmt"
"io"
"log"
"os/exec"
"syscall"
"unsafe"
"github.com/gliderlabs/ssh"
"github.com/kr/pty"
"github.com/urfave/cli"
)
// testServer is an hidden handler used for integration tests
func testServer(c *cli.Context) error {
ssh.Handle(func(s ssh.Session) {
helloMsg := struct {
User string
Environ []string
Command []string
}{
User: s.User(),
Environ: s.Environ(),
Command: s.Command(),
}
if err := json.NewEncoder(s).Encode(&helloMsg); err != nil {
log.Fatalf("failed to write helloMsg: %v", err)
}
cmd := exec.Command(s.Command()[0], s.Command()[1:]...) // #nosec
if s.Command() == nil {
cmd = exec.Command("/bin/sh") // #nosec
}
ptyReq, winCh, isPty := s.Pty()
var cmdErr error
if isPty {
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
f, err := pty.Start(cmd)
if err != nil {
fmt.Fprintf(s, "failed to run command: %v\n", err) // #nosec
_ = s.Exit(1) // #nosec
return
}
go func() {
for win := range winCh {
_, _, _ = syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(win.Height), uint16(win.Width), 0, 0}))) // #nosec
}
}()
go func() {
// stdin
_, _ = io.Copy(f, s) // #nosec
}()
// stdout
_, _ = io.Copy(s, f) // #nosec
cmdErr = cmd.Wait()
} else {
//cmd.Stdin = s
cmd.Stdout = s
cmd.Stderr = s
cmdErr = cmd.Run()
}
if cmdErr != nil {
if exitError, ok := cmdErr.(*exec.ExitError); ok {
_ = s.Exit(exitError.Sys().(syscall.WaitStatus).ExitStatus()) // #nosec
return
}
}
_ = s.Exit(cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus()) // #nosec
})
log.Println("starting ssh server on port 2222...")
return ssh.ListenAndServe(":2222", nil)
}

227
main.go
View file

@ -1,28 +1,19 @@
package main
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"math/rand"
"net"
"os"
"os/exec"
"path"
"strings"
"syscall"
"time"
"unsafe"
"github.com/gliderlabs/ssh"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
_ "github.com/jinzhu/gorm/dialects/sqlite"
"github.com/kr/pty"
"github.com/urfave/cli"
gossh "golang.org/x/crypto/ssh"
)
var (
@ -48,7 +39,16 @@ func main() {
{
Name: "server",
Usage: "Start sshportal server",
Action: server,
Action: func(c *cli.Context) error {
if err := ensureLogDirectory(c.String("logs-location")); err != nil {
return err
}
cfg, err := parseServeConfig(c)
if err != nil {
return err
}
return server(cfg)
},
Flags: []cli.Flag{
cli.StringFlag{
Name: "bind-address, b",
@ -82,7 +82,7 @@ func main() {
},
}, {
Name: "healthcheck",
Action: healthcheck,
Action: func(c *cli.Context) error { return healthcheck(c.String("addr"), c.Bool("wait"), c.Bool("quiet")) },
Flags: []cli.Flag{
cli.StringFlag{
Name: "addr, a",
@ -109,211 +109,54 @@ func main() {
}
}
func server(c *cli.Context) error {
switch len(c.String("aes-key")) {
case 0, 16, 24, 32:
default:
return fmt.Errorf("invalid aes key size, should be 16 or 24, 32")
}
// db
db, err := gorm.Open(c.String("db-driver"), c.String("db-conn"))
if err != nil {
return err
func server(c *configServe) (err error) {
var db = (*gorm.DB)(nil)
// try to setup the local DB
if db, err = gorm.Open(c.dbDriver, c.dbURL); err != nil {
return
}
defer func() {
if err2 := db.Close(); err2 != nil {
panic(err2)
origErr := err
err = db.Close()
if origErr != nil {
err = origErr
}
}()
if err = db.DB().Ping(); err != nil {
return err
}
if c.Bool("debug") {
db.LogMode(true)
return
}
db.LogMode(c.debug)
if err = dbInit(db); err != nil {
return err
return
}
// check for the logdir existence
logsLocation, e := os.Stat(c.String("logs-location"))
if e != nil {
err = os.MkdirAll(c.String("logs-location"), os.ModeDir|os.FileMode(0750))
if err != nil {
return err
}
} else {
if !logsLocation.IsDir() {
log.Fatal("log directory cannot be created")
}
}
opts := []ssh.Option{}
// custom PublicKeyAuth handler
opts = append(opts, ssh.PublicKeyAuth(publicKeyAuthHandler(db, c)))
opts = append(opts, ssh.PasswordAuth(passwordAuthHandler(db, c)))
// retrieve sshportal SSH private key from database
opts = append(opts, func(srv *ssh.Server) error {
var key SSHKey
if err = SSHKeysByIdentifiers(db, []string{"host"}).First(&key).Error; err != nil {
return err
}
SSHKeyDecrypt(c.String("aes-key"), &key)
var signer gossh.Signer
signer, err = gossh.ParsePrivateKey([]byte(key.PrivKey))
if err != nil {
return err
}
srv.AddHostKey(signer)
return nil
})
// create TCP listening socket
ln, err := net.Listen("tcp", c.String("bind-address"))
ln, err := net.Listen("tcp", c.bindAddr)
if err != nil {
return err
}
// configure server
srv := &ssh.Server{
Addr: c.String("bind-address"),
Addr: c.bindAddr,
Handler: shellHandler, // ssh.Server.Handler is the handler for the DefaultSessionHandler
Version: fmt.Sprintf("sshportal-%s", Version),
ChannelHandler: channelHandler,
}
for _, opt := range opts {
for _, opt := range []ssh.Option{
// custom PublicKeyAuth handler
ssh.PublicKeyAuth(publicKeyAuthHandler(db, c)),
ssh.PasswordAuth(passwordAuthHandler(db, c)),
// retrieve sshportal SSH private key from database
privateKeyFromDB(db, c.aesKey),
} {
if err := srv.SetOption(opt); err != nil {
return err
}
}
log.Printf("info: SSH Server accepting connections on %s", c.String("bind-address"))
log.Printf("info: SSH Server accepting connections on %s", c.bindAddr)
return srv.Serve(ln)
}
// perform a healthcheck test without requiring an ssh client or an ssh key (used for Docker's HEALTHCHECK)
func healthcheck(c *cli.Context) error {
config := gossh.ClientConfig{
User: "healthcheck",
HostKeyCallback: func(hostname string, remote net.Addr, key gossh.PublicKey) error { return nil },
Auth: []gossh.AuthMethod{gossh.Password("healthcheck")},
}
if c.Bool("wait") {
for {
if err := healthcheckOnce(c.String("addr"), config, c.Bool("quiet")); err != nil {
if !c.Bool("quiet") {
log.Printf("error: %v", err)
}
time.Sleep(time.Second)
continue
}
return nil
}
}
if err := healthcheckOnce(c.String("addr"), config, c.Bool("quiet")); err != nil {
if c.Bool("quiet") {
return cli.NewExitError("", 1)
}
return err
}
return nil
}
func healthcheckOnce(addr string, config gossh.ClientConfig, quiet bool) error {
client, err := gossh.Dial("tcp", addr, &config)
if err != nil {
return err
}
session, err := client.NewSession()
if err != nil {
return err
}
defer func() {
if err := session.Close(); err != nil {
if !quiet {
log.Printf("failed to close session: %v", err)
}
}
}()
var b bytes.Buffer
session.Stdout = &b
if err := session.Run(""); err != nil {
return err
}
stdout := strings.TrimSpace(b.String())
if stdout != "OK" {
return fmt.Errorf("invalid stdout: %q expected 'OK'", stdout)
}
return nil
}
// testServer is an hidden handler used for integration tests
func testServer(c *cli.Context) error {
ssh.Handle(func(s ssh.Session) {
helloMsg := struct {
User string
Environ []string
Command []string
}{
User: s.User(),
Environ: s.Environ(),
Command: s.Command(),
}
enc := json.NewEncoder(s)
if err := enc.Encode(&helloMsg); err != nil {
log.Fatalf("failed to write helloMsg: %v", err)
}
var cmd *exec.Cmd
if s.Command() == nil {
cmd = exec.Command("/bin/sh") // #nosec
} else {
cmd = exec.Command(s.Command()[0], s.Command()[1:]...) // #nosec
}
ptyReq, winCh, isPty := s.Pty()
var cmdErr error
if isPty {
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
f, err := pty.Start(cmd)
if err != nil {
fmt.Fprintf(s, "failed to run command: %v\n", err)
_ = s.Exit(1)
return
}
go func() {
for win := range winCh {
_, _, _ = syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(win.Height), uint16(win.Width), 0, 0}))) // #nosec
}
}()
go func() {
_, _ = io.Copy(f, s) // stdin
}()
_, _ = io.Copy(s, f) // stdout
cmdErr = cmd.Wait()
} else {
//cmd.Stdin = s
cmd.Stdout = s
cmd.Stderr = s
cmdErr = cmd.Run()
}
if cmdErr != nil {
if exitError, ok := cmdErr.(*exec.ExitError); ok {
waitStatus := exitError.Sys().(syscall.WaitStatus)
_ = s.Exit(waitStatus.ExitStatus())
return
}
}
waitStatus := cmd.ProcessState.Sys().(syscall.WaitStatus)
_ = s.Exit(waitStatus.ExitStatus())
})
log.Println("starting ssh server on port 2222...")
return ssh.ListenAndServe(":2222", nil)
}

View file

@ -323,11 +323,11 @@ GLOBAL OPTIONS:
return err
}
for _, key := range config.SSHKeys {
SSHKeyDecrypt(actx.globalContext.String("aes-key"), key)
SSHKeyDecrypt(actx.config.aesKey, key)
}
if !c.Bool("decrypt") {
for _, key := range config.SSHKeys {
if err := SSHKeyEncrypt(actx.globalContext.String("aes-key"), key); err != nil {
if err := SSHKeyEncrypt(actx.config.aesKey, key); err != nil {
return err
}
}
@ -337,11 +337,11 @@ GLOBAL OPTIONS:
return err
}
for _, host := range config.Hosts {
HostDecrypt(actx.globalContext.String("aes-key"), host)
HostDecrypt(actx.config.aesKey, host)
}
if !c.Bool("decrypt") {
for _, host := range config.Hosts {
if err := HostEncrypt(actx.globalContext.String("aes-key"), host); err != nil {
if err := HostEncrypt(actx.config.aesKey, host); err != nil {
return err
}
}
@ -456,9 +456,9 @@ GLOBAL OPTIONS:
}
}
for _, host := range config.Hosts {
HostDecrypt(actx.globalContext.String("aes-key"), host)
HostDecrypt(actx.config.aesKey, host)
if !c.Bool("decrypt") {
if err := HostEncrypt(actx.globalContext.String("aes-key"), host); err != nil {
if err := HostEncrypt(actx.config.aesKey, host); err != nil {
return err
}
}
@ -492,9 +492,9 @@ GLOBAL OPTIONS:
}
}
for _, sshKey := range config.SSHKeys {
SSHKeyDecrypt(actx.globalContext.String("aes-key"), sshKey)
SSHKeyDecrypt(actx.config.aesKey, sshKey)
if !c.Bool("decrypt") {
if err := SSHKeyEncrypt(actx.globalContext.String("aes-key"), sshKey); err != nil {
if err := SSHKeyEncrypt(actx.config.aesKey, sshKey); err != nil {
return err
}
}
@ -697,7 +697,7 @@ GLOBAL OPTIONS:
}
// encrypt
if err := HostEncrypt(actx.globalContext.String("aes-key"), host); err != nil {
if err := HostEncrypt(actx.config.aesKey, host); err != nil {
return err
}
@ -734,7 +734,7 @@ GLOBAL OPTIONS:
if c.Bool("decrypt") {
for _, host := range hosts {
HostDecrypt(actx.globalContext.String("aes-key"), host)
HostDecrypt(actx.config.aesKey, host)
}
}
@ -1042,14 +1042,14 @@ GLOBAL OPTIONS:
return err
}
fmt.Fprintf(s, "Debug mode (server): %v\n", actx.globalContext.Bool("debug"))
fmt.Fprintf(s, "debug mode (server): %v\n", actx.config.debug)
hostname, _ := os.Hostname()
fmt.Fprintf(s, "Hostname: %s\n", hostname)
fmt.Fprintf(s, "CPUs: %d\n", runtime.NumCPU())
fmt.Fprintf(s, "Demo mode: %v\n", actx.globalContext.Bool("demo"))
fmt.Fprintf(s, "DB Driver: %s\n", actx.globalContext.String("db-driver"))
fmt.Fprintf(s, "DB Conn: %s\n", actx.globalContext.String("db-conn"))
fmt.Fprintf(s, "Bind Address: %s\n", actx.globalContext.String("bind-address"))
fmt.Fprintf(s, "Demo mode: %v\n", actx.config.demo)
fmt.Fprintf(s, "DB Driver: %s\n", actx.config.dbDriver)
fmt.Fprintf(s, "DB Conn: %s\n", actx.config.dbURL)
fmt.Fprintf(s, "Bind Address: %s\n", actx.config.bindAddr)
fmt.Fprintf(s, "System Time: %v\n", time.Now().Format(time.RFC3339Nano))
fmt.Fprintf(s, "OS Type: %s\n", runtime.GOOS)
fmt.Fprintf(s, "OS Architecture: %s\n", runtime.GOARCH)
@ -1095,8 +1095,8 @@ GLOBAL OPTIONS:
}
key, err := NewSSHKey(c.String("type"), c.Uint("length"))
if actx.globalContext.String("aes-key") != "" {
if err2 := SSHKeyEncrypt(actx.globalContext.String("aes-key"), key); err2 != nil {
if actx.config.aesKey != "" {
if err2 := SSHKeyEncrypt(actx.config.aesKey, key); err2 != nil {
return err2
}
}
@ -1141,7 +1141,7 @@ GLOBAL OPTIONS:
if c.Bool("decrypt") {
for _, key := range keys {
SSHKeyDecrypt(actx.globalContext.String("aes-key"), key)
SSHKeyDecrypt(actx.config.aesKey, key)
}
}
@ -1250,7 +1250,7 @@ GLOBAL OPTIONS:
if err := SSHKeysByIdentifiers(SSHKeysPreload(db), c.Args()).First(&key).Error; err != nil {
return err
}
SSHKeyDecrypt(actx.globalContext.String("aes-key"), &key)
SSHKeyDecrypt(actx.config.aesKey, &key)
type line struct {
key string

35
ssh.go
View file

@ -12,7 +12,6 @@ import (
"github.com/gliderlabs/ssh"
"github.com/jinzhu/gorm"
"github.com/moul/sshportal/pkg/bastionsession"
"github.com/urfave/cli"
gossh "golang.org/x/crypto/ssh"
)
@ -27,7 +26,7 @@ type authContext struct {
inputUsername string
db *gorm.DB
userKey UserKey
globalContext *cli.Context
config *configServe
authMethod string
authSuccess bool
}
@ -96,7 +95,6 @@ func channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh
}
actx := ctx.Value(authContextKey).(*authContext)
logsLocation := actx.globalContext.String("logs-location")
switch actx.userType() {
case UserTypeBastion:
@ -145,7 +143,7 @@ func channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh
err = bastionsession.ChannelHandler(srv, conn, newChan, ctx, bastionsession.Config{
Addr: host.DialAddr(),
ClientConfig: clientConfig,
Logs: logsLocation,
Logs: actx.config.logsLocation,
})
now := time.Now()
@ -200,8 +198,8 @@ func bastionClientConfig(ctx ssh.Context, host *Host) (*gossh.ClientConfig, erro
return nil, err2
}
HostDecrypt(actx.globalContext.String("aes-key"), host)
SSHKeyDecrypt(actx.globalContext.String("aes-key"), host.SSHKey)
HostDecrypt(actx.config.aesKey, host)
SSHKeyDecrypt(actx.config.aesKey, host.SSHKey)
switch action {
case ACLActionAllow:
@ -246,12 +244,12 @@ func shellHandler(s ssh.Session) {
panic("should not happen")
}
func passwordAuthHandler(db *gorm.DB, globalContext *cli.Context) ssh.PasswordHandler {
func passwordAuthHandler(db *gorm.DB, cfg *configServe) ssh.PasswordHandler {
return func(ctx ssh.Context, pass string) bool {
actx := &authContext{
db: db,
inputUsername: ctx.User(),
globalContext: globalContext,
config: cfg,
authMethod: "password",
}
actx.authSuccess = actx.userType() == UserTypeHealthcheck
@ -260,12 +258,29 @@ func passwordAuthHandler(db *gorm.DB, globalContext *cli.Context) ssh.PasswordHa
}
}
func publicKeyAuthHandler(db *gorm.DB, globalContext *cli.Context) ssh.PublicKeyHandler {
func privateKeyFromDB(db *gorm.DB, aesKey string) func(*ssh.Server) error {
return func(srv *ssh.Server) error {
var key SSHKey
if err := SSHKeysByIdentifiers(db, []string{"host"}).First(&key).Error; err != nil {
return err
}
SSHKeyDecrypt(aesKey, &key)
signer, err := gossh.ParsePrivateKey([]byte(key.PrivKey))
if err != nil {
return err
}
srv.AddHostKey(signer)
return nil
}
}
func publicKeyAuthHandler(db *gorm.DB, cfg *configServe) ssh.PublicKeyHandler {
return func(ctx ssh.Context, key ssh.PublicKey) bool {
actx := &authContext{
db: db,
inputUsername: ctx.User(),
globalContext: globalContext,
config: cfg,
authMethod: "pubkey",
authSuccess: true,
}

View file

@ -25,9 +25,9 @@ func (caller bastionTelnetCaller) CallTELNET(ctx telnet.Context, w telnet.Writer
for {
// Read 1 byte.
n, err := reader.Read(p)
if n <= 0 && nil == err {
if n <= 0 && err == nil {
continue
} else if n <= 0 && nil != err {
} else if n <= 0 && err != nil {
break
}