adding rest of docker files

main
Thomas Hintz 9 months ago
parent 618fd1c4bb
commit 71dbd5eab7

@ -0,0 +1,14 @@
backendHostname: receiptmanager.home.local
backendIP: 192.168.50.59
backendLanguage: de-DE
backendPort: 5558
dbMode: mssql
encrypted: true
parserIP: 192.168.50.6
parserPort: 8721
parserToken: dd8dfbfbc51233fea39acc23c9a08fc3
sqlDatabase: receiptData
sqlPassword: cbadf3c7d51a20df94a0ed37ed92bff16a75fdaa35d1f53e78f99ebe0179f89e
sqlServerIP: 192.168.50.6
sqlUsername: receiptParser
useSSL: false

@ -0,0 +1,11 @@
flask==2.0.1
PyYAML==5.4.1
gevent==23.9.0
Flask-Cors==3.0.10
pyodbc==4.0.30
requests==2.25.1
uuid==1.30
mysql-connector-python==8.0.25
cryptography==3.4.7
pycryptodome==3.10.1
Wand==0.6.6

@ -0,0 +1,40 @@
from server import server
from util import (
create_ssl_cert,
init_mssql_db,
init_mysql_db,
load_conf,
create_web_config,
check_existing_token,
load_db_conn,
)
def main():
cfg = load_conf()
if cfg["useSSL"]:
create_ssl_cert([cfg["backendIP"]])
if cfg["dbMode"] and cfg["sqlDatabase"] and cfg["sqlPassword"] and cfg["sqlServerIP"] and cfg["sqlUsername"]:
print("Using " + cfg["dbMode"] + " DB")
try:
conn = load_db_conn()[0]
if cfg["dbMode"] == "mssql":
init_mssql_db(conn)
elif cfg["dbMode"] == "mysql":
init_mysql_db(conn)
else:
print("Error! No valid db mode found. Please use mssql or mysql")
except Exception as e:
print(e)
else:
print("No db mode set.")
check_existing_token()
create_web_config()
server()
if __name__ == "__main__":
main()

@ -0,0 +1,430 @@
import json
import uuid
from datetime import datetime
import requests
from flask import Flask, request, send_from_directory
from flask_cors import CORS, cross_origin
from util import (
check_existing_token,
convert_to_mysql_query,
get_category_id,
get_data_from_db,
add_or_update_to_db,
delete_from_db,
get_store_id,
load_conf,
load_db_conn,
delete_receipt,
update_server_config,
convert_pdf_to_png
)
app = Flask(
__name__,
static_url_path="",
static_folder="../webroot",
)
cors = CORS(app, resources={r"/api/upload/*": {"origins": "*"}})
app.config["CORS_HEADERS"] = "Content-Type"
cfg = None
api_token = None
@app.before_first_request
def first():
global cfg
global api_token
cfg = load_conf()
api_token = check_existing_token()
@app.before_request
def before_request():
if request.endpoint in ('static', 'index'):
return
if not request.args:
return "No token provided! Add &token= to URL", 401
if api_token != request.args["token"]:
return "Unauthorized", 401
global cfg
cfg = load_conf()
if ((
not cfg["dbMode"]
or not cfg["sqlDatabase"]
or not cfg["sqlPassword"]
or not cfg["sqlServerIP"]
or not cfg["sqlUsername"]
)
and request.endpoint != 'updateConfig'
and request.endpoint != 'getBackendConfig'
):
return "Settings incomplete!", 512
@app.route("/", methods=["GET"])
def index():
return send_from_directory(app.static_folder, "index.html")
@app.route("/api/getBackendConfig", methods=["GET", "OPTIONS"])
@cross_origin(origin="*", headers=["Content-Type"])
def getBackendConfig():
return json.dumps(cfg)
@app.route("/api/updateConfig", methods=["POST", "OPTIONS"])
@cross_origin(origin="*", headers=["Content-Type"])
def updateConfig():
post_string = json.dumps(request.get_json())
post_json = json.loads(post_string)
update_server_config(post_json)
return "Config updated", 200
@app.route("/api/upload", methods=["POST", "OPTIONS"])
@cross_origin(origin="*", headers=["Content-Type"])
def upload():
file = request.files["file"]
file_name = file.filename
gaussian_blur = True
if file.content_type == "application/pdf":
file = convert_pdf_to_png(file.read())
file_name = file_name.split(".")[0] + ".png"
gaussian_blur = False
if not file:
return "No valid file found", 500
url = (
"http://"
+ str(cfg["parserIP"])
+ ":"
+ str(cfg["parserPort"])
+ "/api/upload?access_token="
+ str(cfg["parserToken"])
+ "&legacy_parser=True"
+ "&grayscale_image=True"
+ "&rotate_image=True"
+ "&gaussian_blur=" + str(gaussian_blur)
+ "&median_blur=True"
)
receipt_upload = requests.post(url, files={"file": (file_name, file)})
if receipt_upload.status_code == 200:
upload_response = json.dumps(receipt_upload.content.decode("utf8"))
response_json = json.loads(upload_response)
response_json = json.loads(response_json)
# Replace " in Date
if '"' in response_json["receiptDate"]:
response_json["receiptDate"] = response_json["receiptDate"].replace('"', "")
# Create 4 digit year
if response_json["receiptDate"] != "null":
year_string = response_json["receiptDate"].split(".")
if len(year_string[2]) < 4:
year_string[2] = "20" + year_string[2]
response_json["receiptDate"] = (
year_string[0] + "." + year_string[1] + "." + year_string[2]
)
conn, cursor = load_db_conn()
for idx, article in enumerate(response_json["receiptItems"]):
article = article[0]
splitted_articles = article.split(" ")
for article in splitted_articles:
if len(article) > 3:
if cfg["dbMode"] == "mysql":
sql_query = "SELECT category FROM purchaseData where article_name like %s order by timestamp desc limit 1"
else:
sql_query = "SELECT TOP 1 category FROM purchaseData where article_name like ? order by timestamp desc"
cursor.execute(sql_query, [f"%{article}%"])
row = cursor.fetchone()
if row:
if cfg["dbMode"] == "mysql":
found_cat = row[0]
else:
found_cat = row.category
copy_array = response_json["receiptItems"][idx]
copy_array.insert(2, found_cat)
response_json["receiptItems"][idx] = copy_array
break
conn.close()
return json.dumps(response_json)
else:
return "Error on upload", receipt_upload.status_code
@app.route("/api/getHistory", methods=["GET", "OPTIONS"])
@cross_origin(origin="*", headers=["Content-Type"])
def get_history():
conn, cursor = load_db_conn()
history_json = []
cursor.execute(
"select SUM(total) as totalSum, location, id, timestamp from purchaseData \
where id is not null \
GROUP BY timestamp, id, location \
ORDER BY timestamp desc"
)
rows = cursor.fetchall()
for row in rows:
if cfg["dbMode"] == "mysql":
add_json = {
"location": row[1],
"totalSum": str(row[0]),
"timestamp": str(row[3]),
"id": row[2],
}
else:
add_json = {
"location": row.location,
"totalSum": str(row.totalSum),
"timestamp": str(row.timestamp),
"id": row.id,
}
history_json.append(add_json)
conn.close()
return json.dumps(history_json)
@app.route("/api/getHistoryDetails", methods=["GET", "OPTIONS"])
@cross_origin(origin="*", headers=["Content-Type"])
def get_history_details():
purchase_id = request.args["purchaseID"]
conn, cursor = load_db_conn()
sql_query = (
"SELECT storeName, total, date, purchaseId " +
"FROM receipts re " +
"JOIN stores st ON re.storeId = st.id " +
"where re.id = ?"
)
if cfg["dbMode"] == "mysql":
sql_query = convert_to_mysql_query(sql_query)
cursor.execute(sql_query, [purchase_id])
row = cursor.fetchone()
if row:
if cfg["dbMode"] == "mysql":
store_name = row[0]
receipt_total = row[1]
receipt_date = row[2]
db_purchase_id = row[3]
else:
store_name = row.storeName
receipt_total = row.total
receipt_date = row.date
db_purchase_id = row.purchaseId
receipt_date = receipt_date.strftime("%d.%m.%Y")
purchase_details = {
"storeName": store_name,
"receiptTotal": str(receipt_total),
"receiptDate": receipt_date,
"purchaseID": db_purchase_id,
"receiptItems": [],
}
sql_query = "select article_name, total, category from purchaseData where id = ?"
if cfg["dbMode"] == "mysql":
sql_query = convert_to_mysql_query(sql_query)
cursor.execute(sql_query, [purchase_id])
rows = cursor.fetchall()
for row in rows:
if cfg["dbMode"] == "mysql":
add_json = [row[0], str(row[1]), row[2]]
else:
add_json = [row.article_name, str(row.total), row.category]
purchase_details["receiptItems"].append(add_json)
conn.close()
return json.dumps(purchase_details)
else:
return "Purchase not found!", 500
@app.route("/api/getValue", methods=["GET", "OPTIONS"])
@cross_origin(origin="*", headers=["Content-Type"])
def get_categories():
get_values_from = request.args["getValuesFrom"]
ret_json = get_data_from_db(get_values_from)
return json.dumps(ret_json, ensure_ascii=False)
@app.route("/api/addValue", methods=["POST", "OPTIONS"])
@cross_origin(origin="*", headers=["Content-Type"])
def add_value():
to_add_array = request.args["toAddArray"]
to_add_value = request.args["toAddValue"]
item_id = request.args["id"]
add_or_update_to_db(to_add_array, item_id, to_add_value)
return "Done!"
@app.route("/api/deleteValue", methods=["POST", "OPTIONS"])
@cross_origin(origin="*", headers=["Content-Type"])
def delete_value():
table_name = request.args["tableName"]
item_id = request.args["id"]
delete_from_db(table_name, item_id)
return "Done!"
@app.route("/api/deleteReceiptFromDB", methods=["POST", "OPTIONS"])
@cross_origin(origin="*", headers=["Content-Type"])
def delete_receipt_from_db():
receipt_id = request.args["purchaseID"]
delete_receipt(receipt_id)
return "Done!"
@app.route("/api/updateReceiptToDB", methods=["POST", "OPTIONS"])
@cross_origin(origin="*", headers=["Content-Type"])
def update_receipt_to_db():
post_string = json.dumps(request.get_json())
post_json = json.loads(post_string)
conn, cursor = load_db_conn()
store_id = get_store_id(post_json["storeName"])
receipt_date = post_json["receiptDate"]
receipt_total = post_json["receiptTotal"]
receipt_id = post_json["purchaseID"]
receipt_date = datetime.strptime(receipt_date, "%d.%m.%Y")
receipt_date = receipt_date.strftime("%m-%d-%Y")
# Clean old db data
delete_receipt(receipt_id)
# Write article positions
for article in post_json["receiptItems"]:
article_name = article[1]
article_sum = article[2]
article_id = int(str(uuid.uuid1().int)[:8])
article_category_id = get_category_id(article[3])
sql_query = "INSERT INTO items values (?,?,?,?)"
if cfg["dbMode"] == "mysql":
sql_query = convert_to_mysql_query(sql_query)
cursor.execute(
sql_query, [article_id, article_name, article_sum, article_category_id]
)
sql_query = "INSERT INTO purchasesArticles values (?,?)"
if cfg["dbMode"] == "mysql":
sql_query = convert_to_mysql_query(sql_query)
cursor.execute(sql_query, [receipt_id, article_id])
# Write receipt summary
if cfg["dbMode"] == "mysql":
sql_query = (
"INSERT INTO receipts values (%s,%s,STR_TO_DATE(%s,'%m-%d-%Y'),%s,%s,%s)"
)
else:
sql_query = "INSERT INTO receipts values (?,?,?,?,?,?)"
cursor.execute(
sql_query, [receipt_id, store_id, receipt_date, receipt_total, None, receipt_id]
)
conn.commit()
conn.close()
return "Done!"
@app.route("/api/writeReceiptToDB", methods=["POST", "OPTIONS"])
@cross_origin(origin="*", headers=["Content-Type"])
def write_receipt_to_db():
post_string = json.dumps(request.get_json())
post_json = json.loads(post_string)
conn, cursor = load_db_conn()
store_id = get_store_id(post_json["storeName"])
receipt_date = post_json["receiptDate"]
receipt_total = post_json["receiptTotal"]
receipt_id = int(str(uuid.uuid1().int)[:6])
receipt_date = datetime.strptime(receipt_date, "%d.%m.%Y")
receipt_date = receipt_date.strftime("%m-%d-%Y")
# Write article positions
for article in post_json["receiptItems"]:
article_id = int(str(uuid.uuid1().int)[:8])
article_name = article[1]
article_sum = article[2]
article_category_id = get_category_id(article[3])
if not article_category_id:
return "Category id for category: " + article[3] + " not found", 500
sql_query = "INSERT INTO items values (?,?,?,?)"
if cfg["dbMode"] == "mysql":
sql_query = convert_to_mysql_query(sql_query)
cursor.execute(
sql_query, [article_id, article_name, article_sum, article_category_id]
)
sql_query = "INSERT INTO purchasesArticles values (?,?)"
if cfg["dbMode"] == "mysql":
sql_query = convert_to_mysql_query(sql_query)
cursor.execute(sql_query, [receipt_id, article_id])
# Write receipt summary
if cfg["dbMode"] == "mysql":
sql_query = (
"INSERT INTO receipts values (%s,%s,STR_TO_DATE(%s,'%m-%d-%Y'),%s,%s,%s)"
)
else:
sql_query = "INSERT INTO receipts values (?,?,?,?,?,?)"
cursor.execute(
sql_query, [receipt_id, store_id, receipt_date, receipt_total, None, receipt_id]
)
conn.commit()
conn.close()
return "Done!"
@app.after_request
def add_header(r):
r.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
r.headers["Pragma"] = "no-cache"
r.headers["Expires"] = "0"
r.headers['Cache-Control'] = 'public, max-age=0'
return r

@ -0,0 +1,40 @@
from gevent import monkey
monkey.patch_all()
from gevent.pywsgi import WSGIServer
from api import app
from util import load_conf, check_existing_token
def server():
cfg = load_conf()
api_token = check_existing_token()
if cfg["useSSL"]:
http_server = WSGIServer(
(cfg["backendIP"], int(cfg["backendPort"])),
app,
certfile="ssl/cert.crt",
keyfile="ssl/key.pem",
)
print(
"Server started. Running on https://"
+ str(cfg["backendIP"])
+ ":"
+ str(cfg["backendPort"])
)
else:
http_server = WSGIServer((cfg["backendIP"], int(cfg["backendPort"])), app)
print(
"Server started. Running on http://"
+ str(cfg["backendIP"])
+ ":"
+ str(cfg["backendPort"])
)
print("API Token: " + api_token)
if cfg['parserIP']:
print("Parser IP set to: " + str(cfg["parserIP"]) + ":" + str(cfg["parserPort"]))
else:
print("No parser IP set.")
http_server.serve_forever()

@ -0,0 +1,609 @@
import pyodbc
import uuid
import json
import yaml
import os
import ipaddress
import socket
from mysql.connector import connect, Error
from datetime import datetime, timedelta
from wand.image import Image
from wand.color import Color
from Crypto.Cipher import AES
from cryptography.fernet import Fernet
from binascii import b2a_hex, a2b_hex
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
cfg = None
api_token = None
key = None
BLOCK_SIZE = 16
SEGMENT_SIZE = 128
def convert_pdf_to_png(pdf_file):
try:
with Image(blob=pdf_file, resolution=200) as pdf:
first_page = pdf.sequence[0]
with Image(first_page) as i:
i.format = 'png'
i.compression_quality = 99
i.background_color = Color('white')
i.alpha_channel = 'remove'
binary_png = i.make_blob("png")
return binary_png
except Exception as e:
if hasattr(e, 'wand_error_code'):
if e.wand_error_code == 415:
print("ERROR ! Please install ghostscript to convert PDF to PNG from https://www.ghostscript.com/download/gsdnld.html or on linux with 'apt-get install ghostscript'. \n" +
"If you are using Linux, have a look at here: https://stackoverflow.com/questions/57208396/imagemagick-ghostscript-delegate-security-policy-blocking-conversion")
else:
print("ERROR ! " + e)
else:
print("ERROR ! " + e)
return None
def encrypt(plaintext):
key = check_existing_key()
key = key.encode('utf-8')
iv = key
aes = AES.new(key, AES.MODE_CFB, iv, segment_size=SEGMENT_SIZE)
plaintext = _pad_string(plaintext)
encrypted_text = aes.encrypt(plaintext.encode())
return b2a_hex(encrypted_text).rstrip().decode()
def decrypt(encrypted_text):
key = check_existing_key()
key = key.encode('utf-8')[:16]
iv = key
aes = AES.new(key, AES.MODE_CFB, iv, segment_size=SEGMENT_SIZE)
encrypted_text_bytes = a2b_hex(encrypted_text)
decrypted_text = aes.decrypt(encrypted_text_bytes)
decrypted_text = _unpad_string(decrypted_text.decode())
return decrypted_text
def _pad_string(value):
length = len(value)
pad_size = BLOCK_SIZE - (length % BLOCK_SIZE)
return value.ljust(length + pad_size, '\x00')
def _unpad_string(value):
while value[-1] == '\x00':
value = value[:-1]
return value
def create_ssl_cert(
ip_addresses=None,
ca_cert="../webroot/ssl/ca.crt",
ca_key="../webroot/ssl/ca.key",
key_file="ssl/key.pem",
cert_file="ssl/cert.crt",
):
root_cert = None
root_key = None
now = datetime.utcnow()
if not os.path.isfile(ca_cert) and not os.path.isfile(ca_key):
# Create CA Cert and Key
root_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend()
)
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, u"DE"),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, u"NRW"),
x509.NameAttribute(NameOID.LOCALITY_NAME, u"NV"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, u"ReceiptManager"),
x509.NameAttribute(NameOID.COMMON_NAME, u"receipt-manager-ca"),
])
basic_contraints = x509.BasicConstraints(ca=True, path_length=1)
root_cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(root_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now)
.not_valid_after(now + timedelta(days=2 * 365))
.add_extension(basic_contraints, False)
.sign(root_key, hashes.SHA256(), default_backend()))
ca_cert_pem = root_cert.public_bytes(encoding=serialization.Encoding.PEM)
ca_key_pem = root_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
open(ca_key, "wb").write(ca_key_pem)
open(ca_cert, "wb").write(ca_cert_pem)
if not os.path.isfile(cert_file) and not os.path.isfile(key_file):
if not root_cert and not root_key:
cert_binary = open(ca_cert,"rb").read()
root_cert = x509.load_pem_x509_certificate(cert_binary, default_backend())
key_binary = open(ca_key,"rb").read()
root_key = serialization.load_pem_private_key(key_binary, None, default_backend())
# Generate our key
key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend(),
)
name = x509.Name(
[x509.NameAttribute(NameOID.COMMON_NAME, cfg["backendHostname"])]
)
# best practice seem to be to include the hostname in the SAN, which *SHOULD* mean COMMON_NAME is ignored.
alt_names = [x509.DNSName(cfg["backendHostname"]), x509.DNSName("localhost")]
# allow addressing by IP, for when you don't have real DNS (common in most testing scenarios
if ip_addresses:
for addr in ip_addresses:
# openssl wants DNSnames for ips...
alt_names.append(x509.DNSName(addr))
# ... whereas golang's crypto/tls is stricter, and needs IPAddresses
# note: older versions of cryptography do not understand ip_address objects
alt_names.append(x509.IPAddress(ipaddress.ip_address(addr)))
san = x509.SubjectAlternativeName(alt_names)
extended_key_usage = x509.ExtendedKeyUsage([x509.oid.ExtendedKeyUsageOID.SERVER_AUTH])
cert = (
x509.CertificateBuilder()
.subject_name(name)
.issuer_name(root_cert.issuer)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now)
.not_valid_after(now + timedelta(days=2 * 365))
.add_extension(san, False)
.add_extension(extended_key_usage, True)
.sign(root_key, hashes.SHA256(), default_backend())
)
cert_pem = cert.public_bytes(encoding=serialization.Encoding.PEM)
key_pem = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
open(key_file, "wb").write(key_pem)
open(cert_file, "wb").write(cert_pem)
def update_server_config(settings):
crypt_config(settings)
load_conf(True)
create_web_config()
def update_config_yaml(settings):
config_file = open('../config/config.yaml', 'w')
yaml.dump(settings, config_file)
config_file.close()
def create_web_config():
web_json = "../webroot/settings/settings.json"
web_cfg = {
"useSSL": cfg["useSSL"],
"backendIP": cfg["backendIP"],
"backendPort": cfg["backendPort"],
"backendToken": api_token,
"language": cfg["backendLanguage"]
}
f = open(web_json, "w")
f.write(json.dumps(web_cfg))
f.close()
def check_existing_key():
if not os.path.isfile(r"../config/.key"):
create_key()
else:
read_key()
return key
def read_key():
global key
if not key:
with open(r"../config/.key") as f:
key = f.readline()
def create_key():
global key
new_key = Fernet.generate_key().decode('utf-8')[:16]
f = open("../config/.key", "w")
f.write(new_key)
key = new_key
f.close()
def check_existing_token():
if not os.path.isfile(r".api_token"):
create_token()
else:
read_token()
return api_token
def read_token():
global api_token
if not api_token:
with open(r".api_token") as f:
api_token = f.readline()
def create_token():
global api_token
new_token = str(uuid.uuid4())
f = open(".api_token", "w")
f.write(new_token)
api_token = new_token
f.close()
def create_initial_config():
use_ssl = os.environ.get("useSSL", False)
if isinstance(use_ssl, str):
if use_ssl.lower() == 'true':
use_ssl = True
else:
use_ssl = False
run_in_docker = os.environ.get("RUN_IN_DOCKER", False)
if not run_in_docker:
backend_ip = socket.gethostbyname(socket.gethostname())
else:
backend_ip = os.environ.get("backendIP", None)
if not backend_ip:
stream = os.popen(r"ip -4 addr show eth0 | grep -Po 'inet \K[\d.]+'")
backend_ip = stream.read().rstrip()
backend_port = os.environ.get("backendPort", 5558)
print("Initial config created.")
temp_config = {
"useSSL": use_ssl,
"backendHostname": "",
"backendIP": backend_ip,
"backendPort": backend_port,
"backendLanguage": "",
"parserIP": "",
"parserPort": "",
"parserToken": "",
"dbMode": "",
"sqlServerIP": "",
"sqlDatabase": "",
"sqlUsername": "",
"sqlPassword": "",
}
config = json.dumps(temp_config)
cfg = json.loads(config)
update_config_yaml(cfg)
def load_conf(force_reload=False):
global cfg
if not os.path.isfile("../config/config.yaml"):
create_initial_config()
if not cfg or force_reload:
with open("../config/config.yaml", "r") as ymlfile:
cfg = yaml.load(ymlfile, Loader=yaml.FullLoader)
cfg = crypt_config(cfg)
return cfg
def crypt_config(settings):
rewrite_config = False
encrypted_cfg = settings.copy()
if 'encrypted' in settings:
is_encrypted = settings['encrypted']
else:
is_encrypted = False
if not is_encrypted:
rewrite_config = True
for c, v in settings.items():
if v and ("Token" in c or "Password" in c):
if not is_encrypted:
rewrite_config = True
encrypted = encrypt(str(v))
encrypted_cfg[c] = encrypted
else:
try:
decrypted = decrypt(str(v))
settings[c] = decrypted
except Exception as e:
if "Non-hexadecimal digit found" in str(e):
print("Decryption failed. Set encryption flag in config yaml to False!")
else:
print(e)
if rewrite_config:
settings['encrypted'] = True
encrypted_cfg['encrypted'] = True
update_config_yaml(encrypted_cfg)
return settings
def delete_from_db(table_name, id):
conn, cur = load_db_conn()
sql_query = "DELETE FROM " + table_name + " WHERE id = ?"
if cfg["dbMode"] == "mysql":
sql_query = convert_to_mysql_query(sql_query)
cur.execute(sql_query, [id])
conn.commit()
conn.close()
def add_or_update_to_db(to_add_table, item_id, to_add_value):
conn, cur = load_db_conn()
if to_add_table == "categories":
name_col = "categoryName"
elif to_add_table == "stores":
name_col = "storeName"
if item_id:
sql_update = (
""" UPDATE """
+ to_add_table
+ """ SET """
+ name_col
+ """ = ? WHERE id = ?"""
)
if cfg["dbMode"] == "mysql":
sql_update = convert_to_mysql_query(sql_update)
cur.execute(sql_update, [to_add_value, item_id])
else:
item_id = int(str(uuid.uuid1().int)[:6])
sql_insert = """ INSERT INTO """ + to_add_table + """ VALUES (?, ?)"""
if cfg["dbMode"] == "mysql":
sql_insert = convert_to_mysql_query(sql_insert)
cur.execute(sql_insert, [item_id, to_add_value])
conn.commit()
conn.close()
return item_id
def get_data_from_db(table_name):
conn, cur = load_db_conn()
if table_name == "categories":
orderby = "categoryName"
elif table_name == "stores":
orderby = "storeName"
sql_select = """ SELECT * from """ + table_name + """ order by """ + orderby
cur.execute(sql_select)
rows = cur.fetchall()
conn.close()
ret_array = json.dumps({"values": []})
ret_json = json.loads(ret_array)
for row in rows:
add_array = {"name": row[1], "id": row[0]}
ret_json["values"].append(add_array)
return ret_json
def load_db_conn():
if cfg["dbMode"] == "mssql":
conn, cur = create_ms_db_conn()
elif cfg["dbMode"] == "mysql":
conn, cur = create_mysql_db_conn()
else:
conn = None
cur = None
print("Error! No valid db mode found. Please use mssql or mysql")
return conn, cur
def create_ms_db_conn():
global cfg
if not cfg:
cfg = load_conf()
try:
conn = pyodbc.connect(
Driver="{ODBC Driver 17 for SQL Server}",
Server=cfg["sqlServerIP"],
Database=cfg["sqlDatabase"],
user=cfg["sqlUsername"],
password=cfg["sqlPassword"],
)
except Error as e:
print(e)
cur = conn.cursor()
return conn, cur
def convert_to_mysql_query(sql_query):
sql_query = sql_query.replace("?", "%s")
return sql_query
def create_mysql_db_conn():
global cfg
if not cfg:
cfg = load_conf()
try:
conn = connect(
host=cfg["sqlServerIP"],
user=cfg["sqlUsername"],
password=cfg["sqlPassword"],
database=cfg["sqlDatabase"],
)
except Error as e:
print(e)
cur = conn.cursor()
return conn, cur
def init_mysql_db(conn):
create_receipts_tags = "CREATE TABLE IF NOT EXISTS tags (id int, tagName nvarchar(50), PRIMARY KEY(id)); "
create_receipts_stores = "CREATE TABLE IF NOT EXISTS stores (id int, storeName nvarchar(50), PRIMARY KEY(id));"
create_receipts_categories = "CREATE TABLE IF NOT EXISTS categories (id int, categoryName nvarchar(50), PRIMARY KEY(id)); "
create_receipts_items = "CREATE TABLE IF NOT EXISTS items (id int, itemName nvarchar(100), itemTotal decimal(15,2), categoryId int, FOREIGN KEY (categoryId) REFERENCES categories(id), PRIMARY KEY(id));"
create_receipts_purchases_articles = " CREATE TABLE IF NOT EXISTS purchasesArticles (id int, itemid int, FOREIGN KEY (itemid) REFERENCES items(id));"
create_receipts_receipts = " CREATE TABLE IF NOT EXISTS receipts (id int, storeId int, `date` date, total decimal(15,2), tagId int, FOREIGN KEY (tagId) REFERENCES tags(id), purchaseId int, PRIMARY KEY(id));"
create_receipts_view = """
CREATE OR REPLACE VIEW purchaseData AS
select i.itemName article_name, 1 amount, itemTotal total, c.categoryName category, storeName location, date timestamp, CONVERT(r.id, char) id from receipts r
JOIN stores s ON r.storeId = s.id
JOIN purchasesArticles pa ON r.purchaseId = pa.id
JOIN items i on pa.itemid = i.id
JOIN categories c on c.id = i.categoryId
"""
if conn:
create_mysql_table(conn, create_receipts_tags)
create_mysql_table(conn, create_receipts_stores)
create_mysql_table(conn, create_receipts_categories)
create_mysql_table(conn, create_receipts_items)
create_mysql_table(conn, create_receipts_purchases_articles)
create_mysql_table(conn, create_receipts_receipts)
create_mysql_table(conn, create_receipts_view)
conn.close()
else:
print("Error! cannot create the database connection.")
def init_mssql_db(conn):
create_receipts_tables = """ IF object_id('tags', 'U') is null
CREATE TABLE tags (id int PRIMARY KEY, tagName nvarchar(50))
IF object_id('stores', 'U') is null
CREATE TABLE stores (id int PRIMARY KEY, storeName nvarchar(50))
IF object_id('categories', 'U') is null
CREATE TABLE categories (id int PRIMARY KEY, categoryName nvarchar(50))
IF object_id('items', 'U') is null
CREATE TABLE items (id int PRIMARY KEY, itemName nvarchar(100), itemTotal decimal(15,2), categoryId int FOREIGN KEY REFERENCES categories(id))
IF object_id('purchasesArticles', 'U') is null
CREATE TABLE purchasesArticles (id int, itemid int FOREIGN KEY REFERENCES items(id))
IF object_id('receipts', 'U') is null
CREATE TABLE receipts (id int PRIMARY KEY, storeId int, [date] date, total decimal(15,2), tagId int FOREIGN KEY REFERENCES tags(id), purchaseId int)
"""
create_receipts_view = """ IF object_id('purchaseData', 'V') is null
EXEC('CREATE VIEW purchaseData AS
select i.itemName article_name, 1 amount, itemTotal total, c.categoryName category, storeName location, date timestamp, CONVERT(varchar, r.id) id from receipts r
JOIN stores s ON r.storeId = s.id
JOIN purchasesArticles pa ON r.purchaseId = pa.id
JOIN items i on pa.itemid = i.id
JOIN categories c on c.id = i.categoryId')
"""
if conn:
create_mssql_table(conn, create_receipts_tables)
create_mssql_table(conn, create_receipts_view)
conn.close()
else:
print("Error! cannot create the database connection.")
def create_mysql_table(conn, sql_query):
try:
cur = conn.cursor()
cur.execute(sql_query)
conn.commit()
except Error as e:
print(e)
def create_mssql_table(conn, sql_query):
try:
cur = conn.cursor()
cur.execute(sql_query)
conn.commit()
except pyodbc.Error as e:
print(e)
def delete_receipt(receipt_id):
conn, cursor = load_db_conn()
sql_query = "DELETE FROM receipts WHERE ID = ?"
if cfg["dbMode"] == "mysql":
sql_query = convert_to_mysql_query(sql_query)
cursor.execute(sql_query, [receipt_id])
sql_query = "DELETE FROM purchasesArticles WHERE ID = ?"
if cfg["dbMode"] == "mysql":
sql_query = convert_to_mysql_query(sql_query)
cursor.execute(sql_query, [receipt_id])
sql_query = "DELETE FROM items where id not in (select itemid from purchasesArticles)"
if cfg["dbMode"] == "mysql":
sql_query = convert_to_mysql_query(sql_query)
cursor.execute(sql_query)
conn.commit()
conn.close()
def get_category_id(category_name):
conn, cursor = load_db_conn()
sql_query = "select id from categories where categoryName = ?"
if cfg["dbMode"] == "mysql":
sql_query = convert_to_mysql_query(sql_query)
cursor.execute(sql_query, [category_name])
rows = cursor.fetchone()
conn.close()
if rows:
category_id = rows[0]
else:
category_id = None
return category_id
def get_store_id(store_name):
conn, cursor = load_db_conn()
sql_query = "select id from stores where storeName = ?"
if cfg["dbMode"] == "mysql":
sql_query = convert_to_mysql_query(sql_query)
cursor.execute(sql_query, [store_name])
rows = cursor.fetchone()
conn.close()
if rows:
store_id = rows[0]
else:
store_id = add_or_update_to_db("stores", None, store_name)
return store_id
Loading…
Cancel
Save