mirror of
https://github.com/nicksherron/bashhub-server.git
synced 2024-11-10 09:02:54 +08:00
server_test: status added; db: query refactor
This commit is contained in:
parent
629475e08f
commit
27c57b37a9
4 changed files with 251 additions and 235 deletions
23
cmd/transfer_test.go
Normal file
23
cmd/transfer_test.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
/*
|
||||
*
|
||||
* Copyright © 2020 nicksherron <nsherron90@gmail.com>
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package cmd
|
||||
|
||||
|
||||
|
||||
|
|
@ -45,7 +45,6 @@ var (
|
|||
QueryDebug bool
|
||||
)
|
||||
|
||||
// DbInit initializes our db.
|
||||
func dbInit() {
|
||||
var gormdb *gorm.DB
|
||||
var err error
|
||||
|
@ -238,7 +237,7 @@ func (cmd Command) commandInsert() int64 {
|
|||
func (cmd Command) commandGet() ([]Query, error) {
|
||||
var (
|
||||
results []Query
|
||||
query string
|
||||
query string
|
||||
)
|
||||
if cmd.Unique || cmd.Query != "" {
|
||||
//postgres
|
||||
|
@ -273,7 +272,7 @@ func (cmd Command) commandGet() ([]Query, error) {
|
|||
WHERE "user_id" = '%v'
|
||||
AND "system_name" = '%v'
|
||||
AND "command" ~ '%v'
|
||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit,)
|
||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit)
|
||||
|
||||
} else if cmd.Path != "" && cmd.Query != "" {
|
||||
query = fmt.Sprintf(`
|
||||
|
@ -282,7 +281,7 @@ func (cmd Command) commandGet() ([]Query, error) {
|
|||
WHERE "user_id" = '%v'
|
||||
AND "path" = '%v'
|
||||
AND "command" ~ '%v'
|
||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
|
||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
|
||||
|
||||
} else if cmd.SystemName != "" && cmd.Unique {
|
||||
query = fmt.Sprintf(`
|
||||
|
@ -292,7 +291,7 @@ func (cmd Command) commandGet() ([]Query, error) {
|
|||
WHERE "user_id" = '%v'
|
||||
AND "system_name" = '%v'
|
||||
) c
|
||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.SystemName, cmd.Limit, )
|
||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.SystemName, cmd.Limit)
|
||||
|
||||
} else if cmd.Path != "" && cmd.Unique {
|
||||
query = fmt.Sprintf(`
|
||||
|
@ -312,7 +311,7 @@ func (cmd Command) commandGet() ([]Query, error) {
|
|||
WHERE "user_id" = '%v'
|
||||
AND "command" ~ '%v'
|
||||
) c
|
||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Query, cmd.Limit,)
|
||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Query, cmd.Limit)
|
||||
|
||||
} else if cmd.Query != "" {
|
||||
query = fmt.Sprintf(`
|
||||
|
@ -320,7 +319,7 @@ func (cmd Command) commandGet() ([]Query, error) {
|
|||
FROM commands
|
||||
WHERE "user_id" = '%v'
|
||||
AND "command" ~ '%v'
|
||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID,cmd.Query, cmd.Limit, )
|
||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Query, cmd.Limit)
|
||||
|
||||
} else {
|
||||
// unique
|
||||
|
@ -344,7 +343,6 @@ func (cmd Command) commandGet() ([]Query, error) {
|
|||
AND "command" regexp '%v'
|
||||
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.SystemName, cmd.Query, cmd.Limit)
|
||||
|
||||
|
||||
} else if cmd.SystemName != "" && cmd.Query != "" && cmd.Unique {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT "command", "uuid", "created" FROM commands
|
||||
|
@ -353,7 +351,6 @@ func (cmd Command) commandGet() ([]Query, error) {
|
|||
AND "command" regexp '%v'
|
||||
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit)
|
||||
|
||||
|
||||
} else if cmd.Path != "" && cmd.Query != "" && cmd.Unique {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT "command", "uuid", "created" FROM commands
|
||||
|
@ -362,7 +359,6 @@ func (cmd Command) commandGet() ([]Query, error) {
|
|||
AND "command" regexp '%v'
|
||||
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
|
||||
|
||||
|
||||
} else if cmd.SystemName != "" && cmd.Query != "" {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT "command", "uuid", "created" FROM commands
|
||||
|
@ -373,7 +369,7 @@ func (cmd Command) commandGet() ([]Query, error) {
|
|||
if QueryDebug {
|
||||
log.Println(query)
|
||||
}
|
||||
|
||||
|
||||
} else if cmd.Path != "" && cmd.Query != "" {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT "command", "uuid", "created" FROM commands
|
||||
|
@ -382,7 +378,6 @@ func (cmd Command) commandGet() ([]Query, error) {
|
|||
AND "command" regexp '%v'
|
||||
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
|
||||
|
||||
|
||||
} else if cmd.SystemName != "" && cmd.Unique {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT "command", "uuid", "created" FROM commands
|
||||
|
@ -404,7 +399,6 @@ func (cmd Command) commandGet() ([]Query, error) {
|
|||
AND "command" regexp '%v'
|
||||
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Query, cmd.Limit)
|
||||
|
||||
|
||||
} else if cmd.Query != "" {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT "command", "uuid", "created" FROM commands
|
||||
|
@ -412,7 +406,6 @@ func (cmd Command) commandGet() ([]Query, error) {
|
|||
AND "command" regexp'%v'
|
||||
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Query, cmd.Limit)
|
||||
|
||||
|
||||
} else {
|
||||
// unique
|
||||
query = fmt.Sprintf(`
|
||||
|
@ -428,14 +421,14 @@ func (cmd Command) commandGet() ([]Query, error) {
|
|||
SELECT "command", "uuid", "created" FROM commands
|
||||
WHERE "user_id" = '%v'
|
||||
AND "path" = '%v'
|
||||
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Limit)
|
||||
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Limit)
|
||||
} else if cmd.SystemName != "" {
|
||||
query = fmt.Sprintf(`SELECT "command", "uuid", "created" FROM commands
|
||||
WHERE "user_id" = '%v'
|
||||
AND "system_name" = '%v'
|
||||
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.SystemName, cmd.Limit)
|
||||
|
||||
} else {
|
||||
|
||||
} else {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT "command", "uuid", "created" FROM commands
|
||||
WHERE "user_id" = '%v'
|
||||
|
@ -468,11 +461,11 @@ func (cmd Command) commandGet() ([]Query, error) {
|
|||
func (cmd Command) commandGetUUID() (Query, error) {
|
||||
var result Query
|
||||
err := db.QueryRow(`
|
||||
SELECT "command","path", "created" , "uuid", "exit_status", "system_name"
|
||||
SELECT "command","path", "created" , "uuid", "exit_status", "system_name", "process_id"
|
||||
FROM commands
|
||||
WHERE "uuid" = $1
|
||||
AND "user_id" = $2`, cmd.Uuid, cmd.User.ID).Scan(&result.Command, &result.Path, &result.Created, &result.Uuid,
|
||||
&result.ExitStatus, &result.SystemName)
|
||||
&result.ExitStatus, &result.SystemName, &result.SessionID)
|
||||
if err != nil {
|
||||
return Query{}, err
|
||||
}
|
||||
|
@ -546,22 +539,24 @@ func (sys System) systemGet() (System, error) {
|
|||
func (status Status) statusGet() (Status, error) {
|
||||
var err error
|
||||
if connectionLimit != 1 {
|
||||
err = db.QueryRow(`select
|
||||
( select count(*) from commands where user_id = $1) as totalCommands,
|
||||
( select count(distinct process_id) from commands where user_id = $1) as totalSessions,
|
||||
( select count(distinct system_name) from commands where user_id = $1) as totalSystems,
|
||||
( select count (*) from commands where to_timestamp(cast(created/1000 as bigint))::date = now()::date and user_id = $1) as totalCommandsToday,
|
||||
( select count(*) from commands where process_id = $2) as sessionTotalCommands`,
|
||||
err = db.QueryRow(`
|
||||
select
|
||||
( select count(*) from commands where user_id = $1) as totalCommands,
|
||||
( select count(distinct process_id) from commands where user_id = $1) as totalSessions,
|
||||
( select count(*) from systems where user_id = $1) as totalSystems,
|
||||
( select count(*) from commands where to_timestamp(cast(created/1000 as bigint))::date = now()::date and user_id = $1) as totalCommandsToday,
|
||||
( select count(*) from commands where process_id = $2) as sessionTotalCommands`,
|
||||
status.User.ID, status.ProcessID).Scan(
|
||||
&status.TotalCommands, &status.TotalSessions, &status.TotalSystems,
|
||||
&status.TotalCommandsToday, &status.SessionTotalCommands)
|
||||
} else {
|
||||
err = db.QueryRow(`select
|
||||
( select count(*) from commands where user_id = $1) as totalCommands,
|
||||
( select count(distinct process_id) from commands where user_id = $1) as totalSessions,
|
||||
( select count(distinct system_name) from commands where user_id = $1) as totalSystems,
|
||||
( select count(*) from commands where date(created/1000, 'unixepoch') = date('now') and user_id = $1) as totalCommandsToday,
|
||||
( select count(*) from commands where process_id = $2) as sessionTotalCommands`,
|
||||
err = db.QueryRow(`
|
||||
select
|
||||
( select count(*) from commands where user_id = $1) as totalCommands,
|
||||
( select count(distinct process_id) from commands where user_id = $1) as totalSessions,
|
||||
( select count(*) from systems where user_id = $1) as totalSystems,
|
||||
( select count(*) from commands where date(created/1000, 'unixepoch') = date('now') and user_id = $1) as totalCommandsToday,
|
||||
( select count(*) from commands where process_id = $2) as sessionTotalCommands`,
|
||||
status.User.ID, status.ProcessID).Scan(
|
||||
&status.TotalCommands, &status.TotalSessions, &status.TotalSystems,
|
||||
&status.TotalCommandsToday, &status.SessionTotalCommands)
|
||||
|
@ -573,12 +568,10 @@ func (status Status) statusGet() (Status, error) {
|
|||
}
|
||||
|
||||
func importCommands(imp Import) {
|
||||
_, err := db.Exec(`INSERT INTO commands
|
||||
("command", "path", "created", "uuid", "exit_status",
|
||||
"system_name", "session_id", "user_id" )
|
||||
VALUES ($1,$2,$3,$4,$5,$6,$7,(select "id" from users where "username" = $8)) ON CONFLICT do nothing`,
|
||||
imp.Command, imp.Path, imp.Created, imp.Uuid, imp.ExitStatus,
|
||||
imp.SystemName, imp.SessionID, imp.Username)
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO commands "command", "path", "created", "uuid", "exit_status","system_name", "session_id", "user_id" )
|
||||
VALUES ($1,$2,$3,$4,$5,$6,$7,(select "id" from users where "username" = $8))ON CONFLICT do nothing`,
|
||||
imp.Command, imp.Path, imp.Created, imp.Uuid, imp.ExitStatus, imp.SystemName, imp.SessionID, imp.Username)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
|
|
@ -23,45 +23,68 @@ import (
|
|||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
noCleanup = flag.Bool("no-cleanup", false, "don't remove testdata directory with sqlite db after test")
|
||||
postgres = flag.Bool("postgres", false, "run postgres tests")
|
||||
postgresUri = flag.String("postgres-uri", "postgres://postgres:@localhost:5444?sslmode=disable", "postgres uri to use for postgres tests")
|
||||
testWork = flag.Bool("testwork", false, "don't remove sqlite db and server log when done and print location")
|
||||
postgres = flag.Bool("postgres", false, "run postgres tests")
|
||||
postgresUri = flag.String("postgres-uri", "postgres://postgres:@localhost:5444?sslmode=disable", "postgres uri to use for postgres tests")
|
||||
sessionStartTime int64
|
||||
pid string
|
||||
dir string
|
||||
router *gin.Engine
|
||||
sysRegistered bool
|
||||
jwtToken string
|
||||
testDir string
|
||||
system sysStruct
|
||||
)
|
||||
|
||||
type sysStruct struct {
|
||||
user string
|
||||
pass string
|
||||
mac int
|
||||
email string
|
||||
systemName string
|
||||
host string
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
||||
flag.Parse()
|
||||
dirCleanup()
|
||||
defer func() {
|
||||
if !*noCleanup {
|
||||
dirCleanup()
|
||||
}
|
||||
}()
|
||||
defer dirCleanup()
|
||||
|
||||
var err error
|
||||
dir, err = os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
testDir = os.TempDir()
|
||||
log.Println("test directory", testDir)
|
||||
testDir, err = ioutil.TempDir("", "bashhub-server-test-")
|
||||
check(err)
|
||||
dir = "/tmp/foo"
|
||||
|
||||
DbPath = filepath.Join(testDir, "test.db")
|
||||
LogFile = filepath.Join(testDir, "server.log")
|
||||
log.Print("sqlite tests")
|
||||
router = setupRouter()
|
||||
|
||||
system = sysStruct{
|
||||
user: "tester",
|
||||
pass: "tester",
|
||||
mac: 888888888888888,
|
||||
email: "test@email.com",
|
||||
host: "some-host",
|
||||
}
|
||||
m.Run()
|
||||
|
||||
if *postgres {
|
||||
|
@ -73,40 +96,142 @@ func TestMain(m *testing.M) {
|
|||
|
||||
}
|
||||
|
||||
func testRequest(method string, u string, body io.Reader) *httptest.ResponseRecorder {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(method, u, body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Add("Authorization", jwtToken)
|
||||
router.ServeHTTP(w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
func check(err error) {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func createUser(t *testing.T) {
|
||||
auth := map[string]interface{}{
|
||||
"email": system.email,
|
||||
"Username": system.user,
|
||||
"password": system.pass,
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(auth)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
body := bytes.NewReader(payloadBytes)
|
||||
w := testRequest("POST", "/api/v1/user", body)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func getToken(t *testing.T) string {
|
||||
|
||||
auth := map[string]interface{}{
|
||||
"username": system.user,
|
||||
"password": system.pass,
|
||||
"mac": strconv.Itoa(system.mac),
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(auth)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
body := bytes.NewReader(payloadBytes)
|
||||
w := testRequest("POST", "/api/v1/login", body)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
buf, err := ioutil.ReadAll(w.Body)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
j := make(map[string]interface{})
|
||||
|
||||
json.Unmarshal(buf, &j)
|
||||
|
||||
if len(j) == 0 {
|
||||
t.Fatal("login failed for getToken")
|
||||
|
||||
}
|
||||
token := fmt.Sprintf("Bearer %v", j["accessToken"])
|
||||
|
||||
if !sysRegistered {
|
||||
// register system
|
||||
return sysRegister(t, token)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func sysRegister(t *testing.T, token string) string {
|
||||
|
||||
jwtToken = token
|
||||
sysPayload := map[string]interface{}{
|
||||
"clientVersion": "1.2.0",
|
||||
"name": system.systemName,
|
||||
"hostname": system.host,
|
||||
"mac": strconv.Itoa(system.mac),
|
||||
}
|
||||
payloadBytes, err := json.Marshal(sysPayload)
|
||||
check(err)
|
||||
|
||||
body := bytes.NewReader(payloadBytes)
|
||||
|
||||
w := testRequest("POST", "/api/v1/system", body)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
sysRegistered = true
|
||||
|
||||
return getToken(t)
|
||||
|
||||
}
|
||||
|
||||
func TestToken(t *testing.T) {
|
||||
createUser(t)
|
||||
sysRegistered = false
|
||||
jwtToken = getToken(t)
|
||||
systems := []string{
|
||||
"system-1",
|
||||
"system-2",
|
||||
"system-3",
|
||||
}
|
||||
for _, sys := range systems {
|
||||
system.systemName = sys
|
||||
system.mac++
|
||||
sysRegistered = false
|
||||
jwtToken = getToken(t)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestCommandInsert(t *testing.T) {
|
||||
var commandTests = []Command{
|
||||
{ProcessId: 90226, ExitStatus: 0, Command: "cat foo.txt"},
|
||||
{ProcessId: 90226, ExitStatus: 0, Command: "ls"},
|
||||
{ProcessId: 90226, ExitStatus: 0, Command: "pwd"},
|
||||
{ProcessId: 90226, ExitStatus: 0, Command: "whoami"},
|
||||
{ProcessId: 90226, ExitStatus: 0, Command: "which cat"},
|
||||
{ProcessId: 90226, ExitStatus: 0, Command: "head foo.txt"},
|
||||
{ProcessId: 90226, ExitStatus: 0, Command: "sed 's/fooobaar/foobar/g' somefile.txt"},
|
||||
{ProcessId: 90226, ExitStatus: 0, Command: "curl google.com"},
|
||||
{ProcessId: 90226, ExitStatus: 0, Command: "file /dev/null"},
|
||||
{ProcessId: 90226, ExitStatus: 0, Command: "df -h"},
|
||||
{ProcessId: 90226, ExitStatus: 127, Command: "catt"},
|
||||
{ProcessId: 90226, ExitStatus: 127, Command: "cay"},
|
||||
{ExitStatus: 0, Command: "cat foo.txt"},
|
||||
{ExitStatus: 0, Command: "ls"},
|
||||
{ExitStatus: 0, Command: "pwd"},
|
||||
{ExitStatus: 0, Command: "whoami"},
|
||||
{ExitStatus: 0, Command: "which cat"},
|
||||
{ExitStatus: 0, Command: "head foo.txt"},
|
||||
{ExitStatus: 0, Command: "sed 's/fooobaar/foobar/g' somefile.txt"},
|
||||
{ExitStatus: 0, Command: "curl google.com"},
|
||||
{ExitStatus: 0, Command: "file /dev/null"},
|
||||
{ExitStatus: 0, Command: "df -h"},
|
||||
{ExitStatus: 127, Command: "catt"},
|
||||
{ExitStatus: 127, Command: "cay"},
|
||||
}
|
||||
|
||||
hourAgo := time.Now().UnixNano() - (1 * time.Hour).Nanoseconds()
|
||||
|
||||
sessionStartTime = time.Now().Unix() * 1000
|
||||
for i := 0; i < 5; i++ {
|
||||
for _, tc := range commandTests {
|
||||
uid, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tc.ProcessId = i
|
||||
tc.Path = dir
|
||||
tc.Created = time.Now().Unix()
|
||||
tc.ProcessStartTime = hourAgo
|
||||
tc.Created = time.Now().Unix() * 1000
|
||||
tc.ProcessStartTime = sessionStartTime
|
||||
tc.Uuid = uid.String()
|
||||
payloadBytes, err := json.Marshal(&tc)
|
||||
if err != nil {
|
||||
|
@ -116,6 +241,7 @@ func TestCommandInsert(t *testing.T) {
|
|||
w := testRequest("POST", "/api/v1/command", body)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -125,14 +251,14 @@ func TestCommandQuery(t *testing.T) {
|
|||
expect int
|
||||
}
|
||||
var queryTests = []queryTest{
|
||||
{query: fmt.Sprintf("path=%v&unique=true&systemName=%v&query=^curl", url.QueryEscape(dir), system), expect: 1},
|
||||
{query: fmt.Sprintf("path=%v&unique=true&systemName=%v&query=^curl", url.QueryEscape(dir), system.systemName), expect: 1},
|
||||
{query: fmt.Sprintf("path=%v&query=^curl&unique=true", url.QueryEscape(dir)), expect: 1},
|
||||
{query: fmt.Sprintf("systemName=%v&query=^curl", system), expect: 5},
|
||||
{query: fmt.Sprintf("systemName=%v&query=^curl", system.systemName), expect: 5},
|
||||
{query: fmt.Sprintf("path=%v&query=^curl", url.QueryEscape(dir)), expect: 5},
|
||||
{query: fmt.Sprintf("systemName=%v&unique=true", system), expect: 10},
|
||||
{query: fmt.Sprintf("systemName=%v&unique=true", system.systemName), expect: 10},
|
||||
{query: fmt.Sprintf("path=%v&unique=true", url.QueryEscape(dir)), expect: 10},
|
||||
{query: fmt.Sprintf("path=%v", url.QueryEscape(dir)), expect: 50},
|
||||
{query: fmt.Sprintf("systemName=%v", system), expect: 50},
|
||||
{query: fmt.Sprintf("systemName=%v", system.systemName), expect: 50},
|
||||
{query: "query=^curl&unique=true", expect: 1},
|
||||
{query: "query=^curl", expect: 5},
|
||||
{query: "unique=true", expect: 10},
|
||||
|
@ -157,7 +283,7 @@ func TestCommandQuery(t *testing.T) {
|
|||
if v.expect != len(data) {
|
||||
t.Fatalf("expected: %v, got: %v -- query: %v ", v.expect, len(data), v.query)
|
||||
}
|
||||
assert.Contains(t, system, data[0].SystemName)
|
||||
assert.Contains(t, system.systemName, data[0].SystemName)
|
||||
assert.Contains(t, dir, data[0].Path)
|
||||
}()
|
||||
}
|
||||
|
@ -201,6 +327,7 @@ func TestCommandFindDelete(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, record.Uuid, data.Uuid)
|
||||
pid = data.SessionID
|
||||
}()
|
||||
|
||||
func() {
|
||||
|
@ -224,3 +351,36 @@ func TestCommandFindDelete(t *testing.T) {
|
|||
}()
|
||||
|
||||
}
|
||||
|
||||
func TestStatus(t *testing.T) {
|
||||
u := fmt.Sprintf("/api/v1/client-view/status?processId=%v&startTime=%v", pid, sessionStartTime)
|
||||
w := testRequest("GET", u, nil)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
b, err := ioutil.ReadAll(w.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var status Status
|
||||
err = json.Unmarshal(b, &status)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, status.TotalCommands, 49)
|
||||
assert.Equal(t, status.TotalSessions, 5)
|
||||
assert.Equal(t, status.TotalSystems, 3)
|
||||
assert.Equal(t, status.TotalCommandsToday, 49)
|
||||
assert.Equal(t, status.SessionTotalCommands, 9)
|
||||
|
||||
}
|
||||
|
||||
|
||||
func dirCleanup() {
|
||||
if !*testWork {
|
||||
os.Chmod(testDir, 0777)
|
||||
err := os.RemoveAll(testDir)
|
||||
check(err)
|
||||
return
|
||||
}
|
||||
log.Println("TESTWORK=", testDir)
|
||||
|
||||
}
|
|
@ -1,160 +0,0 @@
|
|||
/*
|
||||
*
|
||||
* Copyright © 2020 nicksherron <nsherron90@gmail.com>
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
dir string
|
||||
router *gin.Engine
|
||||
sysRegistered bool
|
||||
jwtToken string
|
||||
testDir string
|
||||
)
|
||||
|
||||
const (
|
||||
user = "tester"
|
||||
pass = "tester"
|
||||
mac = "888888888888888"
|
||||
email = "test@email.com"
|
||||
system = "system"
|
||||
)
|
||||
|
||||
func testRequest(method string, u string, body io.Reader) *httptest.ResponseRecorder {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(method, u, body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Add("Authorization", jwtToken)
|
||||
router.ServeHTTP(w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
func check(err error) {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func createUser(t *testing.T) {
|
||||
auth := map[string]interface{}{
|
||||
"Username": user,
|
||||
"password": pass,
|
||||
"email": email,
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(auth)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
body := bytes.NewReader(payloadBytes)
|
||||
w := testRequest("POST", "/api/v1/user", body)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func getToken(t *testing.T) string {
|
||||
|
||||
auth := map[string]interface{}{
|
||||
"username": user,
|
||||
"password": pass,
|
||||
"mac": mac,
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(auth)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
body := bytes.NewReader(payloadBytes)
|
||||
w := testRequest("POST", "/api/v1/login", body)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
buf, err := ioutil.ReadAll(w.Body)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
j := make(map[string]interface{})
|
||||
|
||||
json.Unmarshal(buf, &j)
|
||||
|
||||
if len(j) == 0 {
|
||||
t.Fatal("login failed for getToken")
|
||||
|
||||
}
|
||||
token := fmt.Sprintf("Bearer %v", j["accessToken"])
|
||||
|
||||
if !sysRegistered {
|
||||
// register system
|
||||
return sysRegister(t, token)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func sysRegister(t *testing.T, token string) string {
|
||||
|
||||
jwtToken = token
|
||||
host, err := os.Hostname()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
sys := map[string]interface{}{
|
||||
"clientVersion": "1.2.0",
|
||||
"name": system,
|
||||
"hostname": host,
|
||||
"mac": mac,
|
||||
}
|
||||
payloadBytes, err := json.Marshal(sys)
|
||||
check(err)
|
||||
|
||||
body := bytes.NewReader(payloadBytes)
|
||||
|
||||
w := testRequest("POST", "/api/v1/system", body)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
sysRegistered = true
|
||||
|
||||
return getToken(t)
|
||||
|
||||
}
|
||||
|
||||
func dirCleanup() {
|
||||
dbFiles := []string{
|
||||
"test.db", "test.db-shm", "test.db-wal",
|
||||
}
|
||||
for _, d := range dbFiles {
|
||||
err := os.RemoveAll(filepath.Join([]string{testDir, d}...))
|
||||
check(err)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in a new issue