mirror of
https://github.com/nicksherron/bashhub-server.git
synced 2025-09-08 13:24:12 +08:00
server_test: status added; db: query refactor
This commit is contained in:
parent
629475e08f
commit
27c57b37a9
4 changed files with 251 additions and 235 deletions
23
cmd/transfer_test.go
Normal file
23
cmd/transfer_test.go
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright © 2020 nicksherron <nsherron90@gmail.com>
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,6 @@ var (
|
||||||
QueryDebug bool
|
QueryDebug bool
|
||||||
)
|
)
|
||||||
|
|
||||||
// DbInit initializes our db.
|
|
||||||
func dbInit() {
|
func dbInit() {
|
||||||
var gormdb *gorm.DB
|
var gormdb *gorm.DB
|
||||||
var err error
|
var err error
|
||||||
|
@ -238,7 +237,7 @@ func (cmd Command) commandInsert() int64 {
|
||||||
func (cmd Command) commandGet() ([]Query, error) {
|
func (cmd Command) commandGet() ([]Query, error) {
|
||||||
var (
|
var (
|
||||||
results []Query
|
results []Query
|
||||||
query string
|
query string
|
||||||
)
|
)
|
||||||
if cmd.Unique || cmd.Query != "" {
|
if cmd.Unique || cmd.Query != "" {
|
||||||
//postgres
|
//postgres
|
||||||
|
@ -273,7 +272,7 @@ func (cmd Command) commandGet() ([]Query, error) {
|
||||||
WHERE "user_id" = '%v'
|
WHERE "user_id" = '%v'
|
||||||
AND "system_name" = '%v'
|
AND "system_name" = '%v'
|
||||||
AND "command" ~ '%v'
|
AND "command" ~ '%v'
|
||||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit,)
|
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit)
|
||||||
|
|
||||||
} else if cmd.Path != "" && cmd.Query != "" {
|
} else if cmd.Path != "" && cmd.Query != "" {
|
||||||
query = fmt.Sprintf(`
|
query = fmt.Sprintf(`
|
||||||
|
@ -282,7 +281,7 @@ func (cmd Command) commandGet() ([]Query, error) {
|
||||||
WHERE "user_id" = '%v'
|
WHERE "user_id" = '%v'
|
||||||
AND "path" = '%v'
|
AND "path" = '%v'
|
||||||
AND "command" ~ '%v'
|
AND "command" ~ '%v'
|
||||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
|
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
|
||||||
|
|
||||||
} else if cmd.SystemName != "" && cmd.Unique {
|
} else if cmd.SystemName != "" && cmd.Unique {
|
||||||
query = fmt.Sprintf(`
|
query = fmt.Sprintf(`
|
||||||
|
@ -292,7 +291,7 @@ func (cmd Command) commandGet() ([]Query, error) {
|
||||||
WHERE "user_id" = '%v'
|
WHERE "user_id" = '%v'
|
||||||
AND "system_name" = '%v'
|
AND "system_name" = '%v'
|
||||||
) c
|
) c
|
||||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.SystemName, cmd.Limit, )
|
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.SystemName, cmd.Limit)
|
||||||
|
|
||||||
} else if cmd.Path != "" && cmd.Unique {
|
} else if cmd.Path != "" && cmd.Unique {
|
||||||
query = fmt.Sprintf(`
|
query = fmt.Sprintf(`
|
||||||
|
@ -312,7 +311,7 @@ func (cmd Command) commandGet() ([]Query, error) {
|
||||||
WHERE "user_id" = '%v'
|
WHERE "user_id" = '%v'
|
||||||
AND "command" ~ '%v'
|
AND "command" ~ '%v'
|
||||||
) c
|
) c
|
||||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Query, cmd.Limit,)
|
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Query, cmd.Limit)
|
||||||
|
|
||||||
} else if cmd.Query != "" {
|
} else if cmd.Query != "" {
|
||||||
query = fmt.Sprintf(`
|
query = fmt.Sprintf(`
|
||||||
|
@ -320,7 +319,7 @@ func (cmd Command) commandGet() ([]Query, error) {
|
||||||
FROM commands
|
FROM commands
|
||||||
WHERE "user_id" = '%v'
|
WHERE "user_id" = '%v'
|
||||||
AND "command" ~ '%v'
|
AND "command" ~ '%v'
|
||||||
ORDER BY "created" DESC limit '%v';`, cmd.User.ID,cmd.Query, cmd.Limit, )
|
ORDER BY "created" DESC limit '%v';`, cmd.User.ID, cmd.Query, cmd.Limit)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// unique
|
// unique
|
||||||
|
@ -344,7 +343,6 @@ func (cmd Command) commandGet() ([]Query, error) {
|
||||||
AND "command" regexp '%v'
|
AND "command" regexp '%v'
|
||||||
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.SystemName, cmd.Query, cmd.Limit)
|
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.SystemName, cmd.Query, cmd.Limit)
|
||||||
|
|
||||||
|
|
||||||
} else if cmd.SystemName != "" && cmd.Query != "" && cmd.Unique {
|
} else if cmd.SystemName != "" && cmd.Query != "" && cmd.Unique {
|
||||||
query = fmt.Sprintf(`
|
query = fmt.Sprintf(`
|
||||||
SELECT "command", "uuid", "created" FROM commands
|
SELECT "command", "uuid", "created" FROM commands
|
||||||
|
@ -353,7 +351,6 @@ func (cmd Command) commandGet() ([]Query, error) {
|
||||||
AND "command" regexp '%v'
|
AND "command" regexp '%v'
|
||||||
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit)
|
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.SystemName, cmd.Query, cmd.Limit)
|
||||||
|
|
||||||
|
|
||||||
} else if cmd.Path != "" && cmd.Query != "" && cmd.Unique {
|
} else if cmd.Path != "" && cmd.Query != "" && cmd.Unique {
|
||||||
query = fmt.Sprintf(`
|
query = fmt.Sprintf(`
|
||||||
SELECT "command", "uuid", "created" FROM commands
|
SELECT "command", "uuid", "created" FROM commands
|
||||||
|
@ -362,7 +359,6 @@ func (cmd Command) commandGet() ([]Query, error) {
|
||||||
AND "command" regexp '%v'
|
AND "command" regexp '%v'
|
||||||
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
|
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
|
||||||
|
|
||||||
|
|
||||||
} else if cmd.SystemName != "" && cmd.Query != "" {
|
} else if cmd.SystemName != "" && cmd.Query != "" {
|
||||||
query = fmt.Sprintf(`
|
query = fmt.Sprintf(`
|
||||||
SELECT "command", "uuid", "created" FROM commands
|
SELECT "command", "uuid", "created" FROM commands
|
||||||
|
@ -373,7 +369,7 @@ func (cmd Command) commandGet() ([]Query, error) {
|
||||||
if QueryDebug {
|
if QueryDebug {
|
||||||
log.Println(query)
|
log.Println(query)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if cmd.Path != "" && cmd.Query != "" {
|
} else if cmd.Path != "" && cmd.Query != "" {
|
||||||
query = fmt.Sprintf(`
|
query = fmt.Sprintf(`
|
||||||
SELECT "command", "uuid", "created" FROM commands
|
SELECT "command", "uuid", "created" FROM commands
|
||||||
|
@ -382,7 +378,6 @@ func (cmd Command) commandGet() ([]Query, error) {
|
||||||
AND "command" regexp '%v'
|
AND "command" regexp '%v'
|
||||||
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
|
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Query, cmd.Limit)
|
||||||
|
|
||||||
|
|
||||||
} else if cmd.SystemName != "" && cmd.Unique {
|
} else if cmd.SystemName != "" && cmd.Unique {
|
||||||
query = fmt.Sprintf(`
|
query = fmt.Sprintf(`
|
||||||
SELECT "command", "uuid", "created" FROM commands
|
SELECT "command", "uuid", "created" FROM commands
|
||||||
|
@ -404,7 +399,6 @@ func (cmd Command) commandGet() ([]Query, error) {
|
||||||
AND "command" regexp '%v'
|
AND "command" regexp '%v'
|
||||||
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Query, cmd.Limit)
|
GROUP BY "command" ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Query, cmd.Limit)
|
||||||
|
|
||||||
|
|
||||||
} else if cmd.Query != "" {
|
} else if cmd.Query != "" {
|
||||||
query = fmt.Sprintf(`
|
query = fmt.Sprintf(`
|
||||||
SELECT "command", "uuid", "created" FROM commands
|
SELECT "command", "uuid", "created" FROM commands
|
||||||
|
@ -412,7 +406,6 @@ func (cmd Command) commandGet() ([]Query, error) {
|
||||||
AND "command" regexp'%v'
|
AND "command" regexp'%v'
|
||||||
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Query, cmd.Limit)
|
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Query, cmd.Limit)
|
||||||
|
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// unique
|
// unique
|
||||||
query = fmt.Sprintf(`
|
query = fmt.Sprintf(`
|
||||||
|
@ -428,14 +421,14 @@ func (cmd Command) commandGet() ([]Query, error) {
|
||||||
SELECT "command", "uuid", "created" FROM commands
|
SELECT "command", "uuid", "created" FROM commands
|
||||||
WHERE "user_id" = '%v'
|
WHERE "user_id" = '%v'
|
||||||
AND "path" = '%v'
|
AND "path" = '%v'
|
||||||
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Limit)
|
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.Path, cmd.Limit)
|
||||||
} else if cmd.SystemName != "" {
|
} else if cmd.SystemName != "" {
|
||||||
query = fmt.Sprintf(`SELECT "command", "uuid", "created" FROM commands
|
query = fmt.Sprintf(`SELECT "command", "uuid", "created" FROM commands
|
||||||
WHERE "user_id" = '%v'
|
WHERE "user_id" = '%v'
|
||||||
AND "system_name" = '%v'
|
AND "system_name" = '%v'
|
||||||
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.SystemName, cmd.Limit)
|
ORDER BY "created" DESC limit '%v'`, cmd.User.ID, cmd.SystemName, cmd.Limit)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
query = fmt.Sprintf(`
|
query = fmt.Sprintf(`
|
||||||
SELECT "command", "uuid", "created" FROM commands
|
SELECT "command", "uuid", "created" FROM commands
|
||||||
WHERE "user_id" = '%v'
|
WHERE "user_id" = '%v'
|
||||||
|
@ -468,11 +461,11 @@ func (cmd Command) commandGet() ([]Query, error) {
|
||||||
func (cmd Command) commandGetUUID() (Query, error) {
|
func (cmd Command) commandGetUUID() (Query, error) {
|
||||||
var result Query
|
var result Query
|
||||||
err := db.QueryRow(`
|
err := db.QueryRow(`
|
||||||
SELECT "command","path", "created" , "uuid", "exit_status", "system_name"
|
SELECT "command","path", "created" , "uuid", "exit_status", "system_name", "process_id"
|
||||||
FROM commands
|
FROM commands
|
||||||
WHERE "uuid" = $1
|
WHERE "uuid" = $1
|
||||||
AND "user_id" = $2`, cmd.Uuid, cmd.User.ID).Scan(&result.Command, &result.Path, &result.Created, &result.Uuid,
|
AND "user_id" = $2`, cmd.Uuid, cmd.User.ID).Scan(&result.Command, &result.Path, &result.Created, &result.Uuid,
|
||||||
&result.ExitStatus, &result.SystemName)
|
&result.ExitStatus, &result.SystemName, &result.SessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Query{}, err
|
return Query{}, err
|
||||||
}
|
}
|
||||||
|
@ -546,22 +539,24 @@ func (sys System) systemGet() (System, error) {
|
||||||
func (status Status) statusGet() (Status, error) {
|
func (status Status) statusGet() (Status, error) {
|
||||||
var err error
|
var err error
|
||||||
if connectionLimit != 1 {
|
if connectionLimit != 1 {
|
||||||
err = db.QueryRow(`select
|
err = db.QueryRow(`
|
||||||
( select count(*) from commands where user_id = $1) as totalCommands,
|
select
|
||||||
( select count(distinct process_id) from commands where user_id = $1) as totalSessions,
|
( select count(*) from commands where user_id = $1) as totalCommands,
|
||||||
( select count(distinct system_name) from commands where user_id = $1) as totalSystems,
|
( select count(distinct process_id) from commands where user_id = $1) as totalSessions,
|
||||||
( select count (*) from commands where to_timestamp(cast(created/1000 as bigint))::date = now()::date and user_id = $1) as totalCommandsToday,
|
( select count(*) from systems where user_id = $1) as totalSystems,
|
||||||
( select count(*) from commands where process_id = $2) as sessionTotalCommands`,
|
( select count(*) from commands where to_timestamp(cast(created/1000 as bigint))::date = now()::date and user_id = $1) as totalCommandsToday,
|
||||||
|
( select count(*) from commands where process_id = $2) as sessionTotalCommands`,
|
||||||
status.User.ID, status.ProcessID).Scan(
|
status.User.ID, status.ProcessID).Scan(
|
||||||
&status.TotalCommands, &status.TotalSessions, &status.TotalSystems,
|
&status.TotalCommands, &status.TotalSessions, &status.TotalSystems,
|
||||||
&status.TotalCommandsToday, &status.SessionTotalCommands)
|
&status.TotalCommandsToday, &status.SessionTotalCommands)
|
||||||
} else {
|
} else {
|
||||||
err = db.QueryRow(`select
|
err = db.QueryRow(`
|
||||||
( select count(*) from commands where user_id = $1) as totalCommands,
|
select
|
||||||
( select count(distinct process_id) from commands where user_id = $1) as totalSessions,
|
( select count(*) from commands where user_id = $1) as totalCommands,
|
||||||
( select count(distinct system_name) from commands where user_id = $1) as totalSystems,
|
( select count(distinct process_id) from commands where user_id = $1) as totalSessions,
|
||||||
( select count(*) from commands where date(created/1000, 'unixepoch') = date('now') and user_id = $1) as totalCommandsToday,
|
( select count(*) from systems where user_id = $1) as totalSystems,
|
||||||
( select count(*) from commands where process_id = $2) as sessionTotalCommands`,
|
( select count(*) from commands where date(created/1000, 'unixepoch') = date('now') and user_id = $1) as totalCommandsToday,
|
||||||
|
( select count(*) from commands where process_id = $2) as sessionTotalCommands`,
|
||||||
status.User.ID, status.ProcessID).Scan(
|
status.User.ID, status.ProcessID).Scan(
|
||||||
&status.TotalCommands, &status.TotalSessions, &status.TotalSystems,
|
&status.TotalCommands, &status.TotalSessions, &status.TotalSystems,
|
||||||
&status.TotalCommandsToday, &status.SessionTotalCommands)
|
&status.TotalCommandsToday, &status.SessionTotalCommands)
|
||||||
|
@ -573,12 +568,10 @@ func (status Status) statusGet() (Status, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func importCommands(imp Import) {
|
func importCommands(imp Import) {
|
||||||
_, err := db.Exec(`INSERT INTO commands
|
_, err := db.Exec(`
|
||||||
("command", "path", "created", "uuid", "exit_status",
|
INSERT INTO commands "command", "path", "created", "uuid", "exit_status","system_name", "session_id", "user_id" )
|
||||||
"system_name", "session_id", "user_id" )
|
VALUES ($1,$2,$3,$4,$5,$6,$7,(select "id" from users where "username" = $8))ON CONFLICT do nothing`,
|
||||||
VALUES ($1,$2,$3,$4,$5,$6,$7,(select "id" from users where "username" = $8)) ON CONFLICT do nothing`,
|
imp.Command, imp.Path, imp.Created, imp.Uuid, imp.ExitStatus, imp.SystemName, imp.SessionID, imp.Username)
|
||||||
imp.Command, imp.Path, imp.Created, imp.Uuid, imp.ExitStatus,
|
|
||||||
imp.SystemName, imp.SessionID, imp.Username)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,45 +23,68 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
noCleanup = flag.Bool("no-cleanup", false, "don't remove testdata directory with sqlite db after test")
|
testWork = flag.Bool("testwork", false, "don't remove sqlite db and server log when done and print location")
|
||||||
postgres = flag.Bool("postgres", false, "run postgres tests")
|
postgres = flag.Bool("postgres", false, "run postgres tests")
|
||||||
postgresUri = flag.String("postgres-uri", "postgres://postgres:@localhost:5444?sslmode=disable", "postgres uri to use for postgres tests")
|
postgresUri = flag.String("postgres-uri", "postgres://postgres:@localhost:5444?sslmode=disable", "postgres uri to use for postgres tests")
|
||||||
|
sessionStartTime int64
|
||||||
|
pid string
|
||||||
|
dir string
|
||||||
|
router *gin.Engine
|
||||||
|
sysRegistered bool
|
||||||
|
jwtToken string
|
||||||
|
testDir string
|
||||||
|
system sysStruct
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type sysStruct struct {
|
||||||
|
user string
|
||||||
|
pass string
|
||||||
|
mac int
|
||||||
|
email string
|
||||||
|
systemName string
|
||||||
|
host string
|
||||||
|
}
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
dirCleanup()
|
defer dirCleanup()
|
||||||
defer func() {
|
|
||||||
if !*noCleanup {
|
|
||||||
dirCleanup()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
dir, err = os.Getwd()
|
testDir, err = ioutil.TempDir("", "bashhub-server-test-")
|
||||||
if err != nil {
|
check(err)
|
||||||
log.Fatal(err)
|
dir = "/tmp/foo"
|
||||||
}
|
|
||||||
testDir = os.TempDir()
|
|
||||||
log.Println("test directory", testDir)
|
|
||||||
DbPath = filepath.Join(testDir, "test.db")
|
DbPath = filepath.Join(testDir, "test.db")
|
||||||
LogFile = filepath.Join(testDir, "server.log")
|
LogFile = filepath.Join(testDir, "server.log")
|
||||||
log.Print("sqlite tests")
|
log.Print("sqlite tests")
|
||||||
router = setupRouter()
|
router = setupRouter()
|
||||||
|
|
||||||
|
system = sysStruct{
|
||||||
|
user: "tester",
|
||||||
|
pass: "tester",
|
||||||
|
mac: 888888888888888,
|
||||||
|
email: "test@email.com",
|
||||||
|
host: "some-host",
|
||||||
|
}
|
||||||
m.Run()
|
m.Run()
|
||||||
|
|
||||||
if *postgres {
|
if *postgres {
|
||||||
|
@ -73,40 +96,142 @@ func TestMain(m *testing.M) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testRequest(method string, u string, body io.Reader) *httptest.ResponseRecorder {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(method, u, body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Add("Authorization", jwtToken)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func check(err error) {
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUser(t *testing.T) {
|
||||||
|
auth := map[string]interface{}{
|
||||||
|
"email": system.email,
|
||||||
|
"Username": system.user,
|
||||||
|
"password": system.pass,
|
||||||
|
}
|
||||||
|
|
||||||
|
payloadBytes, err := json.Marshal(auth)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
body := bytes.NewReader(payloadBytes)
|
||||||
|
w := testRequest("POST", "/api/v1/user", body)
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getToken(t *testing.T) string {
|
||||||
|
|
||||||
|
auth := map[string]interface{}{
|
||||||
|
"username": system.user,
|
||||||
|
"password": system.pass,
|
||||||
|
"mac": strconv.Itoa(system.mac),
|
||||||
|
}
|
||||||
|
|
||||||
|
payloadBytes, err := json.Marshal(auth)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := bytes.NewReader(payloadBytes)
|
||||||
|
w := testRequest("POST", "/api/v1/login", body)
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(w.Body)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
j := make(map[string]interface{})
|
||||||
|
|
||||||
|
json.Unmarshal(buf, &j)
|
||||||
|
|
||||||
|
if len(j) == 0 {
|
||||||
|
t.Fatal("login failed for getToken")
|
||||||
|
|
||||||
|
}
|
||||||
|
token := fmt.Sprintf("Bearer %v", j["accessToken"])
|
||||||
|
|
||||||
|
if !sysRegistered {
|
||||||
|
// register system
|
||||||
|
return sysRegister(t, token)
|
||||||
|
}
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
func sysRegister(t *testing.T, token string) string {
|
||||||
|
|
||||||
|
jwtToken = token
|
||||||
|
sysPayload := map[string]interface{}{
|
||||||
|
"clientVersion": "1.2.0",
|
||||||
|
"name": system.systemName,
|
||||||
|
"hostname": system.host,
|
||||||
|
"mac": strconv.Itoa(system.mac),
|
||||||
|
}
|
||||||
|
payloadBytes, err := json.Marshal(sysPayload)
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
body := bytes.NewReader(payloadBytes)
|
||||||
|
|
||||||
|
w := testRequest("POST", "/api/v1/system", body)
|
||||||
|
assert.Equal(t, 201, w.Code)
|
||||||
|
|
||||||
|
sysRegistered = true
|
||||||
|
|
||||||
|
return getToken(t)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func TestToken(t *testing.T) {
|
func TestToken(t *testing.T) {
|
||||||
createUser(t)
|
createUser(t)
|
||||||
sysRegistered = false
|
systems := []string{
|
||||||
jwtToken = getToken(t)
|
"system-1",
|
||||||
|
"system-2",
|
||||||
|
"system-3",
|
||||||
|
}
|
||||||
|
for _, sys := range systems {
|
||||||
|
system.systemName = sys
|
||||||
|
system.mac++
|
||||||
|
sysRegistered = false
|
||||||
|
jwtToken = getToken(t)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCommandInsert(t *testing.T) {
|
func TestCommandInsert(t *testing.T) {
|
||||||
var commandTests = []Command{
|
var commandTests = []Command{
|
||||||
{ProcessId: 90226, ExitStatus: 0, Command: "cat foo.txt"},
|
{ExitStatus: 0, Command: "cat foo.txt"},
|
||||||
{ProcessId: 90226, ExitStatus: 0, Command: "ls"},
|
{ExitStatus: 0, Command: "ls"},
|
||||||
{ProcessId: 90226, ExitStatus: 0, Command: "pwd"},
|
{ExitStatus: 0, Command: "pwd"},
|
||||||
{ProcessId: 90226, ExitStatus: 0, Command: "whoami"},
|
{ExitStatus: 0, Command: "whoami"},
|
||||||
{ProcessId: 90226, ExitStatus: 0, Command: "which cat"},
|
{ExitStatus: 0, Command: "which cat"},
|
||||||
{ProcessId: 90226, ExitStatus: 0, Command: "head foo.txt"},
|
{ExitStatus: 0, Command: "head foo.txt"},
|
||||||
{ProcessId: 90226, ExitStatus: 0, Command: "sed 's/fooobaar/foobar/g' somefile.txt"},
|
{ExitStatus: 0, Command: "sed 's/fooobaar/foobar/g' somefile.txt"},
|
||||||
{ProcessId: 90226, ExitStatus: 0, Command: "curl google.com"},
|
{ExitStatus: 0, Command: "curl google.com"},
|
||||||
{ProcessId: 90226, ExitStatus: 0, Command: "file /dev/null"},
|
{ExitStatus: 0, Command: "file /dev/null"},
|
||||||
{ProcessId: 90226, ExitStatus: 0, Command: "df -h"},
|
{ExitStatus: 0, Command: "df -h"},
|
||||||
{ProcessId: 90226, ExitStatus: 127, Command: "catt"},
|
{ExitStatus: 127, Command: "catt"},
|
||||||
{ProcessId: 90226, ExitStatus: 127, Command: "cay"},
|
{ExitStatus: 127, Command: "cay"},
|
||||||
}
|
}
|
||||||
|
|
||||||
hourAgo := time.Now().UnixNano() - (1 * time.Hour).Nanoseconds()
|
sessionStartTime = time.Now().Unix() * 1000
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
for _, tc := range commandTests {
|
for _, tc := range commandTests {
|
||||||
uid, err := uuid.NewRandom()
|
uid, err := uuid.NewRandom()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
tc.ProcessId = i
|
||||||
tc.Path = dir
|
tc.Path = dir
|
||||||
tc.Created = time.Now().Unix()
|
tc.Created = time.Now().Unix() * 1000
|
||||||
tc.ProcessStartTime = hourAgo
|
tc.ProcessStartTime = sessionStartTime
|
||||||
tc.Uuid = uid.String()
|
tc.Uuid = uid.String()
|
||||||
payloadBytes, err := json.Marshal(&tc)
|
payloadBytes, err := json.Marshal(&tc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -116,6 +241,7 @@ func TestCommandInsert(t *testing.T) {
|
||||||
w := testRequest("POST", "/api/v1/command", body)
|
w := testRequest("POST", "/api/v1/command", body)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,14 +251,14 @@ func TestCommandQuery(t *testing.T) {
|
||||||
expect int
|
expect int
|
||||||
}
|
}
|
||||||
var queryTests = []queryTest{
|
var queryTests = []queryTest{
|
||||||
{query: fmt.Sprintf("path=%v&unique=true&systemName=%v&query=^curl", url.QueryEscape(dir), system), expect: 1},
|
{query: fmt.Sprintf("path=%v&unique=true&systemName=%v&query=^curl", url.QueryEscape(dir), system.systemName), expect: 1},
|
||||||
{query: fmt.Sprintf("path=%v&query=^curl&unique=true", url.QueryEscape(dir)), expect: 1},
|
{query: fmt.Sprintf("path=%v&query=^curl&unique=true", url.QueryEscape(dir)), expect: 1},
|
||||||
{query: fmt.Sprintf("systemName=%v&query=^curl", system), expect: 5},
|
{query: fmt.Sprintf("systemName=%v&query=^curl", system.systemName), expect: 5},
|
||||||
{query: fmt.Sprintf("path=%v&query=^curl", url.QueryEscape(dir)), expect: 5},
|
{query: fmt.Sprintf("path=%v&query=^curl", url.QueryEscape(dir)), expect: 5},
|
||||||
{query: fmt.Sprintf("systemName=%v&unique=true", system), expect: 10},
|
{query: fmt.Sprintf("systemName=%v&unique=true", system.systemName), expect: 10},
|
||||||
{query: fmt.Sprintf("path=%v&unique=true", url.QueryEscape(dir)), expect: 10},
|
{query: fmt.Sprintf("path=%v&unique=true", url.QueryEscape(dir)), expect: 10},
|
||||||
{query: fmt.Sprintf("path=%v", url.QueryEscape(dir)), expect: 50},
|
{query: fmt.Sprintf("path=%v", url.QueryEscape(dir)), expect: 50},
|
||||||
{query: fmt.Sprintf("systemName=%v", system), expect: 50},
|
{query: fmt.Sprintf("systemName=%v", system.systemName), expect: 50},
|
||||||
{query: "query=^curl&unique=true", expect: 1},
|
{query: "query=^curl&unique=true", expect: 1},
|
||||||
{query: "query=^curl", expect: 5},
|
{query: "query=^curl", expect: 5},
|
||||||
{query: "unique=true", expect: 10},
|
{query: "unique=true", expect: 10},
|
||||||
|
@ -157,7 +283,7 @@ func TestCommandQuery(t *testing.T) {
|
||||||
if v.expect != len(data) {
|
if v.expect != len(data) {
|
||||||
t.Fatalf("expected: %v, got: %v -- query: %v ", v.expect, len(data), v.query)
|
t.Fatalf("expected: %v, got: %v -- query: %v ", v.expect, len(data), v.query)
|
||||||
}
|
}
|
||||||
assert.Contains(t, system, data[0].SystemName)
|
assert.Contains(t, system.systemName, data[0].SystemName)
|
||||||
assert.Contains(t, dir, data[0].Path)
|
assert.Contains(t, dir, data[0].Path)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
@ -201,6 +327,7 @@ func TestCommandFindDelete(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
assert.Equal(t, record.Uuid, data.Uuid)
|
assert.Equal(t, record.Uuid, data.Uuid)
|
||||||
|
pid = data.SessionID
|
||||||
}()
|
}()
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
|
@ -224,3 +351,36 @@ func TestCommandFindDelete(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStatus(t *testing.T) {
|
||||||
|
u := fmt.Sprintf("/api/v1/client-view/status?processId=%v&startTime=%v", pid, sessionStartTime)
|
||||||
|
w := testRequest("GET", u, nil)
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
b, err := ioutil.ReadAll(w.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var status Status
|
||||||
|
err = json.Unmarshal(b, &status)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, status.TotalCommands, 49)
|
||||||
|
assert.Equal(t, status.TotalSessions, 5)
|
||||||
|
assert.Equal(t, status.TotalSystems, 3)
|
||||||
|
assert.Equal(t, status.TotalCommandsToday, 49)
|
||||||
|
assert.Equal(t, status.SessionTotalCommands, 9)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func dirCleanup() {
|
||||||
|
if !*testWork {
|
||||||
|
os.Chmod(testDir, 0777)
|
||||||
|
err := os.RemoveAll(testDir)
|
||||||
|
check(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Println("TESTWORK=", testDir)
|
||||||
|
|
||||||
|
}
|
|
@ -1,160 +0,0 @@
|
||||||
/*
|
|
||||||
*
|
|
||||||
* Copyright © 2020 nicksherron <nsherron90@gmail.com>
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
dir string
|
|
||||||
router *gin.Engine
|
|
||||||
sysRegistered bool
|
|
||||||
jwtToken string
|
|
||||||
testDir string
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
user = "tester"
|
|
||||||
pass = "tester"
|
|
||||||
mac = "888888888888888"
|
|
||||||
email = "test@email.com"
|
|
||||||
system = "system"
|
|
||||||
)
|
|
||||||
|
|
||||||
func testRequest(method string, u string, body io.Reader) *httptest.ResponseRecorder {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest(method, u, body)
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Add("Authorization", jwtToken)
|
|
||||||
router.ServeHTTP(w, req)
|
|
||||||
return w
|
|
||||||
}
|
|
||||||
|
|
||||||
func check(err error) {
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createUser(t *testing.T) {
|
|
||||||
auth := map[string]interface{}{
|
|
||||||
"Username": user,
|
|
||||||
"password": pass,
|
|
||||||
"email": email,
|
|
||||||
}
|
|
||||||
|
|
||||||
payloadBytes, err := json.Marshal(auth)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
body := bytes.NewReader(payloadBytes)
|
|
||||||
w := testRequest("POST", "/api/v1/user", body)
|
|
||||||
assert.Equal(t, 200, w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getToken(t *testing.T) string {
|
|
||||||
|
|
||||||
auth := map[string]interface{}{
|
|
||||||
"username": user,
|
|
||||||
"password": pass,
|
|
||||||
"mac": mac,
|
|
||||||
}
|
|
||||||
|
|
||||||
payloadBytes, err := json.Marshal(auth)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
body := bytes.NewReader(payloadBytes)
|
|
||||||
w := testRequest("POST", "/api/v1/login", body)
|
|
||||||
assert.Equal(t, 200, w.Code)
|
|
||||||
|
|
||||||
buf, err := ioutil.ReadAll(w.Body)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
j := make(map[string]interface{})
|
|
||||||
|
|
||||||
json.Unmarshal(buf, &j)
|
|
||||||
|
|
||||||
if len(j) == 0 {
|
|
||||||
t.Fatal("login failed for getToken")
|
|
||||||
|
|
||||||
}
|
|
||||||
token := fmt.Sprintf("Bearer %v", j["accessToken"])
|
|
||||||
|
|
||||||
if !sysRegistered {
|
|
||||||
// register system
|
|
||||||
return sysRegister(t, token)
|
|
||||||
}
|
|
||||||
return token
|
|
||||||
}
|
|
||||||
|
|
||||||
func sysRegister(t *testing.T, token string) string {
|
|
||||||
|
|
||||||
jwtToken = token
|
|
||||||
host, err := os.Hostname()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
sys := map[string]interface{}{
|
|
||||||
"clientVersion": "1.2.0",
|
|
||||||
"name": system,
|
|
||||||
"hostname": host,
|
|
||||||
"mac": mac,
|
|
||||||
}
|
|
||||||
payloadBytes, err := json.Marshal(sys)
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
body := bytes.NewReader(payloadBytes)
|
|
||||||
|
|
||||||
w := testRequest("POST", "/api/v1/system", body)
|
|
||||||
assert.Equal(t, 201, w.Code)
|
|
||||||
|
|
||||||
sysRegistered = true
|
|
||||||
|
|
||||||
return getToken(t)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func dirCleanup() {
|
|
||||||
dbFiles := []string{
|
|
||||||
"test.db", "test.db-shm", "test.db-wal",
|
|
||||||
}
|
|
||||||
for _, d := range dbFiles {
|
|
||||||
err := os.RemoveAll(filepath.Join([]string{testDir, d}...))
|
|
||||||
check(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue