test/server_test: addes postgres tests

This commit is contained in:
nicksherron 2020-02-14 08:53:40 -05:00
parent ce4313b18e
commit 8ee357e2e7
8 changed files with 332 additions and 144 deletions

View file

@ -7,4 +7,4 @@ env:
install: true install: true
script: go run *.go version script: go test ./...

View file

@ -63,3 +63,16 @@ clean:
test: test:
go test -v ./... go test -v ./...
docker-postgres-stop:
docker stop bashhub-postgres-test
docker-postgres-start:
docker run -d --rm --name bashhub-postgres-test -p 5444:5432 postgres
test-postgres:
go test -v ./... -postgres -postgres-uri "postgres://postgres:@localhost:5444?sslmode=disable"
test-docker-postgres: docker-postgres-stop docker-postgres-start test-postgres

View file

@ -38,10 +38,12 @@ import (
) )
var ( var (
db *sql.DB //DB is a connection pool to sqlite or postgres
DB *sql.DB
// DbPath is the postgres connection uri or the sqlite db file location to use for backend. // DbPath is the postgres connection uri or the sqlite db file location to use for backend.
DbPath string DbPath string
connectionLimit int connectionLimit int
QueryDebug bool
) )
// DbInit initializes our db. // DbInit initializes our db.
@ -51,7 +53,7 @@ func dbInit() {
if strings.HasPrefix(DbPath, "postgres://") { if strings.HasPrefix(DbPath, "postgres://") {
// postgres // postgres
db, err = sql.Open("postgres", DbPath) DB, err = sql.Open("postgres", DbPath)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -81,16 +83,16 @@ func dbInit() {
}) })
DbPath = fmt.Sprintf("file:%v?cache=shared&mode=rwc&_loc=auto", DbPath) DbPath = fmt.Sprintf("file:%v?cache=shared&mode=rwc&_loc=auto", DbPath)
db, err = sql.Open("sqlite3_with_regex", DbPath) DB, err = sql.Open("sqlite3_with_regex", DbPath)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
db.Exec("PRAGMA journal_mode=WAL;") DB.Exec("PRAGMA journal_mode=WAL;")
connectionLimit = 1 connectionLimit = 1
} }
db.SetMaxOpenConns(connectionLimit) DB.SetMaxOpenConns(connectionLimit)
gormdb.AutoMigrate(&User{}) gormdb.AutoMigrate(&User{})
gormdb.AutoMigrate(&Command{}) gormdb.AutoMigrate(&Command{})
gormdb.AutoMigrate(&System{}) gormdb.AutoMigrate(&System{})
@ -111,19 +113,19 @@ func dbInit() {
func (c Config) getSecret() string { func (c Config) getSecret() string {
var err error var err error
if connectionLimit != 1 { if connectionLimit != 1 {
_, err = db.Exec(`INSERT INTO configs ("id","created", "secret") _, err = DB.Exec(`INSERT INTO configs ("id","created", "secret")
VALUES (1, now(), (SELECT md5(random()::text))) VALUES (1, now(), (SELECT md5(random()::text)))
ON conflict do nothing;`) ON conflict do nothing;`)
} else { } else {
_, err = db.Exec(`INSERT INTO configs ("id","created" ,"secret") _, err = DB.Exec(`INSERT INTO configs ("id","created" ,"secret")
VALUES (1, current_timestamp, lower(hex(randomblob(16)))) VALUES (1, current_timestamp, lower(hex(randomblob(16))))
ON conflict do nothing;`) ON conflict do nothing;`)
} }
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
err = db.QueryRow(`SELECT "secret" from configs where "id" = 1 `).Scan(&c.Secret) err = DB.QueryRow(`SELECT "secret" from configs where "id" = 1 `).Scan(&c.Secret)
return c.Secret return c.Secret
} }
@ -147,7 +149,7 @@ func comparePasswords(hashedPwd string, plainPwd string) bool {
func (user User) userExists() bool { func (user User) userExists() bool {
var password string var password string
err := db.QueryRow("SELECT password FROM users WHERE username = $1", err := DB.QueryRow("SELECT password FROM users WHERE username = $1",
user.Username).Scan(&password) user.Username).Scan(&password)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
log.Fatalf("error checking if row exists %v", err) log.Fatalf("error checking if row exists %v", err)
@ -160,7 +162,7 @@ func (user User) userExists() bool {
func (user User) userGetID() uint { func (user User) userGetID() uint {
var id uint var id uint
err := db.QueryRow(`SELECT "id" err := DB.QueryRow(`SELECT "id"
FROM users FROM users
WHERE "username" = $1`, WHERE "username" = $1`,
user.Username).Scan(&id) user.Username).Scan(&id)
@ -172,7 +174,7 @@ func (user User) userGetID() uint {
func (user User) userGetSystemName() string { func (user User) userGetSystemName() string {
var systemName string var systemName string
err := db.QueryRow(`SELECT name err := DB.QueryRow(`SELECT name
FROM systems FROM systems
WHERE user_id in (select id from users where username = $1) WHERE user_id in (select id from users where username = $1)
AND mac = $2`, AND mac = $2`,
@ -185,7 +187,7 @@ func (user User) userGetSystemName() string {
func (user User) usernameExists() bool { func (user User) usernameExists() bool {
var exists bool var exists bool
err := db.QueryRow(`SELECT exists (select id FROM users WHERE "username" = $1)`, err := DB.QueryRow(`SELECT exists (select id FROM users WHERE "username" = $1)`,
user.Username).Scan(&exists) user.Username).Scan(&exists)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
log.Fatalf("error checking if row exists %v", err) log.Fatalf("error checking if row exists %v", err)
@ -195,7 +197,7 @@ func (user User) usernameExists() bool {
func (user User) emailExists() bool { func (user User) emailExists() bool {
var exists bool var exists bool
err := db.QueryRow(`SELECT exists (select id FROM users WHERE "email" = $1)`, err := DB.QueryRow(`SELECT exists (select id FROM users WHERE "email" = $1)`,
user.Email).Scan(&exists) user.Email).Scan(&exists)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
log.Fatalf("error checking if row exists %v", err) log.Fatalf("error checking if row exists %v", err)
@ -205,7 +207,7 @@ func (user User) emailExists() bool {
func (user User) userCreate() int64 { func (user User) userCreate() int64 {
user.Password = hashAndSalt(user.Password) user.Password = hashAndSalt(user.Password)
res, err := db.Exec(`INSERT INTO users("registration_code", "username","password","email") res, err := DB.Exec(`INSERT INTO users("registration_code", "username","password","email")
VALUES ($1,$2,$3,$4) ON CONFLICT(username) do nothing`, user.RegistrationCode, VALUES ($1,$2,$3,$4) ON CONFLICT(username) do nothing`, user.RegistrationCode,
user.Username, user.Password, user.Email) user.Username, user.Password, user.Email)
if err != nil { if err != nil {
@ -220,7 +222,7 @@ func (user User) userCreate() int64 {
func (cmd Command) commandInsert() int64 { func (cmd Command) commandInsert() int64 {
res, err := db.Exec(` res, err := DB.Exec(`
INSERT INTO commands("process_id","process_start_time","exit_status","uuid","command", "created", "path", "user_id", "system_name") INSERT INTO commands("process_id","process_start_time","exit_status","uuid","command", "created", "path", "user_id", "system_name")
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9) ON CONFLICT do nothing`, VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9) ON CONFLICT do nothing`,
cmd.ProcessId, cmd.ProcessStartTime, cmd.ExitStatus, cmd.Uuid, cmd.Command, cmd.Created, cmd.Path, cmd.User.ID, cmd.SystemName) cmd.ProcessId, cmd.ProcessStartTime, cmd.ExitStatus, cmd.Uuid, cmd.Command, cmd.Created, cmd.Path, cmd.User.ID, cmd.SystemName)
@ -235,106 +237,107 @@ func (cmd Command) commandInsert() int64 {
} }
func (cmd Command) commandGet() ([]Query, error) { func (cmd Command) commandGet() ([]Query, error) {
var results []Query var (
var rows *sql.Rows results []Query
var err error query string
)
if cmd.Unique || cmd.Query != "" { if cmd.Unique || cmd.Query != "" {
//postgres //postgres
if connectionLimit != 1 { if connectionLimit != 1 {
if cmd.SystemName != "" && cmd.Path != "" && cmd.Query != "" && cmd.Unique { if cmd.SystemName != "" && cmd.Path != "" && cmd.Query != "" && cmd.Unique {
rows, err = db.Query(` query = fmt.Sprintf(`
SELECT * FROM ( SELECT * FROM (
SELECT DISTINCT ON ("command") command, "uuid", "created" SELECT DISTINCT ON ("command") command, "uuid", "created"
FROM commands FROM commands
WHERE "user_id" = $1 WHERE "user_id" = '%v'
AND "path" = $3 AND "path" = '%v'
AND "system_name" = $4 AND "system_name" = '%v'
AND "command" ~ $5 AND "command" ~ '%v'
) c ) c
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.Path, cmd.SystemName, cmd.Query) ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Path, cmd.SystemName, cmd.Query, cmd.Limit,)
} else if cmd.Path != "" && cmd.Query != "" && cmd.Unique { } else if cmd.Path != "" && cmd.Query != "" && cmd.Unique {
rows, err = db.Query(` query = fmt.Sprintf(`
SELECT * FROM ( SELECT * FROM (
SELECT DISTINCT ON ("command") command, "uuid", "created" SELECT DISTINCT ON ("command") command, "uuid", "created"
FROM commands FROM commands
WHERE "user_id" = $1 WHERE "user_id" = '%v'
AND "path" = $3 AND "path" = '%v'
AND "command" ~ $4 AND "command" ~ '%v'
) c ) c
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.Path, cmd.Query) ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
} else if cmd.SystemName != "" && cmd.Query != "" { } else if cmd.SystemName != "" && cmd.Query != "" {
rows, err = db.Query(` query = fmt.Sprintf(`
SELECT "command", "uuid", "created" SELECT "command", "uuid", "created"
FROM commands FROM commands
WHERE "user_id" = $1 WHERE "user_id" = '%v'
AND "system_name" = $3 AND "system_name" = '%v'
AND "command" ~ $4 AND "command" ~ '%v'
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.SystemName, cmd.Query) ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit,)
} else if cmd.Path != "" && cmd.Query != "" { } else if cmd.Path != "" && cmd.Query != "" {
rows, err = db.Query(` query = fmt.Sprintf(`
SELECT "command", "uuid", "created" SELECT "command", "uuid", "created"
FROM commands FROM commands
WHERE "user_id" = $1 WHERE "user_id" = '%v'
AND "path" = $3 AND "path" = '%v'
AND "command" ~ $4 AND "command" ~ '%v'
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.Path, cmd.Query) ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
} else if cmd.SystemName != "" && cmd.Unique { } else if cmd.SystemName != "" && cmd.Unique {
rows, err = db.Query(` query = fmt.Sprintf(`
SELECT * FROM ( SELECT * FROM (
SELECT DISTINCT ON ("command") command, "uuid", "created" SELECT DISTINCT ON ("command") command, "uuid", "created"
FROM commands FROM commands
WHERE "user_id" = $1 WHERE "user_id" = '%v'
AND "system_name" = $3 AND "system_name" = '%v'
) c ) c
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.SystemName) ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.SystemName, cmd.Limit, )
} else if cmd.Path != "" && cmd.Unique { } else if cmd.Path != "" && cmd.Unique {
rows, err = db.Query(` query = fmt.Sprintf(`
SELECT * FROM ( SELECT * FROM (
SELECT DISTINCT ON ("command") command, "uuid", "created" SELECT DISTINCT ON ("command") command, "uuid", "created"
FROM commands FROM commands
WHERE "user_id" = $1 WHERE "user_id" = '%v'
AND "path" = $3 AND "path" = '%v'
) c ) c
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.Path) ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Path, cmd.Limit)
} else if cmd.Query != "" && cmd.Unique { } else if cmd.Query != "" && cmd.Unique {
rows, err = db.Query(` query = fmt.Sprintf(`
SELECT * FROM ( SELECT * FROM (
SELECT DISTINCT ON ("command") command, "uuid", "created" SELECT DISTINCT ON ("command") command, "uuid", "created"
FROM commands FROM commands
WHERE "user_id" = $1 WHERE "user_id" = '%v'
AND "command" ~ $3 AND "command" ~ '%v'
) c ) c
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.Query) ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Query, cmd.Limit,)
} else if cmd.Query != "" { } else if cmd.Query != "" {
rows, err = db.Query(` query = fmt.Sprintf(`
SELECT "command", "uuid", "created" SELECT "command", "uuid", "created"
FROM commands FROM commands
WHERE "user_id" = $1 WHERE "user_id" = '%v'
AND "command" ~ $3 AND "command" ~ '%v'
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit, cmd.Query) ORDER BY "created" DESC limit '%v';`, cmd.User.ID,cmd.Query, cmd.Limit, )
} else { } else {
// unique // unique
rows, err = db.Query(` query = fmt.Sprintf(`
SELECT * FROM ( SELECT * FROM (
SELECT DISTINCT ON ("command") command, "uuid", "created" SELECT DISTINCT ON ("command") command, "uuid", "created"
FROM commands FROM commands
WHERE "user_id" = $1 WHERE "user_id" = '%v'
) c ) c
ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit) ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Limit)
} }
} else { } else {
// sqlite // sqlite
if cmd.SystemName != "" && cmd.Path != "" && cmd.Query != "" && cmd.Unique { if cmd.SystemName != "" && cmd.Path != "" && cmd.Query != "" && cmd.Unique {
// Have to use fmt.Sprintf to build queries where sqlite regexp function is used because of single quotes. Haven't found any other work around. // Have to use fmt.Sprintf to build queries where sqlite regexp function is used because of single quotes. Haven't found any other work around.
query := fmt.Sprintf(` query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = '%v' WHERE "user_id" = '%v'
AND "path" = '%v' AND "path" = '%v'
@ -342,120 +345,120 @@ func (cmd Command) commandGet() ([]Query, error) {
AND "command" regexp '%v' AND "command" regexp '%v'
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.SystemName, cmd.Query, cmd.Limit) GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.SystemName, cmd.Query, cmd.Limit)
rows, err = db.Query(query)
} else if cmd.SystemName != "" && cmd.Query != "" && cmd.Unique { } else if cmd.SystemName != "" && cmd.Query != "" && cmd.Unique {
query := fmt.Sprintf(` query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = '%v' WHERE "user_id" = '%v'
AND "system_name" = '%v' AND "system_name" = '%v'
AND "command" regexp '%v' AND "command" regexp '%v'
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit) GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit)
rows, err = db.Query(query)
} else if cmd.Path != "" && cmd.Query != "" && cmd.Unique { } else if cmd.Path != "" && cmd.Query != "" && cmd.Unique {
query := fmt.Sprintf(` query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = '%v' WHERE "user_id" = '%v'
AND "path" = '%v' AND "path" = '%v'
AND "command" regexp '%v' AND "command" regexp '%v'
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit) GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
rows, err = db.Query(query)
} else if cmd.SystemName != "" && cmd.Query != "" { } else if cmd.SystemName != "" && cmd.Query != "" {
query := fmt.Sprintf(` query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = '%v' WHERE "user_id" = '%v'
AND "system_name" = %v' AND "system_name" = '%v'
AND "command" regexp %v' AND "command" regexp '%v'
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit) ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit)
if QueryDebug {
rows, err = db.Query(query) log.Println(query)
}
} else if cmd.Path != "" && cmd.Query != "" { } else if cmd.Path != "" && cmd.Query != "" {
query := fmt.Sprintf(` query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = '%v' WHERE "user_id" = '%v'
AND "path" = %v' AND "path" = '%v'
AND "command" regexp %v' AND "command" regexp '%v'
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit) ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
rows, err = db.Query(query)
} else if cmd.SystemName != "" && cmd.Unique { } else if cmd.SystemName != "" && cmd.Unique {
rows, err = db.Query(` query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = $1 WHERE "user_id" = '%v'
AND "system_name" = $2 AND "system_name" = '%v'
GROUP BY "command" ORDER BY "created" DESC limit $3`, cmd.User.ID, cmd.SystemName, cmd.Limit) GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.SystemName, cmd.Limit)
} else if cmd.Path != "" && cmd.Unique { } else if cmd.Path != "" && cmd.Unique {
rows, err = db.Query(` query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = $1 WHERE "user_id" = '%v'
AND "path" = $2 AND "path" = '%v'
GROUP BY "command" ORDER BY "created" DESC limit $3`, cmd.User.ID, cmd.Path, cmd.Limit) GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Limit)
} else if cmd.Query != "" && cmd.Unique { } else if cmd.Query != "" && cmd.Unique {
query := fmt.Sprintf(` query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = '%v' WHERE "user_id" = '%v'
AND "command" regexp '%v' AND "command" regexp '%v'
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Query, cmd.Limit) GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Query, cmd.Limit)
rows, err = db.Query(query)
} else if cmd.Query != "" { } else if cmd.Query != "" {
query := fmt.Sprintf(` query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = '%v' WHERE "user_id" = '%v'
AND "command" regexp'%v' AND "command" regexp'%v'
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Query, cmd.Limit) ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Query, cmd.Limit)
rows, err = db.Query(query)
} else { } else {
// unique // unique
rows, err = db.Query(` query = fmt.Sprintf(`
SELECT "command", "uuid", "created" SELECT "command", "uuid", "created"
FROM commands FROM commands
WHERE "user_id" = $1 WHERE "user_id" = '%v'
GROUP BY "command" ORDER BY "created" DESC limit $2;`, cmd.User.ID, cmd.Limit) GROUP BY "command" ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Limit)
} }
} }
} else { } else {
if cmd.Path != "" { if cmd.Path != "" {
rows, err = db.Query(` query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = $1 WHERE "user_id" = '%v'
AND "path" = $3 AND "path" = '%v'
ORDER BY "created" DESC limit $2`, cmd.User.ID, cmd.Limit, cmd.Path) ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Limit)
} else if cmd.SystemName != "" { } else if cmd.SystemName != "" {
rows, err = db.Query(` query = fmt.Sprintf(`SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = '%v'
AND "system_name" = '%v'
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.SystemName, cmd.Limit)
} else {
query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = $1 WHERE "user_id" = '%v'
AND "system_name" = $3 ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Limit)
ORDER BY "created" DESC limit $2`, cmd.User.ID, cmd.Limit, cmd.SystemName)
} else {
rows, err = db.Query(`
SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = $1
ORDER BY "created" DESC limit $2`, cmd.User.ID, cmd.Limit)
} }
} }
if QueryDebug {
fmt.Println(query)
}
rows, err := DB.Query(query)
if err != nil { if err != nil {
return []Query{}, nil return []Query{}, err
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var result Query var result Query
err = rows.Scan(&result.Command, &result.Uuid, &result.Created) err = rows.Scan(&result.Command, &result.Uuid, &result.Created)
if err != nil { if err != nil {
return []Query{}, nil return []Query{}, err
} }
results = append(results, result) results = append(results, result)
} }
@ -465,7 +468,7 @@ func (cmd Command) commandGet() ([]Query, error) {
func (cmd Command) commandGetUUID() (Query, error) { func (cmd Command) commandGetUUID() (Query, error) {
var result Query var result Query
err := db.QueryRow(` err := DB.QueryRow(`
SELECT "command","path", "created" , "uuid", "exit_status", "system_name" SELECT "command","path", "created" , "uuid", "exit_status", "system_name"
FROM commands FROM commands
WHERE "uuid" = $1 WHERE "uuid" = $1
@ -478,7 +481,7 @@ func (cmd Command) commandGetUUID() (Query, error) {
} }
func (cmd Command) commandDelete() int64 { func (cmd Command) commandDelete() int64 {
res, err := db.Exec(` res, err := DB.Exec(`
DELETE FROM commands WHERE "user_id" = $1 AND "uuid" = $2 `, cmd.User.ID, cmd.Uuid) DELETE FROM commands WHERE "user_id" = $1 AND "uuid" = $2 `, cmd.User.ID, cmd.Uuid)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@ -494,7 +497,7 @@ func (cmd Command) commandDelete() int64 {
func (sys System) systemUpdate() int64 { func (sys System) systemUpdate() int64 {
t := time.Now().Unix() t := time.Now().Unix()
res, err := db.Exec(` res, err := DB.Exec(`
UPDATE systems UPDATE systems
SET "hostname" = $1 , "updated" = $2 SET "hostname" = $1 , "updated" = $2
WHERE "user_id" = $3 WHERE "user_id" = $3
@ -513,7 +516,7 @@ func (sys System) systemUpdate() int64 {
func (sys System) systemInsert() int64 { func (sys System) systemInsert() int64 {
t := time.Now().Unix() t := time.Now().Unix()
res, err := db.Exec(`INSERT INTO systems ("name", "mac", "user_id", "hostname", "client_version", "created", "updated") res, err := DB.Exec(`INSERT INTO systems ("name", "mac", "user_id", "hostname", "client_version", "created", "updated")
VALUES ($1, $2, $3, $4, $5, $6, $7)`, VALUES ($1, $2, $3, $4, $5, $6, $7)`,
sys.Name, sys.Mac, sys.User.ID, sys.Hostname, sys.ClientVersion, t, t) sys.Name, sys.Mac, sys.User.ID, sys.Hostname, sys.ClientVersion, t, t)
if err != nil { if err != nil {
@ -528,7 +531,7 @@ func (sys System) systemInsert() int64 {
func (sys System) systemGet() (System, error) { func (sys System) systemGet() (System, error) {
var row System var row System
err := db.QueryRow(`SELECT "name", "mac", "user_id", "hostname", "client_version", err := DB.QueryRow(`SELECT "name", "mac", "user_id", "hostname", "client_version",
"id", "created", "updated" FROM systems "id", "created", "updated" FROM systems
WHERE "user_id" = $1 WHERE "user_id" = $1
AND "mac" = $2`, AND "mac" = $2`,
@ -544,7 +547,7 @@ func (sys System) systemGet() (System, error) {
func (status Status) statusGet() (Status, error) { func (status Status) statusGet() (Status, error) {
var err error var err error
if connectionLimit != 1 { if connectionLimit != 1 {
err = db.QueryRow(`select err = DB.QueryRow(`select
( select count(*) from commands where user_id = $1) as totalCommands, ( select count(*) from commands where user_id = $1) as totalCommands,
( select count(distinct process_id) from commands where user_id = $1) as totalSessions, ( select count(distinct process_id) from commands where user_id = $1) as totalSessions,
( select count(distinct system_name) from commands where user_id = $1) as totalSystems, ( select count(distinct system_name) from commands where user_id = $1) as totalSystems,
@ -554,7 +557,7 @@ func (status Status) statusGet() (Status, error) {
&status.TotalCommands, &status.TotalSessions, &status.TotalSystems, &status.TotalCommands, &status.TotalSessions, &status.TotalSystems,
&status.TotalCommandsToday, &status.SessionTotalCommands) &status.TotalCommandsToday, &status.SessionTotalCommands)
} else { } else {
err = db.QueryRow(`select err = DB.QueryRow(`select
( select count(*) from commands where user_id = $1) as totalCommands, ( select count(*) from commands where user_id = $1) as totalCommands,
( select count(distinct process_id) from commands where user_id = $1) as totalSessions, ( select count(distinct process_id) from commands where user_id = $1) as totalSessions,
( select count(distinct system_name) from commands where user_id = $1) as totalSystems, ( select count(distinct system_name) from commands where user_id = $1) as totalSystems,
@ -571,7 +574,7 @@ func (status Status) statusGet() (Status, error) {
} }
func importCommands(imp Import) { func importCommands(imp Import) {
_, err := db.Exec(`INSERT INTO commands _, err := DB.Exec(`INSERT INTO commands
("command", "path", "created", "uuid", "exit_status", ("command", "path", "created", "uuid", "exit_status",
"system_name", "session_id", "user_id" ) "system_name", "session_id", "user_id" )
VALUES ($1,$2,$3,$4,$5,$6,$7,(select "id" from users where "username" = $8)) ON CONFLICT do nothing`, VALUES ($1,$2,$3,$4,$5,$6,$7,(select "id" from users where "username" = $8)) ON CONFLICT do nothing`,

View file

@ -278,10 +278,9 @@ func SetupRouter() *gin.Engine {
command.Limit = num command.Limit = num
} }
} }
command.Unique = false
if c.Query("unique") == "true" { if c.Query("unique") == "true" {
command.Unique = true command.Unique = true
} else {
command.Unique = false
} }
command.Path = c.Query("path") command.Path = c.Query("path")
command.Query = c.Query("query") command.Query = c.Query("query")

View file

@ -21,6 +21,7 @@ package test
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"flag"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
@ -43,16 +44,34 @@ var (
router *gin.Engine router *gin.Engine
sysRegistered bool sysRegistered bool
jwtToken string jwtToken string
user = "tester" db = flag.String("db", sqliteDB(), "db path")
pass = "tester" postgres = flag.Bool("postgres", false, "run postgres tests")
mac = "888888888888888" postgresUri = flag.String("postgres-uri", "postgres://postgres:@localhost:5444?sslmode=disable", "postgres uri to use for postgres tests")
) )
const (
system = "system"
user = "tester"
pass = "tester"
mac = "888888888888888"
email = "test@email.com"
testdir = "testdata"
)
func sqliteDB() string {
return filepath.Join(dir, "testdata/test.db")
}
func check(err error) {
if err != nil {
log.Fatal(err)
}
}
func createUser(t *testing.T) { func createUser(t *testing.T) {
auth := map[string]interface{}{ auth := map[string]interface{}{
"Username": user, "Username": user,
"password": pass, "password": pass,
"email": "test@email.com", "email": email,
} }
payloadBytes, err := json.Marshal(auth) payloadBytes, err := json.Marshal(auth)
@ -118,14 +137,12 @@ func sysRegister(t *testing.T, token string) string {
} }
sys := map[string]interface{}{ sys := map[string]interface{}{
"clientVersion": "1.2.0", "clientVersion": "1.2.0",
"name": "test-system", "name": system,
"hostname": host, "hostname": host,
"mac": mac, "mac": mac,
} }
payloadBytes, err := json.Marshal(sys) payloadBytes, err := json.Marshal(sys)
if err != nil { check(err)
log.Fatal(err)
}
body := bytes.NewReader(payloadBytes) body := bytes.NewReader(payloadBytes)
@ -145,24 +162,35 @@ func sysRegister(t *testing.T, token string) string {
} }
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
err := os.RemoveAll("testdata")
if err != nil { flag.Parse()
log.Fatal(err)
if *db == sqliteDB() {
err := os.RemoveAll(testdir)
if err != nil {
log.Fatal(err)
}
err = os.Mkdir(testdir, 0700)
if err != nil {
log.Fatal(err)
}
} }
var err error
dir, err = os.Getwd() dir, err = os.Getwd()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
err = os.Mkdir("testdata", 0700) internal.DbPath = *db
if err != nil {
log.Fatal(err)
}
internal.DbPath = filepath.Join(dir, "test.db")
router = internal.SetupRouter() router = internal.SetupRouter()
m.Run() m.Run()
if *postgres {
internal.DbPath = *postgresUri
router = internal.SetupRouter()
m.Run()
}
} }
func TestToken(t *testing.T) { func TestToken(t *testing.T) {
@ -221,9 +249,75 @@ func TestCommand(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
} }
} }
var allQueries = map[string]string{
"unique": "true",
"limit": "1",
"query": "curl",
"path": dir,
"systemName": system,
}
var queryTests []url.Values
allQuery := url.Values{}
for keyP, valP := range allQueries {
allQuery.Add(keyP, valP)
for kepC, valC := range allQueries {
if keyP == kepC {
continue
}
v := url.Values{}
v.Add(kepC, valC)
v.Add(keyP, valP)
queryTests = append(queryTests, v)
}
}
func() { func() {
w := httptest.NewRecorder() w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/command/search?unique=true", nil) u := fmt.Sprintf("/api/v1/command/search?%v", allQuery.Encode())
req, _ := http.NewRequest("GET", u, nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Add("Authorization", jwtToken)
router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
}()
for _, v := range queryTests {
func() {
w := httptest.NewRecorder()
u := fmt.Sprintf("/api/v1/command/search?%v", v.Encode())
req, _ := http.NewRequest("GET", u, nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Add("Authorization", jwtToken)
router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
b, err := ioutil.ReadAll(w.Body)
if err != nil {
t.Fatal(err)
}
var data []internal.Query
err = json.Unmarshal(b, &data)
if err != nil {
t.Fatal(err)
}
assert.GreaterOrEqual(t, len(data), 1)
assert.Contains(t, system, data[0].SystemName)
assert.Contains(t, dir, data[0].Path)
}()
}
func() {
w := httptest.NewRecorder()
v := url.Values{}
v.Add("unique", "true")
u := fmt.Sprintf("/api/v1/command/search?%v", v.Encode())
req, _ := http.NewRequest("GET", u, nil)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Add("Authorization", jwtToken) req.Header.Add("Authorization", jwtToken)
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@ -259,7 +353,11 @@ func TestCommand(t *testing.T) {
}() }()
func() { func() {
w := httptest.NewRecorder() w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/command/search?query=%5Ecurl&unique=true", nil) v := url.Values{}
v.Add("query", "curl")
v.Add("unique", "true")
u := fmt.Sprintf("/api/v1/command/search?%v", v.Encode())
req, _ := http.NewRequest("GET", u, nil)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Add("Authorization", jwtToken) req.Header.Add("Authorization", jwtToken)
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@ -277,7 +375,11 @@ func TestCommand(t *testing.T) {
}() }()
func() { func() {
w := httptest.NewRecorder() w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/command/search?unique=true&systemName=test-system", nil) v := url.Values{}
v.Add("unique", "true")
v.Add("systemName", system)
u := fmt.Sprintf("/api/v1/command/search?%v", v.Encode())
req, _ := http.NewRequest("GET", u, nil)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Add("Authorization", jwtToken) req.Header.Add("Authorization", jwtToken)
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@ -296,8 +398,8 @@ func TestCommand(t *testing.T) {
func() { func() {
w := httptest.NewRecorder() w := httptest.NewRecorder()
v := url.Values{} v := url.Values{}
v.Add("unique", "true")
v.Add("path", dir) v.Add("path", dir)
v.Add("unique", "true")
u := fmt.Sprintf("/api/v1/command/search?%v", v.Encode()) u := fmt.Sprintf("/api/v1/command/search?%v", v.Encode())
req, _ := http.NewRequest("GET", u, nil) req, _ := http.NewRequest("GET", u, nil)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@ -315,4 +417,75 @@ func TestCommand(t *testing.T) {
} }
assert.Equal(t, 10, len(data)) assert.Equal(t, 10, len(data))
}() }()
var record internal.Command
func() {
w := httptest.NewRecorder()
v := url.Values{}
v.Add("limit","1")
v.Add("unique", "true")
u := fmt.Sprintf("/api/v1/command/search?%v", v.Encode())
req, _ := http.NewRequest("GET", u, nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Add("Authorization", jwtToken)
router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
b, err := ioutil.ReadAll(w.Body)
if err != nil {
t.Fatal(err)
}
var data []internal.Command
err = json.Unmarshal(b, &data)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, 1, len(data))
record = data[0]
}()
func() {
w := httptest.NewRecorder()
u := fmt.Sprintf("/api/v1/command/%v", record.Uuid)
req, _ := http.NewRequest("GET", u, nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Add("Authorization", jwtToken)
router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
b, err := ioutil.ReadAll(w.Body)
if err != nil {
t.Fatal(err)
}
var data internal.Command
err = json.Unmarshal(b, &data)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, record.Uuid, data.Uuid)
}()
func() {
w := httptest.NewRecorder()
u := fmt.Sprintf("/api/v1/command/%v", record.Uuid)
req, _ := http.NewRequest("DELETE", u, nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Add("Authorization", jwtToken)
router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
}()
func() {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/command/search?", nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Add("Authorization", jwtToken)
router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
b, err := ioutil.ReadAll(w.Body)
if err != nil {
t.Fatal(err)
}
var data []internal.Command
err = json.Unmarshal(b, &data)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, 49, len(data))
}()
} }

BIN
test/testdata/test.db vendored Normal file

Binary file not shown.

BIN
test/testdata/test.db-shm vendored Normal file

Binary file not shown.

BIN
test/testdata/test.db-wal vendored Normal file

Binary file not shown.