From 71dbd5eab7c6f2b401ff460ece115e82991b9e2a Mon Sep 17 00:00:00 2001 From: Thomas Hintz Date: Sun, 3 Dec 2023 10:18:20 -0800 Subject: [PATCH] adding rest of docker files --- config/config.yaml | 14 ++ requirements.txt | 11 + src/__init__.py | 40 +++ src/api.py | 430 ++++++++++++++++++++++++++++++++ src/server.py | 40 +++ src/ssl/.gitkeep | 0 src/util.py | 609 +++++++++++++++++++++++++++++++++++++++++++++ webroot/.gitkeep | 0 8 files changed, 1144 insertions(+) create mode 100644 config/config.yaml create mode 100644 requirements.txt create mode 100644 src/__init__.py create mode 100644 src/api.py create mode 100644 src/server.py create mode 100644 src/ssl/.gitkeep create mode 100644 src/util.py create mode 100644 webroot/.gitkeep diff --git a/config/config.yaml b/config/config.yaml new file mode 100644 index 0000000..9e9dd3b --- /dev/null +++ b/config/config.yaml @@ -0,0 +1,14 @@ +backendHostname: receiptmanager.home.local +backendIP: 192.168.50.59 +backendLanguage: de-DE +backendPort: 5558 +dbMode: mssql +encrypted: true +parserIP: 192.168.50.6 +parserPort: 8721 +parserToken: dd8dfbfbc51233fea39acc23c9a08fc3 +sqlDatabase: receiptData +sqlPassword: cbadf3c7d51a20df94a0ed37ed92bff16a75fdaa35d1f53e78f99ebe0179f89e +sqlServerIP: 192.168.50.6 +sqlUsername: receiptParser +useSSL: false diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..dbb99d1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +flask==2.0.1 +PyYAML==5.4.1 +gevent==23.9.0 +Flask-Cors==3.0.10 +pyodbc==4.0.30 +requests==2.25.1 +uuid==1.30 +mysql-connector-python==8.0.25 +cryptography==3.4.7 +pycryptodome==3.10.1 +Wand==0.6.6 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..95cf199 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,40 @@ +from server import server +from util import ( + create_ssl_cert, + init_mssql_db, + init_mysql_db, + load_conf, + create_web_config, + check_existing_token, + load_db_conn, +) + +def main(): + cfg = load_conf() + + if cfg["useSSL"]: + create_ssl_cert([cfg["backendIP"]]) + + if cfg["dbMode"] and cfg["sqlDatabase"] and cfg["sqlPassword"] and cfg["sqlServerIP"] and cfg["sqlUsername"]: + print("Using " + cfg["dbMode"] + " DB") + + try: + conn = load_db_conn()[0] + if cfg["dbMode"] == "mssql": + init_mssql_db(conn) + elif cfg["dbMode"] == "mysql": + init_mysql_db(conn) + else: + print("Error! No valid db mode found. 
Please use mssql or mysql") + except Exception as e: + print(e) + else: + print("No db mode set.") + + check_existing_token() + create_web_config() + server() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/api.py b/src/api.py new file mode 100644 index 0000000..ae12c1c --- /dev/null +++ b/src/api.py @@ -0,0 +1,430 @@ +import json +import uuid +from datetime import datetime + +import requests +from flask import Flask, request, send_from_directory +from flask_cors import CORS, cross_origin + +from util import ( + check_existing_token, + convert_to_mysql_query, + get_category_id, + get_data_from_db, + add_or_update_to_db, + delete_from_db, + get_store_id, + load_conf, + load_db_conn, + delete_receipt, + update_server_config, + convert_pdf_to_png +) + +app = Flask( + __name__, + static_url_path="", + static_folder="../webroot", +) +cors = CORS(app, resources={r"/api/upload/*": {"origins": "*"}}) +app.config["CORS_HEADERS"] = "Content-Type" + +cfg = None +api_token = None + + +@app.before_first_request +def first(): + global cfg + global api_token + cfg = load_conf() + api_token = check_existing_token() + +@app.before_request +def before_request(): + if request.endpoint in ('static', 'index'): + return + + if not request.args: + return "No token provided! Add &token= to URL", 401 + + if api_token != request.args["token"]: + return "Unauthorized", 401 + + global cfg + cfg = load_conf() + if (( + not cfg["dbMode"] + or not cfg["sqlDatabase"] + or not cfg["sqlPassword"] + or not cfg["sqlServerIP"] + or not cfg["sqlUsername"] + ) + and request.endpoint != 'updateConfig' + and request.endpoint != 'getBackendConfig' + ): + return "Settings incomplete!", 512 + +@app.route("/", methods=["GET"]) +def index(): + return send_from_directory(app.static_folder, "index.html") + +@app.route("/api/getBackendConfig", methods=["GET", "OPTIONS"]) +@cross_origin(origin="*", headers=["Content-Type"]) +def getBackendConfig(): + return json.dumps(cfg) + +@app.route("/api/updateConfig", methods=["POST", "OPTIONS"]) +@cross_origin(origin="*", headers=["Content-Type"]) +def updateConfig(): + post_string = json.dumps(request.get_json()) + post_json = json.loads(post_string) + + update_server_config(post_json) + + return "Config updated", 200 + +@app.route("/api/upload", methods=["POST", "OPTIONS"]) +@cross_origin(origin="*", headers=["Content-Type"]) +def upload(): + file = request.files["file"] + file_name = file.filename + gaussian_blur = True + + if file.content_type == "application/pdf": + file = convert_pdf_to_png(file.read()) + file_name = file_name.split(".")[0] + ".png" + gaussian_blur = False + + if not file: + return "No valid file found", 500 + + url = ( + "http://" + + str(cfg["parserIP"]) + + ":" + + str(cfg["parserPort"]) + + "/api/upload?access_token=" + + str(cfg["parserToken"]) + + "&legacy_parser=True" + + "&grayscale_image=True" + + "&rotate_image=True" + + "&gaussian_blur=" + str(gaussian_blur) + + "&median_blur=True" + ) + + receipt_upload = requests.post(url, files={"file": (file_name, file)}) + + if receipt_upload.status_code == 200: + upload_response = json.dumps(receipt_upload.content.decode("utf8")) + response_json = json.loads(upload_response) + response_json = json.loads(response_json) + + # Replace " in Date + if '"' in response_json["receiptDate"]: + response_json["receiptDate"] = response_json["receiptDate"].replace('"', "") + + # Create 4 digit year + if response_json["receiptDate"] != "null": + year_string = response_json["receiptDate"].split(".") + 
if len(year_string[2]) < 4: + year_string[2] = "20" + year_string[2] + response_json["receiptDate"] = ( + year_string[0] + "." + year_string[1] + "." + year_string[2] + ) + + conn, cursor = load_db_conn() + for idx, article in enumerate(response_json["receiptItems"]): + article = article[0] + + splitted_articles = article.split(" ") + + for article in splitted_articles: + if len(article) > 3: + + if cfg["dbMode"] == "mysql": + sql_query = "SELECT category FROM purchaseData where article_name like %s order by timestamp desc limit 1" + else: + sql_query = "SELECT TOP 1 category FROM purchaseData where article_name like ? order by timestamp desc" + + cursor.execute(sql_query, [f"%{article}%"]) + row = cursor.fetchone() + + if row: + if cfg["dbMode"] == "mysql": + found_cat = row[0] + else: + found_cat = row.category + + copy_array = response_json["receiptItems"][idx] + copy_array.insert(2, found_cat) + + response_json["receiptItems"][idx] = copy_array + break + + conn.close() + return json.dumps(response_json) + + else: + return "Error on upload", receipt_upload.status_code + + +@app.route("/api/getHistory", methods=["GET", "OPTIONS"]) +@cross_origin(origin="*", headers=["Content-Type"]) +def get_history(): + conn, cursor = load_db_conn() + history_json = [] + + cursor.execute( + "select SUM(total) as totalSum, location, id, timestamp from purchaseData \ + where id is not null \ + GROUP BY timestamp, id, location \ + ORDER BY timestamp desc" + ) + + rows = cursor.fetchall() + + for row in rows: + if cfg["dbMode"] == "mysql": + add_json = { + "location": row[1], + "totalSum": str(row[0]), + "timestamp": str(row[3]), + "id": row[2], + } + else: + add_json = { + "location": row.location, + "totalSum": str(row.totalSum), + "timestamp": str(row.timestamp), + "id": row.id, + } + history_json.append(add_json) + + conn.close() + + return json.dumps(history_json) + + +@app.route("/api/getHistoryDetails", methods=["GET", "OPTIONS"]) +@cross_origin(origin="*", headers=["Content-Type"]) +def get_history_details(): + purchase_id = request.args["purchaseID"] + + conn, cursor = load_db_conn() + sql_query = ( + "SELECT storeName, total, date, purchaseId " + + "FROM receipts re " + + "JOIN stores st ON re.storeId = st.id " + + "where re.id = ?" + ) + + if cfg["dbMode"] == "mysql": + sql_query = convert_to_mysql_query(sql_query) + + cursor.execute(sql_query, [purchase_id]) + row = cursor.fetchone() + + if row: + if cfg["dbMode"] == "mysql": + store_name = row[0] + receipt_total = row[1] + receipt_date = row[2] + db_purchase_id = row[3] + else: + store_name = row.storeName + receipt_total = row.total + receipt_date = row.date + db_purchase_id = row.purchaseId + + receipt_date = receipt_date.strftime("%d.%m.%Y") + + purchase_details = { + "storeName": store_name, + "receiptTotal": str(receipt_total), + "receiptDate": receipt_date, + "purchaseID": db_purchase_id, + "receiptItems": [], + } + + sql_query = "select article_name, total, category from purchaseData where id = ?" 
+ if cfg["dbMode"] == "mysql": + sql_query = convert_to_mysql_query(sql_query) + + cursor.execute(sql_query, [purchase_id]) + rows = cursor.fetchall() + + for row in rows: + if cfg["dbMode"] == "mysql": + add_json = [row[0], str(row[1]), row[2]] + else: + add_json = [row.article_name, str(row.total), row.category] + + purchase_details["receiptItems"].append(add_json) + + conn.close() + + return json.dumps(purchase_details) + else: + return "Purchase not found!", 500 + + +@app.route("/api/getValue", methods=["GET", "OPTIONS"]) +@cross_origin(origin="*", headers=["Content-Type"]) +def get_categories(): + get_values_from = request.args["getValuesFrom"] + + ret_json = get_data_from_db(get_values_from) + + return json.dumps(ret_json, ensure_ascii=False) + + +@app.route("/api/addValue", methods=["POST", "OPTIONS"]) +@cross_origin(origin="*", headers=["Content-Type"]) +def add_value(): + to_add_array = request.args["toAddArray"] + to_add_value = request.args["toAddValue"] + item_id = request.args["id"] + + add_or_update_to_db(to_add_array, item_id, to_add_value) + + return "Done!" + + +@app.route("/api/deleteValue", methods=["POST", "OPTIONS"]) +@cross_origin(origin="*", headers=["Content-Type"]) +def delete_value(): + table_name = request.args["tableName"] + item_id = request.args["id"] + + delete_from_db(table_name, item_id) + + return "Done!" + +@app.route("/api/deleteReceiptFromDB", methods=["POST", "OPTIONS"]) +@cross_origin(origin="*", headers=["Content-Type"]) +def delete_receipt_from_db(): + receipt_id = request.args["purchaseID"] + + delete_receipt(receipt_id) + + return "Done!" + +@app.route("/api/updateReceiptToDB", methods=["POST", "OPTIONS"]) +@cross_origin(origin="*", headers=["Content-Type"]) +def update_receipt_to_db(): + post_string = json.dumps(request.get_json()) + post_json = json.loads(post_string) + + conn, cursor = load_db_conn() + + store_id = get_store_id(post_json["storeName"]) + receipt_date = post_json["receiptDate"] + receipt_total = post_json["receiptTotal"] + receipt_id = post_json["purchaseID"] + + receipt_date = datetime.strptime(receipt_date, "%d.%m.%Y") + receipt_date = receipt_date.strftime("%m-%d-%Y") + + # Clean old db data + delete_receipt(receipt_id) + + # Write article positions + for article in post_json["receiptItems"]: + article_name = article[1] + article_sum = article[2] + article_id = int(str(uuid.uuid1().int)[:8]) + article_category_id = get_category_id(article[3]) + + sql_query = "INSERT INTO items values (?,?,?,?)" + if cfg["dbMode"] == "mysql": + sql_query = convert_to_mysql_query(sql_query) + cursor.execute( + sql_query, [article_id, article_name, article_sum, article_category_id] + ) + + sql_query = "INSERT INTO purchasesArticles values (?,?)" + if cfg["dbMode"] == "mysql": + sql_query = convert_to_mysql_query(sql_query) + cursor.execute(sql_query, [receipt_id, article_id]) + + # Write receipt summary + if cfg["dbMode"] == "mysql": + sql_query = ( + "INSERT INTO receipts values (%s,%s,STR_TO_DATE(%s,'%m-%d-%Y'),%s,%s,%s)" + ) + else: + sql_query = "INSERT INTO receipts values (?,?,?,?,?,?)" + + cursor.execute( + sql_query, [receipt_id, store_id, receipt_date, receipt_total, None, receipt_id] + ) + + conn.commit() + conn.close() + + return "Done!" 
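For reference, the JSON body consumed by /api/updateReceiptToDB above and /api/writeReceiptToDB below looks roughly like the sketch that follows; the field names and item positions are taken from the handlers themselves, while the concrete values are invented for illustration (index 0 of a receiptItems entry is not read by either handler, and the category name must already exist in the categories table or writeReceiptToDB answers with a 500):

    payload = {
        "storeName": "Example Store",        # resolved or created via get_store_id()
        "receiptDate": "03.12.2023",         # parsed with "%d.%m.%Y"
        "receiptTotal": "12.34",
        "purchaseID": 123456,                # only read by updateReceiptToDB
        "receiptItems": [
            # [unused, article name, article total, category name]
            [None, "Milk", "1.19", "Groceries"],
        ],
    }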
+ + +@app.route("/api/writeReceiptToDB", methods=["POST", "OPTIONS"]) +@cross_origin(origin="*", headers=["Content-Type"]) +def write_receipt_to_db(): + post_string = json.dumps(request.get_json()) + post_json = json.loads(post_string) + + conn, cursor = load_db_conn() + + store_id = get_store_id(post_json["storeName"]) + receipt_date = post_json["receiptDate"] + receipt_total = post_json["receiptTotal"] + receipt_id = int(str(uuid.uuid1().int)[:6]) + + receipt_date = datetime.strptime(receipt_date, "%d.%m.%Y") + receipt_date = receipt_date.strftime("%m-%d-%Y") + + # Write article positions + for article in post_json["receiptItems"]: + article_id = int(str(uuid.uuid1().int)[:8]) + article_name = article[1] + article_sum = article[2] + article_category_id = get_category_id(article[3]) + + if not article_category_id: + return "Category id for category: " + article[3] + " not found", 500 + + sql_query = "INSERT INTO items values (?,?,?,?)" + if cfg["dbMode"] == "mysql": + sql_query = convert_to_mysql_query(sql_query) + cursor.execute( + sql_query, [article_id, article_name, article_sum, article_category_id] + ) + + sql_query = "INSERT INTO purchasesArticles values (?,?)" + if cfg["dbMode"] == "mysql": + sql_query = convert_to_mysql_query(sql_query) + cursor.execute(sql_query, [receipt_id, article_id]) + + # Write receipt summary + if cfg["dbMode"] == "mysql": + sql_query = ( + "INSERT INTO receipts values (%s,%s,STR_TO_DATE(%s,'%m-%d-%Y'),%s,%s,%s)" + ) + else: + sql_query = "INSERT INTO receipts values (?,?,?,?,?,?)" + + cursor.execute( + sql_query, [receipt_id, store_id, receipt_date, receipt_total, None, receipt_id] + ) + + conn.commit() + conn.close() + + return "Done!" + +@app.after_request +def add_header(r): + r.headers["Cache-Control"] = "no-cache, no-store, must-revalidate" + r.headers["Pragma"] = "no-cache" + r.headers["Expires"] = "0" + r.headers['Cache-Control'] = 'public, max-age=0' + return r \ No newline at end of file diff --git a/src/server.py b/src/server.py new file mode 100644 index 0000000..dda14a6 --- /dev/null +++ b/src/server.py @@ -0,0 +1,40 @@ +from gevent import monkey +monkey.patch_all() + +from gevent.pywsgi import WSGIServer +from api import app +from util import load_conf, check_existing_token + +def server(): + cfg = load_conf() + api_token = check_existing_token() + + if cfg["useSSL"]: + http_server = WSGIServer( + (cfg["backendIP"], int(cfg["backendPort"])), + app, + certfile="ssl/cert.crt", + keyfile="ssl/key.pem", + ) + print( + "Server started. Running on https://" + + str(cfg["backendIP"]) + + ":" + + str(cfg["backendPort"]) + ) + else: + http_server = WSGIServer((cfg["backendIP"], int(cfg["backendPort"])), app) + print( + "Server started. 
Running on http://" + + str(cfg["backendIP"]) + + ":" + + str(cfg["backendPort"]) + ) + + print("API Token: " + api_token) + if cfg['parserIP']: + print("Parser IP set to: " + str(cfg["parserIP"]) + ":" + str(cfg["parserPort"])) + else: + print("No parser IP set.") + + http_server.serve_forever() \ No newline at end of file diff --git a/src/ssl/.gitkeep b/src/ssl/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/src/util.py b/src/util.py new file mode 100644 index 0000000..9cb2bb8 --- /dev/null +++ b/src/util.py @@ -0,0 +1,609 @@ +import pyodbc +import uuid +import json +import yaml +import os +import ipaddress +import socket +from mysql.connector import connect, Error +from datetime import datetime, timedelta +from wand.image import Image +from wand.color import Color + +from Crypto.Cipher import AES +from cryptography.fernet import Fernet +from binascii import b2a_hex, a2b_hex +from cryptography import x509 +from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa + +cfg = None +api_token = None +key = None +BLOCK_SIZE = 16 +SEGMENT_SIZE = 128 + +def convert_pdf_to_png(pdf_file): + try: + with Image(blob=pdf_file, resolution=200) as pdf: + first_page = pdf.sequence[0] + with Image(first_page) as i: + i.format = 'png' + i.compression_quality = 99 + i.background_color = Color('white') + i.alpha_channel = 'remove' + binary_png = i.make_blob("png") + return binary_png + except Exception as e: + if hasattr(e, 'wand_error_code'): + if e.wand_error_code == 415: + print("ERROR ! Please install ghostscript to convert PDF to PNG from https://www.ghostscript.com/download/gsdnld.html or on linux with 'apt-get install ghostscript'. \n" + + "If you are using Linux, have a look at here: https://stackoverflow.com/questions/57208396/imagemagick-ghostscript-delegate-security-policy-blocking-conversion") + else: + print("ERROR ! " + e) + else: + print("ERROR ! 
" + e) + return None + +def encrypt(plaintext): + key = check_existing_key() + key = key.encode('utf-8') + iv = key + + aes = AES.new(key, AES.MODE_CFB, iv, segment_size=SEGMENT_SIZE) + plaintext = _pad_string(plaintext) + encrypted_text = aes.encrypt(plaintext.encode()) + return b2a_hex(encrypted_text).rstrip().decode() + +def decrypt(encrypted_text): + key = check_existing_key() + key = key.encode('utf-8')[:16] + iv = key + + aes = AES.new(key, AES.MODE_CFB, iv, segment_size=SEGMENT_SIZE) + encrypted_text_bytes = a2b_hex(encrypted_text) + decrypted_text = aes.decrypt(encrypted_text_bytes) + decrypted_text = _unpad_string(decrypted_text.decode()) + return decrypted_text + +def _pad_string(value): + length = len(value) + pad_size = BLOCK_SIZE - (length % BLOCK_SIZE) + return value.ljust(length + pad_size, '\x00') + +def _unpad_string(value): + while value[-1] == '\x00': + value = value[:-1] + return value + +def create_ssl_cert( + ip_addresses=None, + ca_cert="../webroot/ssl/ca.crt", + ca_key="../webroot/ssl/ca.key", + key_file="ssl/key.pem", + cert_file="ssl/cert.crt", + ): + + root_cert = None + root_key = None + now = datetime.utcnow() + if not os.path.isfile(ca_cert) and not os.path.isfile(ca_key): + # Create CA Cert and Key + root_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend() + ) + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, u"DE"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, u"NRW"), + x509.NameAttribute(NameOID.LOCALITY_NAME, u"NV"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, u"ReceiptManager"), + x509.NameAttribute(NameOID.COMMON_NAME, u"receipt-manager-ca"), + ]) + + basic_contraints = x509.BasicConstraints(ca=True, path_length=1) + + root_cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(root_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + timedelta(days=2 * 365)) + .add_extension(basic_contraints, False) + .sign(root_key, hashes.SHA256(), default_backend())) + + ca_cert_pem = root_cert.public_bytes(encoding=serialization.Encoding.PEM) + ca_key_pem = root_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + open(ca_key, "wb").write(ca_key_pem) + open(ca_cert, "wb").write(ca_cert_pem) + + if not os.path.isfile(cert_file) and not os.path.isfile(key_file): + if not root_cert and not root_key: + cert_binary = open(ca_cert,"rb").read() + root_cert = x509.load_pem_x509_certificate(cert_binary, default_backend()) + key_binary = open(ca_key,"rb").read() + root_key = serialization.load_pem_private_key(key_binary, None, default_backend()) + + # Generate our key + key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend(), + ) + + name = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, cfg["backendHostname"])] + ) + + # best practice seem to be to include the hostname in the SAN, which *SHOULD* mean COMMON_NAME is ignored. + alt_names = [x509.DNSName(cfg["backendHostname"]), x509.DNSName("localhost")] + + # allow addressing by IP, for when you don't have real DNS (common in most testing scenarios + if ip_addresses: + for addr in ip_addresses: + # openssl wants DNSnames for ips... + alt_names.append(x509.DNSName(addr)) + # ... 
whereas golang's crypto/tls is stricter, and needs IPAddresses + # note: older versions of cryptography do not understand ip_address objects + alt_names.append(x509.IPAddress(ipaddress.ip_address(addr))) + + san = x509.SubjectAlternativeName(alt_names) + extended_key_usage = x509.ExtendedKeyUsage([x509.oid.ExtendedKeyUsageOID.SERVER_AUTH]) + + cert = ( + x509.CertificateBuilder() + .subject_name(name) + .issuer_name(root_cert.issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + timedelta(days=2 * 365)) + .add_extension(san, False) + .add_extension(extended_key_usage, True) + .sign(root_key, hashes.SHA256(), default_backend()) + ) + cert_pem = cert.public_bytes(encoding=serialization.Encoding.PEM) + key_pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + open(key_file, "wb").write(key_pem) + open(cert_file, "wb").write(cert_pem) + +def update_server_config(settings): + crypt_config(settings) + load_conf(True) + create_web_config() + +def update_config_yaml(settings): + config_file = open('../config/config.yaml', 'w') + yaml.dump(settings, config_file) + config_file.close() + +def create_web_config(): + web_json = "../webroot/settings/settings.json" + web_cfg = { + "useSSL": cfg["useSSL"], + "backendIP": cfg["backendIP"], + "backendPort": cfg["backendPort"], + "backendToken": api_token, + "language": cfg["backendLanguage"] + } + f = open(web_json, "w") + f.write(json.dumps(web_cfg)) + f.close() + +def check_existing_key(): + if not os.path.isfile(r"../config/.key"): + create_key() + else: + read_key() + + return key + +def read_key(): + global key + if not key: + with open(r"../config/.key") as f: + key = f.readline() + +def create_key(): + global key + new_key = Fernet.generate_key().decode('utf-8')[:16] + f = open("../config/.key", "w") + f.write(new_key) + key = new_key + f.close() + +def check_existing_token(): + if not os.path.isfile(r".api_token"): + create_token() + else: + read_token() + + return api_token + +def read_token(): + global api_token + if not api_token: + with open(r".api_token") as f: + api_token = f.readline() + + +def create_token(): + global api_token + new_token = str(uuid.uuid4()) + f = open(".api_token", "w") + f.write(new_token) + api_token = new_token + f.close() + +def create_initial_config(): + use_ssl = os.environ.get("useSSL", False) + + if isinstance(use_ssl, str): + if use_ssl.lower() == 'true': + use_ssl = True + else: + use_ssl = False + + run_in_docker = os.environ.get("RUN_IN_DOCKER", False) + if not run_in_docker: + backend_ip = socket.gethostbyname(socket.gethostname()) + else: + backend_ip = os.environ.get("backendIP", None) + if not backend_ip: + stream = os.popen(r"ip -4 addr show eth0 | grep -Po 'inet \K[\d.]+'") + backend_ip = stream.read().rstrip() + + backend_port = os.environ.get("backendPort", 5558) + + print("Initial config created.") + + temp_config = { + "useSSL": use_ssl, + "backendHostname": "", + "backendIP": backend_ip, + "backendPort": backend_port, + "backendLanguage": "", + "parserIP": "", + "parserPort": "", + "parserToken": "", + "dbMode": "", + "sqlServerIP": "", + "sqlDatabase": "", + "sqlUsername": "", + "sqlPassword": "", + } + + config = json.dumps(temp_config) + cfg = json.loads(config) + update_config_yaml(cfg) + +def load_conf(force_reload=False): + global cfg + + if not os.path.isfile("../config/config.yaml"): + 
create_initial_config() + + if not cfg or force_reload: + with open("../config/config.yaml", "r") as ymlfile: + cfg = yaml.load(ymlfile, Loader=yaml.FullLoader) + + cfg = crypt_config(cfg) + + return cfg + +def crypt_config(settings): + rewrite_config = False + encrypted_cfg = settings.copy() + + if 'encrypted' in settings: + is_encrypted = settings['encrypted'] + else: + is_encrypted = False + + if not is_encrypted: + rewrite_config = True + + for c, v in settings.items(): + if v and ("Token" in c or "Password" in c): + if not is_encrypted: + rewrite_config = True + + encrypted = encrypt(str(v)) + encrypted_cfg[c] = encrypted + else: + try: + decrypted = decrypt(str(v)) + settings[c] = decrypted + except Exception as e: + if "Non-hexadecimal digit found" in str(e): + print("Decryption failed. Set encryption flag in config yaml to False!") + else: + print(e) + + if rewrite_config: + settings['encrypted'] = True + encrypted_cfg['encrypted'] = True + update_config_yaml(encrypted_cfg) + + return settings + +def delete_from_db(table_name, id): + conn, cur = load_db_conn() + + sql_query = "DELETE FROM " + table_name + " WHERE id = ?" + if cfg["dbMode"] == "mysql": + sql_query = convert_to_mysql_query(sql_query) + + cur.execute(sql_query, [id]) + conn.commit() + conn.close() + + +def add_or_update_to_db(to_add_table, item_id, to_add_value): + conn, cur = load_db_conn() + + if to_add_table == "categories": + name_col = "categoryName" + elif to_add_table == "stores": + name_col = "storeName" + + if item_id: + sql_update = ( + """ UPDATE """ + + to_add_table + + """ SET """ + + name_col + + """ = ? WHERE id = ?""" + ) + if cfg["dbMode"] == "mysql": + sql_update = convert_to_mysql_query(sql_update) + + cur.execute(sql_update, [to_add_value, item_id]) + else: + item_id = int(str(uuid.uuid1().int)[:6]) + sql_insert = """ INSERT INTO """ + to_add_table + """ VALUES (?, ?)""" + + if cfg["dbMode"] == "mysql": + sql_insert = convert_to_mysql_query(sql_insert) + + cur.execute(sql_insert, [item_id, to_add_value]) + + conn.commit() + conn.close() + return item_id + + +def get_data_from_db(table_name): + conn, cur = load_db_conn() + + if table_name == "categories": + orderby = "categoryName" + elif table_name == "stores": + orderby = "storeName" + + sql_select = """ SELECT * from """ + table_name + """ order by """ + orderby + cur.execute(sql_select) + rows = cur.fetchall() + conn.close() + + ret_array = json.dumps({"values": []}) + ret_json = json.loads(ret_array) + + for row in rows: + add_array = {"name": row[1], "id": row[0]} + ret_json["values"].append(add_array) + + return ret_json + + +def load_db_conn(): + if cfg["dbMode"] == "mssql": + conn, cur = create_ms_db_conn() + elif cfg["dbMode"] == "mysql": + conn, cur = create_mysql_db_conn() + else: + conn = None + cur = None + + print("Error! No valid db mode found. 
Please use mssql or mysql") + + return conn, cur + + +def create_ms_db_conn(): + global cfg + if not cfg: + cfg = load_conf() + + try: + conn = pyodbc.connect( + Driver="{ODBC Driver 17 for SQL Server}", + Server=cfg["sqlServerIP"], + Database=cfg["sqlDatabase"], + user=cfg["sqlUsername"], + password=cfg["sqlPassword"], + ) + except Error as e: + + print(e) + + cur = conn.cursor() + + return conn, cur + + +def convert_to_mysql_query(sql_query): + sql_query = sql_query.replace("?", "%s") + return sql_query + + +def create_mysql_db_conn(): + global cfg + if not cfg: + cfg = load_conf() + + try: + conn = connect( + host=cfg["sqlServerIP"], + user=cfg["sqlUsername"], + password=cfg["sqlPassword"], + database=cfg["sqlDatabase"], + ) + except Error as e: + print(e) + + cur = conn.cursor() + return conn, cur + + +def init_mysql_db(conn): + create_receipts_tags = "CREATE TABLE IF NOT EXISTS tags (id int, tagName nvarchar(50), PRIMARY KEY(id)); " + create_receipts_stores = "CREATE TABLE IF NOT EXISTS stores (id int, storeName nvarchar(50), PRIMARY KEY(id));" + create_receipts_categories = "CREATE TABLE IF NOT EXISTS categories (id int, categoryName nvarchar(50), PRIMARY KEY(id)); " + create_receipts_items = "CREATE TABLE IF NOT EXISTS items (id int, itemName nvarchar(100), itemTotal decimal(15,2), categoryId int, FOREIGN KEY (categoryId) REFERENCES categories(id), PRIMARY KEY(id));" + create_receipts_purchases_articles = " CREATE TABLE IF NOT EXISTS purchasesArticles (id int, itemid int, FOREIGN KEY (itemid) REFERENCES items(id));" + create_receipts_receipts = " CREATE TABLE IF NOT EXISTS receipts (id int, storeId int, `date` date, total decimal(15,2), tagId int, FOREIGN KEY (tagId) REFERENCES tags(id), purchaseId int, PRIMARY KEY(id));" + + create_receipts_view = """ + CREATE OR REPLACE VIEW purchaseData AS + select i.itemName article_name, 1 amount, itemTotal total, c.categoryName category, storeName location, date timestamp, CONVERT(r.id, char) id from receipts r + JOIN stores s ON r.storeId = s.id + JOIN purchasesArticles pa ON r.purchaseId = pa.id + JOIN items i on pa.itemid = i.id + JOIN categories c on c.id = i.categoryId + """ + + if conn: + create_mysql_table(conn, create_receipts_tags) + create_mysql_table(conn, create_receipts_stores) + create_mysql_table(conn, create_receipts_categories) + create_mysql_table(conn, create_receipts_items) + create_mysql_table(conn, create_receipts_purchases_articles) + create_mysql_table(conn, create_receipts_receipts) + create_mysql_table(conn, create_receipts_view) + conn.close() + else: + print("Error! 
cannot create the database connection.") + + +def init_mssql_db(conn): + create_receipts_tables = """ IF object_id('tags', 'U') is null + CREATE TABLE tags (id int PRIMARY KEY, tagName nvarchar(50)) + IF object_id('stores', 'U') is null + CREATE TABLE stores (id int PRIMARY KEY, storeName nvarchar(50)) + IF object_id('categories', 'U') is null + CREATE TABLE categories (id int PRIMARY KEY, categoryName nvarchar(50)) + IF object_id('items', 'U') is null + CREATE TABLE items (id int PRIMARY KEY, itemName nvarchar(100), itemTotal decimal(15,2), categoryId int FOREIGN KEY REFERENCES categories(id)) + IF object_id('purchasesArticles', 'U') is null + CREATE TABLE purchasesArticles (id int, itemid int FOREIGN KEY REFERENCES items(id)) + IF object_id('receipts', 'U') is null + CREATE TABLE receipts (id int PRIMARY KEY, storeId int, [date] date, total decimal(15,2), tagId int FOREIGN KEY REFERENCES tags(id), purchaseId int) + """ + create_receipts_view = """ IF object_id('purchaseData', 'V') is null + EXEC('CREATE VIEW purchaseData AS + select i.itemName article_name, 1 amount, itemTotal total, c.categoryName category, storeName location, date timestamp, CONVERT(varchar, r.id) id from receipts r + JOIN stores s ON r.storeId = s.id + JOIN purchasesArticles pa ON r.purchaseId = pa.id + JOIN items i on pa.itemid = i.id + JOIN categories c on c.id = i.categoryId') + """ + if conn: + create_mssql_table(conn, create_receipts_tables) + create_mssql_table(conn, create_receipts_view) + conn.close() + else: + print("Error! cannot create the database connection.") + + +def create_mysql_table(conn, sql_query): + try: + cur = conn.cursor() + cur.execute(sql_query) + conn.commit() + + except Error as e: + print(e) + + +def create_mssql_table(conn, sql_query): + try: + cur = conn.cursor() + cur.execute(sql_query) + conn.commit() + + except pyodbc.Error as e: + print(e) + +def delete_receipt(receipt_id): + conn, cursor = load_db_conn() + + sql_query = "DELETE FROM receipts WHERE ID = ?" + if cfg["dbMode"] == "mysql": + sql_query = convert_to_mysql_query(sql_query) + cursor.execute(sql_query, [receipt_id]) + + sql_query = "DELETE FROM purchasesArticles WHERE ID = ?" + if cfg["dbMode"] == "mysql": + sql_query = convert_to_mysql_query(sql_query) + cursor.execute(sql_query, [receipt_id]) + + sql_query = "DELETE FROM items where id not in (select itemid from purchasesArticles)" + if cfg["dbMode"] == "mysql": + sql_query = convert_to_mysql_query(sql_query) + cursor.execute(sql_query) + + conn.commit() + conn.close() + +def get_category_id(category_name): + conn, cursor = load_db_conn() + + sql_query = "select id from categories where categoryName = ?" + if cfg["dbMode"] == "mysql": + sql_query = convert_to_mysql_query(sql_query) + + cursor.execute(sql_query, [category_name]) + rows = cursor.fetchone() + conn.close() + + if rows: + category_id = rows[0] + else: + category_id = None + + return category_id + +def get_store_id(store_name): + conn, cursor = load_db_conn() + + sql_query = "select id from stores where storeName = ?" + if cfg["dbMode"] == "mysql": + sql_query = convert_to_mysql_query(sql_query) + + cursor.execute(sql_query, [store_name]) + rows = cursor.fetchone() + conn.close() + + if rows: + store_id = rows[0] + else: + store_id = add_or_update_to_db("stores", None, store_name) + + return store_id diff --git a/webroot/.gitkeep b/webroot/.gitkeep new file mode 100644 index 0000000..e69de29
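A minimal sketch of exercising the backend this patch adds, assuming the defaults from config/config.yaml (backendIP 192.168.50.59, backendPort 5558, useSSL false) and the API token that server.py prints at startup (check_existing_token() writes it to .api_token in the working directory); the /api routes and the token/file parameters come from api.py, everything else here is illustrative:

    import requests

    BASE = "http://192.168.50.59:5558"         # backendIP:backendPort from config.yaml
    TOKEN = "<token printed at server start>"  # see check_existing_token() in util.py

    # All /api/* routes require ?token=... (only the static index is exempt).
    history = requests.get(BASE + "/api/getHistory", params={"token": TOKEN})
    print(history.json())

    # Upload a receipt image; api.py forwards it to the configured parser service
    # (parserIP/parserPort/parserToken) and returns the parsed receipt as JSON.
    with open("receipt.png", "rb") as fh:
        parsed = requests.post(
            BASE + "/api/upload",
            params={"token": TOKEN},
            files={"file": ("receipt.png", fh)},
        )
    print(parsed.json())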