diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 623acb1..e77d04d 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -101,28 +101,39 @@ func getDataDir(portableMode bool) (string, error) { } func openDatabase() (database.DB, error) { - // Check if it uses MySQL - if dbms, _ := os.LookupEnv("SHIORI_DBMS"); dbms == "mysql" { - user, _ := os.LookupEnv("SHIORI_MYSQL_USER") - password, _ := os.LookupEnv("SHIORI_MYSQL_PASS") - dbName, _ := os.LookupEnv("SHIORI_MYSQL_NAME") - dbAddress, _ := os.LookupEnv("SHIORI_MYSQL_ADDRESS") - - connString := fmt.Sprintf("%s:%s@%s/%s", user, password, dbAddress, dbName) - return database.OpenMySQLDatabase(connString) + switch dbms, _ := os.LookupEnv("SHIORI_DBMS"); dbms { + case "mysql": + return openMySQLDatabase() + case "postgresql": + return openPostgreSQLDatabase() + default: + return openSQLiteDatabase() } - // Check if it uses PostgreSQL - if dbms, _ := os.LookupEnv("SHIORI_DBMS"); dbms == "postgresql" { - host, _ := os.LookupEnv("SHIORI_PG_HOST") - port, _ := os.LookupEnv("SHIORI_PG_PORT") - user, _ := os.LookupEnv("SHIORI_PG_USER") - password, _ := os.LookupEnv("SHIORI_PG_PASS") - dbName, _ := os.LookupEnv("SHIORI_PG_NAME") +} - connString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", host, port, user, password, dbName) - return database.OpenPGDatabase(connString) - } - // If not, just uses SQLite +func openSQLiteDatabase() (database.DB, error) { dbPath := fp.Join(dataDir, "shiori.db") return database.OpenSQLiteDatabase(dbPath) } + +func openMySQLDatabase() (database.DB, error) { + user, _ := os.LookupEnv("SHIORI_MYSQL_USER") + password, _ := os.LookupEnv("SHIORI_MYSQL_PASS") + dbName, _ := os.LookupEnv("SHIORI_MYSQL_NAME") + dbAddress, _ := os.LookupEnv("SHIORI_MYSQL_ADDRESS") + + connString := fmt.Sprintf("%s:%s@%s/%s", user, password, dbAddress, dbName) + return database.OpenMySQLDatabase(connString) +} + +func openPostgreSQLDatabase() (database.DB, error) { + host, _ := os.LookupEnv("SHIORI_PG_HOST") + port, _ := os.LookupEnv("SHIORI_PG_PORT") + user, _ := os.LookupEnv("SHIORI_PG_USER") + password, _ := os.LookupEnv("SHIORI_PG_PASS") + dbName, _ := os.LookupEnv("SHIORI_PG_NAME") + + connString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", + host, port, user, password, dbName) + return database.OpenPGDatabase(connString) +}