mirror of
https://github.com/moul/sshportal.git
synced 2025-09-05 20:24:29 +08:00
main: remove globalContext, and move some functions outside of the main
This commit is contained in:
parent
9cc09b320d
commit
5720123576
11 changed files with 295 additions and 248 deletions
2
acl.go
2
acl.go
|
@ -30,7 +30,7 @@ func CheckACLs(user User, host Host) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// transform map to slice and sort it
|
// transform map to slice and sort it
|
||||||
acls := []*ACL{}
|
acls := make([]*ACL, 0, len(aclMap))
|
||||||
for _, acl := range aclMap {
|
for _, acl := range aclMap {
|
||||||
acls = append(acls, acl)
|
acls = append(acls, acl)
|
||||||
}
|
}
|
||||||
|
|
49
config.go
Normal file
49
config.go
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/urfave/cli"
|
||||||
|
)
|
||||||
|
|
||||||
|
type configServe struct {
|
||||||
|
aesKey string
|
||||||
|
dbDriver, dbURL string
|
||||||
|
logsLocation string
|
||||||
|
bindAddr string
|
||||||
|
debug, demo bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseServeConfig(c *cli.Context) (*configServe, error) {
|
||||||
|
ret := &configServe{
|
||||||
|
aesKey: c.String("aes-key"),
|
||||||
|
dbDriver: c.String("db-driver"),
|
||||||
|
dbURL: c.String("db-conn"),
|
||||||
|
bindAddr: c.String("bind-address"),
|
||||||
|
debug: c.Bool("debug"),
|
||||||
|
demo: c.Bool("demo"),
|
||||||
|
logsLocation: c.String("logs-location"),
|
||||||
|
}
|
||||||
|
switch len(ret.aesKey) {
|
||||||
|
case 0, 16, 24, 32:
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid aes key size, should be 16 or 24, 32")
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureLogDirectory(location string) error {
|
||||||
|
// check for the logdir existence
|
||||||
|
logsLocation, err := os.Stat(location)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return os.MkdirAll(location, os.ModeDir|os.FileMode(0750))
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !logsLocation.IsDir() {
|
||||||
|
return fmt.Errorf("log directory cannot be created")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
14
crypto.go
14
crypto.go
|
@ -95,17 +95,14 @@ func safeDecrypt(key []byte, cryptoText string) string {
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func HostEncrypt(aesKey string, host *Host) error {
|
func HostEncrypt(aesKey string, host *Host) (err error) {
|
||||||
if aesKey == "" {
|
if aesKey == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
var err error
|
|
||||||
if host.Password != "" {
|
if host.Password != "" {
|
||||||
if host.Password, err = encrypt([]byte(aesKey), host.Password); err != nil {
|
host.Password, err = encrypt([]byte(aesKey), host.Password)
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
func HostDecrypt(aesKey string, host *Host) {
|
func HostDecrypt(aesKey string, host *Host) {
|
||||||
if aesKey == "" {
|
if aesKey == "" {
|
||||||
|
@ -116,13 +113,12 @@ func HostDecrypt(aesKey string, host *Host) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SSHKeyEncrypt(aesKey string, key *SSHKey) error {
|
func SSHKeyEncrypt(aesKey string, key *SSHKey) (err error) {
|
||||||
if aesKey == "" {
|
if aesKey == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
var err error
|
|
||||||
key.PrivKey, err = encrypt([]byte(aesKey), key.PrivKey)
|
key.PrivKey, err = encrypt([]byte(aesKey), key.PrivKey)
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
func SSHKeyDecrypt(aesKey string, key *SSHKey) {
|
func SSHKeyDecrypt(aesKey string, key *SSHKey) {
|
||||||
if aesKey == "" {
|
if aesKey == "" {
|
||||||
|
|
7
db.go
7
db.go
|
@ -366,16 +366,11 @@ func (u *User) HasRole(name string) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
func (u *User) CheckRoles(names []string) error {
|
func (u *User) CheckRoles(names []string) error {
|
||||||
ok := false
|
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
if u.HasRole(name) {
|
if u.HasRole(name) {
|
||||||
ok = true
|
return nil
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("you don't have permission to access this feature (requires any of these roles: '%s')", strings.Join(names, "', '"))
|
return fmt.Errorf("you don't have permission to access this feature (requires any of these roles: '%s')", strings.Join(names, "', '"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
12
dbinit.go
12
dbinit.go
|
@ -535,7 +535,9 @@ func dbInit(db *gorm.DB) error {
|
||||||
// create admin user
|
// create admin user
|
||||||
var defaultUserGroup UserGroup
|
var defaultUserGroup UserGroup
|
||||||
db.Where("name = ?", "default").First(&defaultUserGroup)
|
db.Where("name = ?", "default").First(&defaultUserGroup)
|
||||||
db.Table("users").Count(&count)
|
if err := db.Table("users").Count(&count).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
// if no admin, create an account for the first connection
|
// if no admin, create an account for the first connection
|
||||||
inviteToken := randStringBytes(16)
|
inviteToken := randStringBytes(16)
|
||||||
|
@ -588,14 +590,10 @@ func dbInit(db *gorm.DB) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// close unclosed connections
|
// close unclosed connections
|
||||||
if err := db.Table("sessions").Where("status = ?", "active").Updates(&Session{
|
return db.Table("sessions").Where("status = ?", "active").Updates(&Session{
|
||||||
Status: SessionStatusClosed,
|
Status: SessionStatusClosed,
|
||||||
ErrMsg: "sshportal was halted while the connection was still active",
|
ErrMsg: "sshportal was halted while the connection was still active",
|
||||||
}).Error; err != nil {
|
}).Error
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func hardDeleteCallback(scope *gorm.Scope) {
|
func hardDeleteCallback(scope *gorm.Scope) {
|
||||||
|
|
73
healthcheck.go
Normal file
73
healthcheck.go
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/urfave/cli"
|
||||||
|
gossh "golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// perform a healthcheck test without requiring an ssh client or an ssh key (used for Docker's HEALTHCHECK)
|
||||||
|
func healthcheck(addr string, wait, quiet bool) error {
|
||||||
|
cfg := gossh.ClientConfig{
|
||||||
|
User: "healthcheck",
|
||||||
|
HostKeyCallback: func(hostname string, remote net.Addr, key gossh.PublicKey) error { return nil },
|
||||||
|
Auth: []gossh.AuthMethod{gossh.Password("healthcheck")},
|
||||||
|
}
|
||||||
|
|
||||||
|
if wait {
|
||||||
|
for {
|
||||||
|
if err := healthcheckOnce(addr, cfg, quiet); err != nil {
|
||||||
|
if !quiet {
|
||||||
|
log.Printf("error: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := healthcheckOnce(addr, cfg, quiet); err != nil {
|
||||||
|
if quiet {
|
||||||
|
return cli.NewExitError("", 1)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func healthcheckOnce(addr string, config gossh.ClientConfig, quiet bool) error {
|
||||||
|
client, err := gossh.Dial("tcp", addr, &config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := client.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := session.Close(); err != nil {
|
||||||
|
if !quiet {
|
||||||
|
log.Printf("failed to close session: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
session.Stdout = &b
|
||||||
|
if err := session.Run(""); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
stdout := strings.TrimSpace(b.String())
|
||||||
|
if stdout != "OK" {
|
||||||
|
return fmt.Errorf("invalid stdout: %q expected 'OK'", stdout)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
78
hidden.go
Normal file
78
hidden.go
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/gliderlabs/ssh"
|
||||||
|
"github.com/kr/pty"
|
||||||
|
"github.com/urfave/cli"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testServer is an hidden handler used for integration tests
|
||||||
|
func testServer(c *cli.Context) error {
|
||||||
|
ssh.Handle(func(s ssh.Session) {
|
||||||
|
helloMsg := struct {
|
||||||
|
User string
|
||||||
|
Environ []string
|
||||||
|
Command []string
|
||||||
|
}{
|
||||||
|
User: s.User(),
|
||||||
|
Environ: s.Environ(),
|
||||||
|
Command: s.Command(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewEncoder(s).Encode(&helloMsg); err != nil {
|
||||||
|
log.Fatalf("failed to write helloMsg: %v", err)
|
||||||
|
}
|
||||||
|
cmd := exec.Command(s.Command()[0], s.Command()[1:]...) // #nosec
|
||||||
|
if s.Command() == nil {
|
||||||
|
cmd = exec.Command("/bin/sh") // #nosec
|
||||||
|
}
|
||||||
|
ptyReq, winCh, isPty := s.Pty()
|
||||||
|
var cmdErr error
|
||||||
|
if isPty {
|
||||||
|
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||||
|
f, err := pty.Start(cmd)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(s, "failed to run command: %v\n", err) // #nosec
|
||||||
|
_ = s.Exit(1) // #nosec
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
for win := range winCh {
|
||||||
|
_, _, _ = syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
|
||||||
|
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(win.Height), uint16(win.Width), 0, 0}))) // #nosec
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
// stdin
|
||||||
|
_, _ = io.Copy(f, s) // #nosec
|
||||||
|
}()
|
||||||
|
// stdout
|
||||||
|
_, _ = io.Copy(s, f) // #nosec
|
||||||
|
cmdErr = cmd.Wait()
|
||||||
|
} else {
|
||||||
|
//cmd.Stdin = s
|
||||||
|
cmd.Stdout = s
|
||||||
|
cmd.Stderr = s
|
||||||
|
cmdErr = cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmdErr != nil {
|
||||||
|
if exitError, ok := cmdErr.(*exec.ExitError); ok {
|
||||||
|
_ = s.Exit(exitError.Sys().(syscall.WaitStatus).ExitStatus()) // #nosec
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = s.Exit(cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus()) // #nosec
|
||||||
|
})
|
||||||
|
|
||||||
|
log.Println("starting ssh server on port 2222...")
|
||||||
|
return ssh.ListenAndServe(":2222", nil)
|
||||||
|
}
|
231
main.go
231
main.go
|
@ -1,28 +1,19 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/gliderlabs/ssh"
|
"github.com/gliderlabs/ssh"
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
_ "github.com/jinzhu/gorm/dialects/mysql"
|
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||||
"github.com/kr/pty"
|
|
||||||
"github.com/urfave/cli"
|
"github.com/urfave/cli"
|
||||||
gossh "golang.org/x/crypto/ssh"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -46,9 +37,18 @@ func main() {
|
||||||
app.Email = "https://github.com/moul/sshportal"
|
app.Email = "https://github.com/moul/sshportal"
|
||||||
app.Commands = []cli.Command{
|
app.Commands = []cli.Command{
|
||||||
{
|
{
|
||||||
Name: "server",
|
Name: "server",
|
||||||
Usage: "Start sshportal server",
|
Usage: "Start sshportal server",
|
||||||
Action: server,
|
Action: func(c *cli.Context) error {
|
||||||
|
if err := ensureLogDirectory(c.String("logs-location")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cfg, err := parseServeConfig(c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return server(cfg)
|
||||||
|
},
|
||||||
Flags: []cli.Flag{
|
Flags: []cli.Flag{
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "bind-address, b",
|
Name: "bind-address, b",
|
||||||
|
@ -82,7 +82,7 @@ func main() {
|
||||||
},
|
},
|
||||||
}, {
|
}, {
|
||||||
Name: "healthcheck",
|
Name: "healthcheck",
|
||||||
Action: healthcheck,
|
Action: func(c *cli.Context) error { return healthcheck(c.String("addr"), c.Bool("wait"), c.Bool("quiet")) },
|
||||||
Flags: []cli.Flag{
|
Flags: []cli.Flag{
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "addr, a",
|
Name: "addr, a",
|
||||||
|
@ -109,211 +109,54 @@ func main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func server(c *cli.Context) error {
|
func server(c *configServe) (err error) {
|
||||||
switch len(c.String("aes-key")) {
|
var db = (*gorm.DB)(nil)
|
||||||
case 0, 16, 24, 32:
|
|
||||||
default:
|
// try to setup the local DB
|
||||||
return fmt.Errorf("invalid aes key size, should be 16 or 24, 32")
|
if db, err = gorm.Open(c.dbDriver, c.dbURL); err != nil {
|
||||||
}
|
return
|
||||||
// db
|
|
||||||
db, err := gorm.Open(c.String("db-driver"), c.String("db-conn"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err2 := db.Close(); err2 != nil {
|
origErr := err
|
||||||
panic(err2)
|
err = db.Close()
|
||||||
|
if origErr != nil {
|
||||||
|
err = origErr
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if err = db.DB().Ping(); err != nil {
|
if err = db.DB().Ping(); err != nil {
|
||||||
return err
|
return
|
||||||
}
|
|
||||||
if c.Bool("debug") {
|
|
||||||
db.LogMode(true)
|
|
||||||
}
|
}
|
||||||
|
db.LogMode(c.debug)
|
||||||
if err = dbInit(db); err != nil {
|
if err = dbInit(db); err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// check for the logdir existence
|
|
||||||
logsLocation, e := os.Stat(c.String("logs-location"))
|
|
||||||
if e != nil {
|
|
||||||
err = os.MkdirAll(c.String("logs-location"), os.ModeDir|os.FileMode(0750))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if !logsLocation.IsDir() {
|
|
||||||
log.Fatal("log directory cannot be created")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := []ssh.Option{}
|
|
||||||
// custom PublicKeyAuth handler
|
|
||||||
opts = append(opts, ssh.PublicKeyAuth(publicKeyAuthHandler(db, c)))
|
|
||||||
opts = append(opts, ssh.PasswordAuth(passwordAuthHandler(db, c)))
|
|
||||||
|
|
||||||
// retrieve sshportal SSH private key from database
|
|
||||||
opts = append(opts, func(srv *ssh.Server) error {
|
|
||||||
var key SSHKey
|
|
||||||
if err = SSHKeysByIdentifiers(db, []string{"host"}).First(&key).Error; err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
SSHKeyDecrypt(c.String("aes-key"), &key)
|
|
||||||
|
|
||||||
var signer gossh.Signer
|
|
||||||
signer, err = gossh.ParsePrivateKey([]byte(key.PrivKey))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
srv.AddHostKey(signer)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
// create TCP listening socket
|
// create TCP listening socket
|
||||||
ln, err := net.Listen("tcp", c.String("bind-address"))
|
ln, err := net.Listen("tcp", c.bindAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// configure server
|
// configure server
|
||||||
srv := &ssh.Server{
|
srv := &ssh.Server{
|
||||||
Addr: c.String("bind-address"),
|
Addr: c.bindAddr,
|
||||||
Handler: shellHandler, // ssh.Server.Handler is the handler for the DefaultSessionHandler
|
Handler: shellHandler, // ssh.Server.Handler is the handler for the DefaultSessionHandler
|
||||||
Version: fmt.Sprintf("sshportal-%s", Version),
|
Version: fmt.Sprintf("sshportal-%s", Version),
|
||||||
ChannelHandler: channelHandler,
|
ChannelHandler: channelHandler,
|
||||||
}
|
}
|
||||||
for _, opt := range opts {
|
|
||||||
|
for _, opt := range []ssh.Option{
|
||||||
|
// custom PublicKeyAuth handler
|
||||||
|
ssh.PublicKeyAuth(publicKeyAuthHandler(db, c)),
|
||||||
|
ssh.PasswordAuth(passwordAuthHandler(db, c)),
|
||||||
|
// retrieve sshportal SSH private key from database
|
||||||
|
privateKeyFromDB(db, c.aesKey),
|
||||||
|
} {
|
||||||
if err := srv.SetOption(opt); err != nil {
|
if err := srv.SetOption(opt); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("info: SSH Server accepting connections on %s", c.String("bind-address"))
|
log.Printf("info: SSH Server accepting connections on %s", c.bindAddr)
|
||||||
return srv.Serve(ln)
|
return srv.Serve(ln)
|
||||||
}
|
}
|
||||||
|
|
||||||
// perform a healthcheck test without requiring an ssh client or an ssh key (used for Docker's HEALTHCHECK)
|
|
||||||
func healthcheck(c *cli.Context) error {
|
|
||||||
config := gossh.ClientConfig{
|
|
||||||
User: "healthcheck",
|
|
||||||
HostKeyCallback: func(hostname string, remote net.Addr, key gossh.PublicKey) error { return nil },
|
|
||||||
Auth: []gossh.AuthMethod{gossh.Password("healthcheck")},
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Bool("wait") {
|
|
||||||
for {
|
|
||||||
if err := healthcheckOnce(c.String("addr"), config, c.Bool("quiet")); err != nil {
|
|
||||||
if !c.Bool("quiet") {
|
|
||||||
log.Printf("error: %v", err)
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := healthcheckOnce(c.String("addr"), config, c.Bool("quiet")); err != nil {
|
|
||||||
if c.Bool("quiet") {
|
|
||||||
return cli.NewExitError("", 1)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func healthcheckOnce(addr string, config gossh.ClientConfig, quiet bool) error {
|
|
||||||
client, err := gossh.Dial("tcp", addr, &config)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
session, err := client.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := session.Close(); err != nil {
|
|
||||||
if !quiet {
|
|
||||||
log.Printf("failed to close session: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
session.Stdout = &b
|
|
||||||
if err := session.Run(""); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
stdout := strings.TrimSpace(b.String())
|
|
||||||
if stdout != "OK" {
|
|
||||||
return fmt.Errorf("invalid stdout: %q expected 'OK'", stdout)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// testServer is an hidden handler used for integration tests
|
|
||||||
func testServer(c *cli.Context) error {
|
|
||||||
ssh.Handle(func(s ssh.Session) {
|
|
||||||
helloMsg := struct {
|
|
||||||
User string
|
|
||||||
Environ []string
|
|
||||||
Command []string
|
|
||||||
}{
|
|
||||||
User: s.User(),
|
|
||||||
Environ: s.Environ(),
|
|
||||||
Command: s.Command(),
|
|
||||||
}
|
|
||||||
enc := json.NewEncoder(s)
|
|
||||||
if err := enc.Encode(&helloMsg); err != nil {
|
|
||||||
log.Fatalf("failed to write helloMsg: %v", err)
|
|
||||||
}
|
|
||||||
var cmd *exec.Cmd
|
|
||||||
if s.Command() == nil {
|
|
||||||
cmd = exec.Command("/bin/sh") // #nosec
|
|
||||||
} else {
|
|
||||||
cmd = exec.Command(s.Command()[0], s.Command()[1:]...) // #nosec
|
|
||||||
}
|
|
||||||
ptyReq, winCh, isPty := s.Pty()
|
|
||||||
var cmdErr error
|
|
||||||
if isPty {
|
|
||||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
|
||||||
f, err := pty.Start(cmd)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(s, "failed to run command: %v\n", err)
|
|
||||||
_ = s.Exit(1)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
for win := range winCh {
|
|
||||||
_, _, _ = syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
|
|
||||||
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(win.Height), uint16(win.Width), 0, 0}))) // #nosec
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
_, _ = io.Copy(f, s) // stdin
|
|
||||||
}()
|
|
||||||
_, _ = io.Copy(s, f) // stdout
|
|
||||||
cmdErr = cmd.Wait()
|
|
||||||
} else {
|
|
||||||
//cmd.Stdin = s
|
|
||||||
cmd.Stdout = s
|
|
||||||
cmd.Stderr = s
|
|
||||||
cmdErr = cmd.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmdErr != nil {
|
|
||||||
if exitError, ok := cmdErr.(*exec.ExitError); ok {
|
|
||||||
waitStatus := exitError.Sys().(syscall.WaitStatus)
|
|
||||||
_ = s.Exit(waitStatus.ExitStatus())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
waitStatus := cmd.ProcessState.Sys().(syscall.WaitStatus)
|
|
||||||
_ = s.Exit(waitStatus.ExitStatus())
|
|
||||||
})
|
|
||||||
|
|
||||||
log.Println("starting ssh server on port 2222...")
|
|
||||||
return ssh.ListenAndServe(":2222", nil)
|
|
||||||
}
|
|
||||||
|
|
38
shell.go
38
shell.go
|
@ -323,11 +323,11 @@ GLOBAL OPTIONS:
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, key := range config.SSHKeys {
|
for _, key := range config.SSHKeys {
|
||||||
SSHKeyDecrypt(actx.globalContext.String("aes-key"), key)
|
SSHKeyDecrypt(actx.config.aesKey, key)
|
||||||
}
|
}
|
||||||
if !c.Bool("decrypt") {
|
if !c.Bool("decrypt") {
|
||||||
for _, key := range config.SSHKeys {
|
for _, key := range config.SSHKeys {
|
||||||
if err := SSHKeyEncrypt(actx.globalContext.String("aes-key"), key); err != nil {
|
if err := SSHKeyEncrypt(actx.config.aesKey, key); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -337,11 +337,11 @@ GLOBAL OPTIONS:
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, host := range config.Hosts {
|
for _, host := range config.Hosts {
|
||||||
HostDecrypt(actx.globalContext.String("aes-key"), host)
|
HostDecrypt(actx.config.aesKey, host)
|
||||||
}
|
}
|
||||||
if !c.Bool("decrypt") {
|
if !c.Bool("decrypt") {
|
||||||
for _, host := range config.Hosts {
|
for _, host := range config.Hosts {
|
||||||
if err := HostEncrypt(actx.globalContext.String("aes-key"), host); err != nil {
|
if err := HostEncrypt(actx.config.aesKey, host); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -456,9 +456,9 @@ GLOBAL OPTIONS:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, host := range config.Hosts {
|
for _, host := range config.Hosts {
|
||||||
HostDecrypt(actx.globalContext.String("aes-key"), host)
|
HostDecrypt(actx.config.aesKey, host)
|
||||||
if !c.Bool("decrypt") {
|
if !c.Bool("decrypt") {
|
||||||
if err := HostEncrypt(actx.globalContext.String("aes-key"), host); err != nil {
|
if err := HostEncrypt(actx.config.aesKey, host); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -492,9 +492,9 @@ GLOBAL OPTIONS:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, sshKey := range config.SSHKeys {
|
for _, sshKey := range config.SSHKeys {
|
||||||
SSHKeyDecrypt(actx.globalContext.String("aes-key"), sshKey)
|
SSHKeyDecrypt(actx.config.aesKey, sshKey)
|
||||||
if !c.Bool("decrypt") {
|
if !c.Bool("decrypt") {
|
||||||
if err := SSHKeyEncrypt(actx.globalContext.String("aes-key"), sshKey); err != nil {
|
if err := SSHKeyEncrypt(actx.config.aesKey, sshKey); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -697,7 +697,7 @@ GLOBAL OPTIONS:
|
||||||
}
|
}
|
||||||
|
|
||||||
// encrypt
|
// encrypt
|
||||||
if err := HostEncrypt(actx.globalContext.String("aes-key"), host); err != nil {
|
if err := HostEncrypt(actx.config.aesKey, host); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -734,7 +734,7 @@ GLOBAL OPTIONS:
|
||||||
|
|
||||||
if c.Bool("decrypt") {
|
if c.Bool("decrypt") {
|
||||||
for _, host := range hosts {
|
for _, host := range hosts {
|
||||||
HostDecrypt(actx.globalContext.String("aes-key"), host)
|
HostDecrypt(actx.config.aesKey, host)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1042,14 +1042,14 @@ GLOBAL OPTIONS:
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(s, "Debug mode (server): %v\n", actx.globalContext.Bool("debug"))
|
fmt.Fprintf(s, "debug mode (server): %v\n", actx.config.debug)
|
||||||
hostname, _ := os.Hostname()
|
hostname, _ := os.Hostname()
|
||||||
fmt.Fprintf(s, "Hostname: %s\n", hostname)
|
fmt.Fprintf(s, "Hostname: %s\n", hostname)
|
||||||
fmt.Fprintf(s, "CPUs: %d\n", runtime.NumCPU())
|
fmt.Fprintf(s, "CPUs: %d\n", runtime.NumCPU())
|
||||||
fmt.Fprintf(s, "Demo mode: %v\n", actx.globalContext.Bool("demo"))
|
fmt.Fprintf(s, "Demo mode: %v\n", actx.config.demo)
|
||||||
fmt.Fprintf(s, "DB Driver: %s\n", actx.globalContext.String("db-driver"))
|
fmt.Fprintf(s, "DB Driver: %s\n", actx.config.dbDriver)
|
||||||
fmt.Fprintf(s, "DB Conn: %s\n", actx.globalContext.String("db-conn"))
|
fmt.Fprintf(s, "DB Conn: %s\n", actx.config.dbURL)
|
||||||
fmt.Fprintf(s, "Bind Address: %s\n", actx.globalContext.String("bind-address"))
|
fmt.Fprintf(s, "Bind Address: %s\n", actx.config.bindAddr)
|
||||||
fmt.Fprintf(s, "System Time: %v\n", time.Now().Format(time.RFC3339Nano))
|
fmt.Fprintf(s, "System Time: %v\n", time.Now().Format(time.RFC3339Nano))
|
||||||
fmt.Fprintf(s, "OS Type: %s\n", runtime.GOOS)
|
fmt.Fprintf(s, "OS Type: %s\n", runtime.GOOS)
|
||||||
fmt.Fprintf(s, "OS Architecture: %s\n", runtime.GOARCH)
|
fmt.Fprintf(s, "OS Architecture: %s\n", runtime.GOARCH)
|
||||||
|
@ -1095,8 +1095,8 @@ GLOBAL OPTIONS:
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := NewSSHKey(c.String("type"), c.Uint("length"))
|
key, err := NewSSHKey(c.String("type"), c.Uint("length"))
|
||||||
if actx.globalContext.String("aes-key") != "" {
|
if actx.config.aesKey != "" {
|
||||||
if err2 := SSHKeyEncrypt(actx.globalContext.String("aes-key"), key); err2 != nil {
|
if err2 := SSHKeyEncrypt(actx.config.aesKey, key); err2 != nil {
|
||||||
return err2
|
return err2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1141,7 +1141,7 @@ GLOBAL OPTIONS:
|
||||||
|
|
||||||
if c.Bool("decrypt") {
|
if c.Bool("decrypt") {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
SSHKeyDecrypt(actx.globalContext.String("aes-key"), key)
|
SSHKeyDecrypt(actx.config.aesKey, key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1250,7 +1250,7 @@ GLOBAL OPTIONS:
|
||||||
if err := SSHKeysByIdentifiers(SSHKeysPreload(db), c.Args()).First(&key).Error; err != nil {
|
if err := SSHKeysByIdentifiers(SSHKeysPreload(db), c.Args()).First(&key).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
SSHKeyDecrypt(actx.globalContext.String("aes-key"), &key)
|
SSHKeyDecrypt(actx.config.aesKey, &key)
|
||||||
|
|
||||||
type line struct {
|
type line struct {
|
||||||
key string
|
key string
|
||||||
|
|
35
ssh.go
35
ssh.go
|
@ -12,7 +12,6 @@ import (
|
||||||
"github.com/gliderlabs/ssh"
|
"github.com/gliderlabs/ssh"
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/moul/sshportal/pkg/bastionsession"
|
"github.com/moul/sshportal/pkg/bastionsession"
|
||||||
"github.com/urfave/cli"
|
|
||||||
gossh "golang.org/x/crypto/ssh"
|
gossh "golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -27,7 +26,7 @@ type authContext struct {
|
||||||
inputUsername string
|
inputUsername string
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
userKey UserKey
|
userKey UserKey
|
||||||
globalContext *cli.Context
|
config *configServe
|
||||||
authMethod string
|
authMethod string
|
||||||
authSuccess bool
|
authSuccess bool
|
||||||
}
|
}
|
||||||
|
@ -96,7 +95,6 @@ func channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh
|
||||||
}
|
}
|
||||||
|
|
||||||
actx := ctx.Value(authContextKey).(*authContext)
|
actx := ctx.Value(authContextKey).(*authContext)
|
||||||
logsLocation := actx.globalContext.String("logs-location")
|
|
||||||
|
|
||||||
switch actx.userType() {
|
switch actx.userType() {
|
||||||
case UserTypeBastion:
|
case UserTypeBastion:
|
||||||
|
@ -145,7 +143,7 @@ func channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh
|
||||||
err = bastionsession.ChannelHandler(srv, conn, newChan, ctx, bastionsession.Config{
|
err = bastionsession.ChannelHandler(srv, conn, newChan, ctx, bastionsession.Config{
|
||||||
Addr: host.DialAddr(),
|
Addr: host.DialAddr(),
|
||||||
ClientConfig: clientConfig,
|
ClientConfig: clientConfig,
|
||||||
Logs: logsLocation,
|
Logs: actx.config.logsLocation,
|
||||||
})
|
})
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
@ -200,8 +198,8 @@ func bastionClientConfig(ctx ssh.Context, host *Host) (*gossh.ClientConfig, erro
|
||||||
return nil, err2
|
return nil, err2
|
||||||
}
|
}
|
||||||
|
|
||||||
HostDecrypt(actx.globalContext.String("aes-key"), host)
|
HostDecrypt(actx.config.aesKey, host)
|
||||||
SSHKeyDecrypt(actx.globalContext.String("aes-key"), host.SSHKey)
|
SSHKeyDecrypt(actx.config.aesKey, host.SSHKey)
|
||||||
|
|
||||||
switch action {
|
switch action {
|
||||||
case ACLActionAllow:
|
case ACLActionAllow:
|
||||||
|
@ -246,12 +244,12 @@ func shellHandler(s ssh.Session) {
|
||||||
panic("should not happen")
|
panic("should not happen")
|
||||||
}
|
}
|
||||||
|
|
||||||
func passwordAuthHandler(db *gorm.DB, globalContext *cli.Context) ssh.PasswordHandler {
|
func passwordAuthHandler(db *gorm.DB, cfg *configServe) ssh.PasswordHandler {
|
||||||
return func(ctx ssh.Context, pass string) bool {
|
return func(ctx ssh.Context, pass string) bool {
|
||||||
actx := &authContext{
|
actx := &authContext{
|
||||||
db: db,
|
db: db,
|
||||||
inputUsername: ctx.User(),
|
inputUsername: ctx.User(),
|
||||||
globalContext: globalContext,
|
config: cfg,
|
||||||
authMethod: "password",
|
authMethod: "password",
|
||||||
}
|
}
|
||||||
actx.authSuccess = actx.userType() == UserTypeHealthcheck
|
actx.authSuccess = actx.userType() == UserTypeHealthcheck
|
||||||
|
@ -260,12 +258,29 @@ func passwordAuthHandler(db *gorm.DB, globalContext *cli.Context) ssh.PasswordHa
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func publicKeyAuthHandler(db *gorm.DB, globalContext *cli.Context) ssh.PublicKeyHandler {
|
func privateKeyFromDB(db *gorm.DB, aesKey string) func(*ssh.Server) error {
|
||||||
|
return func(srv *ssh.Server) error {
|
||||||
|
var key SSHKey
|
||||||
|
if err := SSHKeysByIdentifiers(db, []string{"host"}).First(&key).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
SSHKeyDecrypt(aesKey, &key)
|
||||||
|
|
||||||
|
signer, err := gossh.ParsePrivateKey([]byte(key.PrivKey))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
srv.AddHostKey(signer)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func publicKeyAuthHandler(db *gorm.DB, cfg *configServe) ssh.PublicKeyHandler {
|
||||||
return func(ctx ssh.Context, key ssh.PublicKey) bool {
|
return func(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||||
actx := &authContext{
|
actx := &authContext{
|
||||||
db: db,
|
db: db,
|
||||||
inputUsername: ctx.User(),
|
inputUsername: ctx.User(),
|
||||||
globalContext: globalContext,
|
config: cfg,
|
||||||
authMethod: "pubkey",
|
authMethod: "pubkey",
|
||||||
authSuccess: true,
|
authSuccess: true,
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,9 +25,9 @@ func (caller bastionTelnetCaller) CallTELNET(ctx telnet.Context, w telnet.Writer
|
||||||
for {
|
for {
|
||||||
// Read 1 byte.
|
// Read 1 byte.
|
||||||
n, err := reader.Read(p)
|
n, err := reader.Read(p)
|
||||||
if n <= 0 && nil == err {
|
if n <= 0 && err == nil {
|
||||||
continue
|
continue
|
||||||
} else if n <= 0 && nil != err {
|
} else if n <= 0 && err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue