bashhub-server/internal/db.go

479 lines
16 KiB
Go
Raw Normal View History

2020-02-08 00:14:22 +08:00
package internal
import (
"database/sql"
"fmt"
"log"
"regexp"
2020-02-08 00:14:22 +08:00
"strings"
"time"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/postgres"
_ "github.com/jinzhu/gorm/dialects/sqlite"
_ "github.com/lib/pq"
"github.com/mattn/go-sqlite3"
2020-02-10 07:30:05 +08:00
"golang.org/x/crypto/bcrypt"
2020-02-08 00:14:22 +08:00
)
// DB is the shared connection pool used by every query helper in this file.
// It is nil until DbInit runs.
var DB *sql.DB

// DbPath names the database: a "postgres://" URL selects PostgreSQL,
// anything else is treated as a SQLite file (DbInit rewrites that form into
// a "file:" URI).
var DbPath string

// connectionLimit caps DB's open connections. DbInit sets it to 50 for
// PostgreSQL and 1 for SQLite, and the query code compares it against 1 to
// tell the two dialects apart.
var connectionLimit int
// DbInit initializes our db.
func DbInit() {
// GormDB contains DB connection state
var gormdb *gorm.DB
var err error
if strings.HasPrefix(DbPath, "postgres://") {
//
DB, err = sql.Open("postgres", DbPath)
if err != nil {
log.Fatal(err)
}
gormdb, err = gorm.Open("postgres", DbPath)
if err != nil {
log.Fatal(err)
}
connectionLimit = 50
} else {
gormdb, err = gorm.Open("sqlite3", DbPath)
2020-02-08 00:14:22 +08:00
if err != nil {
log.Fatal(err)
}
regex := func(re, s string) (bool, error) {
b, e := regexp.MatchString(re, s)
return b, e
}
sql.Register("sqlite3_with_regex",
&sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
return conn.RegisterFunc("regexp", regex, true)
},
})
DbPath = fmt.Sprintf("file:%v?cache=shared&mode=rwc", DbPath)
DB, err = sql.Open("sqlite3_with_regex", DbPath)
2020-02-08 00:14:22 +08:00
if err != nil {
log.Fatal(err)
}
2020-02-08 00:14:22 +08:00
DB.Exec("PRAGMA journal_mode=WAL;")
connectionLimit = 1
}
DB.SetMaxOpenConns(connectionLimit)
gormdb.AutoMigrate(&User{})
gormdb.AutoMigrate(&Command{})
gormdb.AutoMigrate(&System{})
gormdb.Model(&User{}).AddIndex("idx_user", "username")
gormdb.Model(&System{}).AddIndex("idx_mac", "mac")
gormdb.Model(&Command{}).AddIndex("idx_exit_command", "exit_status, command")
2020-02-08 00:14:22 +08:00
// just need gorm for migration.
gormdb.Close()
}
2020-02-10 07:30:05 +08:00
// hashAndSalt returns the bcrypt hash of password, ready to be stored.
// bcrypt generates its own random salt and embeds both the salt and the cost
// in the returned hash.
//
// bcrypt.DefaultCost is used instead of bcrypt.MinCost. MinCost is the
// weakest setting the package accepts and makes offline brute force of a
// leaked users table cheap. Hashes created earlier with MinCost still verify,
// because CompareHashAndPassword reads the cost from the stored hash.
//
// If hashing fails, the error is logged and "" is returned.
func hashAndSalt(password string) string {
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		log.Println(err)
		return ""
	}
	return string(hash)
}
// comparePasswords reports whether plainPwd matches the bcrypt hash
// hashedPwd. Any comparison error, including a plain mismatch, is logged and
// reported as false.
func comparePasswords(hashedPwd string, plainPwd string) bool {
	if err := bcrypt.CompareHashAndPassword([]byte(hashedPwd), []byte(plainPwd)); err != nil {
		log.Println(err)
		return false
	}
	return true
}
2020-02-08 00:14:22 +08:00
func (user User) userExists() bool {
2020-02-10 07:30:05 +08:00
var password string
err := DB.QueryRow("SELECT password FROM users WHERE username = $1",
user.Username).Scan(&password)
if err != nil && err != sql.ErrNoRows {
log.Fatalf("error checking if row exists %v", err)
}
if password != "" {
return comparePasswords(password, user.Password)
}
return false
}
// userGetId returns the users.id of user.Username, or 0 when no such user
// exists. Database errors other than sql.ErrNoRows terminate the process.
func (user User) userGetId() uint {
	const query = "SELECT id FROM users WHERE username = $1"
	var userID uint
	if err := DB.QueryRow(query, user.Username).Scan(&userID); err != nil && err != sql.ErrNoRows {
		log.Fatalf("error checking if row exists %v", err)
	}
	return userID
}
// userGetSystemName returns the name of the system with MAC address user.Mac
// owned by user.Username, or "" when there is no such system. Database
// errors other than sql.ErrNoRows terminate the process.
func (user User) userGetSystemName() string {
	const query = `SELECT name
	FROM systems
	WHERE user_id in (select id from users where username = $1)
	AND mac = $2`
	var name string
	row := DB.QueryRow(query, user.Username, user.Mac)
	if err := row.Scan(&name); err != nil && err != sql.ErrNoRows {
		log.Fatalf("error checking if row exists %v", err)
	}
	return name
}
// usernameExists reports whether a users row with username user.Username
// exists. Database errors other than sql.ErrNoRows terminate the process.
func (user User) usernameExists() bool {
	const query = `SELECT exists (select id FROM users WHERE "username" = $1)`
	found := false
	row := DB.QueryRow(query, user.Username)
	if err := row.Scan(&found); err != nil && err != sql.ErrNoRows {
		log.Fatalf("error checking if row exists %v", err)
	}
	return found
}
func (user User) emailExists() bool {
2020-02-08 00:14:22 +08:00
var exists bool
2020-02-10 07:30:05 +08:00
err := DB.QueryRow(`SELECT exists (select id FROM users WHERE "email" = $1)`,
user.Email).Scan(&exists)
2020-02-08 00:14:22 +08:00
if err != nil && err != sql.ErrNoRows {
log.Fatalf("error checking if row exists %v", err)
}
return exists
}
// userCreate inserts user into the users table and returns the number of
// rows inserted. Because of ON CONFLICT(username) do nothing, an existing
// username is not an error: it yields 0 instead of 1. Any database error
// terminates the process.
func (user User) userCreate() int64 {
	const stmt = `INSERT INTO users("registration_code", "username","password","email")
	VALUES ($1,$2,$3,$4) ON CONFLICT(username) do nothing`
	result, err := DB.Exec(stmt, user.RegistrationCode, user.Username, user.Password, user.Email)
	if err != nil {
		log.Fatal(err)
	}
	affected, err := result.RowsAffected()
	if err != nil {
		log.Fatal(err)
	}
	return affected
}
func (cmd Command) commandInsert() int64 {
2020-02-10 07:47:18 +08:00
res, err := DB.Exec(`INSERT INTO commands("process_id","process_start_time","exit_status","uuid", "command", "created", "path", "user_id", "system_name")
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9)`,
cmd.ProcessId, cmd.ProcessStartTime, cmd.ExitStatus, cmd.Uuid, cmd.Command, cmd.Created, cmd.Path, cmd.User.ID, cmd.SystemName)
2020-02-08 00:14:22 +08:00
if err != nil {
log.Fatal(err)
}
inserted, err := res.RowsAffected()
if err != nil {
log.Fatal(err)
}
return inserted
}
func (cmd Command) commandGet() []Query {
var results []Query
var rows *sql.Rows
var err error
if cmd.Unique || cmd.Query != "" {
//postgres
if connectionLimit != 1 {
2020-02-10 07:47:18 +08:00
if cmd.SystemName != "" && cmd.Path != "" && cmd.Query != "" && cmd.Unique {
rows, err = DB.Query(`SELECT * FROM (
SELECT DISTINCT ON ("command") command, "uuid", "created"
FROM commands
WHERE "user_id" = $1
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like 'bh%'
AND "path" = $3
AND "system_name" = $4
AND "command" ~ $5
) c
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.Path, cmd.SystemName, cmd.Query)
} else if cmd.Path != "" && cmd.Query != "" && cmd.Unique {
rows, err = DB.Query(`SELECT * FROM (
SELECT DISTINCT ON ("command") command, "uuid", "created"
FROM commands
2020-02-10 07:30:05 +08:00
WHERE "user_id" = $1
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like 'bh%'
AND "path" = $3
AND "command" ~ $4
) c
2020-02-10 07:30:05 +08:00
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.Path, cmd.Query)
2020-02-10 07:47:18 +08:00
} else if cmd.SystemName != "" && cmd.Query != "" {
rows, err = DB.Query(`SELECT "command", "uuid", "created"
FROM commands
WHERE "user_id" = $1
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like 'bh%'
AND "system_name" = $3
AND "command" ~ $4
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.SystemName, cmd.Query)
} else if cmd.Path != "" && cmd.Query != "" {
rows, err = DB.Query(`SELECT "command", "uuid", "created"
FROM commands
2020-02-10 07:30:05 +08:00
WHERE "user_id" = $1
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like 'bh%'
AND "path" = $3
AND "command" ~ $4
2020-02-10 07:30:05 +08:00
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.Path, cmd.Query)
2020-02-10 07:47:18 +08:00
} else if cmd.SystemName != "" && cmd.Unique {
rows, err = DB.Query(`SELECT * FROM (
SELECT DISTINCT ON ("command") command, "uuid", "created"
FROM commands
WHERE "user_id" = $1
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like 'bh%'
AND "system_name" = $3
) c
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.SystemName)
} else if cmd.Path != "" && cmd.Unique {
rows, err = DB.Query(`SELECT * FROM (
SELECT DISTINCT ON ("command") command, "uuid", "created"
FROM commands
2020-02-10 07:30:05 +08:00
WHERE "user_id" = $1
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like 'bh%'
AND "path" = $3
) c
2020-02-10 07:30:05 +08:00
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.Path)
} else if cmd.Query != "" && cmd.Unique {
rows, err = DB.Query(`SELECT * FROM (
SELECT DISTINCT ON ("command") command, "uuid", "created"
FROM commands
2020-02-10 07:30:05 +08:00
WHERE "user_id" = $1
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like 'bh%'
AND "command" ~ $3
) c
2020-02-10 07:30:05 +08:00
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.Query)
} else if cmd.Query != "" {
rows, err = DB.Query(`SELECT "command", "uuid", "created"
FROM commands
2020-02-10 07:30:05 +08:00
WHERE "user_id" = $1
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like 'bh%'
AND "command" ~ $3
2020-02-10 07:30:05 +08:00
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.Query)
} else {
// unique
rows, err = DB.Query(`SELECT * FROM (
SELECT DISTINCT ON ("command") command, "uuid", "created"
FROM commands
2020-02-10 07:30:05 +08:00
WHERE "user_id" = $1
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like 'bh%'
) c
2020-02-10 07:30:05 +08:00
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit)
}
} else {
// sqlite
2020-02-10 07:47:18 +08:00
if cmd.SystemName != "" && cmd.Path != "" && cmd.Query != "" && cmd.Unique {
query := fmt.Sprintf(`SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = '%v'
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like '%v'
AND "path" = '%v'
AND "system_name" = '%v'
AND "command" regexp '%v'
GROUP BY "command" ORDER BY "created" DESC limit '%v'`,
cmd.User.ID, "bh%", cmd.Path, cmd.SystemName, cmd.Query, cmd.Limit)
rows, err = DB.Query(query)
} else if cmd.SystemName != "" && cmd.Query != "" && cmd.Unique {
query := fmt.Sprintf(`SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = '%v'
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like '%v'
AND "system_name" = '%v'
AND "command" regexp '%v'
GROUP BY "command" ORDER BY "created" DESC limit '%v'`,
cmd.User.ID, "bh%", cmd.SystemName, cmd.Query, cmd.Limit)
rows, err = DB.Query(query)
} else if cmd.Path != "" && cmd.Query != "" && cmd.Unique {
query := fmt.Sprintf(`SELECT "command", "uuid", "created" FROM commands
2020-02-10 07:30:05 +08:00
WHERE "user_id" = '%v'
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like '%v'
AND "path" = '%v'
AND "command" regexp '%v'
GROUP BY "command" ORDER BY "created" DESC limit '%v'`,
2020-02-10 07:30:05 +08:00
cmd.User.ID, "bh%", cmd.Path, cmd.Query, cmd.Limit)
rows, err = DB.Query(query)
2020-02-10 07:47:18 +08:00
} else if cmd.SystemName != "" && cmd.Query != "" {
query := fmt.Sprintf(`SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = '%v'
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like '%v'
AND "system_name" = %v'
AND "command" regexp %v'
ORDER BY "created" DESC limit '%v'`,
cmd.User.ID, "bh%", cmd.SystemName, cmd.Query, cmd.Limit)
rows, err = DB.Query(query)
} else if cmd.Path != "" && cmd.Query != "" {
query := fmt.Sprintf(`SELECT "command", "uuid", "created" FROM commands
2020-02-10 07:30:05 +08:00
WHERE "user_id" = '%v'
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like '%v'
AND "path" = %v'
AND "command" regexp %v'
ORDER BY "created" DESC limit '%v'`,
2020-02-10 07:30:05 +08:00
cmd.User.ID, "bh%", cmd.Path, cmd.Query, cmd.Limit)
rows, err = DB.Query(query)
2020-02-10 07:47:18 +08:00
}else if cmd.SystemName != "" && cmd.Unique {
rows, err = DB.Query(`SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = $1
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like 'bh%'
AND "system_name" = $2
GROUP BY "command" ORDER BY "created" DESC limit $3`,
cmd.User.ID, cmd.SystemName, cmd.Limit)
} else if cmd.Path != "" && cmd.Unique {
rows, err = DB.Query(`SELECT "command", "uuid", "created" FROM commands
2020-02-10 07:30:05 +08:00
WHERE "user_id" = $1
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like 'bh%'
AND "path" = $2
GROUP BY "command" ORDER BY "created" DESC limit $3`,
2020-02-10 07:30:05 +08:00
cmd.User.ID, cmd.Path, cmd.Limit)
} else if cmd.Query != "" && cmd.Unique {
query := fmt.Sprintf(`SELECT "command", "uuid", "created" FROM commands
2020-02-10 07:30:05 +08:00
WHERE "user_id" = '%v'
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like '%v'
AND "command" regexp '%v'
GROUP BY "command" ORDER BY "created" DESC limit '%v'`,
2020-02-10 07:30:05 +08:00
cmd.User.ID, "bh%", cmd.Query, cmd.Limit)
rows, err = DB.Query(query)
} else if cmd.Query != "" {
query := fmt.Sprintf(`SELECT "command", "uuid", "created" FROM commands
2020-02-10 07:30:05 +08:00
WHERE "user_id" = '%v'
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like '%v'
AND "command" regexp'%v'
ORDER BY "created" DESC limit '%v'`,
2020-02-10 07:30:05 +08:00
cmd.User.ID, "bh%", cmd.Query, cmd.Limit)
rows, err = DB.Query(query)
} else {
// unique
rows, err = DB.Query(`SELECT "command", "uuid", "created"
FROM commands
2020-02-10 07:30:05 +08:00
WHERE "user_id" = $1
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like 'bh%'
2020-02-10 07:30:05 +08:00
GROUP BY "command" ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit)
}
}
2020-02-08 00:14:22 +08:00
} else {
if cmd.Path != "" {
rows, err = DB.Query(`SELECT "command", "uuid", "created" FROM commands
2020-02-10 07:47:18 +08:00
WHERE "user_id" = $1
AND "path" = $3
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like 'bh%'
2020-02-10 07:30:05 +08:00
ORDER BY "created" DESC limit $2`, cmd.User.ID, cmd.Limit, cmd.Path)
2020-02-10 07:47:18 +08:00
} else if cmd.SystemName != "" {
rows, err = DB.Query(`SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = $1
AND "system_name" = $3
AND ("exit_status" = 0 OR "exit_status" = 130) AND "command" not like 'bh%'
ORDER BY "created" DESC limit $2`, cmd.User.ID, cmd.Limit, cmd.SystemName)
} else {
rows, err = DB.Query(`SELECT "command", "uuid", "created" FROM commands
2020-02-10 07:30:05 +08:00
WHERE "user_id" = $1
2020-02-10 07:47:18 +08:00
AND ("exit_status" = 0 OR "exit_status" = 130)
AND "command" not like 'bh%'
2020-02-10 07:30:05 +08:00
ORDER BY "created" DESC limit $2`, cmd.User.ID, cmd.Limit)
}
2020-02-08 00:14:22 +08:00
}
if err != nil {
log.Println(err)
}
defer rows.Close()
for rows.Next() {
var result Query
err = rows.Scan(&result.Command, &result.Uuid, &result.Created)
if err != nil {
log.Println(err)
}
results = append(results, result)
}
return results
}
2020-02-10 08:30:26 +08:00
// commandGetUUID fetches the command with uuid cmd.Uuid belonging to
// cmd.User.ID. Any lookup error, including no matching row, is logged, and
// whatever was scanned into the Query so far is returned as-is.
func (cmd Command) commandGetUUID() Query {
	var found Query
	row := DB.QueryRow(`SELECT "command","path", "created" , "uuid", "exit_status", "system_name"
	FROM commands
	WHERE "uuid" = $1
	AND "user_id" = $2`, cmd.Uuid, cmd.User.ID)
	dest := []interface{}{
		&found.Command, &found.Path, &found.Created,
		&found.Uuid, &found.ExitStatus, &found.SystemName,
	}
	if err := row.Scan(dest...); err != nil {
		log.Println(err)
	}
	return found
}
2020-02-08 00:14:22 +08:00
func (sys System) systemInsert() int64 {
t := time.Now().Unix()
res, err := DB.Exec(`INSERT INTO systems ("name", "mac", "user_id", "hostname", "client_version", "created", "updated")
2020-02-10 07:30:05 +08:00
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
sys.Name, sys.Mac, sys.User.ID, sys.Hostname, sys.ClientVersion, t, t)
2020-02-08 00:14:22 +08:00
if err != nil {
log.Fatal(err)
}
inserted, err := res.RowsAffected()
if err != nil {
log.Fatal(err)
}
return inserted
}
2020-02-10 07:30:05 +08:00
func (sys System) systemGet() SystemQuery {
2020-02-08 00:14:22 +08:00
var row SystemQuery
err := DB.QueryRow(`SELECT "name", "mac", "user_id", "hostname", "client_version",
2020-02-10 07:30:05 +08:00
"id", "created", "updated" FROM systems
WHERE "user_id" $1
AND "mac" = $2`,
sys.User.ID, sys.Mac).Scan(&row)
2020-02-08 00:14:22 +08:00
if err != nil {
2020-02-10 07:30:05 +08:00
return SystemQuery{}
2020-02-08 00:14:22 +08:00
}
2020-02-10 07:30:05 +08:00
return row
2020-02-08 00:14:22 +08:00
}