sshportal/main.go

153 lines
3.3 KiB
Go
Raw Normal View History

2017-09-30 19:12:43 +08:00
package main
import (
"errors"
"fmt"
"io"
"log"
"os"
"path"
"github.com/gliderlabs/ssh"
2017-10-30 23:48:14 +08:00
"github.com/jinzhu/gorm"
2017-10-31 00:12:04 +08:00
_ "github.com/jinzhu/gorm/dialects/mysql"
2017-10-30 23:48:14 +08:00
_ "github.com/jinzhu/gorm/dialects/sqlite"
2017-09-30 19:12:43 +08:00
"github.com/urfave/cli"
)
2017-11-02 05:09:08 +08:00
// sshportalContextKey is an unexported key type for values attached to an
// ssh.Context. Using a distinct named type means these keys can never compare
// equal to keys of another type set by other packages.
type sshportalContextKey string

var (
	// version is the release string reported by the CLI (app.Version).
	// NOTE(review): kept a variable rather than a constant, presumably so a
	// build can override it with -ldflags "-X main.version=..."; confirm.
	version = "0.0.1"

	// userContextKey is the ssh.Context key under which the public-key
	// authenticator in server() stores the authenticated User.
	userContextKey = sshportalContextKey("user")
)
2017-09-30 19:12:43 +08:00
func main() {
app := cli.NewApp()
app.Name = path.Base(os.Args[0])
app.Author = "Manfred Touron"
2017-11-02 05:09:08 +08:00
app.Version = version
2017-09-30 19:12:43 +08:00
app.Email = "https://github.com/moul/sshportal"
app.Flags = []cli.Flag{
cli.StringFlag{
Name: "bind-address, b",
EnvVar: "SSHPORTAL_BIND",
Value: ":2222",
Usage: "SSH server bind address",
},
cli.BoolFlag{
Name: "demo",
Usage: "*unsafe* - demo mode: accept all connections",
},
2017-10-30 23:48:14 +08:00
cli.StringFlag{
Name: "db-driver",
Value: "sqlite3",
2017-10-31 00:12:04 +08:00
Usage: "GORM driver (sqlite3, mysql)",
2017-10-30 23:48:14 +08:00
},
cli.StringFlag{
Name: "db-conn",
Value: "./sshportal.db",
2017-10-31 00:12:04 +08:00
Usage: "GORM connection string",
},
cli.BoolFlag{
Name: "debug, D",
Usage: "Display debug information",
2017-10-30 23:48:14 +08:00
},
2017-11-02 05:11:46 +08:00
cli.StringFlag{
Name: "config-user",
Usage: "SSH user that spawns a configuration shell",
Value: "admin",
},
2017-09-30 19:12:43 +08:00
}
app.Action = server
if err := app.Run(os.Args); err != nil {
log.Fatalf("error: %v", err)
}
2017-09-30 19:12:43 +08:00
}
func server(c *cli.Context) error {
2017-10-30 23:48:14 +08:00
db, err := gorm.Open(c.String("db-driver"), c.String("db-conn"))
if err != nil {
return err
}
defer db.Close()
2017-10-31 00:12:04 +08:00
if c.Bool("debug") {
db.LogMode(true)
}
2017-10-30 23:48:14 +08:00
if err := dbInit(db); err != nil {
return err
}
if c.Bool("demo") {
if err := dbDemo(db); err != nil {
return err
}
}
2017-09-30 19:12:43 +08:00
ssh.Handle(func(s ssh.Session) {
log.Printf("New connection: user=%q remote=%q local=%q command=%q", s.User(), s.RemoteAddr(), s.LocalAddr(), s.Command())
switch s.User() {
2017-11-02 05:11:46 +08:00
case c.String("config-user"):
2017-11-02 00:00:34 +08:00
if err := shell(c, s, s.Command(), db); err != nil {
2017-10-31 17:17:06 +08:00
io.WriteString(s, fmt.Sprintf("error: %v\n", err))
}
2017-09-30 19:12:43 +08:00
default:
2017-10-31 16:24:18 +08:00
host, err := RemoteHostFromSession(s, db)
2017-09-30 19:12:43 +08:00
if err != nil {
io.WriteString(s, fmt.Sprintf("error: %v\n", err))
2017-10-31 16:24:18 +08:00
// FIXME: print available hosts
2017-09-30 19:12:43 +08:00
return
}
2017-10-31 16:24:18 +08:00
if err := proxy(s, host); err != nil {
2017-09-30 19:12:43 +08:00
io.WriteString(s, fmt.Sprintf("error: %v\n", err))
}
}
})
opts := []ssh.Option{}
if !c.Bool("demo") {
return errors.New("use `--demo` for now")
2017-09-30 19:12:43 +08:00
}
opts = append(opts, ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
var (
userKey UserKey
user User
count uint
)
// lookup user by key
db.Where("key = ?", key.Marshal()).First(&userKey)
if userKey.UserID > 0 {
db.Where("id = ?", userKey.UserID).First(&user)
ctx.SetValue(userContextKey, user)
return true
}
// check if there are users in DB
db.Table("users").Count(&count)
if count == 0 { // create an admin user
// if no admin, create an account for the first connection
user = User{
Name: "Administrator",
Email: "admin@sshportal",
Comment: "created by sshportal",
IsAdmin: true,
}
db.Create(&user)
userKey = UserKey{
UserID: user.ID,
Key: key.Marshal(),
}
db.Create(&userKey)
ctx.SetValue(userContextKey, user)
return true
}
return false
}))
2017-09-30 19:12:43 +08:00
log.Printf("SSH Server accepting connections on %s", c.String("bind-address"))
return ssh.ListenAndServe(c.String("bind-address"), nil, opts...)
}