diff --git a/src/app.py b/src/app.py
index b3d430990..391a0231e 100644
--- a/src/app.py
+++ b/src/app.py
@@ -14,6 +14,7 @@ from password_api import password_api
from settings_api import settings_api
from notes_history_api import notes_history_api
from audit_api import audit_api
+from migration_api import migration_api, APP_DB_VERSION
import config_provider
import my_scrypt
@@ -37,6 +38,7 @@ app.register_blueprint(password_api)
app.register_blueprint(settings_api)
app.register_blueprint(notes_history_api)
app.register_blueprint(audit_api)
+app.register_blueprint(migration_api)
class User(UserMixin):
pass
@@ -48,8 +50,18 @@ def login_form():
@app.route('/app', methods=['GET'])
@login_required
def show_app():
+ db_version = int(getOption('db_version'))
+
+ if db_version != APP_DB_VERSION:
+ return redirect('migration')
+
return render_template('app.html')
+@app.route('/migration', methods=['GET'])
+@login_required
+def show_migration():
+ return render_template('migration.html')
+
@app.route('/logout', methods=['POST'])
@login_required
def logout():
diff --git a/src/backup.py b/src/backup.py
index 4d7c8baab..14431e1b8 100644
--- a/src/backup.py
+++ b/src/backup.py
@@ -7,25 +7,29 @@ from shutil import copyfile
import os
import re
-def backup():
+def regular_backup():
now = utils.nowTimestamp()
last_backup_date = int(getOption('last_backup_date'))
if now - last_backup_date > 43200:
- config = config_provider.getConfig()
-
- document_path = config['Document']['documentPath']
- backup_directory = config['Backup']['backupDirectory']
-
- date_str = datetime.utcnow().strftime("%Y-%m-%d %H:%M")
-
- copyfile(document_path, backup_directory + "/" + "backup-" + date_str + ".db")
-
- setOption('last_backup_date', now)
- commit()
+ backup_now()
cleanup_old_backups()
+def backup_now():
+ now = utils.nowTimestamp()
+
+ config = config_provider.getConfig()
+
+ document_path = config['Document']['documentPath']
+ backup_directory = config['Backup']['backupDirectory']
+
+ date_str = datetime.utcnow().strftime("%Y-%m-%d %H:%M")
+
+ copyfile(document_path, backup_directory + "/" + "backup-" + date_str + ".db")
+
+ setOption('last_backup_date', now)
+ commit()
def cleanup_old_backups():
now = datetime.utcnow()
diff --git a/src/migration_api.py b/src/migration_api.py
new file mode 100644
index 000000000..ae2a789c4
--- /dev/null
+++ b/src/migration_api.py
@@ -0,0 +1,72 @@
+import os
+import re
+
+import traceback
+
+from flask import Blueprint, jsonify
+from flask_login import login_required
+
+from sql import getOption, setOption, commit, execute_script
+
+import backup
+
+APP_DB_VERSION = 0
+
+MIGRATIONS_DIR = "src/migrations"
+
+migration_api = Blueprint('migration_api', __name__)
+
+@migration_api.route('/api/migration', methods = ['GET'])
+@login_required
+def getMigrationInfo():
+ return jsonify({
+ 'db_version': int(getOption('db_version')),
+ 'app_db_version': APP_DB_VERSION
+ })
+
+@migration_api.route('/api/migration', methods = ['POST'])
+@login_required
+def runMigration():
+ migrations = []
+
+ backup.backup_now()
+
+ current_db_version = int(getOption('db_version'))
+
+ for file in os.listdir(MIGRATIONS_DIR):
+ match = re.search(r"([0-9]{4})__([a-zA-Z0-9_ ]+)\.sql", file)
+
+ if match:
+ db_version = int(match.group(1))
+
+ if db_version > current_db_version:
+ name = match.group(2)
+
+ migration_record = {
+ 'db_version': db_version,
+ 'name': name
+ }
+
+ migrations.append(migration_record)
+
+ with open(MIGRATIONS_DIR + "/" + file, 'r') as sql_file:
+ sql = sql_file.read()
+
+ try:
+ execute_script(sql)
+
+ setOption('db_version', db_version)
+ commit()
+
+ migration_record['success'] = True
+ except:
+ migration_record['success'] = False
+ migration_record['error'] = traceback.format_exc()
+
+ break
+
+ migrations.sort(key=lambda x: x['db_version'])
+
+ return jsonify({
+ 'migrations': migrations
+ })
diff --git a/src/sql.py b/src/sql.py
index c067f5e35..5ae80b96f 100644
--- a/src/sql.py
+++ b/src/sql.py
@@ -64,6 +64,11 @@ def execute(sql, params=[]):
cursor.execute(sql, params)
return cursor
+def execute_script(sql):
+ cursor = conn.cursor()
+ cursor.executescript(sql)
+ return cursor
+
def getResults(sql, params=[]):
cursor = conn.cursor()
query = cursor.execute(sql, params)
diff --git a/src/templates/app.html b/src/templates/app.html
index aaaa4ce65..5e4b863a2 100644
--- a/src/templates/app.html
+++ b/src/templates/app.html
@@ -98,9 +98,9 @@
- + - +
diff --git a/src/templates/migration.html b/src/templates/migration.html new file mode 100644 index 000000000..74e5e887c --- /dev/null +++ b/src/templates/migration.html @@ -0,0 +1,57 @@ + + + + +