aboutsummaryrefslogtreecommitdiff
path: root/gn3/db/__init__.py
blob: ea800c15e3c0b85b09105b73aad90b096fb3df9a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# pylint: disable=[R0902, R0903]
"""Module that exposes common db operations"""
from dataclasses import asdict, astuple
from typing import Any, Dict, List, Optional, Generator, Tuple, Union
from typing_extensions import Protocol

from gn3.db.metadata_audit import MetadataAudit
from gn3.db.phenotypes import Phenotype
from gn3.db.phenotypes import Probeset
from gn3.db.phenotypes import Publication
from gn3.db.phenotypes import PublishXRef


from gn3.db.metadata_audit import metadata_audit_mapping
from gn3.db.phenotypes import phenotype_mapping
from gn3.db.phenotypes import probeset_mapping
from gn3.db.phenotypes import publication_mapping
from gn3.db.phenotypes import publish_x_ref_mapping


# Per-table translation from dataclass field names to the column names that
# are written into the generated SQL (see update/fetchone/fetchall/insert).
TABLEMAP = dict(
    Phenotype=phenotype_mapping,
    ProbeSet=probeset_mapping,
    Publication=publication_mapping,
    PublishXRef=publish_x_ref_mapping,
    metadata_audit=metadata_audit_mapping,
)

# Per-table dataclass used to wrap rows returned by fetchone/fetchall.
DATACLASSMAP = dict(
    Phenotype=Phenotype,
    ProbeSet=Probeset,
    Publication=Publication,
    PublishXRef=PublishXRef,
    metadata_audit=MetadataAudit,
)


class Dataclass(Protocol):
    """Structural (duck-typed) stand-in for "any dataclass instance".

    Used purely as a type annotation for the ``data``/``where`` arguments of
    the query helpers in this module, which pass them to
    ``dataclasses.asdict``/``astuple``.
    """
    # Classes created with ``@dataclass`` expose this attribute, so requiring
    # it is how the protocol recognises them.
    __dataclass_fields__: Dict


def update(conn: Any,
           table: str,
           data: Dataclass,
           where: Dataclass) -> Optional[int]:
    """Run an UPDATE on ``table`` and return the affected-row count.

    Only fields of ``data`` and ``where`` whose value is not ``None`` and
    whose name is a key of ``TABLEMAP[table]`` are used.  Each field name is
    translated to its column name through that mapping, and every value is
    bound as a query parameter rather than interpolated into the SQL text.

    :param conn: open DB-API connection; its ``cursor()`` is used as a
        context manager.
    :param table: table name; must be a key of ``TABLEMAP``.
    :param data: dataclass holding the new values (the SET list).
    :param where: dataclass holding the match values (ANDed equalities).
    :returns: ``cursor.rowcount`` after executing the statement, or ``None``
        (without executing anything) when ``data`` or ``where`` contributes
        no usable field.
    :raises KeyError: if ``table`` is not a key of ``TABLEMAP``.
    """
    # Looking the table up first means only known table names ever reach
    # the f-string below.
    mapping = TABLEMAP[table]
    data_ = {k: v for k, v in asdict(data).items()
             if v is not None and k in mapping}
    where_ = {k: v for k, v in asdict(where).items()
              if v is not None and k in mapping}
    # Decide on the filtered dicts rather than on the truthiness of every
    # field.  This keeps legitimate falsy values such as 0 or "", and it
    # avoids emitting an empty SET list or WHERE clause (invalid SQL).
    if not (data_ and where_):
        return None
    sql = (f"UPDATE {table} SET "
           + ", ".join(f"{mapping[k]} = %s" for k in data_)
           + " WHERE "
           + " AND ".join(f"{mapping[k]} = %s" for k in where_))
    with conn.cursor() as cursor:
        # SET values come first, then WHERE values, matching placeholder order.
        cursor.execute(sql,
                       tuple(data_.values()) + tuple(where_.values()))
        return cursor.rowcount


def _select_query(
        table: str,
        where: Optional[Dataclass],
        columns: Union[str, List[str]],
) -> Optional[Tuple[str, Tuple]]:
    """Build the parameterised SELECT shared by ``fetchone``/``fetchall``.

    :returns: ``(sql, params)``, or ``None`` when ``where`` is ``None`` or
        has no non-``None`` field whose name is a key of ``TABLEMAP[table]``
        (there would be no WHERE clause to emit).
    :raises KeyError: if ``table`` is not a key of ``TABLEMAP``.
    """
    if where is None:
        return None
    # Looking the table up first means only known table names ever reach
    # the f-string below.
    mapping = TABLEMAP[table]
    # ``is not None`` (not truthiness) so filters on 0 or "" still apply.
    where_ = {mapping[k]: v for k, v in asdict(where).items()
              if v is not None and k in mapping}
    if not where_:
        return None
    # A plain string ("*" or e.g. "Id, Name") is used verbatim.  Only a list
    # is comma-joined, because joining a str would split it into characters.
    cols = columns if isinstance(columns, str) else ", ".join(columns)
    sql = (f"SELECT {cols} FROM {table} WHERE "
           + " AND ".join(f"{col} = %s" for col in where_))
    return sql, tuple(where_.values())


def fetchone(conn: Any,
             table: str,
             where: Optional[Dataclass],
             columns: Union[str, List[str]] = "*") -> Optional[Dataclass]:
    """Run a SELECT on ``table`` and return only the first matching row.

    Non-``None`` fields of ``where`` whose names are keys of
    ``TABLEMAP[table]`` are translated to column names and ANDed together as
    equality filters; their values are bound as query parameters.

    :param conn: open DB-API connection; its ``cursor()`` is used as a
        context manager.
    :param table: table name; must be a key of ``TABLEMAP``.
    :param where: dataclass holding the filter values.  ``None``, or a
        filter with no usable field, returns ``None`` without querying.
    :param columns: ``"*"`` (default), a column-list string used verbatim,
        or a list of column names to comma-join.
    :returns: the first row wrapped in ``DATACLASSMAP[table]``, or ``None``
        when nothing was queried or no row matched.
    :raises KeyError: if ``table`` is not a key of ``TABLEMAP``.
    """
    query = _select_query(table, where, columns)
    if query is None:
        return None
    sql, params = query
    with conn.cursor() as cursor:
        cursor.execute(sql, params)
        row = cursor.fetchone()
    # DB-API fetchone() gives None when no row matched; unpacking it would
    # raise TypeError.
    if row is None:
        return None
    # NOTE(review): the row is passed positionally, so the selected columns
    # must line up with the dataclass field order.  That is unlikely to hold
    # when ``columns`` selects a subset; confirm with callers.
    return DATACLASSMAP[table](*row)


def fetchall(conn: Any,
             table: str,
             where: Optional[Dataclass],
             columns: Union[str, List[str]] = "*") -> Optional[Generator]:
    """Run a SELECT on ``table`` and return every matching row.

    Filtering and ``columns`` handling are the same as for ``fetchone``.

    :param conn: open DB-API connection; its ``cursor()`` is used as a
        context manager.
    :param table: table name; must be a key of ``TABLEMAP``.
    :param where: dataclass holding the filter values.  ``None``, or a
        filter with no usable field, returns ``None`` without querying.
    :param columns: ``"*"`` (default), a column-list string used verbatim,
        or a list of column names to comma-join.
    :returns: a generator yielding one ``DATACLASSMAP[table]`` instance per
        row (all rows are fetched before the cursor closes), or ``None``
        when nothing was queried.
    :raises KeyError: if ``table`` is not a key of ``TABLEMAP``.
    """
    query = _select_query(table, where, columns)
    if query is None:
        return None
    sql, params = query
    with conn.cursor() as cursor:
        cursor.execute(sql, params)
        rows = cursor.fetchall()
    # NOTE(review): rows are passed positionally, so the selected column
    # order must match the dataclass field order; confirm for subsets.
    return (DATACLASSMAP[table](*row) for row in rows)


def insert(conn: Any,
           table: str,
           data: Dataclass) -> Optional[int]:
    """Insert a single row into ``table``.

    Every field of ``data`` whose value is not ``None`` and whose name is a
    key of ``TABLEMAP[table]`` becomes a column, translated through that
    mapping.  Its value is bound as a query parameter.

    :param conn: open DB-API connection; its ``cursor()`` is used as a
        context manager.
    :param table: table name; looked up in ``TABLEMAP`` for each non-None
        field.
    :param data: dataclass holding the values to insert.
    :returns: ``cursor.rowcount`` reported after the INSERT.
    :raises KeyError: if ``table`` is not a key of ``TABLEMAP`` (raised on
        the first non-None field).
    """
    row = {}
    for field_name, value in asdict(data).items():
        if value is not None and field_name in TABLEMAP[table]:
            row[TABLEMAP[table].get(field_name)] = value
    column_list = ", ".join(map(str, row))
    placeholders = ", ".join(["%s"] * len(row))
    statement = f"INSERT INTO {table} ({column_list}) VALUES ({placeholders})"
    with conn.cursor() as cursor:
        cursor.execute(statement, tuple(row.values()))
        return cursor.rowcount


def diff_from_dict(old: Dict, new: Dict) -> Dict:
    """Pair the old and new value of every key found in ``old``.

    Each key of ``old`` maps to ``{"old": old[key], "new": new[key]}``.
    Keys are included whether or not their value changed.  Keys present only
    in ``new`` are ignored, and a key of ``old`` that is missing from ``new``
    raises ``KeyError``.

    >>> diff_from_dict({"id": 1, "data": "a"}, {"id": 2, "data": "b"})
    {'id': {'old': 1, 'new': 2}, 'data': {'old': 'a', 'new': 'b'}}
    """
    return {key: {"old": value, "new": new[key]}
            for key, value in old.items()}