From a606897f71263e336a71f92125a2a3061a583f4c Mon Sep 17 00:00:00 2001 From: nicksherron Date: Tue, 11 Feb 2020 22:13:37 -0500 Subject: [PATCH] added system patch endpoint and cmd/transfer: passes http.client to goroutines --- cmd/transfer.go | 24 ++++++++++++------- internal/db.go | 38 ++++++++++++++++++++++++++---- internal/server.go | 58 +++++++++++++++++++++++++++++++++++++--------- 3 files changed, 97 insertions(+), 23 deletions(-) diff --git a/cmd/transfer.go b/cmd/transfer.go index bfb0785..655735f 100644 --- a/cmd/transfer.go +++ b/cmd/transfer.go @@ -55,10 +55,10 @@ var ( srcToken string dstToken string sysRegistered bool - workers int + workers int wg sync.WaitGroup cmdList commandsList - transferCmd = &cobra.Command{ + transferCmd = &cobra.Command{ Use: "transfer", Short: "transfer bashhub history ", Run: func(cmd *cobra.Command, args []string) { @@ -73,12 +73,15 @@ var ( bar = pb.ProgressBarTemplate(barTemplate).Start(len(cmdList)).SetMaxWidth(70) bar.Set("message", "inserting records \t") } + client := &http.Client{} for _, v := range cmdList { + //commandLookup(v.UUID, client) + //} wg.Add(1) counter++ go func(c cList) { defer wg.Done() - commandLookup(c.UUID) + commandLookup(c.UUID, client) }(v) if counter > workers { wg.Wait() @@ -276,7 +279,7 @@ func getCommandList() commandsList { return result } -func commandLookup(uuid string) { +func commandLookup(uuid string, client *http.Client) { u := strings.TrimSpace(srcURL) + "/api/v1/command/" + strings.TrimSpace(uuid) req, err := http.NewRequest("GET", u, nil) if err != nil { @@ -285,13 +288,19 @@ func commandLookup(uuid string) { req.Header.Add("Authorization", srcToken) - client := &http.Client{} resp, err := client.Do(req) if err != nil { log.Println("Error on response.\n", err) } + //defer func() { + // err = resp.Body.Close() + // if err != nil { + // log.Println(err) + // } + // + //}() defer resp.Body.Close() if resp.StatusCode != 200 { log.Fatalf("failed command lookup from %v, go status code %v", srcURL, resp.StatusCode) @@ -300,10 +309,10 @@ func commandLookup(uuid string) { if err != nil { log.Fatal(err) } - srcSend(body) + srcSend(body, client) } -func srcSend(data []byte) { +func srcSend(data []byte, client *http.Client) { defer func() { if !progress { bar.Add(1) @@ -317,7 +326,6 @@ func srcSend(data []byte) { log.Fatal(err) } req.Header.Add("Authorization", dstToken) - client := &http.Client{} resp, err := client.Do(req) if err != nil { diff --git a/internal/db.go b/internal/db.go index 1d96fe3..99558eb 100644 --- a/internal/db.go +++ b/internal/db.go @@ -158,6 +158,18 @@ func (user User) userExists() bool { return false } +func (user User) userGetID() uint { + var id uint + err := db.QueryRow(`SELECT "id" + FROM users + WHERE "username" = $1`, + user.Username).Scan(&id) + if err != nil && err != sql.ErrNoRows { + log.Fatalf("error checking if row exists %v", err) + } + return id +} + func (user User) userGetSystemName() string { var systemName string err := db.QueryRow(`SELECT name @@ -436,7 +448,7 @@ func (cmd Command) commandGet() ([]Query, error) { } if err != nil { - return []Query{}, nil + return []Query{}, nil } defer rows.Close() for rows.Next() { @@ -478,6 +490,24 @@ func (cmd Command) commandDelete() int64 { return inserted } +func (sys System) systemUpdate() int64 { + + t := time.Now().Unix() + res, err := db.Exec(` + UPDATE systems + SET "hostname" = $1 , "updated" = $3 + WHERE "user_id" = $2 + AND "mac" = $3`, + sys.Hostname, t, sys.User.ID, sys.Mac) + if err != nil { + log.Fatal(err) + } + inserted, err := res.RowsAffected() + if err != nil { + log.Fatal(err) + } + return inserted +} func (sys System) systemInsert() int64 { @@ -539,13 +569,13 @@ func (status Status) statusGet() (Status, error) { return status, err } -func importCommands(q Query) { +func importCommands(imp Import) { _, err := db.Exec(`INSERT INTO commands ("command", "path", "created", "uuid", "exit_status", "system_name", "session_id", "user_id" ) VALUES ($1,$2,$3,$4,$5,$6,$7,(select "id" from users where "username" = $8)) ON CONFLICT do nothing`, - q.Command, q.Path, q.Created, q.Uuid, q.ExitStatus, - q.SystemName, q.SessionID, q.Username) + imp.Command, imp.Path, imp.Created, imp.Uuid, imp.ExitStatus, + imp.SystemName, imp.SessionID, imp.Username) if err != nil { log.Println(err) } diff --git a/internal/server.go b/internal/server.go index 1c09254..e74fbd7 100644 --- a/internal/server.go +++ b/internal/server.go @@ -99,6 +99,7 @@ type Config struct { ID int Created time.Time } +type Import Query var ( // Addr is the listen and server address for our server (gin) @@ -172,9 +173,18 @@ func Run() { }, IdentityHandler: func(c *gin.Context) interface{} { claims := jwt.ExtractClaims(c) + var id uint + switch claims["user_id"].(type) { + case float64: + id = uint(claims["user_id"].(float64)) + + default: + id = claims["user_id"].(uint) + } return &User{ Username: claims["username"].(string), SystemName: claims["systemName"].(string), + ID: id, } }, Authenticator: func(c *gin.Context) (interface{}, error) { @@ -187,6 +197,7 @@ func Run() { return &User{ Username: user.Username, SystemName: user.userGetSystemName(), + ID: user.userGetID(), }, nil } fmt.Println("failed") @@ -315,6 +326,7 @@ func Run() { command.SystemName = claims["systemName"].(string) command.commandInsert() + c.AbortWithStatus(http.StatusOK) }) r.DELETE("/api/v1/command/:uuid", func(c *gin.Context) { @@ -329,13 +341,16 @@ func Run() { } command.Uuid = c.Param("uuid") command.commandDelete() + c.AbortWithStatus(http.StatusOK) + }) r.POST("/api/v1/system", func(c *gin.Context) { var system System err := c.Bind(&system) if err != nil { - log.Fatal(err) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return } claims := jwt.ExtractClaims(c) switch claims["user_id"].(type) { @@ -353,11 +368,6 @@ func Run() { r.GET("/api/v1/system", func(c *gin.Context) { var system System claims := jwt.ExtractClaims(c) - mac := c.Query("mac") - if mac == "" { - c.AbortWithStatus(http.StatusBadRequest) - return - } switch claims["user_id"].(type) { case float64: system.User.ID = uint(claims["user_id"].(float64)) @@ -365,7 +375,11 @@ func Run() { default: system.User.ID = claims["user_id"].(uint) } - + mac := c.Query("mac") + if mac == "" { + c.AbortWithStatus(http.StatusBadRequest) + return + } system.Mac = mac result, err := system.systemGet() if err != nil { @@ -376,6 +390,26 @@ func Run() { }) + r.PATCH("/api/v1/system/:mac", func(c *gin.Context) { + var system System + err := c.Bind(&system) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + claims := jwt.ExtractClaims(c) + switch claims["user_id"].(type) { + case float64: + system.User.ID = uint(claims["user_id"].(float64)) + + default: + system.User.ID = claims["user_id"].(uint) + } + system.Mac = c.Param("mac") + system.systemUpdate() + c.AbortWithStatus(http.StatusOK) + }) + r.GET("/api/v1/client-view/status", func(c *gin.Context) { var status Status claims := jwt.ExtractClaims(c) @@ -412,14 +446,15 @@ func Run() { }) r.POST("/api/v1/import", func(c *gin.Context) { - var query Query - if err := c.ShouldBindJSON(&query); err != nil { + var imp Import + if err := c.ShouldBindJSON(&imp); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } claims := jwt.ExtractClaims(c) - query.Username = claims["username"].(string) - importCommands(query) + imp.Username = claims["username"].(string) + importCommands(imp) + c.AbortWithStatus(http.StatusOK) }) Addr = strings.ReplaceAll(Addr, "http://", "") @@ -428,4 +463,5 @@ func Run() { if err != nil { fmt.Println("Error: \t", err) } + }