From fbb999f36b190f3397ac28b2e16dd86afb9eafc3 Mon Sep 17 00:00:00 2001 From: worker-9 Date: Wed, 18 Aug 2021 18:12:08 -0400 Subject: [PATCH] added sqlite support and ability to add dbs easier --- config/config.go | 1 + database/database.go | 84 +++++++++------------ database/rqlite.go | 117 +++++++++++++++++++++++++++++ database/sqlite.go | 137 ++++++++++++++++++++++++++++++++++ main.go | 5 +- netclient/main.go | 2 +- netclient/wireguard/kernel.go | 34 +++++---- servercfg/serverconf.go | 13 +++- 8 files changed, 323 insertions(+), 70 deletions(-) create mode 100644 database/rqlite.go create mode 100644 database/sqlite.go diff --git a/config/config.go b/config/config.go index e6f370c1..f2e21983 100644 --- a/config/config.go +++ b/config/config.go @@ -53,6 +53,7 @@ type ServerConfig struct { GRPCSSL string `yaml:"grpcssl"` Version string `yaml:"version"` SQLConn string `yaml:"sqlconn"` + Database string `yaml:database` DefaultNodeLimit int32 `yaml:"defaultnodelimit"` Verbosity int32 `yaml:"verbosity"` } diff --git a/database/database.go b/database/database.go index c7705e3a..e8a4c113 100644 --- a/database/database.go +++ b/database/database.go @@ -3,9 +3,8 @@ package database import ( "encoding/json" "errors" - "log" + "github.com/gravitl/netmaker/servercfg" - "github.com/rqlite/gorqlite" ) const NETWORKS_TABLE_NAME = "networks" @@ -22,19 +21,32 @@ const DATABASE_FILENAME = "netmaker.db" const NO_RECORD = "no result found" const NO_RECORDS = "could not find any records" -var Database gorqlite.Connection +// == Constants == +const INIT_DB = "init" +const CREATE_TABLE = "createtable" +const INSERT = "insert" +const INSERT_PEER = "insertpeer" +const DELETE = "delete" +const DELETE_ALL = "deleteall" +const FETCH_ALL = "fetchall" +const CLOSE_DB = "closedb" + +func getCurrentDB() map[string]interface{} { + switch servercfg.GetDB() { + case "rqlite": + return RQLITE_FUNCTIONS + case "sqlite": + return SQLITE_FUNCTIONS + default: + return RQLITE_FUNCTIONS + } +} func InitializeDatabase() error { - //log.Println("sql conn value:",servercfg.GetSQLConn()) - conn, err := gorqlite.Open(servercfg.GetSQLConn()) - if err != nil { + if err := getCurrentDB()[INIT_DB].(func() error)(); err != nil { return err } - - // sqliteDatabase, _ := sql.Open("sqlite3", "./database/"+dbFilename) - Database = conn - Database.SetConsistencyLevel("strong") createTables() return nil } @@ -51,52 +63,36 @@ func createTables() { } func createTable(tableName string) error { - _, err := Database.WriteOne("CREATE TABLE IF NOT EXISTS " + tableName + " (key TEXT NOT NULL UNIQUE PRIMARY KEY, value TEXT)") - if err != nil { - return err - } - return nil + return getCurrentDB()[CREATE_TABLE].(func(string) error)(tableName) } -func isJSONString(value string) bool { +func IsJSONString(value string) bool { var jsonInt interface{} return json.Unmarshal([]byte(value), &jsonInt) == nil } func Insert(key string, value string, tableName string) error { - if key != "" && value != "" && isJSONString(value) { - _, err := Database.WriteOne("INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES ('" + key + "', '" + value + "')") - if err != nil { - return err - } - return nil + if key != "" && value != "" && IsJSONString(value) { + return getCurrentDB()[INSERT].(func(string, string, string) error)(key, value, tableName) } else { return errors.New("invalid insert " + key + " : " + value) } } func InsertPeer(key string, value string) error { - if key != "" && value != "" && isJSONString(value) { - _, err := Database.WriteOne("INSERT OR REPLACE INTO " + PEERS_TABLE_NAME + " (key, value) VALUES ('" + key + "', '" + value + "')") - if err != nil { - return err - } - return nil + if key != "" && value != "" && IsJSONString(value) { + return getCurrentDB()[INSERT_PEER].(func(string, string) error)(key, value) } else { return errors.New("invalid peer insert " + key + " : " + value) } } func DeleteRecord(tableName string, key string) error { - _, err := Database.WriteOne("DELETE FROM " + tableName + " WHERE key = \"" + key + "\"") - if err != nil { - return err - } - return nil + return getCurrentDB()[DELETE].(func(string, string) error)(tableName, key) } func DeleteAllRecords(tableName string) error { - _, err := Database.WriteOne("DELETE TABLE " + tableName) + err := getCurrentDB()[DELETE_ALL].(func(string) error)(tableName) if err != nil { return err } @@ -119,19 +115,9 @@ func FetchRecord(tableName string, key string) (string, error) { } func FetchRecords(tableName string) (map[string]string, error) { - row, err := Database.QueryOne("SELECT * FROM " + tableName + " ORDER BY key") - if err != nil { - return nil, err - } - records := make(map[string]string) - for row.Next() { // Iterate and fetch the records from result cursor - var key string - var value string - row.Scan(&key, &value) - records[key] = value - } - if len(records) == 0 { - return nil, errors.New(NO_RECORDS) - } - return records, nil + return getCurrentDB()[FETCH_ALL].(func(string) (map[string]string, error))(tableName) +} + +func CloseDB() { + getCurrentDB()[CLOSE_DB].(func())() } diff --git a/database/rqlite.go b/database/rqlite.go new file mode 100644 index 00000000..76460f08 --- /dev/null +++ b/database/rqlite.go @@ -0,0 +1,117 @@ +package database + +import ( + "errors" + + "github.com/gravitl/netmaker/servercfg" + "github.com/rqlite/gorqlite" +) + +var RQliteDatabase gorqlite.Connection + +var RQLITE_FUNCTIONS = map[string]interface{}{ + INIT_DB: initRqliteDatabase, + CREATE_TABLE: rqliteCreateTable, + INSERT: rqliteInsert, + INSERT_PEER: rqliteInsertPeer, + DELETE: rqliteDeleteRecord, + DELETE_ALL: rqliteDeleteAllRecords, + FETCH_ALL: rqliteFetchRecords, + CLOSE_DB: rqliteCloseDB, +} + +func initRqliteDatabase() error { + + conn, err := gorqlite.Open(servercfg.GetSQLConn()) + if err != nil { + return err + } + RQliteDatabase = conn + RQliteDatabase.SetConsistencyLevel("strong") + return nil +} + +func rqliteCreateTable(tableName string) error { + _, err := RQliteDatabase.WriteOne("CREATE TABLE IF NOT EXISTS " + tableName + " (key TEXT NOT NULL UNIQUE PRIMARY KEY, value TEXT)") + if err != nil { + return err + } + return nil +} + +func rqliteInsert(key string, value string, tableName string) error { + if key != "" && value != "" && IsJSONString(value) { + _, err := RQliteDatabase.WriteOne("INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES ('" + key + "', '" + value + "')") + if err != nil { + return err + } + return nil + } else { + return errors.New("invalid insert " + key + " : " + value) + } +} + +func rqliteInsertPeer(key string, value string) error { + if key != "" && value != "" && IsJSONString(value) { + _, err := RQliteDatabase.WriteOne("INSERT OR REPLACE INTO " + PEERS_TABLE_NAME + " (key, value) VALUES ('" + key + "', '" + value + "')") + if err != nil { + return err + } + return nil + } else { + return errors.New("invalid peer insert " + key + " : " + value) + } +} + +func rqliteDeleteRecord(tableName string, key string) error { + _, err := RQliteDatabase.WriteOne("DELETE FROM " + tableName + " WHERE key = \"" + key + "\"") + if err != nil { + return err + } + return nil +} + +func rqliteDeleteAllRecords(tableName string) error { + _, err := RQliteDatabase.WriteOne("DELETE TABLE " + tableName) + if err != nil { + return err + } + err = rqliteCreateTable(tableName) + if err != nil { + return err + } + return nil +} + +func rqliteFetchRecord(tableName string, key string) (string, error) { + results, err := FetchRecords(tableName) + if err != nil { + return "", err + } + if results[key] == "" { + return "", errors.New(NO_RECORD) + } + return results[key], nil +} + +func rqliteFetchRecords(tableName string) (map[string]string, error) { + row, err := RQliteDatabase.QueryOne("SELECT * FROM " + tableName + " ORDER BY key") + if err != nil { + return nil, err + } + records := make(map[string]string) + for row.Next() { // Iterate and fetch the records from result cursor + var key string + var value string + row.Scan(&key, &value) + records[key] = value + } + if len(records) == 0 { + return nil, errors.New(NO_RECORDS) + } + return records, nil +} + +func rqliteCloseDB() { + RQliteDatabase.Close() +} diff --git a/database/sqlite.go b/database/sqlite.go new file mode 100644 index 00000000..7fbf39f8 --- /dev/null +++ b/database/sqlite.go @@ -0,0 +1,137 @@ +package database + +import ( + "database/sql" + "errors" + "log" + "os" + "path/filepath" + + _ "github.com/mattn/go-sqlite3" +) + +// == sqlite == +const dbFilename = "netmaker.db" + +var SqliteDB *sql.DB + +var SQLITE_FUNCTIONS = map[string]interface{}{ + INIT_DB: initSqliteDB, + CREATE_TABLE: sqliteCreateTable, + INSERT: sqliteInsert, + INSERT_PEER: sqliteInsertPeer, + DELETE: sqliteDeleteRecord, + DELETE_ALL: sqliteDeleteAllRecords, + FETCH_ALL: sqliteFetchRecords, + CLOSE_DB: sqliteCloseDB, +} + +func initSqliteDB() error { + // == create db file if not present == + if _, err := os.Stat("data"); os.IsNotExist(err) { + log.Println("Could not find data directory, creating it.") + os.Mkdir("data", 0644) + } + dbFilePath := filepath.Join("data", dbFilename) + if _, err := os.Stat(dbFilePath); os.IsNotExist(err) { + log.Println("Could not get database file, creating it.") + os.Create(dbFilePath) + } + // == "connect" the database == + var dbOpenErr error + SqliteDB, dbOpenErr = sql.Open("sqlite3", dbFilePath) + if dbOpenErr != nil { + return dbOpenErr + } + return nil +} + +func sqliteCreateTable(tableName string) error { + statement, err := SqliteDB.Prepare("CREATE TABLE IF NOT EXISTS " + tableName + " (key TEXT NOT NULL UNIQUE PRIMARY KEY, value TEXT)") + if err != nil { + return err + } + _, err = statement.Exec() + if err != nil { + return err + } + log.Println(tableName, "table created") + return nil +} + +func sqliteInsert(key string, value string, tableName string) error { + if key != "" && value != "" && IsJSONString(value) { + insertSQL := "INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES (?, ?)" + statement, err := SqliteDB.Prepare(insertSQL) + if err != nil { + return err + } + _, err = statement.Exec(key, value) + if err != nil { + return err + } + log.Println("inserted", key, ":", value, "into ", tableName) + return nil + } else { + return errors.New("invalid insert " + key + " : " + value) + } +} + +func sqliteInsertPeer(key string, value string) error { + if key != "" && value != "" && IsJSONString(value) { + err := sqliteInsert(key, value, PEERS_TABLE_NAME) + if err != nil { + return err + } + return nil + } else { + return errors.New("invalid peer insert " + key + " : " + value) + } +} + +func sqliteDeleteRecord(tableName string, key string) error { + deleteSQL := "DELETE FROM " + tableName + " WHERE key = \"" + key + "\"" + statement, err := SqliteDB.Prepare(deleteSQL) + if err != nil { + return err + } + if _, err = statement.Exec(); err != nil { + return err + } + return nil +} + +func sqliteDeleteAllRecords(tableName string) error { + deleteSQL := "DELETE FROM " + tableName + statement, err := SqliteDB.Prepare(deleteSQL) + if err != nil { + return err + } + if _, err = statement.Exec(); err != nil { + return err + } + return nil +} + +func sqliteFetchRecords(tableName string) (map[string]string, error) { + row, err := SqliteDB.Query("SELECT * FROM " + tableName + " ORDER BY key") + if err != nil { + return nil, err + } + records := make(map[string]string) + defer row.Close() + for row.Next() { // Iterate and fetch the records from result cursor + var key string + var value string + row.Scan(&key, &value) + records[key] = value + } + if len(records) == 0 { + return nil, errors.New(NO_RECORDS) + } + return records, nil +} + +func sqliteCloseDB() { + SqliteDB.Close() +} diff --git a/main.go b/main.go index 80253051..4f81252f 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( "os/signal" "strconv" "sync" + controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/functions" @@ -25,7 +26,7 @@ import ( func main() { fmt.Println(models.RetrieveLogo()) // print the logo initialize() // initial db and grpc server - defer database.Database.Close() + defer database.CloseDB() startControllers() // start the grpc or rest endpoints } @@ -40,7 +41,7 @@ func initialize() { // Client Mode Prereq Check if err != nil { log.Println("Error running 'id -u' for prereq check. Please investigate or disable client mode.") - log.Fatal(err) + log.Fatal(output, err) } uid, err := strconv.Atoi(string(output[:len(output)-1])) if err != nil { diff --git a/netclient/main.go b/netclient/main.go index 3baad235..c2275f74 100644 --- a/netclient/main.go +++ b/netclient/main.go @@ -316,7 +316,7 @@ func main() { out, err := local.RunCmd("id -u") if err != nil { - log.Fatal(err) + log.Fatal(out, err) } id, err := strconv.Atoi(string(out[:len(out)-1])) diff --git a/netclient/wireguard/kernel.go b/netclient/wireguard/kernel.go index b48b1ddb..0644c394 100644 --- a/netclient/wireguard/kernel.go +++ b/netclient/wireguard/kernel.go @@ -64,17 +64,18 @@ func InitWireguard(node *models.Node, privkey string, peers []wgtypes.PeerConfig network = node.Network } - _, delErr := local.RunCmd("ip link delete dev " + ifacename) - _, addLinkErr := local.RunCmd(ipExec + " link add dev " + ifacename + " type wireguard") - _, addErr := local.RunCmd(ipExec + " address add dev " + ifacename + " " + node.Address + "/24") + delOut, delErr := local.RunCmd("ip link delete dev " + ifacename) + addLinkOut, addLinkErr := local.RunCmd(ipExec + " link add dev " + ifacename + " type wireguard") + addOut, addErr := local.RunCmd(ipExec + " address add dev " + ifacename + " " + node.Address + "/24") if delErr != nil { // pass + log.Println(delOut, delErr) } if addLinkErr != nil { - log.Println(addLinkErr) + log.Println(addLinkOut, addLinkErr) } if addErr != nil { - log.Println(addErr) + log.Println(addOut, addErr) } var nodeport int nodeport = int(node.ListenPort) @@ -162,16 +163,16 @@ func InitWireguard(node *models.Node, privkey string, peers []wgtypes.PeerConfig out, err := local.RunCmd(ipExec + " -4 route add " + gateway + " dev " + ifacename) fmt.Println(string(out)) if err != nil { - fmt.Println("Error encountered adding gateway: " + err.Error()) + fmt.Println("error encountered adding gateway: " + err.Error()) } } } if node.Address6 != "" && node.IsDualStack == "yes" { - fmt.Println("Adding address: " + node.Address6) + fmt.Println("adding address: " + node.Address6) out, err := local.RunCmd(ipExec + " address add dev " + ifacename + " " + node.Address6 + "/64") if err != nil { fmt.Println(out) - fmt.Println("Error encountered adding ipv6: " + err.Error()) + fmt.Println("error encountered adding ipv6: " + err.Error()) } } @@ -268,9 +269,9 @@ func SetPeers(iface string, keepalive int32, peers []wgtypes.PeerConfig) error { for _, currentPeer := range devicePeers { if currentPeer.AllowedIPs[0].String() == peer.AllowedIPs[0].String() && currentPeer.PublicKey.String() != peer.PublicKey.String() { - _, err := local.RunCmd("wg set " + iface + " peer " + currentPeer.PublicKey.String() + " remove") + output, err := local.RunCmd("wg set " + iface + " peer " + currentPeer.PublicKey.String() + " remove") if err != nil { - log.Println("error removing peer", peer.Endpoint.String()) + log.Println(output, "error removing peer", peer.Endpoint.String()) } } } @@ -285,18 +286,19 @@ func SetPeers(iface string, keepalive int32, peers []wgtypes.PeerConfig) error { if keepAliveString == "0" { keepAliveString = "5" } + var output string if peer.Endpoint != nil { - _, err = local.RunCmd("wg set " + iface + " peer " + peer.PublicKey.String() + + output, err = local.RunCmd("wg set " + iface + " peer " + peer.PublicKey.String() + " endpoint " + udpendpoint + " persistent-keepalive " + keepAliveString + " allowed-ips " + allowedips) } else { - _, err = local.RunCmd("wg set " + iface + " peer " + peer.PublicKey.String() + + output, err = local.RunCmd("wg set " + iface + " peer " + peer.PublicKey.String() + " persistent-keepalive " + keepAliveString + " allowed-ips " + allowedips) } if err != nil { - log.Println("error setting peer", peer.PublicKey.String(), err) + log.Println(output, "error setting peer", peer.PublicKey.String(), err) } } @@ -308,15 +310,15 @@ func SetPeers(iface string, keepalive int32, peers []wgtypes.PeerConfig) error { } } if shouldDelete { - _, err := local.RunCmd("wg set " + iface + " peer " + currentPeer.PublicKey.String() + " remove") + output, err := local.RunCmd("wg set " + iface + " peer " + currentPeer.PublicKey.String() + " remove") if err != nil { - log.Println("error removing peer", currentPeer.PublicKey.String()) + log.Println(output, "error removing peer", currentPeer.PublicKey.String()) } else { log.Println("removed peer " + currentPeer.PublicKey.String()) } } } - + return nil } diff --git a/servercfg/serverconf.go b/servercfg/serverconf.go index 001b41b1..4246b70c 100644 --- a/servercfg/serverconf.go +++ b/servercfg/serverconf.go @@ -78,6 +78,15 @@ func GetVersion() string { } return version } +func GetDB() string { + database := "rqlite" + if os.Getenv("DATABASE") == "sqlite" { + database = os.Getenv("DATABASE") + } else if config.Config.Server.Database == "sqlite" { + database = config.Config.Server.Database + } + return database +} func GetAPIHost() string { serverhost := "127.0.0.1" remoteip, _ := GetPublicIP() @@ -313,8 +322,8 @@ func GetSQLConn() string { sqlconn := "http://" if os.Getenv("SQL_CONN") != "" { sqlconn = os.Getenv("SQL_CONN") - } else if config.Config.Server.SQLConn != "" { + } else if config.Config.Server.SQLConn != "" { sqlconn = config.Config.Server.SQLConn } return sqlconn -} \ No newline at end of file +}