about summary refs log tree commit diff
path: root/gn3/db/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/db/__init__.py')
-rw-r--r--gn3/db/__init__.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/gn3/db/__init__.py b/gn3/db/__init__.py
index e69de29..fae4d29 100644
--- a/gn3/db/__init__.py
+++ b/gn3/db/__init__.py
@@ -0,0 +1,70 @@
+# pylint: disable=[R0902, R0903]
+"""Module that exposes common db operations"""
+from typing import Optional, Dict, Any
+from dataclasses import dataclass, asdict, astuple
+from typing_extensions import Protocol
+from MySQLdb import escape_string
+
+from gn3.db.phenotypes import Phenotype
+from gn3.db.phenotypes import PublishXRef
+from gn3.db.phenotypes import Publication
+
+from gn3.db.phenotypes import phenotype_mapping
+from gn3.db.phenotypes import publish_x_ref_mapping
+from gn3.db.phenotypes import publication_mapping
+
+TABLEMAP = {
+    "Phenotype": phenotype_mapping,
+    "PublishXRef": publish_x_ref_mapping,
+    "Publication": publication_mapping,
+}
+
+DATACLASSMAP = {
+    "Phenotype": Phenotype,
+    "PublishXRef": PublishXRef,
+    "Publication": Publication,
+}
+
+
+class Dataclass(Protocol):
+    """Type Definition for a Dataclass"""
+    __dataclass_fields__: Dict
+
+
+def update(conn: Any,
+           table: str,
+           data: Dataclass,
+           where: Dataclass) -> Optional[int]:
+    """Run an UPDATE on a table"""
+    if not any(astuple(data) + astuple(where)):
+        return None
+    sql = f"UPDATE {table} SET "
+    sql += ", ".join(f"{TABLEMAP[table].get(k)} "
+                     f"= '{escape_string(str(v)).decode('utf-8')}'" for
+                     k, v in asdict(data).items()
+                     if v is not None and k in TABLEMAP[table])
+    sql += " WHERE "
+    sql += "AND ".join(f"{TABLEMAP[table].get(k)} = "
+                       f"'{escape_string(str(v)).decode('utf-8')}'" for
+                       k, v in asdict(where).items()
+                       if v is not None and k in TABLEMAP[table])
+    with conn.cursor() as cursor:
+        cursor.execute(sql)
+        return cursor.rowcount
+
+
+def fetchone(conn: Any,
+             table: str,
+             where: Dataclass) -> Optional[Dataclass]:
+    """Run a SELECT on a table. Returns only one result!"""
+    if not any(astuple(where)):
+        return None
+    sql = f"SELECT * FROM {table} "
+    sql += "WHERE "
+    sql += "AND ".join(f"{TABLEMAP[table].get(k)} = "
+                       f"'{escape_string(str(v)).decode('utf-8')}'" for
+                       k, v in asdict(where).items()
+                       if v is not None and k in TABLEMAP[table])
+    with conn.cursor() as cursor:
+        cursor.execute(sql)
+        return DATACLASSMAP[table](*cursor.fetchone())