diff --git a/config/config.go b/config/config.go index 1e2ed656..1dd65327 100644 --- a/config/config.go +++ b/config/config.go @@ -30,6 +30,7 @@ var Config *EnvironmentConfig // EnvironmentConfig : type EnvironmentConfig struct { Server ServerConfig `yaml:"server"` + SQL SQLConfig `yaml:"sql"` } // ServerConfig : @@ -61,6 +62,17 @@ type ServerConfig struct { Verbosity int32 `yaml:"verbosity"` } + +// Generic SQL Config +type SQLConfig struct { + Host string `yaml:"host"` + Port int32 `yaml:"port"` + Username string `yaml:"username"` + Password string `yaml:"password"` + DB string `yaml:"db"` + SSLMode string `yaml:"sslmode"` +} + //reading in the env file func readConfig() *EnvironmentConfig { file := fmt.Sprintf("config/environments/%s.yaml", getEnv()) diff --git a/database/database.go b/database/database.go index b2c28f8c..04f08ae1 100644 --- a/database/database.go +++ b/database/database.go @@ -38,6 +38,8 @@ func getCurrentDB() map[string]interface{} { return RQLITE_FUNCTIONS case "sqlite": return SQLITE_FUNCTIONS + case "postgres": + return PG_FUNCTIONS default: return SQLITE_FUNCTIONS } diff --git a/database/postgres.go b/database/postgres.go new file mode 100644 index 00000000..da9c8122 --- /dev/null +++ b/database/postgres.go @@ -0,0 +1,131 @@ +package database + +import ( + "github.com/gravitl/netmaker/servercfg" + "database/sql" + "errors" + _ "github.com/lib/pq" + "fmt" +) + +var PGDB *sql.DB + +var PG_FUNCTIONS = map[string]interface{}{ + INIT_DB: initPGDB, + CREATE_TABLE: pgCreateTable, + INSERT: pgInsert, + INSERT_PEER: pgInsertPeer, + DELETE: pgDeleteRecord, + DELETE_ALL: pgDeleteAllRecords, + FETCH_ALL: pgFetchRecords, + CLOSE_DB: pgCloseDB, +} + +func getPGConnString() string{ + pgconf := servercfg.GetSQLConf() + pgConn := fmt.Sprintf("host=%s port=%d user=%s "+ + "password=%s dbname=%s sslmode=%s", + pgconf.Host, pgconf.Port, pgconf.Username, pgconf.Password, pgconf.DB, pgconf.SSLMode) + return pgConn +} + + +func initPGDB() error { + connString := getPGConnString() + var dbOpenErr error + PGDB, dbOpenErr = sql.Open("postgres", connString) + if dbOpenErr != nil { + return dbOpenErr + } + dbOpenErr = PGDB.Ping() + + return dbOpenErr +} + +func pgCreateTable(tableName string) error { + statement, err := PGDB.Prepare("CREATE TABLE IF NOT EXISTS " + tableName + " (key TEXT NOT NULL UNIQUE PRIMARY KEY, value TEXT)") + if err != nil { + return err + } + _, err = statement.Exec() + if err != nil { + return err + } + return nil +} + +func pgInsert(key string, value string, tableName string) error { + if key != "" && value != "" && IsJSONString(value) { + insertSQL := "INSERT INTO " + tableName + " (key, value) VALUES ($1, $2) ON CONFLICT (key) DO UPDATE SET value = $3;" + statement, err := PGDB.Prepare(insertSQL) + if err != nil { + return err + } + _, err = statement.Exec(key, value, value) + if err != nil { + return err + } + return nil + } else { + return errors.New("invalid insert " + key + " : " + value) + } +} + +func pgInsertPeer(key string, value string) error { + if key != "" && value != "" && IsJSONString(value) { + err := pgInsert(key, value, PEERS_TABLE_NAME) + if err != nil { + return err + } + return nil + } else { + return errors.New("invalid peer insert " + key + " : " + value) + } +} + +func pgDeleteRecord(tableName string, key string) error { + deleteSQL := "DELETE FROM " + tableName + " WHERE key = \"" + key + "\"" + statement, err := PGDB.Prepare(deleteSQL) + if err != nil { + return err + } + if _, err = statement.Exec(); err != nil { + return err + } + return nil +} + +func pgDeleteAllRecords(tableName string) error { + deleteSQL := "DELETE FROM " + tableName + statement, err := PGDB.Prepare(deleteSQL) + if err != nil { + return err + } + if _, err = statement.Exec(); err != nil { + return err + } + return nil +} + +func pgFetchRecords(tableName string) (map[string]string, error) { + row, err := PGDB.Query("SELECT * FROM " + tableName + " ORDER BY key") + if err != nil { + return nil, err + } + records := make(map[string]string) + defer row.Close() + for row.Next() { // Iterate and fetch the records from result cursor + var key string + var value string + row.Scan(&key, &value) + records[key] = value + } + if len(records) == 0 { + return nil, errors.New(NO_RECORDS) + } + return records, nil +} + +func pgCloseDB() { + PGDB.Close() +} diff --git a/go.mod b/go.mod index c6135080..fb7785c9 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.0 + github.com/lib/pq v1.10.3 // indirect github.com/mattn/go-sqlite3 v1.14.8 github.com/rqlite/gorqlite v0.0.0-20210514125552-08ff1e76b22f github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e diff --git a/go.sum b/go.sum index df7b8ce0..386280bb 100644 --- a/go.sum +++ b/go.sum @@ -73,6 +73,8 @@ github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b h1:c3NTyLNozI github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b/go.mod h1:8w9Rh8m+aHZIG69YPGGem1i5VzoyRC8nw2kA8B+ik5U= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/lib/pq v1.10.3 h1:v9QZf2Sn6AmjXtQeFpdoq/eaNtYP6IN+7lcrygsIAtg= +github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU= github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= diff --git a/servercfg/serverconf.go b/servercfg/serverconf.go index 489a087c..efa0a6ed 100644 --- a/servercfg/serverconf.go +++ b/servercfg/serverconf.go @@ -82,9 +82,9 @@ func GetVersion() string { } func GetDB() string { database := "sqlite" - if os.Getenv("DATABASE") == "rqlite" { + if os.Getenv("DATABASE") != "" { database = os.Getenv("DATABASE") - } else if config.Config.Server.Database == "rqlite" { + } else if config.Config.Server.Database != "" { database = config.Config.Server.Database } return database diff --git a/servercfg/sqlconf.go b/servercfg/sqlconf.go new file mode 100644 index 00000000..0cef15ca --- /dev/null +++ b/servercfg/sqlconf.go @@ -0,0 +1,73 @@ +package servercfg + +import ( + "os" + "github.com/gravitl/netmaker/config" + "strconv" +) + +func GetSQLConf() config.SQLConfig { + var cfg config.SQLConfig + cfg.Host = GetSQLHost() + cfg.Port = GetSQLPort() + cfg.Username = GetSQLUser() + cfg.Password = GetSQLPass() + cfg.DB = GetSQLDB() + cfg.SSLMode = GetSQLSSLMode() + return cfg +} +func GetSQLHost() string { + host := "localhost" + if os.Getenv("SQL_HOST") != "" { + host = os.Getenv("SQL_HOST") + } else if config.Config.SQL.Host != "" { + host = config.Config.SQL.Host + } + return host +} +func GetSQLPort() int32 { + port := int32(5432) + envport, err := strconv.Atoi(os.Getenv("SQL_PORT")) + if err == nil && envport != 0 { + port = int32(envport) + } else if config.Config.SQL.Port != 0 { + port = config.Config.SQL.Port + } + return port +} +func GetSQLUser() string { + user := "posgres" + if os.Getenv("SQL_USER") != "" { + user = os.Getenv("SQL_USER") + } else if config.Config.SQL.Username != "" { + user = config.Config.SQL.Username + } + return user +} +func GetSQLPass() string { + pass := "nopass" + if os.Getenv("SQL_PASS") != "" { + pass = os.Getenv("SQL_PASS") + } else if config.Config.SQL.Password != "" { + pass = config.Config.SQL.Password + } + return pass +} +func GetSQLDB() string { + db := "netmaker" + if os.Getenv("SQL_DB") != "" { + db = os.Getenv("SQL_DB") + } else if config.Config.SQL.DB != "" { + db = config.Config.SQL.DB + } + return db +} +func GetSQLSSLMode() string { + sslmode := "disable" + if os.Getenv("SQL_SSL_MODE") != "" { + sslmode = os.Getenv("SQL_SSL_MODE") + } else if config.Config.SQL.SSLMode != "" { + sslmode = config.Config.SQL.SSLMode + } + return sslmode +}