mirror of
https://github.com/moul/sshportal.git
synced 2024-12-26 01:31:33 +08:00
main: remove globalContext, and move some functions outside of the main
This commit is contained in:
parent
9cc09b320d
commit
5720123576
11 changed files with 295 additions and 248 deletions
2
acl.go
2
acl.go
|
@ -30,7 +30,7 @@ func CheckACLs(user User, host Host) (string, error) {
|
|||
}
|
||||
|
||||
// transform map to slice and sort it
|
||||
acls := []*ACL{}
|
||||
acls := make([]*ACL, 0, len(aclMap))
|
||||
for _, acl := range aclMap {
|
||||
acls = append(acls, acl)
|
||||
}
|
||||
|
|
49
config.go
Normal file
49
config.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
|
||||
type configServe struct {
|
||||
aesKey string
|
||||
dbDriver, dbURL string
|
||||
logsLocation string
|
||||
bindAddr string
|
||||
debug, demo bool
|
||||
}
|
||||
|
||||
func parseServeConfig(c *cli.Context) (*configServe, error) {
|
||||
ret := &configServe{
|
||||
aesKey: c.String("aes-key"),
|
||||
dbDriver: c.String("db-driver"),
|
||||
dbURL: c.String("db-conn"),
|
||||
bindAddr: c.String("bind-address"),
|
||||
debug: c.Bool("debug"),
|
||||
demo: c.Bool("demo"),
|
||||
logsLocation: c.String("logs-location"),
|
||||
}
|
||||
switch len(ret.aesKey) {
|
||||
case 0, 16, 24, 32:
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid aes key size, should be 16 or 24, 32")
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func ensureLogDirectory(location string) error {
|
||||
// check for the logdir existence
|
||||
logsLocation, err := os.Stat(location)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return os.MkdirAll(location, os.ModeDir|os.FileMode(0750))
|
||||
}
|
||||
return err
|
||||
}
|
||||
if !logsLocation.IsDir() {
|
||||
return fmt.Errorf("log directory cannot be created")
|
||||
}
|
||||
return nil
|
||||
}
|
14
crypto.go
14
crypto.go
|
@ -95,17 +95,14 @@ func safeDecrypt(key []byte, cryptoText string) string {
|
|||
return out
|
||||
}
|
||||
|
||||
func HostEncrypt(aesKey string, host *Host) error {
|
||||
func HostEncrypt(aesKey string, host *Host) (err error) {
|
||||
if aesKey == "" {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
if host.Password != "" {
|
||||
if host.Password, err = encrypt([]byte(aesKey), host.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
host.Password, err = encrypt([]byte(aesKey), host.Password)
|
||||
}
|
||||
return nil
|
||||
return
|
||||
}
|
||||
func HostDecrypt(aesKey string, host *Host) {
|
||||
if aesKey == "" {
|
||||
|
@ -116,13 +113,12 @@ func HostDecrypt(aesKey string, host *Host) {
|
|||
}
|
||||
}
|
||||
|
||||
func SSHKeyEncrypt(aesKey string, key *SSHKey) error {
|
||||
func SSHKeyEncrypt(aesKey string, key *SSHKey) (err error) {
|
||||
if aesKey == "" {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
key.PrivKey, err = encrypt([]byte(aesKey), key.PrivKey)
|
||||
return err
|
||||
return
|
||||
}
|
||||
func SSHKeyDecrypt(aesKey string, key *SSHKey) {
|
||||
if aesKey == "" {
|
||||
|
|
7
db.go
7
db.go
|
@ -366,16 +366,11 @@ func (u *User) HasRole(name string) bool {
|
|||
return false
|
||||
}
|
||||
func (u *User) CheckRoles(names []string) error {
|
||||
ok := false
|
||||
for _, name := range names {
|
||||
if u.HasRole(name) {
|
||||
ok = true
|
||||
break
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if ok {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("you don't have permission to access this feature (requires any of these roles: '%s')", strings.Join(names, "', '"))
|
||||
}
|
||||
|
||||
|
|
12
dbinit.go
12
dbinit.go
|
@ -535,7 +535,9 @@ func dbInit(db *gorm.DB) error {
|
|||
// create admin user
|
||||
var defaultUserGroup UserGroup
|
||||
db.Where("name = ?", "default").First(&defaultUserGroup)
|
||||
db.Table("users").Count(&count)
|
||||
if err := db.Table("users").Count(&count).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if count == 0 {
|
||||
// if no admin, create an account for the first connection
|
||||
inviteToken := randStringBytes(16)
|
||||
|
@ -588,14 +590,10 @@ func dbInit(db *gorm.DB) error {
|
|||
}
|
||||
|
||||
// close unclosed connections
|
||||
if err := db.Table("sessions").Where("status = ?", "active").Updates(&Session{
|
||||
return db.Table("sessions").Where("status = ?", "active").Updates(&Session{
|
||||
Status: SessionStatusClosed,
|
||||
ErrMsg: "sshportal was halted while the connection was still active",
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}).Error
|
||||
}
|
||||
|
||||
func hardDeleteCallback(scope *gorm.Scope) {
|
||||
|
|
73
healthcheck.go
Normal file
73
healthcheck.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/urfave/cli"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// perform a healthcheck test without requiring an ssh client or an ssh key (used for Docker's HEALTHCHECK)
|
||||
func healthcheck(addr string, wait, quiet bool) error {
|
||||
cfg := gossh.ClientConfig{
|
||||
User: "healthcheck",
|
||||
HostKeyCallback: func(hostname string, remote net.Addr, key gossh.PublicKey) error { return nil },
|
||||
Auth: []gossh.AuthMethod{gossh.Password("healthcheck")},
|
||||
}
|
||||
|
||||
if wait {
|
||||
for {
|
||||
if err := healthcheckOnce(addr, cfg, quiet); err != nil {
|
||||
if !quiet {
|
||||
log.Printf("error: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := healthcheckOnce(addr, cfg, quiet); err != nil {
|
||||
if quiet {
|
||||
return cli.NewExitError("", 1)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func healthcheckOnce(addr string, config gossh.ClientConfig, quiet bool) error {
|
||||
client, err := gossh.Dial("tcp", addr, &config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil {
|
||||
if !quiet {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
var b bytes.Buffer
|
||||
session.Stdout = &b
|
||||
if err := session.Run(""); err != nil {
|
||||
return err
|
||||
}
|
||||
stdout := strings.TrimSpace(b.String())
|
||||
if stdout != "OK" {
|
||||
return fmt.Errorf("invalid stdout: %q expected 'OK'", stdout)
|
||||
}
|
||||
return nil
|
||||
}
|
78
hidden.go
Normal file
78
hidden.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/kr/pty"
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
|
||||
// testServer is an hidden handler used for integration tests
|
||||
func testServer(c *cli.Context) error {
|
||||
ssh.Handle(func(s ssh.Session) {
|
||||
helloMsg := struct {
|
||||
User string
|
||||
Environ []string
|
||||
Command []string
|
||||
}{
|
||||
User: s.User(),
|
||||
Environ: s.Environ(),
|
||||
Command: s.Command(),
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(s).Encode(&helloMsg); err != nil {
|
||||
log.Fatalf("failed to write helloMsg: %v", err)
|
||||
}
|
||||
cmd := exec.Command(s.Command()[0], s.Command()[1:]...) // #nosec
|
||||
if s.Command() == nil {
|
||||
cmd = exec.Command("/bin/sh") // #nosec
|
||||
}
|
||||
ptyReq, winCh, isPty := s.Pty()
|
||||
var cmdErr error
|
||||
if isPty {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||
f, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
fmt.Fprintf(s, "failed to run command: %v\n", err) // #nosec
|
||||
_ = s.Exit(1) // #nosec
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
for win := range winCh {
|
||||
_, _, _ = syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
|
||||
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(win.Height), uint16(win.Width), 0, 0}))) // #nosec
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
// stdin
|
||||
_, _ = io.Copy(f, s) // #nosec
|
||||
}()
|
||||
// stdout
|
||||
_, _ = io.Copy(s, f) // #nosec
|
||||
cmdErr = cmd.Wait()
|
||||
} else {
|
||||
//cmd.Stdin = s
|
||||
cmd.Stdout = s
|
||||
cmd.Stderr = s
|
||||
cmdErr = cmd.Run()
|
||||
}
|
||||
|
||||
if cmdErr != nil {
|
||||
if exitError, ok := cmdErr.(*exec.ExitError); ok {
|
||||
_ = s.Exit(exitError.Sys().(syscall.WaitStatus).ExitStatus()) // #nosec
|
||||
return
|
||||
}
|
||||
}
|
||||
_ = s.Exit(cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus()) // #nosec
|
||||
})
|
||||
|
||||
log.Println("starting ssh server on port 2222...")
|
||||
return ssh.ListenAndServe(":2222", nil)
|
||||
}
|
231
main.go
231
main.go
|
@ -1,28 +1,19 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/jinzhu/gorm"
|
||||
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
"github.com/kr/pty"
|
||||
"github.com/urfave/cli"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -46,9 +37,18 @@ func main() {
|
|||
app.Email = "https://github.com/moul/sshportal"
|
||||
app.Commands = []cli.Command{
|
||||
{
|
||||
Name: "server",
|
||||
Usage: "Start sshportal server",
|
||||
Action: server,
|
||||
Name: "server",
|
||||
Usage: "Start sshportal server",
|
||||
Action: func(c *cli.Context) error {
|
||||
if err := ensureLogDirectory(c.String("logs-location")); err != nil {
|
||||
return err
|
||||
}
|
||||
cfg, err := parseServeConfig(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return server(cfg)
|
||||
},
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "bind-address, b",
|
||||
|
@ -82,7 +82,7 @@ func main() {
|
|||
},
|
||||
}, {
|
||||
Name: "healthcheck",
|
||||
Action: healthcheck,
|
||||
Action: func(c *cli.Context) error { return healthcheck(c.String("addr"), c.Bool("wait"), c.Bool("quiet")) },
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "addr, a",
|
||||
|
@ -109,211 +109,54 @@ func main() {
|
|||
}
|
||||
}
|
||||
|
||||
func server(c *cli.Context) error {
|
||||
switch len(c.String("aes-key")) {
|
||||
case 0, 16, 24, 32:
|
||||
default:
|
||||
return fmt.Errorf("invalid aes key size, should be 16 or 24, 32")
|
||||
}
|
||||
// db
|
||||
db, err := gorm.Open(c.String("db-driver"), c.String("db-conn"))
|
||||
if err != nil {
|
||||
return err
|
||||
func server(c *configServe) (err error) {
|
||||
var db = (*gorm.DB)(nil)
|
||||
|
||||
// try to setup the local DB
|
||||
if db, err = gorm.Open(c.dbDriver, c.dbURL); err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err2 := db.Close(); err2 != nil {
|
||||
panic(err2)
|
||||
origErr := err
|
||||
err = db.Close()
|
||||
if origErr != nil {
|
||||
err = origErr
|
||||
}
|
||||
}()
|
||||
if err = db.DB().Ping(); err != nil {
|
||||
return err
|
||||
}
|
||||
if c.Bool("debug") {
|
||||
db.LogMode(true)
|
||||
return
|
||||
}
|
||||
db.LogMode(c.debug)
|
||||
if err = dbInit(db); err != nil {
|
||||
return err
|
||||
return
|
||||
}
|
||||
|
||||
// check for the logdir existence
|
||||
logsLocation, e := os.Stat(c.String("logs-location"))
|
||||
if e != nil {
|
||||
err = os.MkdirAll(c.String("logs-location"), os.ModeDir|os.FileMode(0750))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if !logsLocation.IsDir() {
|
||||
log.Fatal("log directory cannot be created")
|
||||
}
|
||||
}
|
||||
|
||||
opts := []ssh.Option{}
|
||||
// custom PublicKeyAuth handler
|
||||
opts = append(opts, ssh.PublicKeyAuth(publicKeyAuthHandler(db, c)))
|
||||
opts = append(opts, ssh.PasswordAuth(passwordAuthHandler(db, c)))
|
||||
|
||||
// retrieve sshportal SSH private key from database
|
||||
opts = append(opts, func(srv *ssh.Server) error {
|
||||
var key SSHKey
|
||||
if err = SSHKeysByIdentifiers(db, []string{"host"}).First(&key).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
SSHKeyDecrypt(c.String("aes-key"), &key)
|
||||
|
||||
var signer gossh.Signer
|
||||
signer, err = gossh.ParsePrivateKey([]byte(key.PrivKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv.AddHostKey(signer)
|
||||
return nil
|
||||
})
|
||||
|
||||
// create TCP listening socket
|
||||
ln, err := net.Listen("tcp", c.String("bind-address"))
|
||||
ln, err := net.Listen("tcp", c.bindAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// configure server
|
||||
srv := &ssh.Server{
|
||||
Addr: c.String("bind-address"),
|
||||
Addr: c.bindAddr,
|
||||
Handler: shellHandler, // ssh.Server.Handler is the handler for the DefaultSessionHandler
|
||||
Version: fmt.Sprintf("sshportal-%s", Version),
|
||||
ChannelHandler: channelHandler,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
|
||||
for _, opt := range []ssh.Option{
|
||||
// custom PublicKeyAuth handler
|
||||
ssh.PublicKeyAuth(publicKeyAuthHandler(db, c)),
|
||||
ssh.PasswordAuth(passwordAuthHandler(db, c)),
|
||||
// retrieve sshportal SSH private key from database
|
||||
privateKeyFromDB(db, c.aesKey),
|
||||
} {
|
||||
if err := srv.SetOption(opt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("info: SSH Server accepting connections on %s", c.String("bind-address"))
|
||||
log.Printf("info: SSH Server accepting connections on %s", c.bindAddr)
|
||||
return srv.Serve(ln)
|
||||
}
|
||||
|
||||
// perform a healthcheck test without requiring an ssh client or an ssh key (used for Docker's HEALTHCHECK)
|
||||
func healthcheck(c *cli.Context) error {
|
||||
config := gossh.ClientConfig{
|
||||
User: "healthcheck",
|
||||
HostKeyCallback: func(hostname string, remote net.Addr, key gossh.PublicKey) error { return nil },
|
||||
Auth: []gossh.AuthMethod{gossh.Password("healthcheck")},
|
||||
}
|
||||
|
||||
if c.Bool("wait") {
|
||||
for {
|
||||
if err := healthcheckOnce(c.String("addr"), config, c.Bool("quiet")); err != nil {
|
||||
if !c.Bool("quiet") {
|
||||
log.Printf("error: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := healthcheckOnce(c.String("addr"), config, c.Bool("quiet")); err != nil {
|
||||
if c.Bool("quiet") {
|
||||
return cli.NewExitError("", 1)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func healthcheckOnce(addr string, config gossh.ClientConfig, quiet bool) error {
|
||||
client, err := gossh.Dial("tcp", addr, &config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil {
|
||||
if !quiet {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
var b bytes.Buffer
|
||||
session.Stdout = &b
|
||||
if err := session.Run(""); err != nil {
|
||||
return err
|
||||
}
|
||||
stdout := strings.TrimSpace(b.String())
|
||||
if stdout != "OK" {
|
||||
return fmt.Errorf("invalid stdout: %q expected 'OK'", stdout)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// testServer is an hidden handler used for integration tests
|
||||
func testServer(c *cli.Context) error {
|
||||
ssh.Handle(func(s ssh.Session) {
|
||||
helloMsg := struct {
|
||||
User string
|
||||
Environ []string
|
||||
Command []string
|
||||
}{
|
||||
User: s.User(),
|
||||
Environ: s.Environ(),
|
||||
Command: s.Command(),
|
||||
}
|
||||
enc := json.NewEncoder(s)
|
||||
if err := enc.Encode(&helloMsg); err != nil {
|
||||
log.Fatalf("failed to write helloMsg: %v", err)
|
||||
}
|
||||
var cmd *exec.Cmd
|
||||
if s.Command() == nil {
|
||||
cmd = exec.Command("/bin/sh") // #nosec
|
||||
} else {
|
||||
cmd = exec.Command(s.Command()[0], s.Command()[1:]...) // #nosec
|
||||
}
|
||||
ptyReq, winCh, isPty := s.Pty()
|
||||
var cmdErr error
|
||||
if isPty {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||
f, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
fmt.Fprintf(s, "failed to run command: %v\n", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
for win := range winCh {
|
||||
_, _, _ = syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
|
||||
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(win.Height), uint16(win.Width), 0, 0}))) // #nosec
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
_, _ = io.Copy(f, s) // stdin
|
||||
}()
|
||||
_, _ = io.Copy(s, f) // stdout
|
||||
cmdErr = cmd.Wait()
|
||||
} else {
|
||||
//cmd.Stdin = s
|
||||
cmd.Stdout = s
|
||||
cmd.Stderr = s
|
||||
cmdErr = cmd.Run()
|
||||
}
|
||||
|
||||
if cmdErr != nil {
|
||||
if exitError, ok := cmdErr.(*exec.ExitError); ok {
|
||||
waitStatus := exitError.Sys().(syscall.WaitStatus)
|
||||
_ = s.Exit(waitStatus.ExitStatus())
|
||||
return
|
||||
}
|
||||
}
|
||||
waitStatus := cmd.ProcessState.Sys().(syscall.WaitStatus)
|
||||
_ = s.Exit(waitStatus.ExitStatus())
|
||||
})
|
||||
|
||||
log.Println("starting ssh server on port 2222...")
|
||||
return ssh.ListenAndServe(":2222", nil)
|
||||
}
|
||||
|
|
38
shell.go
38
shell.go
|
@ -323,11 +323,11 @@ GLOBAL OPTIONS:
|
|||
return err
|
||||
}
|
||||
for _, key := range config.SSHKeys {
|
||||
SSHKeyDecrypt(actx.globalContext.String("aes-key"), key)
|
||||
SSHKeyDecrypt(actx.config.aesKey, key)
|
||||
}
|
||||
if !c.Bool("decrypt") {
|
||||
for _, key := range config.SSHKeys {
|
||||
if err := SSHKeyEncrypt(actx.globalContext.String("aes-key"), key); err != nil {
|
||||
if err := SSHKeyEncrypt(actx.config.aesKey, key); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -337,11 +337,11 @@ GLOBAL OPTIONS:
|
|||
return err
|
||||
}
|
||||
for _, host := range config.Hosts {
|
||||
HostDecrypt(actx.globalContext.String("aes-key"), host)
|
||||
HostDecrypt(actx.config.aesKey, host)
|
||||
}
|
||||
if !c.Bool("decrypt") {
|
||||
for _, host := range config.Hosts {
|
||||
if err := HostEncrypt(actx.globalContext.String("aes-key"), host); err != nil {
|
||||
if err := HostEncrypt(actx.config.aesKey, host); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -456,9 +456,9 @@ GLOBAL OPTIONS:
|
|||
}
|
||||
}
|
||||
for _, host := range config.Hosts {
|
||||
HostDecrypt(actx.globalContext.String("aes-key"), host)
|
||||
HostDecrypt(actx.config.aesKey, host)
|
||||
if !c.Bool("decrypt") {
|
||||
if err := HostEncrypt(actx.globalContext.String("aes-key"), host); err != nil {
|
||||
if err := HostEncrypt(actx.config.aesKey, host); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -492,9 +492,9 @@ GLOBAL OPTIONS:
|
|||
}
|
||||
}
|
||||
for _, sshKey := range config.SSHKeys {
|
||||
SSHKeyDecrypt(actx.globalContext.String("aes-key"), sshKey)
|
||||
SSHKeyDecrypt(actx.config.aesKey, sshKey)
|
||||
if !c.Bool("decrypt") {
|
||||
if err := SSHKeyEncrypt(actx.globalContext.String("aes-key"), sshKey); err != nil {
|
||||
if err := SSHKeyEncrypt(actx.config.aesKey, sshKey); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -697,7 +697,7 @@ GLOBAL OPTIONS:
|
|||
}
|
||||
|
||||
// encrypt
|
||||
if err := HostEncrypt(actx.globalContext.String("aes-key"), host); err != nil {
|
||||
if err := HostEncrypt(actx.config.aesKey, host); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -734,7 +734,7 @@ GLOBAL OPTIONS:
|
|||
|
||||
if c.Bool("decrypt") {
|
||||
for _, host := range hosts {
|
||||
HostDecrypt(actx.globalContext.String("aes-key"), host)
|
||||
HostDecrypt(actx.config.aesKey, host)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1042,14 +1042,14 @@ GLOBAL OPTIONS:
|
|||
return err
|
||||
}
|
||||
|
||||
fmt.Fprintf(s, "Debug mode (server): %v\n", actx.globalContext.Bool("debug"))
|
||||
fmt.Fprintf(s, "debug mode (server): %v\n", actx.config.debug)
|
||||
hostname, _ := os.Hostname()
|
||||
fmt.Fprintf(s, "Hostname: %s\n", hostname)
|
||||
fmt.Fprintf(s, "CPUs: %d\n", runtime.NumCPU())
|
||||
fmt.Fprintf(s, "Demo mode: %v\n", actx.globalContext.Bool("demo"))
|
||||
fmt.Fprintf(s, "DB Driver: %s\n", actx.globalContext.String("db-driver"))
|
||||
fmt.Fprintf(s, "DB Conn: %s\n", actx.globalContext.String("db-conn"))
|
||||
fmt.Fprintf(s, "Bind Address: %s\n", actx.globalContext.String("bind-address"))
|
||||
fmt.Fprintf(s, "Demo mode: %v\n", actx.config.demo)
|
||||
fmt.Fprintf(s, "DB Driver: %s\n", actx.config.dbDriver)
|
||||
fmt.Fprintf(s, "DB Conn: %s\n", actx.config.dbURL)
|
||||
fmt.Fprintf(s, "Bind Address: %s\n", actx.config.bindAddr)
|
||||
fmt.Fprintf(s, "System Time: %v\n", time.Now().Format(time.RFC3339Nano))
|
||||
fmt.Fprintf(s, "OS Type: %s\n", runtime.GOOS)
|
||||
fmt.Fprintf(s, "OS Architecture: %s\n", runtime.GOARCH)
|
||||
|
@ -1095,8 +1095,8 @@ GLOBAL OPTIONS:
|
|||
}
|
||||
|
||||
key, err := NewSSHKey(c.String("type"), c.Uint("length"))
|
||||
if actx.globalContext.String("aes-key") != "" {
|
||||
if err2 := SSHKeyEncrypt(actx.globalContext.String("aes-key"), key); err2 != nil {
|
||||
if actx.config.aesKey != "" {
|
||||
if err2 := SSHKeyEncrypt(actx.config.aesKey, key); err2 != nil {
|
||||
return err2
|
||||
}
|
||||
}
|
||||
|
@ -1141,7 +1141,7 @@ GLOBAL OPTIONS:
|
|||
|
||||
if c.Bool("decrypt") {
|
||||
for _, key := range keys {
|
||||
SSHKeyDecrypt(actx.globalContext.String("aes-key"), key)
|
||||
SSHKeyDecrypt(actx.config.aesKey, key)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1250,7 +1250,7 @@ GLOBAL OPTIONS:
|
|||
if err := SSHKeysByIdentifiers(SSHKeysPreload(db), c.Args()).First(&key).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
SSHKeyDecrypt(actx.globalContext.String("aes-key"), &key)
|
||||
SSHKeyDecrypt(actx.config.aesKey, &key)
|
||||
|
||||
type line struct {
|
||||
key string
|
||||
|
|
35
ssh.go
35
ssh.go
|
@ -12,7 +12,6 @@ import (
|
|||
"github.com/gliderlabs/ssh"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/moul/sshportal/pkg/bastionsession"
|
||||
"github.com/urfave/cli"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
|
@ -27,7 +26,7 @@ type authContext struct {
|
|||
inputUsername string
|
||||
db *gorm.DB
|
||||
userKey UserKey
|
||||
globalContext *cli.Context
|
||||
config *configServe
|
||||
authMethod string
|
||||
authSuccess bool
|
||||
}
|
||||
|
@ -96,7 +95,6 @@ func channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh
|
|||
}
|
||||
|
||||
actx := ctx.Value(authContextKey).(*authContext)
|
||||
logsLocation := actx.globalContext.String("logs-location")
|
||||
|
||||
switch actx.userType() {
|
||||
case UserTypeBastion:
|
||||
|
@ -145,7 +143,7 @@ func channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh
|
|||
err = bastionsession.ChannelHandler(srv, conn, newChan, ctx, bastionsession.Config{
|
||||
Addr: host.DialAddr(),
|
||||
ClientConfig: clientConfig,
|
||||
Logs: logsLocation,
|
||||
Logs: actx.config.logsLocation,
|
||||
})
|
||||
|
||||
now := time.Now()
|
||||
|
@ -200,8 +198,8 @@ func bastionClientConfig(ctx ssh.Context, host *Host) (*gossh.ClientConfig, erro
|
|||
return nil, err2
|
||||
}
|
||||
|
||||
HostDecrypt(actx.globalContext.String("aes-key"), host)
|
||||
SSHKeyDecrypt(actx.globalContext.String("aes-key"), host.SSHKey)
|
||||
HostDecrypt(actx.config.aesKey, host)
|
||||
SSHKeyDecrypt(actx.config.aesKey, host.SSHKey)
|
||||
|
||||
switch action {
|
||||
case ACLActionAllow:
|
||||
|
@ -246,12 +244,12 @@ func shellHandler(s ssh.Session) {
|
|||
panic("should not happen")
|
||||
}
|
||||
|
||||
func passwordAuthHandler(db *gorm.DB, globalContext *cli.Context) ssh.PasswordHandler {
|
||||
func passwordAuthHandler(db *gorm.DB, cfg *configServe) ssh.PasswordHandler {
|
||||
return func(ctx ssh.Context, pass string) bool {
|
||||
actx := &authContext{
|
||||
db: db,
|
||||
inputUsername: ctx.User(),
|
||||
globalContext: globalContext,
|
||||
config: cfg,
|
||||
authMethod: "password",
|
||||
}
|
||||
actx.authSuccess = actx.userType() == UserTypeHealthcheck
|
||||
|
@ -260,12 +258,29 @@ func passwordAuthHandler(db *gorm.DB, globalContext *cli.Context) ssh.PasswordHa
|
|||
}
|
||||
}
|
||||
|
||||
func publicKeyAuthHandler(db *gorm.DB, globalContext *cli.Context) ssh.PublicKeyHandler {
|
||||
func privateKeyFromDB(db *gorm.DB, aesKey string) func(*ssh.Server) error {
|
||||
return func(srv *ssh.Server) error {
|
||||
var key SSHKey
|
||||
if err := SSHKeysByIdentifiers(db, []string{"host"}).First(&key).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
SSHKeyDecrypt(aesKey, &key)
|
||||
|
||||
signer, err := gossh.ParsePrivateKey([]byte(key.PrivKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv.AddHostKey(signer)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func publicKeyAuthHandler(db *gorm.DB, cfg *configServe) ssh.PublicKeyHandler {
|
||||
return func(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
actx := &authContext{
|
||||
db: db,
|
||||
inputUsername: ctx.User(),
|
||||
globalContext: globalContext,
|
||||
config: cfg,
|
||||
authMethod: "pubkey",
|
||||
authSuccess: true,
|
||||
}
|
||||
|
|
|
@ -25,9 +25,9 @@ func (caller bastionTelnetCaller) CallTELNET(ctx telnet.Context, w telnet.Writer
|
|||
for {
|
||||
// Read 1 byte.
|
||||
n, err := reader.Read(p)
|
||||
if n <= 0 && nil == err {
|
||||
if n <= 0 && err == nil {
|
||||
continue
|
||||
} else if n <= 0 && nil != err {
|
||||
} else if n <= 0 && err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue