server_test: status added; db: query refactor

This commit is contained in:
nicksherron 2020-02-14 15:23:07 -05:00
parent 629475e08f
commit 27c57b37a9
4 changed files with 251 additions and 235 deletions

23
cmd/transfer_test.go Normal file
View file

@ -0,0 +1,23 @@
/*
*
* Copyright © 2020 nicksherron <nsherron90@gmail.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package cmd

View file

@ -45,7 +45,6 @@ var (
QueryDebug bool
)
// DbInit initializes our db.
func dbInit() {
var gormdb *gorm.DB
var err error
@ -238,7 +237,7 @@ func (cmd Command) commandInsert() int64 {
func (cmd Command) commandGet() ([]Query, error) {
var (
results []Query
query string
query string
)
if cmd.Unique || cmd.Query != "" {
//postgres
@ -273,7 +272,7 @@ func (cmd Command) commandGet() ([]Query, error) {
WHERE "user_id" = '%v'
AND "system_name" = '%v'
AND "command" ~ '%v'
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit,)
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit)
} else if cmd.Path != "" && cmd.Query != "" {
query = fmt.Sprintf(`
@ -282,7 +281,7 @@ func (cmd Command) commandGet() ([]Query, error) {
WHERE "user_id" = '%v'
AND "path" = '%v'
AND "command" ~ '%v'
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
} else if cmd.SystemName != "" && cmd.Unique {
query = fmt.Sprintf(`
@ -292,7 +291,7 @@ func (cmd Command) commandGet() ([]Query, error) {
WHERE "user_id" = '%v'
AND "system_name" = '%v'
) c
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.SystemName, cmd.Limit, )
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.SystemName, cmd.Limit)
} else if cmd.Path != "" && cmd.Unique {
query = fmt.Sprintf(`
@ -312,7 +311,7 @@ func (cmd Command) commandGet() ([]Query, error) {
WHERE "user_id" = '%v'
AND "command" ~ '%v'
) c
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Query, cmd.Limit,)
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Query, cmd.Limit)
} else if cmd.Query != "" {
query = fmt.Sprintf(`
@ -320,7 +319,7 @@ func (cmd Command) commandGet() ([]Query, error) {
FROM commands
WHERE "user_id" = '%v'
AND "command" ~ '%v'
ORDER BY "created" DESC limit '%v';`, cmd.User.ID,cmd.Query, cmd.Limit, )
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Query, cmd.Limit)
} else {
// unique
@ -344,7 +343,6 @@ func (cmd Command) commandGet() ([]Query, error) {
AND "command" regexp '%v'
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.SystemName, cmd.Query, cmd.Limit)
} else if cmd.SystemName != "" && cmd.Query != "" && cmd.Unique {
query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands
@ -353,7 +351,6 @@ func (cmd Command) commandGet() ([]Query, error) {
AND "command" regexp '%v'
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit)
} else if cmd.Path != "" && cmd.Query != "" && cmd.Unique {
query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands
@ -362,7 +359,6 @@ func (cmd Command) commandGet() ([]Query, error) {
AND "command" regexp '%v'
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
} else if cmd.SystemName != "" && cmd.Query != "" {
query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands
@ -373,7 +369,7 @@ func (cmd Command) commandGet() ([]Query, error) {
if QueryDebug {
log.Println(query)
}
} else if cmd.Path != "" && cmd.Query != "" {
query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands
@ -382,7 +378,6 @@ func (cmd Command) commandGet() ([]Query, error) {
AND "command" regexp '%v'
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
} else if cmd.SystemName != "" && cmd.Unique {
query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands
@ -404,7 +399,6 @@ func (cmd Command) commandGet() ([]Query, error) {
AND "command" regexp '%v'
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Query, cmd.Limit)
} else if cmd.Query != "" {
query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands
@ -412,7 +406,6 @@ func (cmd Command) commandGet() ([]Query, error) {
AND "command" regexp'%v'
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Query, cmd.Limit)
} else {
// unique
query = fmt.Sprintf(`
@ -428,14 +421,14 @@ func (cmd Command) commandGet() ([]Query, error) {
SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = '%v'
AND "path" = '%v'
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Limit)
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Limit)
} else if cmd.SystemName != "" {
query = fmt.Sprintf(`SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = '%v'
AND "system_name" = '%v'
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.SystemName, cmd.Limit)
} else {
} else {
query = fmt.Sprintf(`
SELECT "command", "uuid", "created" FROM commands
WHERE "user_id" = '%v'
@ -468,11 +461,11 @@ func (cmd Command) commandGet() ([]Query, error) {
func (cmd Command) commandGetUUID() (Query, error) {
var result Query
err := db.QueryRow(`
SELECT "command","path", "created" , "uuid", "exit_status", "system_name"
SELECT "command","path", "created" , "uuid", "exit_status", "system_name", "process_id"
FROM commands
WHERE "uuid" = $1
AND "user_id" = $2`, cmd.Uuid, cmd.User.ID).Scan(&result.Command, &result.Path, &result.Created, &result.Uuid,
&result.ExitStatus, &result.SystemName)
&result.ExitStatus, &result.SystemName, &result.SessionID)
if err != nil {
return Query{}, err
}
@ -546,22 +539,24 @@ func (sys System) systemGet() (System, error) {
func (status Status) statusGet() (Status, error) {
var err error
if connectionLimit != 1 {
err = db.QueryRow(`select
( select count(*) from commands where user_id = $1) as totalCommands,
( select count(distinct process_id) from commands where user_id = $1) as totalSessions,
( select count(distinct system_name) from commands where user_id = $1) as totalSystems,
( select count (*) from commands where to_timestamp(cast(created/1000 as bigint))::date = now()::date and user_id = $1) as totalCommandsToday,
( select count(*) from commands where process_id = $2) as sessionTotalCommands`,
err = db.QueryRow(`
select
( select count(*) from commands where user_id = $1) as totalCommands,
( select count(distinct process_id) from commands where user_id = $1) as totalSessions,
( select count(*) from systems where user_id = $1) as totalSystems,
( select count(*) from commands where to_timestamp(cast(created/1000 as bigint))::date = now()::date and user_id = $1) as totalCommandsToday,
( select count(*) from commands where process_id = $2) as sessionTotalCommands`,
status.User.ID, status.ProcessID).Scan(
&status.TotalCommands, &status.TotalSessions, &status.TotalSystems,
&status.TotalCommandsToday, &status.SessionTotalCommands)
} else {
err = db.QueryRow(`select
( select count(*) from commands where user_id = $1) as totalCommands,
( select count(distinct process_id) from commands where user_id = $1) as totalSessions,
( select count(distinct system_name) from commands where user_id = $1) as totalSystems,
( select count(*) from commands where date(created/1000, 'unixepoch') = date('now') and user_id = $1) as totalCommandsToday,
( select count(*) from commands where process_id = $2) as sessionTotalCommands`,
err = db.QueryRow(`
select
( select count(*) from commands where user_id = $1) as totalCommands,
( select count(distinct process_id) from commands where user_id = $1) as totalSessions,
( select count(*) from systems where user_id = $1) as totalSystems,
( select count(*) from commands where date(created/1000, 'unixepoch') = date('now') and user_id = $1) as totalCommandsToday,
( select count(*) from commands where process_id = $2) as sessionTotalCommands`,
status.User.ID, status.ProcessID).Scan(
&status.TotalCommands, &status.TotalSessions, &status.TotalSystems,
&status.TotalCommandsToday, &status.SessionTotalCommands)
@ -573,12 +568,10 @@ func (status Status) statusGet() (Status, error) {
}
func importCommands(imp Import) {
_, err := db.Exec(`INSERT INTO commands
("command", "path", "created", "uuid", "exit_status",
"system_name", "session_id", "user_id" )
VALUES ($1,$2,$3,$4,$5,$6,$7,(select "id" from users where "username" = $8)) ON CONFLICT do nothing`,
imp.Command, imp.Path, imp.Created, imp.Uuid, imp.ExitStatus,
imp.SystemName, imp.SessionID, imp.Username)
_, err := db.Exec(`
INSERT INTO commands "command", "path", "created", "uuid", "exit_status","system_name", "session_id", "user_id" )
VALUES ($1,$2,$3,$4,$5,$6,$7,(select "id" from users where "username" = $8))ON CONFLICT do nothing`,
imp.Command, imp.Path, imp.Created, imp.Uuid, imp.ExitStatus, imp.SystemName, imp.SessionID, imp.Username)
if err != nil {
log.Println(err)
}

View file

@ -23,45 +23,68 @@ import (
"encoding/json"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strconv"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
var (
noCleanup = flag.Bool("no-cleanup", false, "don't remove testdata directory with sqlite db after test")
postgres = flag.Bool("postgres", false, "run postgres tests")
postgresUri = flag.String("postgres-uri", "postgres://postgres:@localhost:5444?sslmode=disable", "postgres uri to use for postgres tests")
testWork = flag.Bool("testwork", false, "don't remove sqlite db and server log when done and print location")
postgres = flag.Bool("postgres", false, "run postgres tests")
postgresUri = flag.String("postgres-uri", "postgres://postgres:@localhost:5444?sslmode=disable", "postgres uri to use for postgres tests")
sessionStartTime int64
pid string
dir string
router *gin.Engine
sysRegistered bool
jwtToken string
testDir string
system sysStruct
)
type sysStruct struct {
user string
pass string
mac int
email string
systemName string
host string
}
func TestMain(m *testing.M) {
flag.Parse()
dirCleanup()
defer func() {
if !*noCleanup {
dirCleanup()
}
}()
defer dirCleanup()
var err error
dir, err = os.Getwd()
if err != nil {
log.Fatal(err)
}
testDir = os.TempDir()
log.Println("test directory", testDir)
testDir, err = ioutil.TempDir("", "bashhub-server-test-")
check(err)
dir = "/tmp/foo"
DbPath = filepath.Join(testDir, "test.db")
LogFile = filepath.Join(testDir, "server.log")
log.Print("sqlite tests")
router = setupRouter()
system = sysStruct{
user: "tester",
pass: "tester",
mac: 888888888888888,
email: "test@email.com",
host: "some-host",
}
m.Run()
if *postgres {
@ -73,40 +96,142 @@ func TestMain(m *testing.M) {
}
func testRequest(method string, u string, body io.Reader) *httptest.ResponseRecorder {
w := httptest.NewRecorder()
req, _ := http.NewRequest(method, u, body)
req.Header.Set("Content-Type", "application/json")
req.Header.Add("Authorization", jwtToken)
router.ServeHTTP(w, req)
return w
}
func check(err error) {
if err != nil {
log.Fatal(err)
}
}
func createUser(t *testing.T) {
auth := map[string]interface{}{
"email": system.email,
"Username": system.user,
"password": system.pass,
}
payloadBytes, err := json.Marshal(auth)
if err != nil {
log.Fatal(err)
}
body := bytes.NewReader(payloadBytes)
w := testRequest("POST", "/api/v1/user", body)
assert.Equal(t, 200, w.Code)
}
func getToken(t *testing.T) string {
auth := map[string]interface{}{
"username": system.user,
"password": system.pass,
"mac": strconv.Itoa(system.mac),
}
payloadBytes, err := json.Marshal(auth)
if err != nil {
log.Fatal(err)
}
body := bytes.NewReader(payloadBytes)
w := testRequest("POST", "/api/v1/login", body)
assert.Equal(t, 200, w.Code)
buf, err := ioutil.ReadAll(w.Body)
if err != nil {
t.Fatal(err)
}
j := make(map[string]interface{})
json.Unmarshal(buf, &j)
if len(j) == 0 {
t.Fatal("login failed for getToken")
}
token := fmt.Sprintf("Bearer %v", j["accessToken"])
if !sysRegistered {
// register system
return sysRegister(t, token)
}
return token
}
func sysRegister(t *testing.T, token string) string {
jwtToken = token
sysPayload := map[string]interface{}{
"clientVersion": "1.2.0",
"name": system.systemName,
"hostname": system.host,
"mac": strconv.Itoa(system.mac),
}
payloadBytes, err := json.Marshal(sysPayload)
check(err)
body := bytes.NewReader(payloadBytes)
w := testRequest("POST", "/api/v1/system", body)
assert.Equal(t, 201, w.Code)
sysRegistered = true
return getToken(t)
}
func TestToken(t *testing.T) {
createUser(t)
sysRegistered = false
jwtToken = getToken(t)
systems := []string{
"system-1",
"system-2",
"system-3",
}
for _, sys := range systems {
system.systemName = sys
system.mac++
sysRegistered = false
jwtToken = getToken(t)
}
}
func TestCommandInsert(t *testing.T) {
var commandTests = []Command{
{ProcessId: 90226, ExitStatus: 0, Command: "cat foo.txt"},
{ProcessId: 90226, ExitStatus: 0, Command: "ls"},
{ProcessId: 90226, ExitStatus: 0, Command: "pwd"},
{ProcessId: 90226, ExitStatus: 0, Command: "whoami"},
{ProcessId: 90226, ExitStatus: 0, Command: "which cat"},
{ProcessId: 90226, ExitStatus: 0, Command: "head foo.txt"},
{ProcessId: 90226, ExitStatus: 0, Command: "sed 's/fooobaar/foobar/g' somefile.txt"},
{ProcessId: 90226, ExitStatus: 0, Command: "curl google.com"},
{ProcessId: 90226, ExitStatus: 0, Command: "file /dev/null"},
{ProcessId: 90226, ExitStatus: 0, Command: "df -h"},
{ProcessId: 90226, ExitStatus: 127, Command: "catt"},
{ProcessId: 90226, ExitStatus: 127, Command: "cay"},
{ExitStatus: 0, Command: "cat foo.txt"},
{ExitStatus: 0, Command: "ls"},
{ExitStatus: 0, Command: "pwd"},
{ExitStatus: 0, Command: "whoami"},
{ExitStatus: 0, Command: "which cat"},
{ExitStatus: 0, Command: "head foo.txt"},
{ExitStatus: 0, Command: "sed 's/fooobaar/foobar/g' somefile.txt"},
{ExitStatus: 0, Command: "curl google.com"},
{ExitStatus: 0, Command: "file /dev/null"},
{ExitStatus: 0, Command: "df -h"},
{ExitStatus: 127, Command: "catt"},
{ExitStatus: 127, Command: "cay"},
}
hourAgo := time.Now().UnixNano() - (1 * time.Hour).Nanoseconds()
sessionStartTime = time.Now().Unix() * 1000
for i := 0; i < 5; i++ {
for _, tc := range commandTests {
uid, err := uuid.NewRandom()
if err != nil {
t.Fatal(err)
}
tc.ProcessId = i
tc.Path = dir
tc.Created = time.Now().Unix()
tc.ProcessStartTime = hourAgo
tc.Created = time.Now().Unix() * 1000
tc.ProcessStartTime = sessionStartTime
tc.Uuid = uid.String()
payloadBytes, err := json.Marshal(&tc)
if err != nil {
@ -116,6 +241,7 @@ func TestCommandInsert(t *testing.T) {
w := testRequest("POST", "/api/v1/command", body)
assert.Equal(t, 200, w.Code)
}
}
}
@ -125,14 +251,14 @@ func TestCommandQuery(t *testing.T) {
expect int
}
var queryTests = []queryTest{
{query: fmt.Sprintf("path=%v&unique=true&systemName=%v&query=^curl", url.QueryEscape(dir), system), expect: 1},
{query: fmt.Sprintf("path=%v&unique=true&systemName=%v&query=^curl", url.QueryEscape(dir), system.systemName), expect: 1},
{query: fmt.Sprintf("path=%v&query=^curl&unique=true", url.QueryEscape(dir)), expect: 1},
{query: fmt.Sprintf("systemName=%v&query=^curl", system), expect: 5},
{query: fmt.Sprintf("systemName=%v&query=^curl", system.systemName), expect: 5},
{query: fmt.Sprintf("path=%v&query=^curl", url.QueryEscape(dir)), expect: 5},
{query: fmt.Sprintf("systemName=%v&unique=true", system), expect: 10},
{query: fmt.Sprintf("systemName=%v&unique=true", system.systemName), expect: 10},
{query: fmt.Sprintf("path=%v&unique=true", url.QueryEscape(dir)), expect: 10},
{query: fmt.Sprintf("path=%v", url.QueryEscape(dir)), expect: 50},
{query: fmt.Sprintf("systemName=%v", system), expect: 50},
{query: fmt.Sprintf("systemName=%v", system.systemName), expect: 50},
{query: "query=^curl&unique=true", expect: 1},
{query: "query=^curl", expect: 5},
{query: "unique=true", expect: 10},
@ -157,7 +283,7 @@ func TestCommandQuery(t *testing.T) {
if v.expect != len(data) {
t.Fatalf("expected: %v, got: %v -- query: %v ", v.expect, len(data), v.query)
}
assert.Contains(t, system, data[0].SystemName)
assert.Contains(t, system.systemName, data[0].SystemName)
assert.Contains(t, dir, data[0].Path)
}()
}
@ -201,6 +327,7 @@ func TestCommandFindDelete(t *testing.T) {
t.Fatal(err)
}
assert.Equal(t, record.Uuid, data.Uuid)
pid = data.SessionID
}()
func() {
@ -224,3 +351,36 @@ func TestCommandFindDelete(t *testing.T) {
}()
}
func TestStatus(t *testing.T) {
u := fmt.Sprintf("/api/v1/client-view/status?processId=%v&startTime=%v", pid, sessionStartTime)
w := testRequest("GET", u, nil)
assert.Equal(t, 200, w.Code)
b, err := ioutil.ReadAll(w.Body)
if err != nil {
t.Fatal(err)
}
var status Status
err = json.Unmarshal(b, &status)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, status.TotalCommands, 49)
assert.Equal(t, status.TotalSessions, 5)
assert.Equal(t, status.TotalSystems, 3)
assert.Equal(t, status.TotalCommandsToday, 49)
assert.Equal(t, status.SessionTotalCommands, 9)
}
func dirCleanup() {
if !*testWork {
os.Chmod(testDir, 0777)
err := os.RemoveAll(testDir)
check(err)
return
}
log.Println("TESTWORK=", testDir)
}

View file

@ -1,160 +0,0 @@
/*
*
* Copyright © 2020 nicksherron <nsherron90@gmail.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package internal
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
var (
dir string
router *gin.Engine
sysRegistered bool
jwtToken string
testDir string
)
const (
user = "tester"
pass = "tester"
mac = "888888888888888"
email = "test@email.com"
system = "system"
)
func testRequest(method string, u string, body io.Reader) *httptest.ResponseRecorder {
w := httptest.NewRecorder()
req, _ := http.NewRequest(method, u, body)
req.Header.Set("Content-Type", "application/json")
req.Header.Add("Authorization", jwtToken)
router.ServeHTTP(w, req)
return w
}
func check(err error) {
if err != nil {
log.Fatal(err)
}
}
func createUser(t *testing.T) {
auth := map[string]interface{}{
"Username": user,
"password": pass,
"email": email,
}
payloadBytes, err := json.Marshal(auth)
if err != nil {
log.Fatal(err)
}
body := bytes.NewReader(payloadBytes)
w := testRequest("POST", "/api/v1/user", body)
assert.Equal(t, 200, w.Code)
}
func getToken(t *testing.T) string {
auth := map[string]interface{}{
"username": user,
"password": pass,
"mac": mac,
}
payloadBytes, err := json.Marshal(auth)
if err != nil {
log.Fatal(err)
}
body := bytes.NewReader(payloadBytes)
w := testRequest("POST", "/api/v1/login", body)
assert.Equal(t, 200, w.Code)
buf, err := ioutil.ReadAll(w.Body)
if err != nil {
t.Fatal(err)
}
j := make(map[string]interface{})
json.Unmarshal(buf, &j)
if len(j) == 0 {
t.Fatal("login failed for getToken")
}
token := fmt.Sprintf("Bearer %v", j["accessToken"])
if !sysRegistered {
// register system
return sysRegister(t, token)
}
return token
}
func sysRegister(t *testing.T, token string) string {
jwtToken = token
host, err := os.Hostname()
if err != nil {
log.Fatal(err)
}
sys := map[string]interface{}{
"clientVersion": "1.2.0",
"name": system,
"hostname": host,
"mac": mac,
}
payloadBytes, err := json.Marshal(sys)
check(err)
body := bytes.NewReader(payloadBytes)
w := testRequest("POST", "/api/v1/system", body)
assert.Equal(t, 201, w.Code)
sysRegistered = true
return getToken(t)
}
func dirCleanup() {
dbFiles := []string{
"test.db", "test.db-shm", "test.db-wal",
}
for _, d := range dbFiles {
err := os.RemoveAll(filepath.Join([]string{testDir, d}...))
check(err)
}
}