aboutsummaryrefslogtreecommitdiff
path: root/gn3/case_attributes.py
blob: 60616e49ecb271beb8b936ad404ec4fac15a1bc8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
"""Implement case-attribute manipulations."""
import os
import csv
import json
import uuid
import requests
import tempfile
from enum import Enum, auto
from pathlib import Path
from functools import reduce
from datetime import datetime
from urllib.parse import urljoin

from MySQLdb.cursors import DictCursor
from authlib.integrations.flask_oauth2.errors import _HTTPException
from flask import jsonify, request, Response, Blueprint, current_app

from gn3.commands import run_cmd

from gn3.db_utils import Connection, database_connection

from gn3.auth.authorisation.users import User
from gn3.auth.authorisation.errors import AuthorisationError
from gn3.auth.authorisation.oauth2.resource_server import require_oauth

# Flask blueprint on which every case-attribute endpoint in this module is
# registered.
caseattr = Blueprint("case-attribute", __name__)

# Name of the sub-directory (joined onto the configured `TMPDIR` by the
# endpoints in this module) in which the case-attribute diff files are kept.
CATTR_DIFFS_DIR = "case-attribute-diffs"

class NoDiffError(ValueError):
    """Raised if there is no difference between the old and new data."""
    def __init__(self):
        """Initialise the exception with its fixed, human-readable message."""
        # `super().__init__` is already bound to this instance: passing `self`
        # again would put the exception object itself into `args` and break
        # `str(exc)`.
        super().__init__("No difference between existing data and sent data.")

class EditStatus(Enum):
    """The states that a case-attribute edit can be in."""
    review = auto()
    approved = auto()
    rejected = auto()

    def __str__(self):
        """Use the bare member name as the human-readable form."""
        return str(self.name)

class CAJSONEncoder(json.JSONEncoder):
    """JSON encoder that also handles `datetime` and `uuid.UUID` values."""
    def default(self, obj):
        """Serialise UUIDs and datetimes; defer anything else to the parent."""
        if isinstance(obj, uuid.UUID):
            return str(obj)
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)

def required_access(inbredset_id: int,
                    access_levels: tuple[str, ...]) -> tuple[str, ...]:
    """
    Check whether the user has ALL the privileges in `access_levels`.

    The auth server is first asked for the resource ID of the InbredSet group,
    then for the user's roles on that resource. The privilege IDs carried by
    those roles are compared against `access_levels`.

    Returns: the tuple of privilege IDs the user holds on the resource.

    Raises: `AuthorisationError` if acquiring the token raises an HTTP
        exception, if either auth-server request does not respond with status
        200, or if any of the privileges in `access_levels` is missing.
        A `requests` timeout error propagates if the auth server does not
        respond in time.
    """
    # `requests` waits indefinitely unless a timeout is given: never let a
    # stuck auth server hang the request handler.
    request_timeout = 30 # seconds

    def __species_id__(conn):
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT SpeciesId FROM InbredSet WHERE InbredSetId=%s",
                (inbredset_id,))
            return cursor.fetchone()[0]
    try:
        with (require_oauth.acquire("profile resource") as the_token,
              database_connection(current_app.config["SQL_URI"]) as conn):
            result = requests.get(
                urljoin(current_app.config["AUTH_SERVER_URL"],
                        "auth/resource/inbredset/resource-id"
                        f"/{__species_id__(conn)}/{inbredset_id}"),
                timeout=request_timeout)
            if result.status_code == 200:
                resource_id = result.json()["resource-id"]
                auth = requests.post(
                    urljoin(current_app.config["AUTH_SERVER_URL"],
                            "auth/resource/authorisation"),
                    json={"resource-ids": [resource_id]},
                    headers={"Authorization": f"Bearer {the_token.access_token}"},
                    timeout=request_timeout)
                if auth.status_code == 200:
                    privs = tuple(priv["privilege_id"]
                                  for role in auth.json()[resource_id]["roles"]
                                  for priv in role["privileges"])
                    if all(lvl in privs for lvl in access_levels):
                        return privs
    except _HTTPException as httpe:
        raise AuthorisationError("You need to be logged in.") from httpe

    # Reached on any non-200 response or on a missing privilege.
    raise AuthorisationError(
        f"User does not have the privileges {access_levels}")

def __inbredset_group__(conn, inbredset_id):
    """Fetch the InbredSet row for `inbredset_id` and return it as a dict."""
    query = "SELECT * FROM InbredSet WHERE InbredSetId=%(inbredset_id)s"
    params = {"inbredset_id": inbredset_id}
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(query, params)
        row = cursor.fetchone()
        return dict(row)

def __inbredset_strains__(conn, inbredset_id):
    """Fetch every strain/sample linked to the InbredSet group, by name."""
    query = (
        "SELECT s.* FROM StrainXRef AS sxr INNER JOIN Strain AS s "
        "ON sxr.StrainId=s.Id WHERE sxr.InbredSetId=%(inbredset_id)s "
        "ORDER BY s.Name ASC")
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(query, {"inbredset_id": inbredset_id})
        rows = cursor.fetchall()
        return tuple(map(dict, rows))

def __case_attribute_labels_by_inbred_set__(conn, inbredset_id):
    """Fetch the CaseAttribute rows that belong to the InbredSet group."""
    query = "SELECT * FROM CaseAttribute WHERE InbredSetId=%(inbredset_id)s"
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(query, {"inbredset_id": inbredset_id})
        rows = cursor.fetchall()
        return tuple(map(dict, rows))

@caseattr.route("/<int:inbredset_id>", methods=["GET"])
def inbredset_group(inbredset_id: int) -> Response:
    """Respond with the InbredSet group's top-level details as JSON."""
    with database_connection(current_app.config["SQL_URI"]) as conn:
        group = __inbredset_group__(conn, inbredset_id)
        return jsonify(group)

@caseattr.route("/<int:inbredset_id>/strains", methods=["GET"])
def inbredset_strains(inbredset_id: int) -> Response:
    """Respond with ALL strains/samples of the InbredSet group as JSON."""
    with database_connection(current_app.config["SQL_URI"]) as conn:
        strains = __inbredset_strains__(conn, inbredset_id)
        return jsonify(strains)

@caseattr.route("/<int:inbredset_id>/names", methods=["GET"])
def inbredset_case_attribute_names(inbredset_id: int) -> Response:
    """Respond with ALL case-attributes of the InbredSet group as JSON."""
    with database_connection(current_app.config["SQL_URI"]) as conn:
        labels = __case_attribute_labels_by_inbred_set__(conn, inbredset_id)
        return jsonify(labels)

def __by_strain__(accumulator, item):
    """
    Fold one case-attribute row (`item`) into the per-strain `accumulator`.

    `accumulator` maps strain names to a dict holding the strain's details
    (`StrainName`, `StrainName2`, `Symbol`, `Alias`, where present in `item`)
    and a `case-attributes` dict of attribute name -> value. A new dict is
    returned; `accumulator` itself is not modified.
    """
    attr = {item["CaseAttributeName"]: item["CaseAttributeValue"]}
    strain_name = item["StrainName"]
    if bool(accumulator.get(strain_name)):
        # Strain already seen: merge the new attribute into its existing ones.
        return {
            **accumulator,
            strain_name: {
                **accumulator[strain_name],
                "case-attributes": {
                    **accumulator[strain_name]["case-attributes"],
                    **attr
                }
            }
        }
    # First row for this strain: keep every strain accumulated so far (they
    # were previously dropped here) and add the new strain's entry.
    return {
        **accumulator,
        strain_name: {
            **{
                key: value for key,value in item.items()
                if key in ("StrainName", "StrainName2", "Symbol", "Alias")
            },
            "case-attributes": attr
        }
    }

def __case_attribute_values_by_inbred_set__(
        conn: Connection, inbredset_id: int) -> tuple[dict, ...]:
    """
    Fetch the case-attribute values of an InbredSet group, grouped by strain.
    This is internal to the module: do not call it from outside.
    """
    query = (
        "SELECT ca.Name AS CaseAttributeName, "
        "caxrn.Value AS CaseAttributeValue, s.Name AS StrainName, "
        "s.Name2 AS StrainName2, s.Symbol, s.Alias "
        "FROM CaseAttribute AS ca "
        "INNER JOIN CaseAttributeXRefNew AS caxrn "
        "ON ca.CaseAttributeId=caxrn.CaseAttributeId "
        "INNER JOIN Strain AS s "
        "ON caxrn.StrainId=s.Id "
        "WHERE ca.InbredSetId=%(inbredset_id)s "
        "ORDER BY StrainName")
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(query, {"inbredset_id": inbredset_id})
        grouped = {}
        # Fold the rows one at a time into the per-strain mapping.
        for row in cursor.fetchall():
            grouped = __by_strain__(grouped, row)
        return tuple(grouped.values())

@caseattr.route("/<int:inbredset_id>/values", methods=["GET"])
def inbredset_case_attribute_values(inbredset_id: int) -> Response:
    """Respond with the group's (InbredSet's) case-attribute values as JSON."""
    with database_connection(current_app.config["SQL_URI"]) as conn:
        values = __case_attribute_values_by_inbred_set__(conn, inbredset_id)
        return jsonify(values)

def __process_orig_data__(fieldnames, cadata, strains) -> tuple[dict, ...]:
    """
    Build one flat row per strain from the data fetched from the database.

    Each row has the strain's name under "Strain", followed by a value for
    every field in `fieldnames` after the first; missing values become "".
    """
    by_strain = {}
    for item in cadata:
        by_strain[item["StrainName"]] = item

    def __row__(strain):
        name = strain["Name"]
        row = {"Strain": name}
        for key in fieldnames[1:]:
            row[key] = by_strain.get(
                name, {}).get("case-attributes", {}).get(key, "")
        return row

    return tuple(__row__(strain) for strain in strains)

def __process_edit_data__(fieldnames, form_data) -> tuple[dict, ...]:
    """
    Build one flat row per strain from the data submitted in the form.

    `form_data` maps each strain name to a dict with a "case-attributes"
    mapping; fields missing from that mapping are filled in with "".
    """
    rows = []
    for strain, cattrs in form_data.items():
        row = {"Strain": strain}
        for field in fieldnames[1:]:
            row[field] = cattrs["case-attributes"].get(field, "")
        rows.append(row)
    return tuple(rows)

def __write_csv__(fieldnames, data):
    """
    Write the given `data` to a csv file and return the path to the file.

    The file is a fresh temporary file with a ".csv" suffix; removing it is
    the caller's responsibility. If writing fails, the file is removed before
    the error is re-raised.
    """
    fdesc, filepath = tempfile.mkstemp(".csv")
    try:
        # Wrap the descriptor we already hold rather than reopening the path.
        # `newline=""` lets the csv module control line endings itself, as
        # its documentation requires.
        with os.fdopen(fdesc, "w", encoding="utf-8", newline="") as csvfile:
            writer = csv.DictWriter(
                csvfile, fieldnames=fieldnames, dialect="unix")
            writer.writeheader()
            writer.writerows(data)
    except Exception:
        # Do not leave a half-written temporary file behind.
        os.unlink(filepath)
        raise

    return filepath

def __compute_diff__(
        fieldnames: tuple[str, ...],
        original_data: tuple[dict, ...],
        edit_data: tuple[dict, ...]):
    """
    Return the diff of the data.

    Both datasets are written to temporary CSV files which are compared with
    the external `csvdiff` command. The parsed JSON output is returned when
    the command's result code is 0, otherwise an empty dict is returned. The
    temporary files are always removed, even when a step fails.
    """
    basefilename = __write_csv__(fieldnames, original_data)
    try:
        deltafilename = __write_csv__(fieldnames, edit_data)
        try:
            diff_results = run_cmd(json.dumps(
                ["csvdiff", basefilename, deltafilename, "--format", "json"]))
        finally:
            os.unlink(deltafilename)
    finally:
        os.unlink(basefilename)
    if diff_results["code"] == 0:
        return json.loads(diff_results["output"])
    return {}

def __queue_diff__(conn: Connection, diff_data, diff_data_dir: Path) -> Path:
    """
    Queue diff for future processing.

    The diff is saved to the database with the `review` status, then written
    as JSON (together with its database ID) into `diff_data_dir`.

    Returns: `diff`
        On success, this will return the filename where the diff was saved.
        On failure, it will raise a MySQL error.

    Raises: `NoDiffError` if the diff has no additions, modifications or
        deletions.
    """
    diff = diff_data["diff"]
    if not (bool(diff["Additions"])
            or bool(diff["Modifications"])
            or bool(diff["Deletions"])):
        raise NoDiffError

    diff_data_dir.mkdir(parents=True, exist_ok=True)

    created = datetime.now()
    # We want this to fail if the metadata items below are not provided.
    filepath = Path(
        diff_data_dir,
        f"{diff_data['inbredset_id']}:::{diff_data['user_id']}:::"
        f"{created.isoformat()}.json")
    the_diff = {**diff_data, "created": created}
    # Save to the database and serialise BEFORE creating the file, so that a
    # failure in either step does not leave an empty diff file behind.
    insert_id = __save_diff__(conn, the_diff, EditStatus.review)
    serialised = json.dumps({**the_diff, "db_id": insert_id},
                            cls=CAJSONEncoder)
    with open(filepath, "w", encoding="utf8") as diff_file:
        diff_file.write(serialised)
    return filepath

def __save_diff__(conn: Connection, diff_data: dict, status: EditStatus) -> int:
    """
    Insert the diff into the audit table, or update the status of an
    existing row, and return the row's ID.
    """
    db_id = diff_data.get("db_id")
    query = (
        "INSERT INTO "
        "caseattributes_audit(id, status, editor, json_diff_data, time_stamp) "
        "VALUES(%(db_id)s, %(status)s, %(editor)s, %(diff)s, %(ts)s) "
        "ON DUPLICATE KEY UPDATE status=%(status)s")
    params = {
        "db_id": db_id,
        "status": str(status),
        "editor": str(diff_data["user_id"]),
        "diff": json.dumps(diff_data, cls=CAJSONEncoder),
        "ts": diff_data["created"].isoformat()
    }
    with conn.cursor() as cursor:
        cursor.execute(query, params)
        # Fall back to the ID of the inserted row when none was provided.
        return db_id or cursor.lastrowid

def __load_diff__(diff_filename):
    """Read the diff file and convert its metadata back to native types."""
    with open(diff_filename, encoding="utf8") as diff_file:
        raw = json.load(diff_file)
    # The values below are stored as strings/numbers in the JSON file.
    converted = {
        "db_id": int(raw["db_id"]),
        "inbredset_id": int(raw["inbredset_id"]),
        "user_id": uuid.UUID(raw["user_id"]),
        "created": datetime.fromisoformat(raw["created"])
    }
    return {**raw, **converted}

def __apply_diff__(
        conn: Connection, inbredset_id: int, user: User, diff_filename) -> None:
    """
    Apply the changes in the diff stored at `diff_filename` to the database,
    provided the user holds the privileges needed to do so.

    Only the privileges check exists so far: applying the diff itself is not
    implemented yet.
    """
    needed_privileges = ("system:inbredset:edit-case-attribute",
                         "system:inbredset:apply-case-attribute-edit")
    required_access(inbredset_id, needed_privileges)
    raise NotImplementedError

def __reject_diff__(conn: Connection,
                    inbredset_id: int,
                    user: User,
                    diff_filename: Path,
                    diff: dict) -> Path:
    """
    Reject the changes in the diff at `diff_filename` to the data in the
    database if the user has appropriate privileges.

    The diff is marked as rejected in the database and its file is renamed
    with a "-rejected" marker before the file extension.

    Returns: the path the diff file was renamed to.
    """
    required_access(
        inbredset_id, ("system:inbredset:edit-case-attribute",
                       "system:inbredset:apply-case-attribute-edit"))
    __save_diff__(conn, diff, EditStatus.rejected)
    new_path = Path(diff_filename.parent, f"{diff_filename.stem}-rejected{diff_filename.suffix}")
    os.rename(diff_filename, new_path)
    # Hand back the path that exists now: `diff_filename` is gone after the
    # rename.
    return new_path

@caseattr.route("/<int:inbredset_id>/add", methods=["POST"])
def add_case_attributes(inbredset_id: int) -> Response:
    """
    Add a new case attribute for `InbredSetId`.

    Only the privileges check exists so far: the rest is not implemented yet.
    """
    needed_privileges = ("system:inbredset:create-case-attribute",)
    required_access(inbredset_id, needed_privileges)
    with (require_oauth.acquire("profile resource") as _token,
          database_connection(current_app.config["SQL_URI"]) as _conn):
        raise NotImplementedError

@caseattr.route("/<int:inbredset_id>/delete", methods=["POST"])
def delete_case_attributes(inbredset_id: int) -> Response:
    """
    Delete a case attribute from `InbredSetId`.

    Only the privileges check exists so far: the rest is not implemented yet.
    """
    needed_privileges = ("system:inbredset:delete-case-attribute",)
    required_access(inbredset_id, needed_privileges)
    with (require_oauth.acquire("profile resource") as _token,
          database_connection(current_app.config["SQL_URI"]) as _conn):
        raise NotImplementedError

@caseattr.route("/<int:inbredset_id>/edit", methods=["POST"])
def edit_case_attributes(inbredset_id: int) -> Response:
    """
    Edit the case attributes for `InbredSetId` based on data received.

    The submitted "edit-data" is diffed against the values in the database
    and the diff is queued. If there is no difference, a 400 response is
    returned. Otherwise an attempt is made to apply the diff immediately: if
    the user lacks the privileges for that, the diff remains queued for
    approval and its filename is reported in the response.
    """
    with (require_oauth.acquire("profile resource") as the_token,
          database_connection(current_app.config["SQL_URI"]) as conn):
        required_access(inbredset_id,
                        ("system:inbredset:edit-case-attribute",))
        user = the_token.user
        # "Strain" always comes first: the data-processing helpers treat the
        # remaining fields as the case-attribute names.
        fieldnames = (["Strain"] + sorted(
            attr["Name"] for attr in
            __case_attribute_labels_by_inbred_set__(conn, inbredset_id)))
        try:
            diff_filename = __queue_diff__(
                conn, {
                    "inbredset_id": inbredset_id,
                    "user_id": str(user.user_id),
                    "fieldnames": fieldnames,
                    "diff": __compute_diff__(
                        fieldnames,
                        __process_orig_data__(
                            fieldnames,
                            __case_attribute_values_by_inbred_set__(conn, inbredset_id),
                            __inbredset_strains__(conn, inbredset_id)),
                        __process_edit_data__(fieldnames, request.json["edit-data"]))
                },
                Path(current_app.config.get("TMPDIR"), CATTR_DIFFS_DIR))
        except NoDiffError as _nde:
            msg = "There were no changes to make from submitted data."
            response = jsonify({
                "diff-status": "error",
                "error_description": msg
            })
            response.status_code = 400
            return response

        try:
            # `inbredset_id` was previously missing from this call, which
            # made it fail with a TypeError for every request.
            __apply_diff__(conn, inbredset_id, user, diff_filename)
            return jsonify({
                "diff-status": "applied",
                "message": ("The changes to the case-attributes have been "
                            "applied successfully.")
            })
        except AuthorisationError as _auth_err:
            return jsonify({
                "diff-status": "queued",
                "message": ("The changes to the case-attributes have been "
                            "queued for approval."),
                "diff-filename": str(diff_filename.name)
            })

@caseattr.route("/approve/<path:filename>", methods=["POST"])
def approve_case_attributes_diff(filename: str) -> Response:
    """
    Approve the changes to the case attributes in the diff.

    `filename` is the name of a queued diff file within the diffs directory.
    The parameter is named after the route's `filename` variable, which is
    what Flask passes in. Building the response after the diff is applied is
    not implemented yet.
    """
    diff_dir = Path(current_app.config.get("TMPDIR"), CATTR_DIFFS_DIR)
    diff_filename = Path(diff_dir, filename)
    # `filename` comes straight from the URL: refuse anything that resolves
    # to a location outside the diffs directory.
    if not diff_filename.resolve().is_relative_to(diff_dir.resolve()):
        response = jsonify({
            "diff-status": "error",
            "error_description": "Invalid diff filename."
        })
        response.status_code = 400
        return response
    the_diff = __load_diff__(diff_filename)
    with (require_oauth.acquire("profile resource") as the_token,
          database_connection(current_app.config["SQL_URI"]) as conn):
        __apply_diff__(
            conn, the_diff["inbredset_id"], the_token.user, diff_filename)
        raise NotImplementedError

@caseattr.route("/reject/<path:filename>", methods=["POST"])
def reject_case_attributes_diff(filename: str) -> Response:
    """
    Reject the changes to the case attributes in the diff.

    `filename` is the name of a queued diff file within the diffs directory.
    A 400 response is returned if it points outside that directory. The
    response reports the name the diff file had when the request was made.
    """
    diff_dir = Path(current_app.config.get("TMPDIR"), CATTR_DIFFS_DIR)
    diff_filename = Path(diff_dir, filename)
    # `filename` comes straight from the URL: refuse anything that resolves
    # to a location outside the diffs directory.
    if not diff_filename.resolve().is_relative_to(diff_dir.resolve()):
        response = jsonify({
            "diff-status": "error",
            "error_description": "Invalid diff filename."
        })
        response.status_code = 400
        return response
    the_diff = __load_diff__(diff_filename)
    with (require_oauth.acquire("profile resource") as the_token,
          database_connection(current_app.config["SQL_URI"]) as conn):
        __reject_diff__(conn, the_diff["inbredset_id"], the_token.user, diff_filename, the_diff)
        return jsonify({
            "message": "Rejected diff successfully",
            "diff_filename": diff_filename.name
        })