mirror of
https://github.com/nicksherron/bashhub-server.git
synced 2025-09-03 19:04:11 +08:00
444 lines
10 KiB
Go
444 lines
10 KiB
Go
/*
|
|
*
|
|
* Copyright © 2020 nicksherron <nsherron90@gmail.com>
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*
|
|
*/
|
|
|
|
package cmd
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/cheggaaa/pb/v3"
|
|
"github.com/spf13/cobra"
|
|
"golang.org/x/crypto/ssh/terminal"
|
|
)
|
|
|
|
// cList is a single command record as decoded from the source server's JSON.
type cList struct {
	// Retries counts how many times looking this command up has been
	// re-queued; it carries no json tag.
	Retries int
	// UUID identifies the command; commandLookup appends it to the
	// /api/v1/command/ path.
	UUID string `json:"uuid"`
	// Command is decoded from the "command" JSON field.
	Command string `json:"command"`
	// Created is decoded from the "created" JSON field.
	Created int64 `json:"created"`
}
|
|
|
|
// commandsList is the slice of command records that getCommandList decodes
// from the search response.
type commandsList []cList
|
|
|
|
var (
|
|
barTemplate = `{{string . "message" | green }}{{counters . }} {{bar . }} {{percent . }} {{speed . "%s inserts/sec" | green}}`
|
|
bar *pb.ProgressBar
|
|
progress bool
|
|
srcUser string
|
|
dstUser string
|
|
srcURL string
|
|
dstURL string
|
|
srcPass string
|
|
dstPass string
|
|
srcToken string
|
|
dstToken string
|
|
sysRegistered bool
|
|
workers int
|
|
unique bool
|
|
limit int
|
|
dstCounter uint64
|
|
srcCounter uint64
|
|
inserted uint64
|
|
wgSrc sync.WaitGroup
|
|
wgDst sync.WaitGroup
|
|
cmdList commandsList
|
|
|
|
transferCmd = &cobra.Command{
|
|
Use: "transfer",
|
|
Short: "Transfer bashhub history from one server to another",
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
cmd.Flags().Parse(args)
|
|
switch {
|
|
case srcUser == "":
|
|
_ = cmd.Usage()
|
|
fmt.Print("\n\n")
|
|
log.Fatal("src-user can't be blank")
|
|
case dstUser == "":
|
|
_ = cmd.Usage()
|
|
fmt.Print("\n\n")
|
|
log.Fatal("--dst-user can't be blank")
|
|
case srcPass == "" || dstPass == "":
|
|
if srcPass == "" {
|
|
srcPass = credentials("source")
|
|
}
|
|
if dstPass == "" {
|
|
dstPass = credentials("destination")
|
|
}
|
|
|
|
}
|
|
|
|
if workers > 10 && srcURL == "https://bashhub.com" {
|
|
msg := fmt.Sprintf(`
|
|
WARNING: errors are likely to occur when setting workers higher
|
|
than 10 when transferring from https://bashhub.com`)
|
|
fmt.Print(msg, "\n\n")
|
|
}
|
|
run()
|
|
},
|
|
}
|
|
)
|
|
|
|
func init() {
|
|
rootCmd.AddCommand(transferCmd)
|
|
transferCmd.PersistentFlags().StringVar(&srcURL, "src-url", "https://bashhub.com", "source url ")
|
|
transferCmd.PersistentFlags().StringVar(&srcUser, "src-user", "", "source username")
|
|
transferCmd.PersistentFlags().StringVar(&srcPass, "src-pass", "", "source password (default is password prompt)")
|
|
transferCmd.PersistentFlags().StringVar(&dstURL, "dst-url", "http://localhost:8080", "destination url")
|
|
transferCmd.PersistentFlags().StringVar(&dstUser, "dst-user", "", "destination username")
|
|
transferCmd.PersistentFlags().StringVar(&dstPass, "dst-pass", "", "destination password (default is password prompt)")
|
|
transferCmd.PersistentFlags().BoolVarP(&progress, "quiet", "q", false, "don't show progress bar")
|
|
transferCmd.PersistentFlags().IntVarP(&workers, "workers", "w", 10, "max number of concurrent requests")
|
|
transferCmd.PersistentFlags().BoolVarP(&unique, "unique", "u", true, "don't include duplicate commands")
|
|
transferCmd.PersistentFlags().IntVarP(&limit, "number", "n", 10000, "limit number of commands to transfer")
|
|
|
|
}
|
|
func credentials(s string) string {
|
|
|
|
fmt.Printf("\nEnter %s password: ", s)
|
|
bytePassword, err := terminal.ReadPassword(0)
|
|
if err != nil {
|
|
check(err)
|
|
}
|
|
password := string(bytePassword)
|
|
|
|
return strings.TrimSpace(password)
|
|
}
|
|
func run() {
|
|
sysRegistered = false
|
|
srcToken = getToken(srcURL, srcUser, srcPass)
|
|
sysRegistered = false
|
|
dstToken = getToken(dstURL, dstUser, dstPass)
|
|
cmdList = getCommandList()
|
|
|
|
if !progress {
|
|
bar = pb.ProgressBarTemplate(barTemplate).Start(len(cmdList)).SetMaxWidth(70)
|
|
bar.Set("message", "transferring ")
|
|
}
|
|
fmt.Print("\nstarting transfer...\n\n")
|
|
queue := make(chan cList, len(cmdList))
|
|
pipe := make(chan []byte, len(cmdList))
|
|
|
|
// ignore http errors. We try and recover them
|
|
log.SetOutput(nil)
|
|
go func() {
|
|
for {
|
|
select {
|
|
case item := <-queue:
|
|
wgDst.Add(1)
|
|
atomic.AddUint64(&dstCounter, 1)
|
|
go func(cmd cList) {
|
|
defer wgDst.Done()
|
|
commandLookup(cmd, pipe, queue)
|
|
}(item)
|
|
if atomic.CompareAndSwapUint64(&dstCounter, uint64(workers), 0) {
|
|
wgDst.Wait()
|
|
}
|
|
case result := <-pipe:
|
|
wgSrc.Add(1)
|
|
atomic.AddUint64(&srcCounter, 1)
|
|
go func(data []byte) {
|
|
srcSend(data, 0)
|
|
}(result)
|
|
if atomic.CompareAndSwapUint64(&srcCounter, uint64(workers), 0) {
|
|
wgSrc.Wait()
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
for _, v := range cmdList {
|
|
v.Retries = 0
|
|
queue <- v
|
|
}
|
|
|
|
for {
|
|
if atomic.CompareAndSwapUint64(&inserted, uint64(len(cmdList)), 0) {
|
|
break
|
|
}
|
|
time.Sleep(200 * time.Millisecond)
|
|
}
|
|
if !progress {
|
|
bar.Finish()
|
|
}
|
|
}
|
|
func sysRegister(mac string, site string, user string, pass string) string {
|
|
|
|
var token string
|
|
func() {
|
|
var null *string
|
|
auth := map[string]interface{}{
|
|
"username": user,
|
|
"password": pass,
|
|
"mac": null,
|
|
}
|
|
|
|
payloadBytes, err := json.Marshal(auth)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
body := bytes.NewReader(payloadBytes)
|
|
|
|
u := fmt.Sprintf("%v/api/v1/login", site)
|
|
req, err := http.NewRequest("POST", u, body)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
buf, err := ioutil.ReadAll(resp.Body)
|
|
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
j := make(map[string]interface{})
|
|
|
|
err = json.Unmarshal(buf, &j)
|
|
check(err)
|
|
|
|
if len(j) == 0 {
|
|
log.Fatal("login failed for ", site)
|
|
|
|
}
|
|
token = fmt.Sprintf("Bearer %v", j["accessToken"])
|
|
|
|
}()
|
|
|
|
host, err := os.Hostname()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
sys := map[string]interface{}{
|
|
"clientVersion": "1.2.0",
|
|
"name": "transfer",
|
|
"hostname": host,
|
|
"mac": mac,
|
|
}
|
|
|
|
payloadBytes, err := json.Marshal(sys)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
body := bytes.NewReader(payloadBytes)
|
|
|
|
u := fmt.Sprintf("%v/api/v1/system", srcURL)
|
|
req, err := http.NewRequest("POST", u, body)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Add("Authorization", token)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
sysRegistered = true
|
|
return getToken(site, user, pass)
|
|
|
|
}
|
|
|
|
func getToken(site string, user string, pass string) string {
|
|
mac := "888888888888888"
|
|
auth := map[string]interface{}{
|
|
"username": user,
|
|
"password": pass,
|
|
"mac": mac,
|
|
}
|
|
|
|
payloadBytes, err := json.Marshal(auth)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
body := bytes.NewReader(payloadBytes)
|
|
|
|
u := fmt.Sprintf("%v/api/v1/login", site)
|
|
req, err := http.NewRequest("POST", u, body)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode == 409 && !sysRegistered {
|
|
// register system
|
|
return sysRegister(mac, site, user, pass)
|
|
}
|
|
buf, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
j := make(map[string]interface{})
|
|
|
|
err = json.Unmarshal(buf, &j)
|
|
check(err)
|
|
|
|
if len(j) == 0 || resp.StatusCode == 401 {
|
|
log.Fatal("login failed for ", site)
|
|
|
|
}
|
|
return fmt.Sprintf("Bearer %v", j["accessToken"])
|
|
}
|
|
|
|
func getCommandList() commandsList {
|
|
u := strings.TrimSpace(srcURL) + fmt.Sprintf("/api/v1/command/search?unique=%v&limit=%v", unique, limit)
|
|
req, err := http.NewRequest("GET", u, nil)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
req.Header.Add("Authorization", srcToken)
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
|
|
if err != nil {
|
|
log.Fatal("Error on response.\n", err)
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != 200 {
|
|
log.Fatalf("failed to get command list from %v, go status code %v", srcURL, resp.StatusCode)
|
|
}
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
var result commandsList
|
|
err = json.Unmarshal(body, &result)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
func commandLookup(item cList, pipe chan []byte, queue chan cList) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
mem := strings.Contains(fmt.Sprintf("%v", r), "runtime error: invalid memory address")
|
|
eof := strings.Contains(fmt.Sprintf("%v", r), "EOF")
|
|
if mem || eof {
|
|
if item.Retries < 10 {
|
|
item.Retries++
|
|
queue <- item
|
|
return
|
|
} else {
|
|
log.SetOutput(os.Stderr)
|
|
log.Println("ERROR: failed over 10 times looking up command from source with uuid: ", item.UUID)
|
|
log.SetOutput(nil)
|
|
}
|
|
} else {
|
|
log.SetOutput(os.Stderr)
|
|
log.Fatal(r)
|
|
}
|
|
}
|
|
}()
|
|
|
|
u := strings.TrimSpace(srcURL) + "/api/v1/command/" + strings.TrimSpace(item.UUID)
|
|
req, err := http.NewRequest("GET", u, nil)
|
|
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
req.Header.Add("Authorization", srcToken)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if resp.StatusCode != 200 {
|
|
err = fmt.Errorf("%v response from %v: %v", resp.StatusCode, srcURL, string(body))
|
|
log.SetOutput(os.Stderr)
|
|
log.Fatal(err)
|
|
}
|
|
pipe <- body
|
|
}
|
|
|
|
func srcSend(data []byte, retries int) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
retries++
|
|
if retries < 10 {
|
|
srcSend(data, retries)
|
|
return
|
|
}
|
|
log.SetOutput(os.Stderr)
|
|
log.Println("Error on response.\n", r)
|
|
log.SetOutput(nil)
|
|
}
|
|
if !progress {
|
|
bar.Add(1)
|
|
}
|
|
atomic.AddUint64(&inserted, 1)
|
|
wgSrc.Done()
|
|
}()
|
|
|
|
body := bytes.NewReader(data)
|
|
|
|
u := dstURL + "/api/v1/import"
|
|
req, err := http.NewRequest("POST", u, body)
|
|
if err != nil {
|
|
log.SetOutput(os.Stderr)
|
|
log.Fatal(err)
|
|
}
|
|
|
|
req.Header.Add("Authorization", dstToken)
|
|
_, err = http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
log.SetOutput(os.Stderr)
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// check terminates the program through log.Fatal when err is non-nil and
// does nothing otherwise.
func check(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
|