120 files changed, 4419 insertions, 5779 deletions
diff --git a/.gitignore b/.gitignore index 176cfba..e5d73a5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +# Version: 1.0.0 +# Last updated: 2025-02-05 +# Maintainer: GeneNetwork Team + # General .DS_Store .AppleDouble @@ -188,4 +192,5 @@ dmypy.json /**/yoyo*.ini # utility scripts -/run-dev.sh \ No newline at end of file +/run-dev.sh +.aider* diff --git a/.guix-channel b/.guix-channel index 384da17..7fd99ea 100644 --- a/.guix-channel +++ b/.guix-channel @@ -1,21 +1,30 @@ -;; This file lets us present this repo as a Guix channel. - (channel - (version 0) - (directory ".guix") - (dependencies - (channel - (name guix-bioinformatics) - (url "https://git.genenetwork.org/guix-bioinformatics")) - ;; FIXME: guix-bioinformatics depends on guix-past. So, there - ;; should be no reason to explicitly depend on guix-past. But, the - ;; channel does not build otherwise. This is probably a guix bug. - (channel - (name guix-past) - (url "https://gitlab.inria.fr/guix-hpc/guix-past") - (introduction - (channel-introduction - (version 0) - (commit "0c119db2ea86a389769f4d2b9c6f5c41c027e336") - (signer - "3CE4 6455 8A84 FDC6 9DB4 0CFB 090B 1199 3D9A EBB5")))))) + (version 0) + (directory ".guix") + (dependencies + (channel + (name guix-bioinformatics) + (url "https://git.genenetwork.org/guix-bioinformatics") + (branch "master")) + ;; FIXME: guix-bioinformatics depends on guix-past. So, there + ;; should be no reason to explicitly depend on guix-past. But, the + ;; channel does not build otherwise. This is probably a guix bug. + (channel + (name guix-past) + (url "https://codeberg.org/guix-science/guix-past") + (introduction + (channel-introduction + (version 0) + (commit "c3bc94ee752ec545e39c1b8a29f739405767b51c") + (signer + "3CE4 6455 8A84 FDC6 9DB4 0CFB 090B 1199 3D9A EBB5")))) + (channel + (name guix-rust-past-crates) + (url "https://codeberg.org/guix/guix-rust-past-crates.git") + (branch "trunk") + (introduction + (channel-introduction + (version 0) + (commit "b8b7ffbd1cec9f56f93fae4da3a74163bbc9c570") + (signer + "F4C2 D1DF 3FDE EA63 D1D3 0776 ACC6 6D09 CA52 8292")))))) diff --git a/.guix/genenetwork3-package.scm b/.guix/genenetwork3-package.scm index ab0f5b8..9030c45 100644 --- a/.guix/genenetwork3-package.scm +++ b/.guix/genenetwork3-package.scm @@ -34,7 +34,11 @@ (substitute* "tests/fixtures/rdf.py" (("virtuoso-t") (string-append #$virtuoso-ose "/bin/virtuoso-t")))))) - (add-after 'build 'rdf-tests + ;; The logical flow for running tests is to perform static + ;; checks(pylint and mypy) before running the unit-tests in + ;; order to catch issues earlier. Network tests such as RDF + ;; should run after the unit tests to maintain that order. + (add-after 'check 'rdf-tests (lambda _ (invoke "pytest" "-k" "rdf"))) (add-before 'build 'pylint diff --git a/MANIFEST.in b/MANIFEST.in index f7fe787..09b6512 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,2 @@ -global-include scripts/**/*.R +global-include scripts/**/*.R gn3/**/*.json global-exclude *~ *.py[cod] \ No newline at end of file diff --git a/example-run-dev.sh b/example-run-dev.sh index a0c5d61..959411f 100644 --- a/example-run-dev.sh +++ b/example-run-dev.sh @@ -3,7 +3,6 @@ ## Copy to run-dev.sh and update the appropriate environment variables. 
export SQL_URI="${SQL_URI:+${SQL_URI}}" -export AUTH_DB="${AUTH_DB:+${AUTH_DB}}" export FLASK_DEBUG=1 export FLASK_APP="main.py" export AUTHLIB_INSECURE_TRANSPORT=true @@ -20,12 +19,5 @@ then exit 1; fi -if [ -z "${AUTH_DB}" ] -then - echo "ERROR: You need to specify the 'AUTH_DB' environment variable"; - exit 1; -fi - - # flask run --port=8080 flask ${CMD_ARGS[@]} diff --git a/gn3/api/case_attributes.py b/gn3/api/case_attributes.py new file mode 100644 index 0000000..28337ea --- /dev/null +++ b/gn3/api/case_attributes.py @@ -0,0 +1,296 @@ +"""Implement case-attribute manipulations.""" +from typing import Union +from pathlib import Path + +from functools import reduce +from urllib.parse import urljoin + +import requests +from MySQLdb.cursors import DictCursor +from authlib.integrations.flask_oauth2.errors import _HTTPException +from flask import ( + jsonify, + request, + Response, + Blueprint, + current_app) +from gn3.db.case_attributes import ( + CaseAttributeEdit, + EditStatus, + queue_edit, + apply_change, + get_changes) + + +from gn3.db_utils import Connection, database_connection + +from gn3.oauth2.authorisation import require_token +from gn3.oauth2.errors import AuthorisationError + +caseattr = Blueprint("case-attribute", __name__) + + +def required_access( + token: dict, + inbredset_id: int, + access_levels: tuple[str, ...] +) -> Union[bool, tuple[str, ...]]: + """Check whether the user has the appropriate access""" + def __species_id__(conn): + with conn.cursor() as cursor: + cursor.execute( + "SELECT SpeciesId FROM InbredSet WHERE InbredSetId=%s", + (inbredset_id,)) + return cursor.fetchone()[0] + try: + with database_connection(current_app.config["SQL_URI"]) as conn: + result = requests.get( + # this section fetches the resource ID from the auth server + urljoin(current_app.config["AUTH_SERVER_URL"], + "auth/resource/populations/resource-id" + f"/{__species_id__(conn)}/{inbredset_id}"), + timeout=300) + if result.status_code == 200: + resource_id = result.json()["resource-id"] + auth = requests.post( + # this section fetches the authorisations/privileges that + # the current user has on the resource we got above + urljoin(current_app.config["AUTH_SERVER_URL"], + "auth/resource/authorisation"), + json={"resource-ids": [resource_id]}, + headers={ + "Authorization": f"Bearer {token['access_token']}"}, + timeout=300) + if auth.status_code == 200: + privs = tuple(priv["privilege_id"] + for role in auth.json()[resource_id]["roles"] + for priv in role["privileges"]) + if all(lvl in privs for lvl in access_levels): + return privs + except _HTTPException as httpe: + raise AuthorisationError("You need to be logged in.") from httpe + + raise AuthorisationError( + f"User does not have the privileges {access_levels}") + + +def __inbredset_group__(conn, inbredset_id): + """Return InbredSet group's top-level details.""" + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + "SELECT * FROM InbredSet WHERE InbredSetId=%(inbredset_id)s", + {"inbredset_id": inbredset_id}) + return dict(cursor.fetchone()) + + +def __inbredset_strains__(conn, inbredset_id): + """Return all samples/strains for given InbredSet group.""" + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + "SELECT s.* FROM StrainXRef AS sxr INNER JOIN Strain AS s " + "ON sxr.StrainId=s.Id WHERE sxr.InbredSetId=%(inbredset_id)s " + "ORDER BY s.Name ASC", + {"inbredset_id": inbredset_id}) + return tuple(dict(row) for row in cursor.fetchall()) + + +def __case_attribute_labels_by_inbred_set__(conn, 
inbredset_id): + """Return the case-attribute labels/names for the given InbredSet group.""" + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + "SELECT * FROM CaseAttribute WHERE InbredSetId=%(inbredset_id)s", + {"inbredset_id": inbredset_id}) + return tuple(dict(row) for row in cursor.fetchall()) + + +@caseattr.route("/<int:inbredset_id>", methods=["GET"]) +def inbredset_group(inbredset_id: int) -> Response: + """Retrieve InbredSet group's details.""" + with database_connection(current_app.config["SQL_URI"]) as conn: + return jsonify(__inbredset_group__(conn, inbredset_id)) + + +@caseattr.route("/<int:inbredset_id>/strains", methods=["GET"]) +def inbredset_strains(inbredset_id: int) -> Response: + """Retrieve ALL strains/samples relating to a specific InbredSet group.""" + with database_connection(current_app.config["SQL_URI"]) as conn: + return jsonify(__inbredset_strains__(conn, inbredset_id)) + + +@caseattr.route("/<int:inbredset_id>/names", methods=["GET"]) +def inbredset_case_attribute_names(inbredset_id: int) -> Response: + """Retrieve ALL case-attributes for a specific InbredSet group.""" + with database_connection(current_app.config["SQL_URI"]) as conn: + return jsonify( + __case_attribute_labels_by_inbred_set__(conn, inbredset_id)) + + +def __by_strain__(accumulator, item): + attr = {item["CaseAttributeName"]: item["CaseAttributeValue"]} + strain_name = item["StrainName"] + if bool(accumulator.get(strain_name)): + return { + **accumulator, + strain_name: { + **accumulator[strain_name], + "case-attributes": { + **accumulator[strain_name]["case-attributes"], + **attr + } + } + } + return { + **accumulator, + strain_name: { + **{ + key: value for key, value in item.items() + if key in ("StrainName", "StrainName2", "Symbol", "Alias") + }, + "case-attributes": attr + } + } + + +def __case_attribute_values_by_inbred_set__( + conn: Connection, inbredset_id: int) -> tuple[dict, ...]: + """ + Retrieve Case-Attributes by their InbredSet ID. Do not call this outside + this module. + """ + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + "SELECT ca.Name AS CaseAttributeName, " + "caxrn.Value AS CaseAttributeValue, s.Name AS StrainName, " + "s.Name2 AS StrainName2, s.Symbol, s.Alias " + "FROM CaseAttribute AS ca " + "INNER JOIN CaseAttributeXRefNew AS caxrn " + "ON ca.CaseAttributeId=caxrn.CaseAttributeId " + "INNER JOIN Strain AS s " + "ON caxrn.StrainId=s.Id " + "WHERE caxrn.InbredSetId=%(inbredset_id)s " + "ORDER BY StrainName", + {"inbredset_id": inbredset_id}) + return tuple( + reduce(__by_strain__, cursor.fetchall(), {}).values()) + + +@caseattr.route("/<int:inbredset_id>/values", methods=["GET"]) +def inbredset_case_attribute_values(inbredset_id: int) -> Response: + """Retrieve the group's (InbredSet's) case-attribute values.""" + with database_connection(current_app.config["SQL_URI"]) as conn: + return jsonify(__case_attribute_values_by_inbred_set__(conn, inbredset_id)) + + +# pylint: disable=[too-many-locals] +@caseattr.route("/<int:inbredset_id>/edit", methods=["POST"]) +@require_token +def edit_case_attributes(inbredset_id: int, auth_token=None) -> tuple[Response, int]: + """Edit the case attributes for `InbredSetId` based on data received. 
+ + :inbredset_id: Identifier for the population that the case attribute belongs + :auth_token: A validated JWT from the auth server + """ + with database_connection(current_app.config["SQL_URI"]) as conn, conn.cursor() as cursor: + data = request.json["edit-data"] # type: ignore + edit = CaseAttributeEdit( + inbredset_id=inbredset_id, + status=EditStatus.review, + user_id=auth_token["jwt"]["sub"], + changes=data + ) + directory = (Path(current_app.config["LMDB_DATA_PATH"]) / + "case-attributes" / str(inbredset_id)) + queue_edit(cursor=cursor, + directory=directory, + edit=edit) + return jsonify({ + "diff-status": "queued", + "message": ("The changes to the case-attributes have been " + "queued for approval."), + }), 201 + + +@caseattr.route("/<int:inbredset_id>/diffs/<string:change_type>/list", methods=["GET"]) +def list_diffs(inbredset_id: int, change_type: str) -> tuple[Response, int]: + """List any changes that have been made by change_type.""" + with (database_connection(current_app.config["SQL_URI"]) as conn, + conn.cursor(cursorclass=DictCursor) as cursor): + directory = (Path(current_app.config["LMDB_DATA_PATH"]) / + "case-attributes" / str(inbredset_id)) + return jsonify( + get_changes( + cursor=cursor, + change_type=EditStatus[change_type], + directory=directory + ) + ), 200 + + +@caseattr.route("/<int:inbredset_id>/approve/<int:change_id>", methods=["POST"]) +@require_token +def approve_case_attributes_diff( + inbredset_id: int, + change_id: int, auth_token=None +) -> tuple[Response, int]: + """Approve the changes to the case attributes in the diff.""" + try: + required_access(auth_token, + inbredset_id, + ("system:inbredset:edit-case-attribute",)) + with (database_connection(current_app.config["SQL_URI"]) as conn, + conn.cursor() as cursor): + directory = (Path(current_app.config["LMDB_DATA_PATH"]) / + "case-attributes" / str(inbredset_id)) + match apply_change(cursor, change_type=EditStatus.approved, + change_id=change_id, + directory=directory): + case True: + return jsonify({ + "diff-status": "approved", + "message": (f"Successfully approved # {change_id}") + }), 201 + case _: + return jsonify({ + "diff-status": "queued", + "message": (f"Was not able to successfully approve # {change_id}") + }), 200 + except AuthorisationError as __auth_err: + return jsonify({ + "diff-status": "queued", + "message": ("You don't have the right privileges to edit this resource.") + }), 401 + + +@caseattr.route("/<int:inbredset_id>/reject/<int:change_id>", methods=["POST"]) +@require_token +def reject_case_attributes_diff( + inbredset_id: int, change_id: int, auth_token=None +) -> tuple[Response, int]: + """Reject the changes to the case attributes in the diff.""" + try: + required_access(auth_token, + inbredset_id, + ("system:inbredset:edit-case-attribute", + "system:inbredset:apply-case-attribute-edit")) + with database_connection(current_app.config["SQL_URI"]) as conn, \ + conn.cursor() as cursor: + directory = (Path(current_app.config["LMDB_DATA_PATH"]) / + "case-attributes" / str(inbredset_id)) + match apply_change(cursor, change_type=EditStatus.rejected, + change_id=change_id, + directory=directory): + case True: + return jsonify({ + "diff-status": "rejected", + "message": ("The changes to the case-attributes have been " + "rejected.") + }), 201 + case _: + return jsonify({ + "diff-status": "queued", + "message": ("Failed to reject changes") + }), 200 + except AuthorisationError as __auth_err: + return jsonify({ + "message": ("You don't have the right privileges to edit this resource.") + 
}), 401 diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py index c77dd93..3667a24 100644 --- a/gn3/api/correlation.py +++ b/gn3/api/correlation.py @@ -1,5 +1,6 @@ """Endpoints for running correlations""" import sys +import logging from functools import reduce import redis @@ -8,14 +9,14 @@ from flask import Blueprint from flask import request from flask import current_app -from gn3.settings import SQL_URI from gn3.db_utils import database_connection from gn3.commands import run_sample_corr_cmd from gn3.responses.pcorrs_responses import build_response -from gn3.commands import run_async_cmd, compose_pcorrs_command from gn3.computations.correlations import map_shared_keys_to_values from gn3.computations.correlations import compute_tissue_correlation from gn3.computations.correlations import compute_all_lit_correlation +from gn3.commands import ( + run_async_cmd, compute_job_queue, compose_pcorrs_command) correlation = Blueprint("correlation", __name__) @@ -88,6 +89,7 @@ def compute_tissue_corr(corr_method="pearson"): return jsonify(results) + @correlation.route("/partial", methods=["POST"]) def partial_correlation(): """API endpoint for partial correlations.""" @@ -111,9 +113,9 @@ def partial_correlation(): args = request.get_json() with_target_db = args.get("with_target_db", True) request_errors = __errors__( - args, ("primary_trait", "control_traits", - ("target_db" if with_target_db else "target_traits"), - "method")) + args, ("primary_trait", "control_traits", + ("target_db" if with_target_db else "target_traits"), + "method")) if request_errors: return build_response({ "status": "error", @@ -127,7 +129,7 @@ def partial_correlation(): tuple( trait_fullname(trait) for trait in args["control_traits"]), args["method"], target_database=args["target_db"], - criteria = int(args.get("criteria", 500))) + criteria=int(args.get("criteria", 500))) else: command = compose_pcorrs_command( trait_fullname(args["primary_trait"]), @@ -137,10 +139,17 @@ def partial_correlation(): trait_fullname(trait) for trait in args["target_traits"])) queueing_results = run_async_cmd( - conn=conn, - cmd=command, - job_queue=current_app.config.get("REDIS_JOB_QUEUE"), - env = {"PYTHONPATH": ":".join(sys.path), "SQL_URI": SQL_URI}) + conn=conn, + cmd=command, + job_queue=compute_job_queue(current_app), + options={ + "env": { + "PYTHONPATH": ":".join(sys.path), + "SQL_URI": current_app.config["SQL_URI"] + }, + }, + log_level=logging.getLevelName( + current_app.logger.getEffectiveLevel()).lower()) return build_response({ "status": "success", "results": queueing_results, diff --git a/gn3/api/ctl.py b/gn3/api/ctl.py index ac33d63..39c286f 100644 --- a/gn3/api/ctl.py +++ b/gn3/api/ctl.py @@ -2,7 +2,7 @@ from flask import Blueprint from flask import request -from flask import jsonify +from flask import jsonify, current_app from gn3.computations.ctl import call_ctl_script @@ -18,7 +18,8 @@ def run_ctl(): """ ctl_data = request.json - (cmd_results, response) = call_ctl_script(ctl_data) + (cmd_results, response) = call_ctl_script( + ctl_data, current_app.config["TMPDIR"]) return (jsonify({ "results": response }), 200) if response is not None else (jsonify({"error": str(cmd_results)}), 401) diff --git a/gn3/api/general.py b/gn3/api/general.py index b984361..8b57f23 100644 --- a/gn3/api/general.py +++ b/gn3/api/general.py @@ -1,5 +1,6 @@ """General API endpoints. 
Put endpoints that can't be grouped together nicely here.""" +import os from flask import Blueprint from flask import current_app from flask import jsonify @@ -68,3 +69,31 @@ def run_r_qtl(geno_filestr, pheno_filestr): cmd = (f"Rscript {rqtl_wrapper} " f"{geno_filestr} {pheno_filestr}") return jsonify(run_cmd(cmd)), 201 + + +@general.route("/stream ", methods=["GET"]) +def stream(): + """ + This endpoint streams the stdout content from a file. + It expects an identifier to be passed as a query parameter. + Example: `/stream?id=<identifier>` + The `id` will be used to locate the corresponding file. + You can also pass an optional `peak` parameter + to specify the file position to start reading from. + Query Parameters: + - `id` (required): The identifier used to locate the file. + - `peak` (optional): The position in the file to start reading from. + Returns: + - dict with data(stdout), run_id unique id for file, + pointer last read position for file + """ + run_id = request.args.get("id", "") + output_file = os.path.join(current_app.config.get("TMPDIR"), + f"{run_id}.txt") + seek_position = int(request.args.get("peak", 0)) + with open(output_file, encoding="utf-8") as file_handler: + # read to the last position default to 0 + file_handler.seek(seek_position) + return jsonify({"data": file_handler.readlines(), + "run_id": run_id, + "pointer": file_handler.tell()}) diff --git a/gn3/api/heatmaps.py b/gn3/api/heatmaps.py index 172d555..d3f9a45 100644 --- a/gn3/api/heatmaps.py +++ b/gn3/api/heatmaps.py @@ -33,7 +33,7 @@ def clustered_heatmaps(): with io.StringIO() as io_str: figure = build_heatmap(conn, traits_fullnames, - current_app.config["GENOTYPE_FILES"], + f'{current_app.config["GENOTYPE_FILES"]}/genotype', vertical=vertical, tmpdir=current_app.config["TMPDIR"]) figure.write_json(io_str) diff --git a/gn3/api/llm.py b/gn3/api/llm.py index b9ffbb2..dc8412e 100644 --- a/gn3/api/llm.py +++ b/gn3/api/llm.py @@ -1,16 +1,25 @@ """Api endpoints for gnqa""" +import ipaddress import json +import string +import uuid + from datetime import datetime +from datetime import timedelta +from typing import Optional +from functools import wraps from flask import Blueprint from flask import current_app from flask import jsonify from flask import request +from authlib.jose.errors import DecodeError from gn3.llms.process import get_gnqa from gn3.llms.errors import LLMError -from gn3.auth.authorisation.oauth2.resource_server import require_oauth -from gn3.auth import db + +from gn3.oauth2.authorisation import require_token +from gn3 import sqlite_db_utils as db gnqa = Blueprint("gnqa", __name__) @@ -26,6 +35,7 @@ CREATE TABLE IF NOT EXISTS history( ) WITHOUT ROWID """ + RATING_TABLE_CREATE_QUERY = """ CREATE TABLE IF NOT EXISTS Rating( user_id TEXT NOT NULL, @@ -39,40 +49,177 @@ CREATE TABLE IF NOT EXISTS Rating( """ +RATE_LIMITER_TABLE_CREATE_QUERY = """ +CREATE TABLE IF NOT EXISTS Limiter( + identifier TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + tokens INTEGER, + expiry_time TIMESTAMP, + PRIMARY KEY(identifier) +) +""" + + def database_setup(): """Temporary method to remove the need to have CREATE queries in functions""" with db.connection(current_app.config["LLM_DB_PATH"]) as conn: cursor = conn.cursor() cursor.execute(HISTORY_TABLE_CREATE_QUERY) cursor.execute(RATING_TABLE_CREATE_QUERY) + cursor.execute(RATE_LIMITER_TABLE_CREATE_QUERY) + + +def clean_query(query:str) -> str: + """This function cleans up query removing + punctuation and whitepace and transform to + lowercase + 
clean_query("!hello test.") -> "hello test" + """ + strip_chars = string.punctuation + string.whitespace + str_query = query.lower().strip(strip_chars) + return str_query + + +def is_verified_anonymous_user(header_metadata): + """This function should verify autheniticity of metadate from gn2 """ + anony_id = header_metadata.get("Anonymous-Id") #should verify this + metadata signature + user_status = header_metadata.get("Anonymous-Status", "") + _user_signed_metadata = ( + header_metadata.get("Anony-Metadata", "")) # TODO~ verify this for integrity with tokens + return bool(anony_id) and user_status.lower() == "verified" + +def with_gnqna_fallback(view_func): + """Allow fallback to GNQNA user if token auth fails or token is malformed.""" + @wraps(view_func) + def wrapper(*args, **kwargs): + def call_with_anonymous_fallback(): + return view_func.__wrapped__(*args, + **{**kwargs, "auth_token": None, "valid_anony": True}) + + try: + response = view_func(*args, **kwargs) + + is_invalid_token = ( + isinstance(response, tuple) and + len(response) == 2 and + response[1] == 400 + ) + + if is_invalid_token and is_verified_anonymous_user(dict(request.headers)): + return call_with_anonymous_fallback() + + return response + + except (DecodeError, ValueError): # occurs when trying to parse the token or auth results + if is_verified_anonymous_user(dict(request.headers)): + return call_with_anonymous_fallback() + return view_func.__wrapped__(*args, **kwargs) + + return wrapper + + +def is_valid_address(ip_string) -> bool : + """Function checks if is a valid ip address is valid""" + # todo !verify data is sent from gn2 + try: + ipaddress.ip_address(ip_string) + return True + except ValueError: + return False + + +def check_rate_limiter(ip_address, db_path, query, tokens_lifespan=1440, default_tokens=4): + """ + Checks if an anonymous user has a valid token within the given lifespan. + If expired or not found, creates or resets the token bucket. + `tokens_lifespan` is in seconds. 1440 seconds. + default_token set to 4 requests per hour. + """ + # Extract IP address /identifier + if not ip_address or not is_valid_address(ip_address): + raise ValueError("Please provide a valid IP address") + now = datetime.utcnow() + new_expiry = (now + timedelta(seconds=tokens_lifespan)).strftime("%Y-%m-%d %H:%M:%S") + + with db.connection(db_path) as conn: + cursor = conn.cursor() + # Fetch existing limiter record + cursor.execute(""" + SELECT tokens, expiry_time FROM Limiter + WHERE identifier = ? + """, (ip_address,)) + row = cursor.fetchone() + + if row: + tokens, expiry_time_str = row + expiry_time = datetime.strptime(expiry_time_str, "%Y-%m-%d %H:%M:%S") + time_diff = (expiry_time - now).total_seconds() + + if 0 < time_diff <= tokens_lifespan: + if tokens > 0: + # Consume token + cursor.execute(""" + UPDATE Limiter + SET tokens = tokens - 1 + WHERE identifier = ? AND tokens > 0 + """, (ip_address,)) + return True + else: + raise LLMError("Rate limit exceeded. Please try again later.", + query) + else: + # Token expired — reset ~probably reset this after 200 status + cursor.execute(""" + UPDATE Limiter + SET tokens = ?, expiry_time = ? + WHERE identifier = ? + """, (default_tokens, new_expiry, ip_address)) + return True + else: + # New user — insert record ~probably reset this after 200 status + cursor.execute(""" + INSERT INTO Limiter(identifier, tokens, expiry_time) + VALUES (?, ?, ?) 
+ """, (ip_address, default_tokens, new_expiry)) + return True @gnqa.route("/search", methods=["GET"]) -def search(): +@with_gnqna_fallback +@require_token +def search(auth_token=None, valid_anony=False): """Api endpoint for searching queries in fahamu Api""" query = request.args.get("query", "") if not query: return jsonify({"error": "query get parameter is missing in the request"}), 400 + fahamu_token = current_app.config.get("FAHAMU_AUTH_TOKEN") if not fahamu_token: raise LLMError( "Request failed: an LLM authorisation token is required ", query) database_setup() - with (db.connection(current_app.config["LLM_DB_PATH"]) as conn, - require_oauth.acquire("profile user") as token): + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: cursor = conn.cursor() previous_answer_query = """ SELECT user_id, task_id, query, results FROM history - WHERE created_at > DATE('now', '-1 day') AND - user_id = ? AND + WHERE created_at > DATE('now', '-21 day') AND query = ? ORDER BY created_at DESC LIMIT 1 """ - res = cursor.execute(previous_answer_query, (str(token.user.user_id), query)) + res = cursor.execute(previous_answer_query, (clean_query(query),)) previous_result = res.fetchone() if previous_result: _, _, _, response = previous_result + response = json.loads(response) + response["query"] = query return response + if valid_anony: + # rate limit anonymous verified users + user_metadata = json.loads(request.headers.get("Anony-Metadata", {})) + check_rate_limiter(user_metadata.get("ip_address", ""), + current_app.config["LLM_DB_PATH"], + request.args.get("query", "")) + task_id, answer, refs = get_gnqa( query, fahamu_token, current_app.config.get("DATA_DIR")) response = { @@ -81,52 +228,51 @@ def search(): "answer": answer, "references": refs } + user_id = str(uuid.uuid4()) if valid_anony else get_user_id(auth_token) cursor.execute( """INSERT INTO history(user_id, task_id, query, results) VALUES(?, ?, ?, ?) - """, (str(token.user.user_id), str(task_id["task_id"]), - query, + """, (user_id, str(task_id["task_id"]), + clean_query(query), json.dumps(response)) ) return response @gnqa.route("/rating/<task_id>", methods=["POST"]) -@require_oauth("profile") -def rate_queries(task_id): +@require_token +def rate_queries(task_id, auth_token=None): """Api endpoint for rating GNQA query and answer""" database_setup() - with (require_oauth.acquire("profile") as token, - db.connection(current_app.config["LLM_DB_PATH"]) as conn): + user_id = get_user_id(auth_token) + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: results = request.json - user_id, query, answer, weight = (token.user.user_id, - results.get("query"), - results.get("answer"), - results.get("weight", 0)) + query, answer, weight = (results.get("query"), + results.get("answer"), + results.get("weight", 0)) cursor = conn.cursor() cursor.execute("""INSERT INTO Rating(user_id, query, answer, weight, task_id) VALUES(?, ?, ?, ?, ?) ON CONFLICT(task_id) DO UPDATE SET weight=excluded.weight - """, (str(user_id), query, answer, weight, task_id)) + """, (user_id, query, answer, weight, task_id)) return { "message": "You have successfully rated this query. Thank you!" 
}, 200 @gnqa.route("/search/records", methods=["GET"]) -@require_oauth("profile user") -def get_user_search_records(): +@require_token +def get_user_search_records(auth_token=None): """get all history records for a given user using their user id """ - with (require_oauth.acquire("profile user") as token, - db.connection(current_app.config["LLM_DB_PATH"]) as conn): + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: cursor = conn.cursor() cursor.execute( """SELECT task_id, query, created_at from history WHERE user_id=?""", - (str(token.user.user_id),)) + (get_user_id(auth_token),)) results = [dict(item) for item in cursor.fetchall()] return jsonify(sorted(results, reverse=True, key=lambda x: datetime.strptime(x.get("created_at"), @@ -134,17 +280,15 @@ def get_user_search_records(): @gnqa.route("/search/record/<task_id>", methods=["GET"]) -@require_oauth("profile user") -def get_user_record_by_task(task_id): +@require_token +def get_user_record_by_task(task_id, auth_token = None): """Get user previous search record by task id """ - with (require_oauth.acquire("profile user") as token, - db.connection(current_app.config["LLM_DB_PATH"]) as conn): + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: cursor = conn.cursor() cursor.execute( """SELECT results from history Where task_id=? and user_id=?""", - (task_id, - str(token.user.user_id),)) + (task_id, get_user_id(auth_token),)) record = cursor.fetchone() if record: return dict(record).get("results") @@ -152,28 +296,34 @@ def get_user_record_by_task(task_id): @gnqa.route("/search/record/<task_id>", methods=["DELETE"]) -@require_oauth("profile user") -def delete_record(task_id): +@require_token +def delete_record(task_id, auth_token = None): """Delete user previous seach record by task-id""" - with (require_oauth.acquire("profile user") as token, - db.connection(current_app.config["LLM_DB_PATH"]) as conn): + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: cursor = conn.cursor() query = """DELETE FROM history WHERE task_id=? and user_id=?""" - cursor.execute(query, (task_id, token.user.user_id,)) + cursor.execute(query, (task_id, get_user_id(auth_token),)) return {"msg": f"Successfully Deleted the task {task_id}"} @gnqa.route("/search/records", methods=["DELETE"]) -@require_oauth("profile user") -def delete_records(): +@require_token +def delete_records(auth_token=None): """ Delete a users records using for all given task ids""" - with (require_oauth.acquire("profile user") as token, - db.connection(current_app.config["LLM_DB_PATH"]) as conn): + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: task_ids = list(request.json.values()) cursor = conn.cursor() - query = f""" -DELETE FROM history WHERE task_id IN ({', '.join('?' * len(task_ids))}) AND user_id=? - """ - cursor.execute(query, (*task_ids, str(token.user.user_id),)) + query = ("DELETE FROM history WHERE task_id IN " + f"({', '.join('?' 
* len(task_ids))}) " + "AND user_id=?") + cursor.execute(query, (*task_ids, get_user_id(auth_token),)) return jsonify({}) + + +def get_user_id(auth_token: Optional[dict] = None): + """Retrieve the user ID from the JWT token.""" + if auth_token is None or auth_token.get("jwt", {}).get("sub") is None: + raise LLMError("Invalid auth token encountered") + user_id = auth_token["jwt"]["sub"] + return user_id diff --git a/gn3/api/lmdb_sample_data.py b/gn3/api/lmdb_sample_data.py new file mode 100644 index 0000000..eaa71c2 --- /dev/null +++ b/gn3/api/lmdb_sample_data.py @@ -0,0 +1,40 @@ +"""API endpoint for retrieving sample data from LMDB storage""" +import hashlib +from pathlib import Path + +import lmdb +from flask import Blueprint, current_app, jsonify + +lmdb_sample_data = Blueprint("lmdb_sample_data", __name__) + + +@lmdb_sample_data.route("/sample-data/<string:dataset>/<int:trait_id>", methods=["GET"]) +def get_sample_data(dataset: str, trait_id: int): + """Retrieve sample data from LMDB for a given dataset and trait. + + Path Parameters: + dataset: The name of the dataset + trait_id: The ID of the trait + + Returns: + JSON object mapping sample IDs to their values + """ + checksum = hashlib.md5( + f"{dataset}-{trait_id}".encode() + ).hexdigest() + + db_path = Path(current_app.config["LMDB_DATA_PATH"]) / checksum + if not db_path.exists(): + return jsonify(error="No data found for given dataset and trait"), 404 + try: + with lmdb.open(str(db_path), max_dbs=15, readonly=True) as env: + data = {} + with env.begin(write=False) as txn: + cursor = txn.cursor() + for key, value in cursor: + data[key.decode()] = float(value.decode()) + + return jsonify(data) + + except lmdb.Error as err: + return jsonify(error=f"LMDB error: {str(err)}"), 500 diff --git a/gn3/api/metadata.py b/gn3/api/metadata.py index 6110880..e272c0d 100644 --- a/gn3/api/metadata.py +++ b/gn3/api/metadata.py @@ -9,7 +9,8 @@ from flask import Blueprint from flask import request from flask import current_app -from gn3.auth.authorisation.errors import AuthorisationError + +from gn3.oauth2.errors import AuthorisationError from gn3.db.datasets import (retrieve_metadata, save_metadata, get_history) @@ -27,6 +28,7 @@ from gn3.api.metadata_api import wiki metadata = Blueprint("metadata", __name__) metadata.register_blueprint(wiki.wiki_blueprint) +metadata.register_blueprint(wiki.rif_blueprint) @metadata.route("/datasets/<name>", methods=["GET"]) @@ -170,7 +172,7 @@ def view_history(id_): "history": history, }) if history.get("error"): - raise Exception(history.get("error_description")) + raise Exception(history.get("error_description")) # pylint: disable=[broad-exception-raised] return history diff --git a/gn3/api/metadata_api/wiki.py b/gn3/api/metadata_api/wiki.py index 8df6cfb..70c5cf4 100644 --- a/gn3/api/metadata_api/wiki.py +++ b/gn3/api/metadata_api/wiki.py @@ -1,27 +1,31 @@ -"""API for accessing/editting wiki metadata""" +"""API for accessing/editing rif/wiki metadata""" import datetime from typing import Any, Dict +from typing import Optional from flask import Blueprint, request, jsonify, current_app, make_response from gn3 import db_utils -from gn3.auth.authorisation.oauth2.resource_server import require_oauth +from gn3.oauth2.authorisation import require_token from gn3.db import wiki from gn3.db.rdf.wiki import ( get_wiki_entries_by_symbol, get_comment_history, update_wiki_comment, + get_rif_entries_by_symbol, + delete_wiki_entries_by_id, ) wiki_blueprint = Blueprint("wiki", __name__, url_prefix="wiki") +rif_blueprint = 
Blueprint("rif", __name__, url_prefix="rif") +@wiki_blueprint.route("/edit", methods=["POST"], defaults={'comment_id': None}) @wiki_blueprint.route("/<int:comment_id>/edit", methods=["POST"]) -@require_oauth("profile") -def edit_wiki(comment_id: int): - """Edit wiki comment. This is achieved by adding another entry with a new VersionId""" +def edit_wiki(comment_id: Optional[int]): + """Edit/Insert wiki comment. This is achieved by adding another entry with a new VersionId""" # FIXME: attempt to check and fix for types here with relevant errors payload: Dict[str, Any] = request.json # type: ignore pubmed_ids = [str(x) for x in payload.get("pubmed_ids", [])] @@ -48,13 +52,17 @@ def edit_wiki(comment_id: int): VALUES (%(Id)s, %(versionId)s, %(symbol)s, %(PubMed_ID)s, %(SpeciesID)s, %(comment)s, %(email)s, %(createtime)s, %(user_ip)s, %(weburl)s, %(initial)s, %(reason)s) """ with db_utils.database_connection(current_app.config["SQL_URI"]) as conn: - cursor = conn.cursor() - next_version = 0 + cursor, next_version = conn.cursor(), 0 + if not comment_id: + comment_id = wiki.get_next_comment_id(cursor) + insert_dict["Id"] = comment_id + else: + next_version = wiki.get_next_comment_version(cursor, comment_id) + try: category_ids = wiki.get_categories_ids( cursor, payload["categories"]) species_id = wiki.get_species_id(cursor, payload["species"]) - next_version = wiki.get_next_comment_version(cursor, comment_id) except wiki.MissingDBDataException as missing_exc: return jsonify(error=f"Error editing wiki entry, {missing_exc}"), 500 insert_dict["SpeciesID"] = species_id @@ -83,7 +91,7 @@ def edit_wiki(comment_id: int): sparql_auth_uri=current_app.config["SPARQL_AUTH_URI"] ) except Exception as exc: - conn.rollback() # type: ignore + conn.rollback() # type: ignore raise exc return jsonify({"success": "ok"}) return jsonify(error="Error editing wiki entry, most likely due to DB error!"), 500 @@ -154,3 +162,57 @@ def get_history(comment_id): payload.headers["Content-Type"] = "application/ld+json" return payload, status_code return jsonify(data), status_code + + +@rif_blueprint.route("/<string:symbol>", methods=["GET"]) +def get_ncbi_rif_entries(symbol: str): + """Fetch NCBI RIF entries""" + status_code = 200 + response = get_rif_entries_by_symbol( + symbol, + sparql_uri=current_app.config["SPARQL_ENDPOINT"]) + data = response.get("data") + if not data: + data, status_code = {}, 404 + if request.headers.get("Accept") == "application/ld+json": + payload = make_response(response) + payload.headers["Content-Type"] = "application/ld+json" + return payload, status_code + return jsonify(data), status_code + + +@wiki_blueprint.route("/delete", methods=["POST"], defaults={'comment_id': None}) +@wiki_blueprint.route("/<int:comment_id>/delete", methods=["POST"]) +@require_token +def delete_wiki(comment_id: Optional[int] = None, **kwargs): # pylint: disable=[unused-argument] + """Delete a wiki entry by its comment_id from both SQL and RDF.""" + if comment_id is None: + return jsonify(error="comment_id is required for deletion."), 400 + with (db_utils.database_connection(current_app.config["SQL_URI"]) as conn, + conn.cursor() as cursor): + try: + # Delete from SQL + delete_query = "DELETE FROM GeneRIF WHERE Id = %s" + current_app.logger.debug( + f"Running query: {delete_query} with Id={comment_id}") + cursor.execute(delete_query, (comment_id,)) + # Delete from RDF + try: + delete_wiki_entries_by_id( + wiki_id=comment_id, + sparql_user=current_app.config["SPARQL_USER"], + 
sparql_password=current_app.config["SPARQL_PASSWORD"], + sparql_auth_uri=current_app.config["SPARQL_AUTH_URI"], + graph="<http://genenetwork.org>" + ) + # pylint: disable=W0718 + except Exception as rdf_exc: + current_app.logger.error(f"RDF deletion failed: {rdf_exc}") + conn.rollback() + return jsonify(error="Failed to delete wiki entry from RDF store."), 500 + return jsonify({"success": "Wiki entry deleted successfully."}), 200 + # pylint: disable=W0718 + except Exception as exc: + conn.rollback() + current_app.logger.error(f"Error deleting wiki entry: {exc}") + return jsonify(error="Error deleting wiki entry, most likely due to DB error!"), 500 diff --git a/gn3/api/rqtl.py b/gn3/api/rqtl.py index e029d8d..eb49f8b 100644 --- a/gn3/api/rqtl.py +++ b/gn3/api/rqtl.py @@ -1,5 +1,7 @@ """Endpoints for running the rqtl cmd""" + import os +import uuid from pathlib import Path from flask import Blueprint @@ -7,30 +9,40 @@ from flask import current_app from flask import jsonify from flask import request -from gn3.computations.rqtl import generate_rqtl_cmd, process_rqtl_mapping, \ - process_rqtl_pairscan, process_perm_output -from gn3.fs_helpers import assert_paths_exist, get_tmpdir +from gn3.computations.rqtl import ( + generate_rqtl_cmd, + process_rqtl_mapping, + process_rqtl_pairscan, + process_perm_output, +) +from gn3.computations.streaming import run_process, enable_streaming +from gn3.fs_helpers import assert_path_exists, get_tmpdir rqtl = Blueprint("rqtl", __name__) + @rqtl.route("/compute", methods=["POST"]) -def compute(): +@enable_streaming +def compute(stream_output_file): """Given at least a geno_file and pheno_file, generate and -run the rqtl_wrapper script and return the results as JSON + run the rqtl_wrapper script and return the results as JSON """ - genofile = request.form['geno_file'] - phenofile = request.form['pheno_file'] - - assert_paths_exist([genofile, phenofile]) - - # Split kwargs by those with values and boolean ones that just convert to True/False + genofile = request.form["geno_file"] + phenofile = request.form["pheno_file"] + assert_path_exists(genofile) + assert_path_exists(phenofile) kwargs = ["covarstruct", "model", "method", "nperm", "scale", "control"] boolean_kwargs = ["addcovar", "interval", "pstrata", "pairscan"] all_kwargs = kwargs + boolean_kwargs - rqtl_kwargs = {"geno": genofile, "pheno": phenofile, "outdir": current_app.config.get("TMPDIR")} + rqtl_kwargs = { + "geno": genofile, + "pheno": phenofile, + "outdir": current_app.config.get("TMPDIR"), + } rqtl_bool_kwargs = [] + for kwarg in all_kwargs: if kwarg in request.form: if kwarg in kwargs: @@ -38,30 +50,43 @@ run the rqtl_wrapper script and return the results as JSON if kwarg in boolean_kwargs: rqtl_bool_kwargs.append(kwarg) - outdir = os.path.join(get_tmpdir(),"gn3") + outdir = os.path.join(get_tmpdir(), "gn3") if not os.path.isdir(outdir): os.mkdir(outdir) rqtl_cmd = generate_rqtl_cmd( rqtl_wrapper_cmd=str( - Path(__file__).absolute().parent.parent.parent.joinpath( - 'scripts/rqtl_wrapper.R')), + Path(__file__) + .absolute() + .parent.parent.parent.joinpath("scripts/rqtl_wrapper.R") + ), rqtl_wrapper_kwargs=rqtl_kwargs, - rqtl_wrapper_bool_kwargs=rqtl_bool_kwargs + rqtl_wrapper_bool_kwargs=rqtl_bool_kwargs, ) rqtl_output = {} - if not os.path.isfile(os.path.join(current_app.config.get("TMPDIR"), - "gn3", rqtl_cmd.get('output_file'))): - os.system(rqtl_cmd.get('rqtl_cmd')) + run_id = request.args.get("id", str(uuid.uuid4())) + if not os.path.isfile( + os.path.join( + current_app.config.get("TMPDIR"), 
"gn3", rqtl_cmd.get("output_file") + ) + ): + pass + run_process(rqtl_cmd.get("rqtl_cmd").split(), stream_output_file, run_id) if "pairscan" in rqtl_bool_kwargs: - rqtl_output['results'] = process_rqtl_pairscan(rqtl_cmd.get('output_file'), genofile) + rqtl_output["results"] = process_rqtl_pairscan( + rqtl_cmd.get("output_file"), genofile + ) else: - rqtl_output['results'] = process_rqtl_mapping(rqtl_cmd.get('output_file')) - - if int(rqtl_kwargs['nperm']) > 0: - rqtl_output['perm_results'], rqtl_output['suggestive'], rqtl_output['significant'] = \ - process_perm_output(rqtl_cmd.get('output_file')) + rqtl_output["results"] = process_rqtl_mapping(rqtl_cmd.get("output_file")) + if int(rqtl_kwargs["nperm"]) > 0: + # pylint: disable=C0301 + perm_output_file = f"PERM_{rqtl_cmd.get('output_file')}" + ( + rqtl_output["perm_results"], + rqtl_output["suggestive"], + rqtl_output["significant"], + ) = process_perm_output(perm_output_file) return jsonify(rqtl_output) diff --git a/gn3/api/rqtl2.py b/gn3/api/rqtl2.py new file mode 100644 index 0000000..dc06d1d --- /dev/null +++ b/gn3/api/rqtl2.py @@ -0,0 +1,55 @@ +""" File contains endpoints for rqlt2""" +import shutil +from pathlib import Path +from flask import current_app +from flask import jsonify +from flask import Blueprint +from flask import request +from gn3.computations.rqtl2 import (compose_rqtl2_cmd, + prepare_files, + validate_required_keys, + write_input_file, + process_qtl2_results + ) +from gn3.computations.streaming import run_process +from gn3.computations.streaming import enable_streaming + +rqtl2 = Blueprint("rqtl2", __name__) + + +@rqtl2.route("/compute", methods=["POST"]) +@enable_streaming +def compute(log_file): + """Endpoint for computing QTL analysis using R/QTL2""" + data = request.json + required_keys = ["crosstype", "geno_data","pheno_data", "geno_codes"] + valid, error = validate_required_keys(required_keys,data) + if not valid: + return jsonify(error=error), 400 + # Provide atleast one of this data entries. + if "physical_map_data" not in data and "geno_map_data" not in data: + return jsonify(error="You need to Provide\ + Either the Physical map or Geno Map data of markers"), 400 + run_id = request.args.get("id", "output") + # prepare necessary files and dir for computation + (workspace_dir, input_file, + output_file, _log2_file) = prepare_files(current_app.config.get("TMPDIR")) + # write the input file with data required for creating the cross + write_input_file(input_file, workspace_dir, data) + # TODO : Implement a better way for fetching the file Path. 
+ rqtl_path =Path(__file__).absolute().parent.parent.parent.joinpath("scripts/rqtl2_wrapper.R") + if not rqtl_path.is_file(): + return jsonify({"error" : f"The script {rqtl_path} does not exists"}), 400 + rqtl2_cmd = compose_rqtl2_cmd(rqtl_path, input_file, + output_file, workspace_dir, + data, current_app.config) + process_output = run_process(rqtl2_cmd.split(),log_file, run_id) + if process_output["code"] != 0: + # Err out for any non-zero status code + return jsonify(process_output), 400 + results = process_qtl2_results(output_file) + shutil.rmtree(workspace_dir, ignore_errors=True, onerror=None) + # append this at end of computation to the log file to mark end of gn3 computation + with open(log_file, "ab+") as file_handler: + file_handler.write("Done with GN3 Computation".encode("utf-8")) + return jsonify(results) diff --git a/gn3/api/search.py b/gn3/api/search.py index f696428..e814a00 100644 --- a/gn3/api/search.py +++ b/gn3/api/search.py @@ -6,7 +6,7 @@ import gzip import json from functools import partial, reduce from pathlib import Path -from typing import Callable +from typing import Union, Callable import urllib.parse from flask import abort, Blueprint, current_app, jsonify, request @@ -40,6 +40,7 @@ def combine_queries(operator: int, *queries: xapian.Query) -> xapian.Query: return reduce(partial(xapian.Query, operator), queries) +# pylint: disable=R0903 class FieldProcessor(xapian.FieldProcessor): """ Field processor for use in a xapian query parser. @@ -65,7 +66,10 @@ def field_processor_or(*field_processors: FieldProcessorFunction) -> FieldProces for field_processor in field_processors])) -def liftover(chain_file: Path, position: ChromosomalPosition) -> Maybe[ChromosomalPosition]: +def liftover( + chain_file: Union[str, Path], + position: ChromosomalPosition +) -> Maybe[ChromosomalPosition]: """Liftover chromosomal position using chain file.""" # The chain file format is described at # https://genome.ucsc.edu/goldenPath/help/chain.html @@ -91,7 +95,10 @@ def liftover(chain_file: Path, position: ChromosomalPosition) -> Maybe[Chromosom return Nothing -def liftover_interval(chain_file: str, interval: ChromosomalInterval) -> ChromosomalInterval: +def liftover_interval( + chain_file: Union[str, Path], + interval: ChromosomalInterval +) -> ChromosomalInterval: """ Liftover interval using chain file. @@ -258,7 +265,7 @@ def parse_query(synteny_files_directory: Path, query: str): xapian.Query(species_prefix + lifted_species), chromosome_prefix, range_prefixes.index("position"), - partial(liftover_interval, + partial(liftover_interval,# type: ignore[arg-type] synteny_files_directory / chain_file))) queryparser.add_boolean_prefix( shorthand, diff --git a/gn3/api/streaming.py b/gn3/api/streaming.py new file mode 100644 index 0000000..2b6b431 --- /dev/null +++ b/gn3/api/streaming.py @@ -0,0 +1,26 @@ +""" File contains endpoint for computational streaming""" +import os +from flask import current_app +from flask import jsonify +from flask import Blueprint +from flask import request + +streaming = Blueprint("stream", __name__) + + +@streaming.route("/<identifier>", methods=["GET"]) +def stream(identifier): + """ This endpoint streams stdout from a file. 
+ It expects the identifier to be the filename + in the TMPDIR created at the main computation + endpoint see example api/rqtl.""" + output_file = os.path.join(current_app.config.get("TMPDIR"), + f"{identifier}.txt") + seek_position = int(request.args.get("peak", 0)) + with open(output_file, encoding="utf-8") as file_handler: + # read from the last read position default to 0 + file_handler.seek(seek_position) + results = {"data": file_handler.readlines(), + "run_id": identifier, + "pointer": file_handler.tell()} + return jsonify(results) diff --git a/gn3/api/wgcna.py b/gn3/api/wgcna.py index fa044cf..5468a2e 100644 --- a/gn3/api/wgcna.py +++ b/gn3/api/wgcna.py @@ -17,7 +17,9 @@ def run_wgcna(): wgcna_script = current_app.config["WGCNA_RSCRIPT"] - results = call_wgcna_script(wgcna_script, wgcna_data) + results = call_wgcna_script(wgcna_script, + wgcna_data, + current_app.config["TMPDIR"]) if results.get("data") is None: return jsonify(results), 401 diff --git a/gn3/app.py b/gn3/app.py index c8f0c5a..9a4269c 100644 --- a/gn3/app.py +++ b/gn3/app.py @@ -1,6 +1,7 @@ """Entry point from spinning up flask""" import os import sys +import json import logging from pathlib import Path @@ -10,7 +11,7 @@ from typing import Union from flask import Flask from flask_cors import CORS # type: ignore -from gn3.loggers import setup_app_handlers +from gn3.loggers import loglevel, setup_app_logging, setup_modules_logging from gn3.api.gemma import gemma from gn3.api.rqtl import rqtl from gn3.api.general import general @@ -26,19 +27,49 @@ from gn3.api.search import search from gn3.api.metadata import metadata from gn3.api.sampledata import sampledata from gn3.api.llm import gnqa -from gn3.auth import oauth2 -from gn3.case_attributes import caseattr +from gn3.api.rqtl2 import rqtl2 +from gn3.api.streaming import streaming +from gn3.api.case_attributes import caseattr +from gn3.api.lmdb_sample_data import lmdb_sample_data + + +class ConfigurationError(Exception): + """Raised in case of a configuration error.""" + + +def verify_app_config(app: Flask) -> None: + """Verify that configuration variables are as expected + It includes: + 1. making sure mandatory settings are defined + 2. provides examples for what to set as config variables (helps local dev) + """ + app_config = { + "AUTH_SERVER_URL": """AUTH_SERVER_URL is used for api requests that need login. 
+ For local dev, use the running auth server url, which defaults to http://127.0.0.1:8081 + """, + } + error_message = [] + + for setting, err in app_config.items(): + print(f"{setting}: {app.config.get(setting)}") + if setting in app.config and bool(app.config[setting]): + continue + error_message.append(err) + if error_message: + raise ConfigurationError("\n".join(error_message)) def create_app(config: Union[Dict, str, None] = None) -> Flask: """Create a new flask object""" app = Flask(__name__) # Load default configuration - app.config.from_object("gn3.settings") + app.config.from_file( + Path(__file__).absolute().parent.joinpath("settings.json"), + load=json.load) # Load environment configuration if "GN3_CONF" in os.environ: - app.config.from_envvar('GN3_CONF') + app.config.from_envvar("GN3_CONF") # Load app specified configuration if config is not None: @@ -52,7 +83,10 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask: if secrets_file and Path(secrets_file).exists(): app.config.from_envvar("GN3_SECRETS") # END: SECRETS - setup_app_handlers(app) + verify_app_config(app) + setup_app_logging(app) + setup_modules_logging(loglevel(app), ("gn3.db.menu", + "gn_libs.mysqldb")) # DO NOT log anything before this point logging.info("Guix Profile: '%s'.", os.environ.get("GUIX_PROFILE")) logging.info("Python Executable: '%s'.", sys.executable) @@ -61,7 +95,9 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask: app, origins=app.config["CORS_ORIGINS"], allow_headers=app.config["CORS_HEADERS"], - supports_credentials=True, intercept_exceptions=False) + supports_credentials=True, + intercept_exceptions=False, + ) app.register_blueprint(general, url_prefix="/api/") app.register_blueprint(gemma, url_prefix="/api/gemma") @@ -76,9 +112,11 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask: app.register_blueprint(search, url_prefix="/api/search") app.register_blueprint(metadata, url_prefix="/api/metadata") app.register_blueprint(sampledata, url_prefix="/api/sampledata") - app.register_blueprint(oauth2, url_prefix="/api/oauth2") app.register_blueprint(caseattr, url_prefix="/api/case-attribute") app.register_blueprint(gnqa, url_prefix="/api/llm") + app.register_blueprint(rqtl2, url_prefix="/api/rqtl2") + app.register_blueprint(streaming, url_prefix="/api/stream") + app.register_blueprint(lmdb_sample_data, url_prefix="/api/lmdb") register_error_handlers(app) return app diff --git a/gn3/auth/__init__.py b/gn3/auth/__init__.py deleted file mode 100644 index cd65e9b..0000000 --- a/gn3/auth/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Top-Level `Auth` module""" -from . import authorisation - -from .views import oauth2 diff --git a/gn3/auth/authorisation/__init__.py b/gn3/auth/authorisation/__init__.py deleted file mode 100644 index abd2747..0000000 --- a/gn3/auth/authorisation/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""The authorisation module.""" -from .checks import authorised_p diff --git a/gn3/auth/authorisation/checks.py b/gn3/auth/authorisation/checks.py deleted file mode 100644 index 17daca4..0000000 --- a/gn3/auth/authorisation/checks.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Functions to check for authorisation.""" -from functools import wraps -from typing import Callable - -from flask import request, current_app as app - -from gn3.auth import db -from gn3.auth.authorisation.oauth2.resource_server import require_oauth - -from . 
import privileges as auth_privs -from .errors import InvalidData, AuthorisationError - -def __system_privileges_in_roles__(conn, user): - """ - This really is a hack since groups are not treated as resources at the - moment of writing this. - - We need a way of allowing the user to have the system:group:* privileges. - """ - query = ( - "SELECT DISTINCT p.* FROM users AS u " - "INNER JOIN group_user_roles_on_resources AS guror " - "ON u.user_id=guror.user_id " - "INNER JOIN roles AS r ON guror.role_id=r.role_id " - "INNER JOIN role_privileges AS rp ON r.role_id=rp.role_id " - "INNER JOIN privileges AS p ON rp.privilege_id=p.privilege_id " - "WHERE u.user_id=? AND p.privilege_id LIKE 'system:%'") - with db.cursor(conn) as cursor: - cursor.execute(query, (str(user.user_id),)) - return (row["privilege_id"] for row in cursor.fetchall()) - -def authorised_p( - privileges: tuple[str, ...], - error_description: str = ( - "You lack authorisation to perform requested action"), - oauth2_scope = "profile"): - """Authorisation decorator.""" - assert len(privileges) > 0, "You must provide at least one privilege" - def __build_authoriser__(func: Callable): - @wraps(func) - def __authoriser__(*args, **kwargs): - # the_user = user or (hasattr(g, "user") and g.user) - with require_oauth.acquire(oauth2_scope) as the_token: - the_user = the_token.user - if the_user: - with db.connection(app.config["AUTH_DB"]) as conn: - user_privileges = tuple( - priv.privilege_id for priv in - auth_privs.user_privileges(conn, the_user)) + tuple( - priv_id for priv_id in - __system_privileges_in_roles__(conn, the_user)) - - not_assigned = [ - priv for priv in privileges if priv not in user_privileges] - if len(not_assigned) == 0: - return func(*args, **kwargs) - - raise AuthorisationError(error_description) - return __authoriser__ - return __build_authoriser__ - -def require_json(func): - """Ensure the request has JSON data.""" - @wraps(func) - def __req_json__(*args, **kwargs): - if bool(request.json): - return func(*args, **kwargs) - raise InvalidData("Expected JSON data in the request.") - return __req_json__ diff --git a/gn3/auth/authorisation/data/__init__.py b/gn3/auth/authorisation/data/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/gn3/auth/authorisation/data/__init__.py +++ /dev/null diff --git a/gn3/auth/authorisation/data/genotypes.py b/gn3/auth/authorisation/data/genotypes.py deleted file mode 100644 index 8f901a5..0000000 --- a/gn3/auth/authorisation/data/genotypes.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Handle linking of Genotype data to the Auth(entic|oris)ation system.""" -import uuid -from typing import Iterable - -from MySQLdb.cursors import DictCursor - -import gn3.auth.db as authdb -import gn3.db_utils as gn3db -from gn3.auth.dictify import dictify -from gn3.auth.authorisation.checks import authorised_p -from gn3.auth.authorisation.groups.models import Group - -def linked_genotype_data(conn: authdb.DbConnection) -> Iterable[dict]: - """Retrive genotype data that is linked to user groups.""" - with authdb.cursor(conn) as cursor: - cursor.execute("SELECT * FROM linked_genotype_data") - return (dict(row) for row in cursor.fetchall()) - -@authorised_p(("system:data:link-to-group",), - error_description=( - "You do not have sufficient privileges to link data to (a) " - "group(s)."), - oauth2_scope="profile group resource") -def ungrouped_genotype_data(# pylint: disable=[too-many-arguments] - authconn: authdb.DbConnection, gn3conn: gn3db.Connection, - search_query: str, selected: tuple[dict, 
...] = tuple(), - limit: int = 10000, offset: int = 0) -> tuple[ - dict, ...]: - """Retrieve genotype data that is not linked to any user group.""" - params = tuple( - (row["SpeciesId"], row["InbredSetId"], row["GenoFreezeId"]) - for row in linked_genotype_data(authconn)) + tuple( - (row["SpeciesId"], row["InbredSetId"], row["GenoFreezeId"]) - for row in selected) - query = ( - "SELECT s.SpeciesId, iset.InbredSetId, iset.InbredSetName, " - "gf.Id AS GenoFreezeId, gf.Name AS dataset_name, " - "gf.FullName AS dataset_fullname, " - "gf.ShortName AS dataset_shortname " - "FROM Species AS s INNER JOIN InbredSet AS iset " - "ON s.SpeciesId=iset.SpeciesId INNER JOIN GenoFreeze AS gf " - "ON iset.InbredSetId=gf.InbredSetId ") - - if len(params) > 0 or bool(search_query): - query = query + "WHERE " - - if len(params) > 0: - paramstr = ", ".join(["(%s, %s, %s)"] * len(params)) - query = query + ( - "(s.SpeciesId, iset.InbredSetId, gf.Id) " - f"NOT IN ({paramstr}) " - ) + ("AND " if bool(search_query) else "") - - if bool(search_query): - query = query + ( - "CONCAT(gf.Name, ' ', gf.FullName, ' ', gf.ShortName) LIKE %s ") - params = params + ((f"%{search_query}%",),)# type: ignore[operator] - - query = query + f"LIMIT {int(limit)} OFFSET {int(offset)}" - with gn3conn.cursor(DictCursor) as cursor: - cursor.execute( - query, tuple(item for sublist in params for item in sublist)) - return tuple(row for row in cursor.fetchall()) - -@authorised_p( - ("system:data:link-to-group",), - error_description=( - "You do not have sufficient privileges to link data to (a) " - "group(s)."), - oauth2_scope="profile group resource") -def link_genotype_data( - conn: authdb.DbConnection, group: Group, datasets: dict) -> dict: - """Link genotye `datasets` to `group`.""" - with authdb.cursor(conn) as cursor: - cursor.executemany( - "INSERT INTO linked_genotype_data VALUES " - "(:data_link_id, :group_id, :SpeciesId, :InbredSetId, " - ":GenoFreezeId, :dataset_name, :dataset_fullname, " - ":dataset_shortname) " - "ON CONFLICT (SpeciesId, InbredSetId, GenoFreezeId) DO NOTHING", - tuple({ - "data_link_id": str(uuid.uuid4()), - "group_id": str(group.group_id), - **{ - key: value for key,value in dataset.items() if key in ( - "GenoFreezeId", "InbredSetId", "SpeciesId", - "dataset_fullname", "dataset_name", "dataset_shortname") - } - } for dataset in datasets)) - return { - "description": ( - f"Successfully linked {len(datasets)} to group " - f"'{group.group_name}'."), - "group": dictify(group), - "datasets": datasets - } diff --git a/gn3/auth/authorisation/data/mrna.py b/gn3/auth/authorisation/data/mrna.py deleted file mode 100644 index bdfc5c1..0000000 --- a/gn3/auth/authorisation/data/mrna.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Handle linking of mRNA Assay data to the Auth(entic|oris)ation system.""" -import uuid -from typing import Iterable -from MySQLdb.cursors import DictCursor - -import gn3.auth.db as authdb -import gn3.db_utils as gn3db -from gn3.auth.dictify import dictify -from gn3.auth.authorisation.checks import authorised_p -from gn3.auth.authorisation.groups.models import Group - -def linked_mrna_data(conn: authdb.DbConnection) -> Iterable[dict]: - """Retrieve mRNA Assay data that is linked to user groups.""" - with authdb.cursor(conn) as cursor: - cursor.execute("SELECT * FROM linked_mrna_data") - return (dict(row) for row in cursor.fetchall()) - -@authorised_p(("system:data:link-to-group",), - error_description=( - "You do not have sufficient privileges to link data to (a) " - "group(s)."), - oauth2_scope="profile 
group resource") -def ungrouped_mrna_data(# pylint: disable=[too-many-arguments] - authconn: authdb.DbConnection, gn3conn: gn3db.Connection, - search_query: str, selected: tuple[dict, ...] = tuple(), - limit: int = 10000, offset: int = 0) -> tuple[ - dict, ...]: - """Retrieve mrna data that is not linked to any user group.""" - params = tuple( - (row["SpeciesId"], row["InbredSetId"], row["ProbeFreezeId"], - row["ProbeSetFreezeId"]) - for row in linked_mrna_data(authconn)) + tuple( - (row["SpeciesId"], row["InbredSetId"], row["ProbeFreezeId"], - row["ProbeSetFreezeId"]) - for row in selected) - query = ( - "SELECT s.SpeciesId, iset.InbredSetId, iset.InbredSetName, " - "pf.ProbeFreezeId, pf.Name AS StudyName, psf.Id AS ProbeSetFreezeId, " - "psf.Name AS dataset_name, psf.FullName AS dataset_fullname, " - "psf.ShortName AS dataset_shortname " - "FROM Species AS s INNER JOIN InbredSet AS iset " - "ON s.SpeciesId=iset.SpeciesId INNER JOIN ProbeFreeze AS pf " - "ON iset.InbredSetId=pf.InbredSetId INNER JOIN ProbeSetFreeze AS psf " - "ON pf.ProbeFreezeId=psf.ProbeFreezeId ") + ( - "WHERE " if (len(params) > 0 or bool(search_query)) else "") - - if len(params) > 0: - paramstr = ", ".join(["(%s, %s, %s, %s)"] * len(params)) - query = query + ( - "(s.SpeciesId, iset.InbredSetId, pf.ProbeFreezeId, psf.Id) " - f"NOT IN ({paramstr}) " - ) + ("AND " if bool(search_query) else "") - - if bool(search_query): - query = query + ( - "CONCAT(pf.Name, psf.Name, ' ', psf.FullName, ' ', psf.ShortName) " - "LIKE %s ") - params = params + ((f"%{search_query}%",),)# type: ignore[operator] - - query = query + f"LIMIT {int(limit)} OFFSET {int(offset)}" - with gn3conn.cursor(DictCursor) as cursor: - cursor.execute( - query, tuple(item for sublist in params for item in sublist)) - return tuple(row for row in cursor.fetchall()) - -@authorised_p( - ("system:data:link-to-group",), - error_description=( - "You do not have sufficient privileges to link data to (a) " - "group(s)."), - oauth2_scope="profile group resource") -def link_mrna_data( - conn: authdb.DbConnection, group: Group, datasets: dict) -> dict: - """Link genotye `datasets` to `group`.""" - with authdb.cursor(conn) as cursor: - cursor.executemany( - "INSERT INTO linked_mrna_data VALUES " - "(:data_link_id, :group_id, :SpeciesId, :InbredSetId, " - ":ProbeFreezeId, :ProbeSetFreezeId, :dataset_name, " - ":dataset_fullname, :dataset_shortname) " - "ON CONFLICT " - "(SpeciesId, InbredSetId, ProbeFreezeId, ProbeSetFreezeId) " - "DO NOTHING", - tuple({ - "data_link_id": str(uuid.uuid4()), - "group_id": str(group.group_id), - **{ - key: value for key,value in dataset.items() if key in ( - "SpeciesId", "InbredSetId", "ProbeFreezeId", - "ProbeSetFreezeId", "dataset_fullname", "dataset_name", - "dataset_shortname") - } - } for dataset in datasets)) - return { - "description": ( - f"Successfully linked {len(datasets)} to group " - f"'{group.group_name}'."), - "group": dictify(group), - "datasets": datasets - } diff --git a/gn3/auth/authorisation/data/phenotypes.py b/gn3/auth/authorisation/data/phenotypes.py deleted file mode 100644 index ff98295..0000000 --- a/gn3/auth/authorisation/data/phenotypes.py +++ /dev/null @@ -1,140 +0,0 @@ -"""Handle linking of Phenotype data to the Auth(entic|oris)ation system.""" -import uuid -from typing import Any, Iterable - -from MySQLdb.cursors import DictCursor - -import gn3.auth.db as authdb -import gn3.db_utils as gn3db -from gn3.auth.dictify import dictify -from gn3.auth.authorisation.checks import authorised_p -from 
gn3.auth.authorisation.groups.models import Group - -def linked_phenotype_data( - authconn: authdb.DbConnection, gn3conn: gn3db.Connection, - species: str = "") -> Iterable[dict[str, Any]]: - """Retrieve phenotype data linked to user groups.""" - authkeys = ("SpeciesId", "InbredSetId", "PublishFreezeId", "PublishXRefId") - with (authdb.cursor(authconn) as authcursor, - gn3conn.cursor(DictCursor) as gn3cursor): - authcursor.execute("SELECT * FROM linked_phenotype_data") - linked = tuple(tuple(row[key] for key in authkeys) - for row in authcursor.fetchall()) - if len(linked) <= 0: - return iter(()) - paramstr = ", ".join(["(%s, %s, %s, %s)"] * len(linked)) - query = ( - "SELECT spc.SpeciesId, spc.Name AS SpeciesName, iset.InbredSetId, " - "iset.InbredSetName, pf.Id AS PublishFreezeId, " - "pf.Name AS dataset_name, pf.FullName AS dataset_fullname, " - "pf.ShortName AS dataset_shortname, pxr.Id AS PublishXRefId " - "FROM " - "Species AS spc " - "INNER JOIN InbredSet AS iset " - "ON spc.SpeciesId=iset.SpeciesId " - "INNER JOIN PublishFreeze AS pf " - "ON iset.InbredSetId=pf.InbredSetId " - "INNER JOIN PublishXRef AS pxr " - "ON pf.InbredSetId=pxr.InbredSetId") + ( - " WHERE" if (len(linked) > 0 or bool(species)) else "") + ( - (" (spc.SpeciesId, iset.InbredSetId, pf.Id, pxr.Id) " - f"IN ({paramstr})") if len(linked) > 0 else "") + ( - " AND"if len(linked) > 0 else "") + ( - " spc.SpeciesName=%s" if bool(species) else "") - params = tuple(item for sublist in linked for item in sublist) + ( - (species,) if bool(species) else tuple()) - gn3cursor.execute(query, params) - return (item for item in gn3cursor.fetchall()) - -@authorised_p(("system:data:link-to-group",), - error_description=( - "You do not have sufficient privileges to link data to (a) " - "group(s)."), - oauth2_scope="profile group resource") -def ungrouped_phenotype_data( - authconn: authdb.DbConnection, gn3conn: gn3db.Connection): - """Retrieve phenotype data that is not linked to any user group.""" - with gn3conn.cursor() as cursor: - params = tuple( - (row["SpeciesId"], row["InbredSetId"], row["PublishFreezeId"], - row["PublishXRefId"]) - for row in linked_phenotype_data(authconn, gn3conn)) - paramstr = ", ".join(["(?, ?, ?, ?)"] * len(params)) - query = ( - "SELECT spc.SpeciesId, spc.SpeciesName, iset.InbredSetId, " - "iset.InbredSetName, pf.Id AS PublishFreezeId, " - "pf.Name AS dataset_name, pf.FullName AS dataset_fullname, " - "pf.ShortName AS dataset_shortname, pxr.Id AS PublishXRefId " - "FROM " - "Species AS spc " - "INNER JOIN InbredSet AS iset " - "ON spc.SpeciesId=iset.SpeciesId " - "INNER JOIN PublishFreeze AS pf " - "ON iset.InbredSetId=pf.InbredSetId " - "INNER JOIN PublishXRef AS pxr " - "ON pf.InbredSetId=pxr.InbredSetId") - if len(params) > 0: - query = query + ( - f" WHERE (iset.InbredSetId, pf.Id, pxr.Id) NOT IN ({paramstr})") - - cursor.execute(query, params) - return tuple(dict(row) for row in cursor.fetchall()) - - return tuple() - -def __traits__(gn3conn: gn3db.Connection, params: tuple[dict, ...]) -> tuple[dict, ...]: - """An internal utility function. 
Don't use outside of this module.""" - if len(params) < 1: - return tuple() - paramstr = ", ".join(["(%s, %s, %s, %s)"] * len(params)) - with gn3conn.cursor(DictCursor) as cursor: - cursor.execute( - "SELECT spc.SpeciesId, iset.InbredSetId, pf.Id AS PublishFreezeId, " - "pf.Name AS dataset_name, pf.FullName AS dataset_fullname, " - "pf.ShortName AS dataset_shortname, pxr.Id AS PublishXRefId " - "FROM " - "Species AS spc " - "INNER JOIN InbredSet AS iset " - "ON spc.SpeciesId=iset.SpeciesId " - "INNER JOIN PublishFreeze AS pf " - "ON iset.InbredSetId=pf.InbredSetId " - "INNER JOIN PublishXRef AS pxr " - "ON pf.InbredSetId=pxr.InbredSetId " - "WHERE (spc.SpeciesName, iset.InbredSetName, pf.Name, pxr.Id) " - f"IN ({paramstr})", - tuple( - itm for sublist in ( - (item["species"], item["group"], item["dataset"], item["name"]) - for item in params) - for itm in sublist)) - return cursor.fetchall() - -@authorised_p(("system:data:link-to-group",), - error_description=( - "You do not have sufficient privileges to link data to (a) " - "group(s)."), - oauth2_scope="profile group resource") -def link_phenotype_data( - authconn:authdb.DbConnection, gn3conn: gn3db.Connection, group: Group, - traits: tuple[dict, ...]) -> dict: - """Link phenotype traits to a user group.""" - with authdb.cursor(authconn) as cursor: - params = tuple({ - "data_link_id": str(uuid.uuid4()), - "group_id": str(group.group_id), - **item - } for item in __traits__(gn3conn, traits)) - cursor.executemany( - "INSERT INTO linked_phenotype_data " - "VALUES (" - ":data_link_id, :group_id, :SpeciesId, :InbredSetId, " - ":PublishFreezeId, :dataset_name, :dataset_fullname, " - ":dataset_shortname, :PublishXRefId" - ")", - params) - return { - "description": ( - f"Successfully linked {len(traits)} traits to group."), - "group": dictify(group), - "traits": params - } diff --git a/gn3/auth/authorisation/data/views.py b/gn3/auth/authorisation/data/views.py deleted file mode 100644 index 81811dd..0000000 --- a/gn3/auth/authorisation/data/views.py +++ /dev/null @@ -1,310 +0,0 @@ -"""Handle data endpoints.""" -import sys -import uuid -import json -from typing import Any -from functools import partial - -import redis -from MySQLdb.cursors import DictCursor -from authlib.integrations.flask_oauth2.errors import _HTTPException -from flask import request, jsonify, Response, Blueprint, current_app as app - -import gn3.db_utils as gn3db -from gn3 import jobs -from gn3.commands import run_async_cmd -from gn3.db.traits import build_trait_name - -from gn3.auth import db -from gn3.auth.db_utils import with_db_connection - -from gn3.auth.authorisation.checks import require_json -from gn3.auth.authorisation.errors import InvalidData, NotFoundError - -from gn3.auth.authorisation.groups.models import group_by_id - -from gn3.auth.authorisation.users.models import user_resource_roles - -from gn3.auth.authorisation.resources.checks import authorised_for -from gn3.auth.authorisation.resources.models import ( - user_resources, public_resources, attach_resources_data) - -from gn3.auth.authorisation.oauth2.resource_server import require_oauth - -from gn3.auth.authorisation.users import User -from gn3.auth.authorisation.data.phenotypes import link_phenotype_data -from gn3.auth.authorisation.data.mrna import link_mrna_data, ungrouped_mrna_data -from gn3.auth.authorisation.data.genotypes import ( - link_genotype_data, ungrouped_genotype_data) - -data = Blueprint("data", __name__) - -@data.route("species") -def list_species() -> Response: - """List all available species 
information.""" - with (gn3db.database_connection(app.config["SQL_URI"]) as gn3conn, - gn3conn.cursor(DictCursor) as cursor): - cursor.execute("SELECT * FROM Species") - return jsonify(tuple(dict(row) for row in cursor.fetchall())) - -@data.route("/authorisation", methods=["POST"]) -@require_json -def authorisation() -> Response: - """Retrive the authorisation level for datasets/traits for the user.""" - # Access endpoint with something like: - # curl -X POST http://127.0.0.1:8080/api/oauth2/data/authorisation \ - # -H "Content-Type: application/json" \ - # -d '{"traits": ["HC_M2_0606_P::1442370_at", "BXDGeno::01.001.695", - # "BXDPublish::10001"]}' - db_uri = app.config["AUTH_DB"] - privileges = {} - user = User(uuid.uuid4(), "anon@ymous.user", "Anonymous User") - with db.connection(db_uri) as auth_conn: - try: - with require_oauth.acquire("profile group resource") as the_token: - user = the_token.user - resources = attach_resources_data( - auth_conn, user_resources(auth_conn, the_token.user)) - resources_roles = user_resource_roles(auth_conn, the_token.user) - privileges = { - resource_id: tuple( - privilege.privilege_id - for roles in resources_roles[resource_id] - for privilege in roles.privileges)#("group:resource:view-resource",) - for resource_id, is_authorised - in authorised_for( - auth_conn, the_token.user, - ("group:resource:view-resource",), tuple( - resource.resource_id for resource in resources)).items() - if is_authorised - } - except _HTTPException as exc: - err_msg = json.loads(exc.body) - if err_msg["error"] == "missing_authorization": - resources = attach_resources_data( - auth_conn, public_resources(auth_conn)) - else: - raise exc from None - - def __gen_key__(resource, data_item): - if resource.resource_category.resource_category_key.lower() == "phenotype": - return ( - f"{resource.resource_category.resource_category_key.lower()}::" - f"{data_item['dataset_name']}::{data_item['PublishXRefId']}") - return ( - f"{resource.resource_category.resource_category_key.lower()}::" - f"{data_item['dataset_name']}") - - data_to_resource_map = { - __gen_key__(resource, data_item): resource.resource_id - for resource in resources - for data_item in resource.resource_data - } - privileges = { - **{ - resource.resource_id: ("system:resource:public-read",) - for resource in resources if resource.public - }, - **privileges} - - args = request.get_json() - traits_names = args["traits"] # type: ignore[index] - def __translate__(val): - return { - "Temp": "Temp", - "ProbeSet": "mRNA", - "Geno": "Genotype", - "Publish": "Phenotype" - }[val] - - def __trait_key__(trait): - dataset_type = __translate__(trait['db']['dataset_type']).lower() - dataset_name = trait["db"]["dataset_name"] - if dataset_type == "phenotype": - return f"{dataset_type}::{dataset_name}::{trait['trait_name']}" - return f"{dataset_type}::{dataset_name}" - - return jsonify(tuple( - { - "user": user._asdict(), - **{key:trait[key] for key in ("trait_fullname", "trait_name")}, - "dataset_name": trait["db"]["dataset_name"], - "dataset_type": __translate__(trait["db"]["dataset_type"]), - "resource_id": data_to_resource_map.get(__trait_key__(trait)), - "privileges": privileges.get( - data_to_resource_map.get( - __trait_key__(trait), - uuid.UUID("4afa415e-94cb-4189-b2c6-f9ce2b6a878d")), - tuple()) + ( - # Temporary traits do not exist in db: Set them - # as public-read - ("system:resource:public-read",) - if trait["db"]["dataset_type"] == "Temp" - else tuple()) - } for trait in - (build_trait_name(trait_fullname) - for 
trait_fullname in traits_names))) - -def __search_mrna__(): - query = __request_key__("query", "") - limit = int(__request_key__("limit", 10000)) - offset = int(__request_key__("offset", 0)) - with gn3db.database_connection(app.config["SQL_URI"]) as gn3conn: - __ungrouped__ = partial( - ungrouped_mrna_data, gn3conn=gn3conn, search_query=query, - selected=__request_key_list__("selected"), - limit=limit, offset=offset) - return jsonify(with_db_connection(__ungrouped__)) - -def __request_key__(key: str, default: Any = ""): - if bool(request.json): - return request.json.get(#type: ignore[union-attr] - key, request.args.get(key, request.form.get(key, default))) - return request.args.get(key, request.form.get(key, default)) - -def __request_key_list__(key: str, default: tuple[Any, ...] = tuple()): - if bool(request.json): - return (request.json.get(key,[])#type: ignore[union-attr] - or request.args.getlist(key) or request.form.getlist(key) - or list(default)) - return (request.args.getlist(key) - or request.form.getlist(key) or list(default)) - -def __search_genotypes__(): - query = __request_key__("query", "") - limit = int(__request_key__("limit", 10000)) - offset = int(__request_key__("offset", 0)) - with gn3db.database_connection(app.config["SQL_URI"]) as gn3conn: - __ungrouped__ = partial( - ungrouped_genotype_data, gn3conn=gn3conn, search_query=query, - selected=__request_key_list__("selected"), - limit=limit, offset=offset) - return jsonify(with_db_connection(__ungrouped__)) - -def __search_phenotypes__(): - # launch the external process to search for phenotypes - redisuri = app.config["REDIS_URI"] - with redis.Redis.from_url(redisuri, decode_responses=True) as redisconn: - job_id = uuid.uuid4() - selected = __request_key__("selected_traits", []) - command =[ - sys.executable, "-m", "scripts.search_phenotypes", - __request_key__("species_name"), - __request_key__("query"), - str(job_id), - f"--host={__request_key__('gn3_server_uri')}", - f"--auth-db-uri={app.config['AUTH_DB']}", - f"--gn3-db-uri={app.config['SQL_URI']}", - f"--redis-uri={redisuri}", - f"--per-page={__request_key__('per_page')}"] +( - [f"--selected={json.dumps(selected)}"] - if len(selected) > 0 else []) - jobs.create_job(redisconn, { - "job_id": job_id, "command": command, "status": "queued", - "search_results": tuple()}) - return jsonify({ - "job_id": job_id, - "command_id": run_async_cmd( - redisconn, app.config.get("REDIS_JOB_QUEUE"), command), - "command": command - }) - -@data.route("/search", methods=["GET"]) -@require_oauth("profile group resource") -def search_unlinked_data(): - """Search for various unlinked data.""" - dataset_type = request.json["dataset_type"] - search_fns = { - "mrna": __search_mrna__, - "genotype": __search_genotypes__, - "phenotype": __search_phenotypes__ - } - return search_fns[dataset_type]() - -@data.route("/search/phenotype/<uuid:job_id>", methods=["GET"]) -def pheno_search_results(job_id: uuid.UUID) -> Response: - """Get the search results from the external script""" - def __search_error__(err): - raise NotFoundError(err["error_description"]) - redisuri = app.config["REDIS_URI"] - with redis.Redis.from_url(redisuri, decode_responses=True) as redisconn: - return jobs.job(redisconn, job_id).either( - __search_error__, jsonify) - -@data.route("/link/genotype", methods=["POST"]) -def link_genotypes() -> Response: - """Link genotype data to group.""" - def __values__(form) -> dict[str, Any]: - if not bool(form.get("species_name", "").strip()): - raise InvalidData("Expected 'species_name' not 
provided.") - if not bool(form.get("group_id")): - raise InvalidData("Expected 'group_id' not provided.",) - try: - _group_id = uuid.UUID(form.get("group_id")) - except TypeError as terr: - raise InvalidData("Expected a UUID for 'group_id' value.") from terr - if not bool(form.get("selected")): - raise InvalidData("Expected at least one dataset to be provided.") - return { - "group_id": uuid.UUID(form.get("group_id")), - "datasets": form.get("selected") - } - - def __link__(conn: db.DbConnection, group_id: uuid.UUID, datasets: dict): - return link_genotype_data(conn, group_by_id(conn, group_id), datasets) - - return jsonify(with_db_connection( - partial(__link__, **__values__(request.json)))) - -@data.route("/link/mrna", methods=["POST"]) -def link_mrna() -> Response: - """Link mrna data to group.""" - def __values__(form) -> dict[str, Any]: - if not bool(form.get("species_name", "").strip()): - raise InvalidData("Expected 'species_name' not provided.") - if not bool(form.get("group_id")): - raise InvalidData("Expected 'group_id' not provided.",) - try: - _group_id = uuid.UUID(form.get("group_id")) - except TypeError as terr: - raise InvalidData("Expected a UUID for 'group_id' value.") from terr - if not bool(form.get("selected")): - raise InvalidData("Expected at least one dataset to be provided.") - return { - "group_id": uuid.UUID(form.get("group_id")), - "datasets": form.get("selected") - } - - def __link__(conn: db.DbConnection, group_id: uuid.UUID, datasets: dict): - return link_mrna_data(conn, group_by_id(conn, group_id), datasets) - - return jsonify(with_db_connection( - partial(__link__, **__values__(request.json)))) - -@data.route("/link/phenotype", methods=["POST"]) -def link_phenotype() -> Response: - """Link phenotype data to group.""" - def __values__(form): - if not bool(form.get("species_name", "").strip()): - raise InvalidData("Expected 'species_name' not provided.") - if not bool(form.get("group_id")): - raise InvalidData("Expected 'group_id' not provided.",) - try: - _group_id = uuid.UUID(form.get("group_id")) - except TypeError as terr: - raise InvalidData("Expected a UUID for 'group_id' value.") from terr - if not bool(form.get("selected")): - raise InvalidData("Expected at least one dataset to be provided.") - return { - "group_id": uuid.UUID(form["group_id"]), - "traits": form["selected"] - } - - with gn3db.database_connection(app.config["SQL_URI"]) as gn3conn: - def __link__(conn: db.DbConnection, group_id: uuid.UUID, - traits: tuple[dict, ...]) -> dict: - return link_phenotype_data( - conn, gn3conn, group_by_id(conn, group_id), traits) - - return jsonify(with_db_connection( - partial(__link__, **__values__(request.json)))) diff --git a/gn3/auth/authorisation/errors.py b/gn3/auth/authorisation/errors.py deleted file mode 100644 index 3bc7a04..0000000 --- a/gn3/auth/authorisation/errors.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Authorisation exceptions""" - -class AuthorisationError(Exception): - """ - Top-level exception for the `gn3.auth.authorisation` package. - - All exceptions in this package should inherit from this class. 
- """ - error_code: int = 400 - -class ForbiddenAccess(AuthorisationError): - """Raised for forbidden access.""" - error_code: int = 403 - -class UserRegistrationError(AuthorisationError): - """Raised whenever a user registration fails""" - -class NotFoundError(AuthorisationError): - """Raised whenever we try fetching (a/an) object(s) that do(es) not exist.""" - error_code: int = 404 - -class InvalidData(AuthorisationError): - """ - Exception if user requests invalid data - """ - error_code: int = 400 - -class InconsistencyError(AuthorisationError): - """ - Exception raised due to data inconsistencies - """ - error_code: int = 500 - -class PasswordError(AuthorisationError): - """ - Raise in case of an error with passwords. - """ - -class UsernameError(AuthorisationError): - """ - Raise in case of an error with a user's name. - """ diff --git a/gn3/auth/authorisation/groups/__init__.py b/gn3/auth/authorisation/groups/__init__.py deleted file mode 100644 index 1cb0bba..0000000 --- a/gn3/auth/authorisation/groups/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Initialise the `gn3.auth.authorisation.groups` package""" - -from .models import Group, GroupRole diff --git a/gn3/auth/authorisation/groups/data.py b/gn3/auth/authorisation/groups/data.py deleted file mode 100644 index ee6f70e..0000000 --- a/gn3/auth/authorisation/groups/data.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Handles the resource objects' data.""" -from MySQLdb.cursors import DictCursor - -from gn3 import db_utils as gn3db -from gn3.auth import db as authdb -from gn3.auth.authorisation.groups import Group -from gn3.auth.authorisation.checks import authorised_p -from gn3.auth.authorisation.errors import NotFoundError - -def __fetch_mrna_data_by_ids__( - conn: gn3db.Connection, dataset_ids: tuple[str, ...]) -> tuple[ - dict, ...]: - """Fetch mRNA Assay data by ID.""" - with conn.cursor(DictCursor) as cursor: - paramstr = ", ".join(["%s"] * len(dataset_ids)) - cursor.execute( - "SELECT psf.Id, psf.Name AS dataset_name, " - "psf.FullName AS dataset_fullname, " - "ifiles.GN_AccesionId AS accession_id FROM ProbeSetFreeze AS psf " - "INNER JOIN InfoFiles AS ifiles ON psf.Name=ifiles.InfoPageName " - f"WHERE psf.Id IN ({paramstr})", - dataset_ids) - res = cursor.fetchall() - if res: - return tuple(dict(row) for row in res) - raise NotFoundError("Could not find mRNA Assay data with the given ID.") - -def __fetch_geno_data_by_ids__( - conn: gn3db.Connection, dataset_ids: tuple[str, ...]) -> tuple[ - dict, ...]: - """Fetch genotype data by ID.""" - with conn.cursor(DictCursor) as cursor: - paramstr = ", ".join(["%s"] * len(dataset_ids)) - cursor.execute( - "SELECT gf.Id, gf.Name AS dataset_name, " - "gf.FullName AS dataset_fullname, " - "ifiles.GN_AccesionId AS accession_id FROM GenoFreeze AS gf " - "INNER JOIN InfoFiles AS ifiles ON gf.Name=ifiles.InfoPageName " - f"WHERE gf.Id IN ({paramstr})", - dataset_ids) - res = cursor.fetchall() - if res: - return tuple(dict(row) for row in res) - raise NotFoundError("Could not find Genotype data with the given ID.") - -def __fetch_pheno_data_by_ids__( - conn: gn3db.Connection, dataset_ids: tuple[str, ...]) -> tuple[ - dict, ...]: - """Fetch phenotype data by ID.""" - with conn.cursor(DictCursor) as cursor: - paramstr = ", ".join(["%s"] * len(dataset_ids)) - cursor.execute( - "SELECT pxf.Id, iset.InbredSetName, pf.Id AS dataset_id, " - "pf.Name AS dataset_name, pf.FullName AS dataset_fullname, " - "ifiles.GN_AccesionId AS accession_id " - "FROM PublishXRef AS pxf " - "INNER JOIN InbredSet AS iset ON 
pxf.InbredSetId=iset.InbredSetId " - "INNER JOIN PublishFreeze AS pf ON iset.InbredSetId=pf.InbredSetId " - "INNER JOIN InfoFiles AS ifiles ON pf.Name=ifiles.InfoPageName " - f"WHERE pxf.Id IN ({paramstr})", - dataset_ids) - res = cursor.fetchall() - if res: - return tuple(dict(row) for row in res) - raise NotFoundError( - "Could not find Phenotype/Publish data with the given IDs.") - -def __fetch_data_by_id( - conn: gn3db.Connection, dataset_type: str, - dataset_ids: tuple[str, ...]) -> tuple[dict, ...]: - """Fetch data from MySQL by IDs.""" - fetch_fns = { - "mrna": __fetch_mrna_data_by_ids__, - "genotype": __fetch_geno_data_by_ids__, - "phenotype": __fetch_pheno_data_by_ids__ - } - return fetch_fns[dataset_type](conn, dataset_ids) - -@authorised_p(("system:data:link-to-group",), - error_description=( - "You do not have sufficient privileges to link data to (a) " - "group(s)."), - oauth2_scope="profile group resource") -def link_data_to_group( - authconn: authdb.DbConnection, gn3conn: gn3db.Connection, - dataset_type: str, dataset_ids: tuple[str, ...], group: Group) -> tuple[ - dict, ...]: - """Link the given data to the specified group.""" - the_data = __fetch_data_by_id(gn3conn, dataset_type, dataset_ids) - with authdb.cursor(authconn) as cursor: - params = tuple({ - "group_id": str(group.group_id), "dataset_type": { - "mrna": "mRNA", "genotype": "Genotype", - "phenotype": "Phenotype" - }[dataset_type], - "dataset_or_trait_id": item["Id"], - "dataset_name": item["dataset_name"], - "dataset_fullname": item["dataset_fullname"], - "accession_id": item["accession_id"] - } for item in the_data) - cursor.executemany( - "INSERT INTO linked_group_data VALUES" - "(:group_id, :dataset_type, :dataset_or_trait_id, :dataset_name, " - ":dataset_fullname, :accession_id)", - params) - return params diff --git a/gn3/auth/authorisation/groups/models.py b/gn3/auth/authorisation/groups/models.py deleted file mode 100644 index 7212a78..0000000 --- a/gn3/auth/authorisation/groups/models.py +++ /dev/null @@ -1,400 +0,0 @@ -"""Handle the management of resource/user groups.""" -import json -from uuid import UUID, uuid4 -from functools import reduce -from typing import Any, Sequence, Iterable, Optional, NamedTuple - -from flask import g -from pymonad.maybe import Just, Maybe, Nothing - -from gn3.auth import db -from gn3.auth.dictify import dictify -from gn3.auth.authorisation.users import User, user_by_id - -from ..checks import authorised_p -from ..privileges import Privilege -from ..errors import NotFoundError, AuthorisationError, InconsistencyError -from ..roles.models import ( - Role, create_role, check_user_editable, revoke_user_role_by_name, - assign_user_role_by_name) - -class Group(NamedTuple): - """Class representing a group.""" - group_id: UUID - group_name: str - group_metadata: dict[str, Any] - - def dictify(self): - """Return a dict representation of `Group` objects.""" - return { - "group_id": self.group_id, "group_name": self.group_name, - "group_metadata": self.group_metadata - } - -DUMMY_GROUP = Group( - group_id=UUID("77cee65b-fe29-4383-ae41-3cb3b480cc70"), - group_name="GN3_DUMMY_GROUP", - group_metadata={ - "group-description": "This is a dummy group to use as a placeholder" - }) - -class GroupRole(NamedTuple): - """Class representing a role tied/belonging to a group.""" - group_role_id: UUID - group: Group - role: Role - - def dictify(self) -> dict[str, Any]: - """Return a dict representation of `GroupRole` objects.""" - return { - "group_role_id": self.group_role_id, "group": 
dictify(self.group), - "role": dictify(self.role) - } - -class GroupCreationError(AuthorisationError): - """Raised whenever a group creation fails""" - -class MembershipError(AuthorisationError): - """Raised when there is an error with a user's membership to a group.""" - - def __init__(self, user: User, groups: Sequence[Group]): - """Initialise the `MembershipError` exception object.""" - groups_str = ", ".join(group.group_name for group in groups) - error_description = ( - f"User '{user.name} ({user.email})' is a member of {len(groups)} " - f"groups ({groups_str})") - super().__init__(f"{type(self).__name__}: {error_description}.") - -def user_membership(conn: db.DbConnection, user: User) -> Sequence[Group]: - """Returns all the groups that a member belongs to""" - query = ( - "SELECT groups.group_id, group_name, groups.group_metadata " - "FROM group_users INNER JOIN groups " - "ON group_users.group_id=groups.group_id " - "WHERE group_users.user_id=?") - with db.cursor(conn) as cursor: - cursor.execute(query, (str(user.user_id),)) - groups = tuple(Group(row[0], row[1], json.loads(row[2])) - for row in cursor.fetchall()) - - return groups - -@authorised_p( - privileges = ("system:group:create-group",), - error_description = ( - "You do not have the appropriate privileges to enable you to " - "create a new group."), - oauth2_scope = "profile group") -def create_group( - conn: db.DbConnection, group_name: str, group_leader: User, - group_description: Optional[str] = None) -> Group: - """Create a new group.""" - user_groups = user_membership(conn, group_leader) - if len(user_groups) > 0: - raise MembershipError(group_leader, user_groups) - - with db.cursor(conn) as cursor: - new_group = save_group( - cursor, group_name,( - {"group_description": group_description} - if group_description else {})) - add_user_to_group(cursor, new_group, group_leader) - revoke_user_role_by_name(cursor, group_leader, "group-creator") - assign_user_role_by_name(cursor, group_leader, "group-leader") - return new_group - -@authorised_p(("group:role:create-role",), - error_description="Could not create the group role") -def create_group_role( - conn: db.DbConnection, group: Group, role_name: str, - privileges: Iterable[Privilege]) -> GroupRole: - """Create a role attached to a group.""" - with db.cursor(conn) as cursor: - group_role_id = uuid4() - role = create_role(cursor, role_name, privileges) - cursor.execute( - ("INSERT INTO group_roles(group_role_id, group_id, role_id) " - "VALUES(?, ?, ?)"), - (str(group_role_id), str(group.group_id), str(role.role_id))) - - return GroupRole(group_role_id, group, role) - -def authenticated_user_group(conn) -> Maybe: - """ - Returns the currently authenticated user's group. - - Look into returning a Maybe object. 
- """ - user = g.user - with db.cursor(conn) as cursor: - cursor.execute( - ("SELECT groups.* FROM group_users " - "INNER JOIN groups ON group_users.group_id=groups.group_id " - "WHERE group_users.user_id = ?"), - (str(user.user_id),)) - groups = tuple(Group(UUID(row[0]), row[1], json.loads(row[2] or "{}")) - for row in cursor.fetchall()) - - if len(groups) > 1: - raise MembershipError(user, groups) - - if len(groups) == 1: - return Just(groups[0]) - - return Nothing - -def user_group(conn: db.DbConnection, user: User) -> Maybe[Group]: - """Returns the given user's group""" - with db.cursor(conn) as cursor: - cursor.execute( - ("SELECT groups.group_id, groups.group_name, groups.group_metadata " - "FROM group_users " - "INNER JOIN groups ON group_users.group_id=groups.group_id " - "WHERE group_users.user_id = ?"), - (str(user.user_id),)) - groups = tuple( - Group(UUID(row[0]), row[1], json.loads(row[2] or "{}")) - for row in cursor.fetchall()) - - if len(groups) > 1: - raise MembershipError(user, groups) - - if len(groups) == 1: - return Just(groups[0]) - - return Nothing - -def is_group_leader(conn: db.DbConnection, user: User, group: Group) -> bool: - """Check whether the given `user` is the leader of `group`.""" - - ugroup = user_group(conn, user).maybe( - False, lambda val: val) # type: ignore[arg-type, misc] - if not group: - # User cannot be a group leader if not a member of ANY group - return False - - if not ugroup == group: - # User cannot be a group leader if not a member of THIS group - return False - - with db.cursor(conn) as cursor: - cursor.execute( - ("SELECT roles.role_name FROM user_roles LEFT JOIN roles " - "ON user_roles.role_id = roles.role_id WHERE user_id = ?"), - (str(user.user_id),)) - role_names = tuple(row[0] for row in cursor.fetchall()) - - return "group-leader" in role_names - -def all_groups(conn: db.DbConnection) -> Maybe[Sequence[Group]]: - """Retrieve all existing groups""" - with db.cursor(conn) as cursor: - cursor.execute("SELECT * FROM groups") - res = cursor.fetchall() - if res: - return Just(tuple( - Group(row["group_id"], row["group_name"], - json.loads(row["group_metadata"])) for row in res)) - - return Nothing - -def save_group( - cursor: db.DbCursor, group_name: str, - group_metadata: dict[str, Any]) -> Group: - """Save a group to db""" - the_group = Group(uuid4(), group_name, group_metadata) - cursor.execute( - ("INSERT INTO groups " - "VALUES(:group_id, :group_name, :group_metadata) " - "ON CONFLICT (group_id) DO UPDATE SET " - "group_name=:group_name, group_metadata=:group_metadata"), - {"group_id": str(the_group.group_id), "group_name": the_group.group_name, - "group_metadata": json.dumps(the_group.group_metadata)}) - return the_group - -def add_user_to_group(cursor: db.DbCursor, the_group: Group, user: User): - """Add `user` to `the_group` as a member.""" - cursor.execute( - ("INSERT INTO group_users VALUES (:group_id, :user_id) " - "ON CONFLICT (group_id, user_id) DO NOTHING"), - {"group_id": str(the_group.group_id), "user_id": str(user.user_id)}) - -@authorised_p( - privileges = ("system:group:view-group",), - error_description = ( - "You do not have the appropriate privileges to access the list of users" - " in the group.")) -def group_users(conn: db.DbConnection, group_id: UUID) -> Iterable[User]: - """Retrieve all users that are members of group with id `group_id`.""" - with db.cursor(conn) as cursor: - cursor.execute( - "SELECT u.* FROM group_users AS gu INNER JOIN users AS u " - "ON gu.user_id = u.user_id WHERE gu.group_id=:group_id", - 
{"group_id": str(group_id)}) - results = cursor.fetchall() - - return (User(UUID(row["user_id"]), row["email"], row["name"]) - for row in results) - -@authorised_p( - privileges = ("system:group:view-group",), - error_description = ( - "You do not have the appropriate privileges to access the group.")) -def group_by_id(conn: db.DbConnection, group_id: UUID) -> Group: - """Retrieve a group by its ID""" - with db.cursor(conn) as cursor: - cursor.execute("SELECT * FROM groups WHERE group_id=:group_id", - {"group_id": str(group_id)}) - row = cursor.fetchone() - if row: - return Group( - UUID(row["group_id"]), - row["group_name"], - json.loads(row["group_metadata"])) - - raise NotFoundError(f"Could not find group with ID '{group_id}'.") - -@authorised_p(("system:group:view-group", "system:group:edit-group"), - error_description=("You do not have the appropriate authorisation" - " to act upon the join requests."), - oauth2_scope="profile group") -def join_requests(conn: db.DbConnection, user: User): - """List all the join requests for the user's group.""" - with db.cursor(conn) as cursor: - group = user_group(conn, user).maybe(DUMMY_GROUP, lambda grp: grp)# type: ignore[misc] - if group != DUMMY_GROUP and is_group_leader(conn, user, group): - cursor.execute( - "SELECT gjr.*, u.email, u.name FROM group_join_requests AS gjr " - "INNER JOIN users AS u ON gjr.requester_id=u.user_id " - "WHERE gjr.group_id=? AND gjr.status='PENDING'", - (str(group.group_id),)) - return tuple(dict(row)for row in cursor.fetchall()) - - raise AuthorisationError( - "You do not have the appropriate authorisation to access the " - "group's join requests.") - -@authorised_p(("system:group:view-group", "system:group:edit-group"), - error_description=("You do not have the appropriate authorisation" - " to act upon the join requests."), - oauth2_scope="profile group") -def accept_reject_join_request( - conn: db.DbConnection, request_id: UUID, user: User, status: str) -> dict: - """Accept/Reject a join request.""" - assert status in ("ACCEPTED", "REJECTED"), f"Invalid status '{status}'." - with db.cursor(conn) as cursor: - group = user_group(conn, user).maybe(DUMMY_GROUP, lambda grp: grp) # type: ignore[misc] - cursor.execute("SELECT * FROM group_join_requests WHERE request_id=?", - (str(request_id),)) - row = cursor.fetchone() - if row: - if group.group_id == UUID(row["group_id"]): - try: - the_user = user_by_id(conn, UUID(row["requester_id"])) - if status == "ACCEPTED": - add_user_to_group(cursor, group, the_user) - revoke_user_role_by_name(cursor, the_user, "group-creator") - cursor.execute( - "UPDATE group_join_requests SET status=? " - "WHERE request_id=?", - (status, str(request_id))) - return {"request_id": request_id, "status": status} - except NotFoundError as nfe: - raise InconsistencyError( - "Could not find user associated with join request." 
- ) from nfe - raise AuthorisationError( - "You cannot act on other groups join requests") - raise NotFoundError(f"Could not find request with ID '{request_id}'") - -def __organise_privileges__(acc, row): - role_id = UUID(row["role_id"]) - role = acc.get(role_id, False) - if role: - return { - **acc, - role_id: Role( - role.role_id, role.role_name, - bool(int(row["user_editable"])), - role.privileges + ( - Privilege(row["privilege_id"], - row["privilege_description"]),)) - } - return { - **acc, - role_id: Role( - UUID(row["role_id"]), row["role_name"], - bool(int(row["user_editable"])), - (Privilege(row["privilege_id"], row["privilege_description"]),)) - } - -# @authorised_p(("group:role:view",), -# "Insufficient privileges to view role", -# oauth2_scope="profile group role") -def group_role_by_id( - conn: db.DbConnection, group: Group, group_role_id: UUID) -> GroupRole: - """Retrieve GroupRole from id by its `group_role_id`.""" - ## TODO: do privileges check before running actual query - ## the check commented out above doesn't work correctly - with db.cursor(conn) as cursor: - cursor.execute( - "SELECT gr.group_role_id, r.*, p.* " - "FROM group_roles AS gr " - "INNER JOIN roles AS r ON gr.role_id=r.role_id " - "INNER JOIN role_privileges AS rp ON rp.role_id=r.role_id " - "INNER JOIN privileges AS p ON p.privilege_id=rp.privilege_id " - "WHERE gr.group_role_id=? AND gr.group_id=?", - (str(group_role_id), str(group.group_id))) - rows = cursor.fetchall() - if rows: - roles: tuple[Role,...] = tuple(reduce( - __organise_privileges__, rows, {}).values()) - assert len(roles) == 1 - return GroupRole(group_role_id, group, roles[0]) - raise NotFoundError( - f"Group role with ID '{group_role_id}' does not exist.") - -@authorised_p(("group:role:edit-role",), - "You do not have the privilege to edit a role.", - oauth2_scope="profile group role") -def add_privilege_to_group_role(conn: db.DbConnection, group_role: GroupRole, - privilege: Privilege) -> GroupRole: - """Add `privilege` to `group_role`.""" - ## TODO: do privileges check. - check_user_editable(group_role.role) - with db.cursor(conn) as cursor: - cursor.execute( - "INSERT INTO role_privileges(role_id,privilege_id) " - "VALUES (?, ?) ON CONFLICT (role_id, privilege_id) " - "DO NOTHING", - (str(group_role.role.role_id), str(privilege.privilege_id))) - return GroupRole( - group_role.group_role_id, - group_role.group, - Role(group_role.role.role_id, - group_role.role.role_name, - group_role.role.user_editable, - group_role.role.privileges + (privilege,))) - -@authorised_p(("group:role:edit-role",), - "You do not have the privilege to edit a role.", - oauth2_scope="profile group role") -def delete_privilege_from_group_role( - conn: db.DbConnection, group_role: GroupRole, - privilege: Privilege) -> GroupRole: - """Delete `privilege` to `group_role`.""" - ## TODO: do privileges check. - check_user_editable(group_role.role) - with db.cursor(conn) as cursor: - cursor.execute( - "DELETE FROM role_privileges WHERE " - "role_id=? 
AND privilege_id=?", - (str(group_role.role.role_id), str(privilege.privilege_id))) - return GroupRole( - group_role.group_role_id, - group_role.group, - Role(group_role.role.role_id, - group_role.role.role_name, - group_role.role.user_editable, - tuple(priv for priv in group_role.role.privileges - if priv != privilege))) diff --git a/gn3/auth/authorisation/groups/views.py b/gn3/auth/authorisation/groups/views.py deleted file mode 100644 index a849a73..0000000 --- a/gn3/auth/authorisation/groups/views.py +++ /dev/null @@ -1,430 +0,0 @@ -"""The views/routes for the `gn3.auth.authorisation.groups` package.""" -import uuid -import datetime -from typing import Iterable -from functools import partial - -from MySQLdb.cursors import DictCursor -from flask import request, jsonify, Response, Blueprint, current_app - -from gn3.auth import db -from gn3 import db_utils as gn3db - -from gn3.auth.dictify import dictify -from gn3.auth.db_utils import with_db_connection -from gn3.auth.authorisation.users import User -from gn3.auth.authorisation.oauth2.resource_server import require_oauth - -from .data import link_data_to_group -from .models import ( - Group, user_group, all_groups, DUMMY_GROUP, GroupRole, group_by_id, - join_requests, group_role_by_id, GroupCreationError, - accept_reject_join_request, group_users as _group_users, - create_group as _create_group, add_privilege_to_group_role, - delete_privilege_from_group_role, create_group_role as _create_group_role) - -from ..roles.models import Role -from ..roles.models import user_roles - -from ..checks import authorised_p -from ..privileges import Privilege, privileges_by_ids -from ..errors import InvalidData, NotFoundError, AuthorisationError - -groups = Blueprint("groups", __name__) - -@groups.route("/list", methods=["GET"]) -@require_oauth("profile group") -def list_groups(): - """Return the list of groups that exist.""" - with db.connection(current_app.config["AUTH_DB"]) as conn: - the_groups = all_groups(conn) - - return jsonify(the_groups.maybe( - [], lambda grps: [dictify(grp) for grp in grps])) - -@groups.route("/create", methods=["POST"]) -@require_oauth("profile group") -def create_group(): - """Create a new group.""" - with require_oauth.acquire("profile group") as the_token: - group_name=request.form.get("group_name", "").strip() - if not bool(group_name): - raise GroupCreationError("Could not create the group.") - - db_uri = current_app.config["AUTH_DB"] - with db.connection(db_uri) as conn: - user = the_token.user - new_group = _create_group( - conn, group_name, user, request.form.get("group_description")) - return jsonify({ - **dictify(new_group), "group_leader": dictify(user) - }) - -@groups.route("/members/<uuid:group_id>", methods=["GET"]) -@require_oauth("profile group") -def group_members(group_id: uuid.UUID) -> Response: - """Retrieve all the members of a group.""" - with require_oauth.acquire("profile group") as the_token:# pylint: disable=[unused-variable] - db_uri = current_app.config["AUTH_DB"] - ## Check that user has appropriate privileges and remove the pylint disable above - with db.connection(db_uri) as conn: - return jsonify(tuple( - dictify(user) for user in _group_users(conn, group_id))) - -@groups.route("/requests/join/<uuid:group_id>", methods=["POST"]) -@require_oauth("profile group") -def request_to_join(group_id: uuid.UUID) -> Response: - """Request to join a group.""" - def __request__(conn: db.DbConnection, user: User, group_id: uuid.UUID, - message: str): - with db.cursor(conn) as cursor: - group = 
user_group(conn, user).maybe(# type: ignore[misc] - False, lambda grp: grp)# type: ignore[arg-type] - if group: - error = AuthorisationError( - "You cannot request to join a new group while being a " - "member of an existing group.") - error.error_code = 400 - raise error - request_id = uuid.uuid4() - cursor.execute( - "INSERT INTO group_join_requests VALUES " - "(:request_id, :group_id, :user_id, :ts, :status, :msg)", - { - "request_id": str(request_id), - "group_id": str(group_id), - "user_id": str(user.user_id), - "ts": datetime.datetime.now().timestamp(), - "status": "PENDING", - "msg": message - }) - return { - "request_id": request_id, - "message": "Successfully sent the join request." - } - - with require_oauth.acquire("profile group") as the_token: - form = request.form - results = with_db_connection(partial( - __request__, user=the_token.user, group_id=group_id, message=form.get( - "message", "I hereby request that you add me to your group."))) - return jsonify(results) - -@groups.route("/requests/join/list", methods=["GET"]) -@require_oauth("profile group") -def list_join_requests() -> Response: - """List the pending join requests.""" - with require_oauth.acquire("profile group") as the_token: - return jsonify(with_db_connection(partial( - join_requests, user=the_token.user))) - -@groups.route("/requests/join/accept", methods=["POST"]) -@require_oauth("profile group") -def accept_join_requests() -> Response: - """Accept a join request.""" - with require_oauth.acquire("profile group") as the_token: - form = request.form - request_id = uuid.UUID(form.get("request_id")) - return jsonify(with_db_connection(partial( - accept_reject_join_request, request_id=request_id, - user=the_token.user, status="ACCEPTED"))) - -@groups.route("/requests/join/reject", methods=["POST"]) -@require_oauth("profile group") -def reject_join_requests() -> Response: - """Reject a join request.""" - with require_oauth.acquire("profile group") as the_token: - form = request.form - request_id = uuid.UUID(form.get("request_id")) - return jsonify(with_db_connection(partial( - accept_reject_join_request, request_id=request_id, - user=the_token.user, status="REJECTED"))) - -def unlinked_mrna_data( - conn: db.DbConnection, group: Group) -> tuple[dict, ...]: - """ - Retrieve all mRNA Assay data linked to a group but not linked to any - resource. - """ - query = ( - "SELECT lmd.* FROM linked_mrna_data lmd " - "LEFT JOIN mrna_resources mr ON lmd.data_link_id=mr.data_link_id " - "WHERE lmd.group_id=? AND mr.data_link_id IS NULL") - with db.cursor(conn) as cursor: - cursor.execute(query, (str(group.group_id),)) - return tuple(dict(row) for row in cursor.fetchall()) - -def unlinked_genotype_data( - conn: db.DbConnection, group: Group) -> tuple[dict, ...]: - """ - Retrieve all genotype data linked to a group but not linked to any resource. - """ - query = ( - "SELECT lgd.* FROM linked_genotype_data lgd " - "LEFT JOIN genotype_resources gr ON lgd.data_link_id=gr.data_link_id " - "WHERE lgd.group_id=? AND gr.data_link_id IS NULL") - with db.cursor(conn) as cursor: - cursor.execute(query, (str(group.group_id),)) - return tuple(dict(row) for row in cursor.fetchall()) - -def unlinked_phenotype_data( - authconn: db.DbConnection, gn3conn: gn3db.Connection, - group: Group) -> tuple[dict, ...]: - """ - Retrieve all phenotype data linked to a group but not linked to any - resource. 
- """ - with db.cursor(authconn) as authcur, gn3conn.cursor(DictCursor) as gn3cur: - authcur.execute( - "SELECT lpd.* FROM linked_phenotype_data AS lpd " - "LEFT JOIN phenotype_resources AS pr " - "ON lpd.data_link_id=pr.data_link_id " - "WHERE lpd.group_id=? AND pr.data_link_id IS NULL", - (str(group.group_id),)) - results = authcur.fetchall() - ids: dict[tuple[str, ...], str] = { - ( - row["SpeciesId"], row["InbredSetId"], row["PublishFreezeId"], - row["PublishXRefId"]): row["data_link_id"] - for row in results - } - if len(ids.keys()) < 1: - return tuple() - paramstr = ", ".join(["(%s, %s, %s, %s)"] * len(ids.keys())) - gn3cur.execute( - "SELECT spc.SpeciesId, spc.SpeciesName, iset.InbredSetId, " - "iset.InbredSetName, pf.Id AS PublishFreezeId, " - "pf.Name AS dataset_name, pf.FullName AS dataset_fullname, " - "pf.ShortName AS dataset_shortname, pxr.Id AS PublishXRefId, " - "pub.PubMed_ID, pub.Title, pub.Year, " - "phen.Pre_publication_description, " - "phen.Post_publication_description, phen.Original_description " - "FROM " - "Species AS spc " - "INNER JOIN InbredSet AS iset " - "ON spc.SpeciesId=iset.SpeciesId " - "INNER JOIN PublishFreeze AS pf " - "ON iset.InbredSetId=pf.InbredSetId " - "INNER JOIN PublishXRef AS pxr " - "ON pf.InbredSetId=pxr.InbredSetId " - "INNER JOIN Publication AS pub " - "ON pxr.PublicationId=pub.Id " - "INNER JOIN Phenotype AS phen " - "ON pxr.PhenotypeId=phen.Id " - "WHERE (spc.SpeciesId, iset.InbredSetId, pf.Id, pxr.Id) " - f"IN ({paramstr})", - tuple(item for sublist in ids.keys() for item in sublist)) - return tuple({ - **{key: value for key, value in row.items() if key not in - ("Post_publication_description", "Pre_publication_description", - "Original_description")}, - "description": ( - row["Post_publication_description"] or - row["Pre_publication_description"] or - row["Original_description"]), - "data_link_id": ids[tuple(str(row[key]) for key in ( - "SpeciesId", "InbredSetId", "PublishFreezeId", - "PublishXRefId"))] - } for row in gn3cur.fetchall()) - -@groups.route("/<string:resource_type>/unlinked-data") -@require_oauth("profile group resource") -def unlinked_data(resource_type: str) -> Response: - """View data linked to the group but not linked to any resource.""" - if resource_type not in ("all", "mrna", "genotype", "phenotype"): - raise AuthorisationError(f"Invalid resource type {resource_type}") - - with require_oauth.acquire("profile group resource") as the_token: - db_uri = current_app.config["AUTH_DB"] - gn3db_uri = current_app.config["SQL_URI"] - with (db.connection(db_uri) as authconn, - gn3db.database_connection(gn3db_uri) as gn3conn): - ugroup = user_group(authconn, the_token.user).maybe(# type: ignore[misc] - DUMMY_GROUP, lambda grp: grp) - if ugroup == DUMMY_GROUP: - return jsonify(tuple()) - - unlinked_fns = { - "mrna": unlinked_mrna_data, - "genotype": unlinked_genotype_data, - "phenotype": lambda conn, grp: partial( - unlinked_phenotype_data, gn3conn=gn3conn)( - authconn=conn, group=grp) - } - return jsonify(tuple( - dict(row) for row in unlinked_fns[resource_type]( - authconn, ugroup))) - - return jsonify(tuple()) - -@groups.route("/data/link", methods=["POST"]) -@require_oauth("profile group resource") -def link_data() -> Response: - """Link selected data to specified group.""" - with require_oauth.acquire("profile group resource") as _the_token: - form = request.form - group_id = uuid.UUID(form["group_id"]) - dataset_ids = form.getlist("dataset_ids") - dataset_type = form.get("dataset_type") - if dataset_type not in ("mrna", 
"genotype", "phenotype"): - raise InvalidData("Unexpected dataset type requested!") - def __link__(conn: db.DbConnection): - group = group_by_id(conn, group_id) - with gn3db.database_connection(current_app.config["SQL_URI"]) as gn3conn: - return link_data_to_group( - conn, gn3conn, dataset_type, dataset_ids, group) - - return jsonify(with_db_connection(__link__)) - -@groups.route("/roles", methods=["GET"]) -@require_oauth("profile group") -def group_roles(): - """Return a list of all available group roles.""" - with require_oauth.acquire("profile group role") as the_token: - def __list_roles__(conn: db.DbConnection): - ## TODO: Check that user has appropriate privileges - with db.cursor(conn) as cursor: - group = user_group(conn, the_token.user).maybe(# type: ignore[misc] - DUMMY_GROUP, lambda grp: grp) - if group == DUMMY_GROUP: - return tuple() - cursor.execute( - "SELECT gr.group_role_id, r.* " - "FROM group_roles AS gr INNER JOIN roles AS r " - "ON gr.role_id=r.role_id " - "WHERE group_id=?", - (str(group.group_id),)) - return tuple( - GroupRole(uuid.UUID(row["group_role_id"]), - group, - Role(uuid.UUID(row["role_id"]), - row["role_name"], - bool(int(row["user_editable"])), - tuple())) - for row in cursor.fetchall()) - return jsonify(tuple( - dictify(role) for role in with_db_connection(__list_roles__))) - -@groups.route("/privileges", methods=["GET"]) -@require_oauth("profile group") -def group_privileges(): - """Return a list of all available group roles.""" - with require_oauth.acquire("profile group role") as the_token: - def __list_privileges__(conn: db.DbConnection) -> Iterable[Privilege]: - ## TODO: Check that user has appropriate privileges - this_user_roles = user_roles(conn, the_token.user) - with db.cursor(conn) as cursor: - cursor.execute("SELECT * FROM privileges " - "WHERE privilege_id LIKE 'group:%'") - group_level_roles = tuple( - Privilege(row["privilege_id"], row["privilege_description"]) - for row in cursor.fetchall()) - return tuple(privilege for arole in this_user_roles - for privilege in arole.privileges) + group_level_roles - return jsonify(tuple( - dictify(priv) for priv in with_db_connection(__list_privileges__))) - - - -@groups.route("/role/create", methods=["POST"]) -@require_oauth("profile group") -def create_group_role(): - """Create a new group role.""" - with require_oauth.acquire("profile group role") as the_token: - ## TODO: Check that user has appropriate privileges - @authorised_p(("group:role:create-role",), - "You do not have the privilege to create new roles", - oauth2_scope="profile group role") - def __create__(conn: db.DbConnection) -> GroupRole: - ## TODO: Check user cannot assign any privilege they don't have. 
- form = request.form - role_name = form.get("role_name", "").strip() - privileges_ids = form.getlist("privileges[]") - if len(role_name) == 0: - raise InvalidData("Role name not provided!") - if len(privileges_ids) == 0: - raise InvalidData( - "At least one privilege needs to be provided.") - - group = user_group(conn, the_token.user).maybe(# type: ignore[misc] - DUMMY_GROUP, lambda grp: grp) - - if group == DUMMY_GROUP: - raise AuthorisationError( - "A user without a group cannot create a new role.") - privileges = privileges_by_ids(conn, tuple(privileges_ids)) - if len(privileges_ids) != len(privileges): - raise InvalidData( - f"{len(privileges_ids) - len(privileges)} of the selected " - "privileges were not found in the database.") - - return _create_group_role(conn, group, role_name, privileges) - - return jsonify(with_db_connection(__create__)) - -@groups.route("/role/<uuid:group_role_id>", methods=["GET"]) -@require_oauth("profile group") -def view_group_role(group_role_id: uuid.UUID): - """Return the details of the given role.""" - with require_oauth.acquire("profile group role") as the_token: - def __group_role__(conn: db.DbConnection) -> GroupRole: - group = user_group(conn, the_token.user).maybe(#type: ignore[misc] - DUMMY_GROUP, lambda grp: grp) - - if group == DUMMY_GROUP: - raise AuthorisationError( - "A user without a group cannot view group roles.") - return group_role_by_id(conn, group, group_role_id) - return jsonify(dictify(with_db_connection(__group_role__))) - -def __add_remove_priv_to_from_role__(conn: db.DbConnection, - group_role_id: uuid.UUID, - direction: str, - user: User) -> GroupRole: - assert direction in ("ADD", "DELETE") - group = user_group(conn, user).maybe(# type: ignore[misc] - DUMMY_GROUP, lambda grp: grp) - - if group == DUMMY_GROUP: - raise AuthorisationError( - "You need to be a member of a group to edit roles.") - try: - privilege_id = request.form.get("privilege_id", "") - assert bool(privilege_id), "Privilege to add must be provided." 
- privileges = privileges_by_ids(conn, (privilege_id,)) - if len(privileges) == 0: - raise NotFoundError("Privilege not found.") - dir_fns = { - "ADD": add_privilege_to_group_role, - "DELETE": delete_privilege_from_group_role - } - return dir_fns[direction]( - conn, - group_role_by_id(conn, group, group_role_id), - privileges[0]) - except AssertionError as aerr: - raise InvalidData(aerr.args[0]) from aerr - -@groups.route("/role/<uuid:group_role_id>/privilege/add", methods=["POST"]) -@require_oauth("profile group") -def add_priv_to_role(group_role_id: uuid.UUID) -> Response: - """Add privilege to group role.""" - with require_oauth.acquire("profile group role") as the_token: - return jsonify({ - **dictify(with_db_connection(partial( - __add_remove_priv_to_from_role__, group_role_id=group_role_id, - direction="ADD", user=the_token.user))), - "description": "Privilege added successfully" - }) - -@groups.route("/role/<uuid:group_role_id>/privilege/delete", methods=["POST"]) -@require_oauth("profile group") -def delete_priv_from_role(group_role_id: uuid.UUID) -> Response: - """Delete privilege from group role.""" - with require_oauth.acquire("profile group role") as the_token: - return jsonify({ - **dictify(with_db_connection(partial( - __add_remove_priv_to_from_role__, group_role_id=group_role_id, - direction="DELETE", user=the_token.user))), - "description": "Privilege deleted successfully" - }) diff --git a/gn3/auth/authorisation/oauth2/__init__.py b/gn3/auth/authorisation/oauth2/__init__.py deleted file mode 100644 index d083773..0000000 --- a/gn3/auth/authorisation/oauth2/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""OAuth2 modules.""" diff --git a/gn3/auth/authorisation/oauth2/oauth2client.py b/gn3/auth/authorisation/oauth2/oauth2client.py deleted file mode 100644 index dc54a41..0000000 --- a/gn3/auth/authorisation/oauth2/oauth2client.py +++ /dev/null @@ -1,234 +0,0 @@ -"""OAuth2 Client model.""" -import json -import datetime -from uuid import UUID -from typing import Sequence, Optional, NamedTuple - -from pymonad.maybe import Just, Maybe, Nothing - -from gn3.auth import db - -from gn3.auth.authorisation.errors import NotFoundError -from gn3.auth.authorisation.users import User, users, user_by_id, same_password - -class OAuth2Client(NamedTuple): - """ - Client to the OAuth2 Server. - - This is defined according to the mixin at - https://docs.authlib.org/en/latest/specs/rfc6749.html#authlib.oauth2.rfc6749.ClientMixin - """ - client_id: UUID - client_secret: str - client_id_issued_at: datetime.datetime - client_secret_expires_at: datetime.datetime - client_metadata: dict - user: User - - def check_client_secret(self, client_secret: str) -> bool: - """Check whether the `client_secret` matches this client.""" - return same_password(client_secret, self.client_secret) - - @property - def token_endpoint_auth_method(self) -> str: - """Return the token endpoint authorisation method.""" - return self.client_metadata.get("token_endpoint_auth_method", ["none"]) - - @property - def client_type(self) -> str: - """ - Return the token endpoint authorisation method. - - Acceptable client types: - * public: Unable to use registered client secrets, e.g. browsers, apps - on mobile devices. - * confidential: able to securely authenticate with authorisation server - e.g. being able to keep their registered client secret safe. 
- """ - return self.client_metadata.get("client_type", "public") - - def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool: - """ - Check if the client supports the given method for the given endpoint. - - Acceptable methods: - * none: Client is a public client and does not have a client secret - * client_secret_post: Client uses the HTTP POST parameters - * client_secret_basic: Client uses HTTP Basic - """ - if endpoint == "token": - return (method in self.token_endpoint_auth_method - and method == "client_secret_post") - if endpoint in ("introspection", "revoke"): - return (method in self.token_endpoint_auth_method - and method == "client_secret_basic") - return False - - @property - def id(self):# pylint: disable=[invalid-name] - """Return the client_id.""" - return self.client_id - - @property - def grant_types(self) -> Sequence[str]: - """ - Return the grant types that this client supports. - - Valid grant types: - * authorisation_code - * implicit - * client_credentials - * password - """ - return self.client_metadata.get("grant_types", []) - - def check_grant_type(self, grant_type: str) -> bool: - """ - Validate that client can handle the given grant types - """ - return grant_type in self.grant_types - - @property - def redirect_uris(self) -> Sequence[str]: - """Return the redirect_uris that this client supports.""" - return self.client_metadata.get('redirect_uris', []) - - def check_redirect_uri(self, redirect_uri: str) -> bool: - """ - Check whether the given `redirect_uri` is one of the expected ones. - """ - return redirect_uri in self.redirect_uris - - @property - def response_types(self) -> Sequence[str]: - """Return the response_types that this client supports.""" - return self.client_metadata.get("response_type", []) - - def check_response_type(self, response_type: str) -> bool: - """Check whether this client supports `response_type`.""" - return response_type in self.response_types - - @property - def scope(self) -> Sequence[str]: - """Return valid scopes for this client.""" - return tuple(set(self.client_metadata.get("scope", []))) - - def get_allowed_scope(self, scope: str) -> str: - """Return list of scopes in `scope` that are supported by this client.""" - if not bool(scope): - return "" - requested = scope.split() - return " ".join(sorted(set( - scp for scp in requested if scp in self.scope))) - - def get_client_id(self): - """Return this client's identifier.""" - return self.client_id - - def get_default_redirect_uri(self) -> str: - """Return the default redirect uri""" - return self.client_metadata.get("default_redirect_uri", "") - -def client(conn: db.DbConnection, client_id: UUID, - user: Optional[User] = None) -> Maybe: - """Retrieve a client by its ID""" - with db.cursor(conn) as cursor: - cursor.execute( - "SELECT * FROM oauth2_clients WHERE client_id=?", (str(client_id),)) - result = cursor.fetchone() - the_user = user - if result: - if not bool(the_user): - try: - the_user = user_by_id(conn, result["user_id"]) - except NotFoundError as _nfe: - the_user = None - - return Just( - OAuth2Client(UUID(result["client_id"]), - result["client_secret"], - datetime.datetime.fromtimestamp( - result["client_id_issued_at"]), - datetime.datetime.fromtimestamp( - result["client_secret_expires_at"]), - json.loads(result["client_metadata"]), - the_user))# type: ignore[arg-type] - - return Nothing - -def client_by_id_and_secret(conn: db.DbConnection, client_id: UUID, - client_secret: str) -> OAuth2Client: - """Retrieve a client by its ID and secret""" - with 
db.cursor(conn) as cursor: - cursor.execute( - "SELECT * FROM oauth2_clients WHERE client_id=?", - (str(client_id),)) - row = cursor.fetchone() - if bool(row) and same_password(client_secret, row["client_secret"]): - return OAuth2Client( - client_id, client_secret, - datetime.datetime.fromtimestamp(row["client_id_issued_at"]), - datetime.datetime.fromtimestamp( - row["client_secret_expires_at"]), - json.loads(row["client_metadata"]), - user_by_id(conn, UUID(row["user_id"]))) - - raise NotFoundError("Could not find client with the given credentials.") - -def save_client(conn: db.DbConnection, the_client: OAuth2Client) -> OAuth2Client: - """Persist the client details into the database.""" - with db.cursor(conn) as cursor: - query = ( - "INSERT INTO oauth2_clients " - "(client_id, client_secret, client_id_issued_at, " - "client_secret_expires_at, client_metadata, user_id) " - "VALUES " - "(:client_id, :client_secret, :client_id_issued_at, " - ":client_secret_expires_at, :client_metadata, :user_id) " - "ON CONFLICT (client_id) DO UPDATE SET " - "client_secret=:client_secret, " - "client_id_issued_at=:client_id_issued_at, " - "client_secret_expires_at=:client_secret_expires_at, " - "client_metadata=:client_metadata, user_id=:user_id") - cursor.execute( - query, - { - "client_id": str(the_client.client_id), - "client_secret": the_client.client_secret, - "client_id_issued_at": ( - the_client.client_id_issued_at.timestamp()), - "client_secret_expires_at": ( - the_client.client_secret_expires_at.timestamp()), - "client_metadata": json.dumps(the_client.client_metadata), - "user_id": str(the_client.user.user_id) - }) - return the_client - -def oauth2_clients(conn: db.DbConnection) -> tuple[OAuth2Client, ...]: - """Fetch a list of all OAuth2 clients.""" - with db.cursor(conn) as cursor: - cursor.execute("SELECT * FROM oauth2_clients") - clients_rs = cursor.fetchall() - the_users = { - usr.user_id: usr for usr in users( - conn, tuple({UUID(result["user_id"]) for result in clients_rs})) - } - return tuple(OAuth2Client(UUID(result["client_id"]), - result["client_secret"], - datetime.datetime.fromtimestamp( - result["client_id_issued_at"]), - datetime.datetime.fromtimestamp( - result["client_secret_expires_at"]), - json.loads(result["client_metadata"]), - the_users[UUID(result["user_id"])]) - for result in clients_rs) - -def delete_client(conn: db.DbConnection, the_client: OAuth2Client) -> OAuth2Client: - """Delete the given client from the database""" - with db.cursor(conn) as cursor: - params = (str(the_client.client_id),) - cursor.execute("DELETE FROM authorisation_code WHERE client_id=?", - params) - cursor.execute("DELETE FROM oauth2_tokens WHERE client_id=?", params) - cursor.execute("DELETE FROM oauth2_clients WHERE client_id=?", params) - return the_client diff --git a/gn3/auth/authorisation/oauth2/oauth2token.py b/gn3/auth/authorisation/oauth2/oauth2token.py deleted file mode 100644 index bb19039..0000000 --- a/gn3/auth/authorisation/oauth2/oauth2token.py +++ /dev/null @@ -1,133 +0,0 @@ -"""OAuth2 Token""" -import uuid -import datetime -from typing import NamedTuple, Optional - -from pymonad.maybe import Just, Maybe, Nothing - -from gn3.auth import db - -from gn3.auth.authorisation.errors import NotFoundError -from gn3.auth.authorisation.users import User, user_by_id - -from .oauth2client import client, OAuth2Client - -class OAuth2Token(NamedTuple): - """Implement Tokens for OAuth2.""" - token_id: uuid.UUID - client: OAuth2Client - token_type: str - access_token: str - refresh_token: 
Optional[str] - scope: str - revoked: bool - issued_at: datetime.datetime - expires_in: int - user: User - - @property - def expires_at(self) -> datetime.datetime: - """Return the time when the token expires.""" - return self.issued_at + datetime.timedelta(seconds=self.expires_in) - - def check_client(self, client: OAuth2Client) -> bool:# pylint: disable=[redefined-outer-name] - """Check whether the token is issued to given `client`.""" - return client.client_id == self.client.client_id - - def get_expires_in(self) -> int: - """Return the `expires_in` value for the token.""" - return self.expires_in - - def get_scope(self) -> str: - """Return the valid scope for the token.""" - return self.scope - - def is_expired(self) -> bool: - """Check whether the token is expired.""" - return self.expires_at < datetime.datetime.now() - - def is_revoked(self): - """Check whether the token has been revoked.""" - return self.revoked - -def __token_from_resultset__(conn: db.DbConnection, rset) -> Maybe: - def __identity__(val): - return val - try: - the_user = user_by_id(conn, uuid.UUID(rset["user_id"])) - except NotFoundError as _nfe: - the_user = None - the_client = client(conn, uuid.UUID(rset["client_id"]), the_user) - - if the_client.is_just() and bool(the_user): - return Just(OAuth2Token(token_id=uuid.UUID(rset["token_id"]), - client=the_client.maybe(None, __identity__), - token_type=rset["token_type"], - access_token=rset["access_token"], - refresh_token=rset["refresh_token"], - scope=rset["scope"], - revoked=(rset["revoked"] == 1), - issued_at=datetime.datetime.fromtimestamp( - rset["issued_at"]), - expires_in=rset["expires_in"], - user=the_user))# type: ignore[arg-type] - - return Nothing - -def token_by_access_token(conn: db.DbConnection, token_str: str) -> Maybe: - """Retrieve token by its token string""" - with db.cursor(conn) as cursor: - cursor.execute("SELECT * FROM oauth2_tokens WHERE access_token=?", - (token_str,)) - res = cursor.fetchone() - if res: - return __token_from_resultset__(conn, res) - - return Nothing - -def token_by_refresh_token(conn: db.DbConnection, token_str: str) -> Maybe: - """Retrieve token by its token string""" - with db.cursor(conn) as cursor: - cursor.execute( - "SELECT * FROM oauth2_tokens WHERE refresh_token=?", - (token_str,)) - res = cursor.fetchone() - if res: - return __token_from_resultset__(conn, res) - - return Nothing - -def revoke_token(token: OAuth2Token) -> OAuth2Token: - """ - Return a new token derived from `token` with the `revoked` field set to - `True`. 
- """ - return OAuth2Token( - token_id=token.token_id, client=token.client, - token_type=token.token_type, access_token=token.access_token, - refresh_token=token.refresh_token, scope=token.scope, revoked=True, - issued_at=token.issued_at, expires_in=token.expires_in, user=token.user) - -def save_token(conn: db.DbConnection, token: OAuth2Token) -> None: - """Save/Update the token.""" - with db.cursor(conn) as cursor: - cursor.execute( - ("INSERT INTO oauth2_tokens VALUES (:token_id, :client_id, " - ":token_type, :access_token, :refresh_token, :scope, :revoked, " - ":issued_at, :expires_in, :user_id) " - "ON CONFLICT (token_id) DO UPDATE SET " - "refresh_token=:refresh_token, revoked=:revoked, " - "expires_in=:expires_in " - "WHERE token_id=:token_id"), - { - "token_id": str(token.token_id), - "client_id": str(token.client.client_id), - "token_type": token.token_type, - "access_token": token.access_token, - "refresh_token": token.refresh_token, - "scope": token.scope, - "revoked": 1 if token.revoked else 0, - "issued_at": int(token.issued_at.timestamp()), - "expires_in": token.expires_in, - "user_id": str(token.user.user_id) - }) diff --git a/gn3/auth/authorisation/oauth2/resource_server.py b/gn3/auth/authorisation/oauth2/resource_server.py deleted file mode 100644 index e806dc5..0000000 --- a/gn3/auth/authorisation/oauth2/resource_server.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Protect the resources endpoints""" - -from flask import current_app as app -from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator -from authlib.integrations.flask_oauth2 import ResourceProtector - -from gn3.auth import db -from gn3.auth.authorisation.oauth2.oauth2token import token_by_access_token - -class BearerTokenValidator(_BearerTokenValidator): - """Extends `authlib.oauth2.rfc6750.BearerTokenValidator`""" - def authenticate_token(self, token_string: str): - with db.connection(app.config["AUTH_DB"]) as conn: - return token_by_access_token(conn, token_string).maybe(# type: ignore[misc] - None, lambda tok: tok) - -require_oauth = ResourceProtector() - -require_oauth.register_token_validator(BearerTokenValidator()) diff --git a/gn3/auth/authorisation/privileges.py b/gn3/auth/authorisation/privileges.py deleted file mode 100644 index 7907d76..0000000 --- a/gn3/auth/authorisation/privileges.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Handle privileges""" -from typing import Any, Iterable, NamedTuple - -from gn3.auth import db -from gn3.auth.authorisation.users import User - -class Privilege(NamedTuple): - """Class representing a privilege: creates immutable objects.""" - privilege_id: str - privilege_description: str - - def dictify(self) -> dict[str, Any]: - """Return a dict representation of `Privilege` objects.""" - return { - "privilege_id": self.privilege_id, - "privilege_description": self.privilege_description - } - -def user_privileges(conn: db.DbConnection, user: User) -> Iterable[Privilege]: - """Fetch the user's privileges from the database.""" - with db.cursor(conn) as cursor: - cursor.execute( - ("SELECT p.privilege_id, p.privilege_description " - "FROM user_roles AS ur " - "INNER JOIN role_privileges AS rp ON ur.role_id=rp.role_id " - "INNER JOIN privileges AS p ON rp.privilege_id=p.privilege_id " - "WHERE ur.user_id=?"), - (str(user.user_id),)) - results = cursor.fetchall() - - return (Privilege(row[0], row[1]) for row in results) - -def privileges_by_ids( - conn: db.DbConnection, privileges_ids: tuple[str, ...]) -> tuple[ - Privilege, ...]: - """Fetch privileges by their ids.""" - if 
len(privileges_ids) == 0: - return tuple() - - with db.cursor(conn) as cursor: - clause = ", ".join(["?"] * len(privileges_ids)) - cursor.execute( - f"SELECT * FROM privileges WHERE privilege_id IN ({clause})", - privileges_ids) - return tuple( - Privilege(row["privilege_id"], row["privilege_description"]) - for row in cursor.fetchall()) diff --git a/gn3/auth/authorisation/resources/__init__.py b/gn3/auth/authorisation/resources/__init__.py deleted file mode 100644 index 869ab60..0000000 --- a/gn3/auth/authorisation/resources/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Initialise the `gn3.auth.authorisation.resources` package.""" -from .models import Resource, ResourceCategory diff --git a/gn3/auth/authorisation/resources/checks.py b/gn3/auth/authorisation/resources/checks.py deleted file mode 100644 index 1f5a0f9..0000000 --- a/gn3/auth/authorisation/resources/checks.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Handle authorisation checks for resources""" -from uuid import UUID -from functools import reduce -from typing import Sequence - -from gn3.auth import db -from gn3.auth.authorisation.users import User - -def __organise_privileges_by_resource_id__(rows): - def __organise__(privs, row): - resource_id = UUID(row["resource_id"]) - return { - **privs, - resource_id: (row["privilege_id"],) + privs.get( - resource_id, tuple()) - } - return reduce(__organise__, rows, {}) - -def authorised_for(conn: db.DbConnection, user: User, privileges: tuple[str], - resource_ids: Sequence[UUID]) -> dict[UUID, bool]: - """ - Check whether `user` is authorised to access `resources` according to given - `privileges`. - """ - with db.cursor(conn) as cursor: - cursor.execute( - ("SELECT guror.*, rp.privilege_id FROM " - "group_user_roles_on_resources AS guror " - "INNER JOIN group_roles AS gr ON " - "(guror.group_id=gr.group_id AND guror.role_id=gr.role_id) " - "INNER JOIN roles AS r ON gr.role_id=r.role_id " - "INNER JOIN role_privileges AS rp ON r.role_id=rp.role_id " - "WHERE guror.user_id=? 
" - f"AND guror.resource_id IN ({', '.join(['?']*len(resource_ids))})" - f"AND rp.privilege_id IN ({', '.join(['?']*len(privileges))})"), - ((str(user.user_id),) + tuple( - str(r_id) for r_id in resource_ids) + tuple(privileges))) - resource_privileges = __organise_privileges_by_resource_id__( - cursor.fetchall()) - authorised = tuple(resource_id for resource_id, res_privileges - in resource_privileges.items() - if all(priv in res_privileges - for priv in privileges)) - return { - resource_id: resource_id in authorised - for resource_id in resource_ids - } diff --git a/gn3/auth/authorisation/resources/models.py b/gn3/auth/authorisation/resources/models.py deleted file mode 100644 index cf7769e..0000000 --- a/gn3/auth/authorisation/resources/models.py +++ /dev/null @@ -1,579 +0,0 @@ -"""Handle the management of resources.""" -import json -import sqlite3 -from uuid import UUID, uuid4 -from functools import reduce, partial -from typing import Any, Dict, Sequence, Optional, NamedTuple - -from gn3.auth import db -from gn3.auth.dictify import dictify -from gn3.auth.authorisation.users import User -from gn3.auth.db_utils import with_db_connection - -from .checks import authorised_for - -from ..checks import authorised_p -from ..errors import NotFoundError, AuthorisationError -from ..groups.models import ( - Group, GroupRole, user_group, group_by_id, is_group_leader) - -class MissingGroupError(AuthorisationError): - """Raised for any resource operation without a group.""" - -class ResourceCategory(NamedTuple): - """Class representing a resource category.""" - resource_category_id: UUID - resource_category_key: str - resource_category_description: str - - def dictify(self) -> dict[str, Any]: - """Return a dict representation of `ResourceCategory` objects.""" - return { - "resource_category_id": self.resource_category_id, - "resource_category_key": self.resource_category_key, - "resource_category_description": self.resource_category_description - } - -class Resource(NamedTuple): - """Class representing a resource.""" - group: Group - resource_id: UUID - resource_name: str - resource_category: ResourceCategory - public: bool - resource_data: Sequence[dict[str, Any]] = tuple() - - def dictify(self) -> dict[str, Any]: - """Return a dict representation of `Resource` objects.""" - return { - "group": dictify(self.group), "resource_id": self.resource_id, - "resource_name": self.resource_name, - "resource_category": dictify(self.resource_category), - "public": self.public, - "resource_data": self.resource_data - } - -def __assign_resource_owner_role__(cursor, resource, user): - """Assign `user` the 'Resource Owner' role for `resource`.""" - cursor.execute( - "SELECT gr.* FROM group_roles AS gr INNER JOIN roles AS r " - "ON gr.role_id=r.role_id WHERE r.role_name='resource-owner' " - "AND gr.group_id=?", - (str(resource.group.group_id),)) - role = cursor.fetchone() - if not role: - cursor.execute("SELECT * FROM roles WHERE role_name='resource-owner'") - role = cursor.fetchone() - cursor.execute( - "INSERT INTO group_roles VALUES " - "(:group_role_id, :group_id, :role_id)", - {"group_role_id": str(uuid4()), - "group_id": str(resource.group.group_id), - "role_id": role["role_id"]}) - - cursor.execute( - "INSERT INTO group_user_roles_on_resources " - "VALUES (" - ":group_id, :user_id, :role_id, :resource_id" - ")", - {"group_id": str(resource.group.group_id), - "user_id": str(user.user_id), - "role_id": role["role_id"], - "resource_id": str(resource.resource_id)}) - 
-@authorised_p(("group:resource:create-resource",), - error_description="Insufficient privileges to create a resource", - oauth2_scope="profile resource") -def create_resource( - conn: db.DbConnection, resource_name: str, - resource_category: ResourceCategory, user: User, - public: bool) -> Resource: - """Create a resource item.""" - with db.cursor(conn) as cursor: - group = user_group(conn, user).maybe( - False, lambda grp: grp)# type: ignore[misc, arg-type] - if not group: - raise MissingGroupError( - "User with no group cannot create a resource.") - resource = Resource( - group, uuid4(), resource_name, resource_category, public) - cursor.execute( - "INSERT INTO resources VALUES (?, ?, ?, ?, ?)", - (str(resource.group.group_id), str(resource.resource_id), - resource_name, - str(resource.resource_category.resource_category_id), - 1 if resource.public else 0)) - __assign_resource_owner_role__(cursor, resource, user) - - return resource - -def resource_category_by_id( - conn: db.DbConnection, category_id: UUID) -> ResourceCategory: - """Retrieve a resource category by its ID.""" - with db.cursor(conn) as cursor: - cursor.execute( - "SELECT * FROM resource_categories WHERE " - "resource_category_id=?", - (str(category_id),)) - results = cursor.fetchone() - if results: - return ResourceCategory( - UUID(results["resource_category_id"]), - results["resource_category_key"], - results["resource_category_description"]) - - raise NotFoundError( - f"Could not find a ResourceCategory with ID '{category_id}'") - -def resource_categories(conn: db.DbConnection) -> Sequence[ResourceCategory]: - """Retrieve all available resource categories""" - with db.cursor(conn) as cursor: - cursor.execute("SELECT * FROM resource_categories") - return tuple( - ResourceCategory(UUID(row[0]), row[1], row[2]) - for row in cursor.fetchall()) - return tuple() - -def public_resources(conn: db.DbConnection) -> Sequence[Resource]: - """List all resources marked as public""" - categories = { - str(cat.resource_category_id): cat for cat in resource_categories(conn) - } - with db.cursor(conn) as cursor: - cursor.execute("SELECT * FROM resources WHERE public=1") - results = cursor.fetchall() - group_uuids = tuple(row[0] for row in results) - query = ("SELECT * FROM groups WHERE group_id IN " - f"({', '.join(['?'] * len(group_uuids))})") - cursor.execute(query, group_uuids) - groups = { - row[0]: Group( - UUID(row[0]), row[1], json.loads(row[2] or "{}")) - for row in cursor.fetchall() - } - return tuple( - Resource(groups[row[0]], UUID(row[1]), row[2], categories[row[3]], - bool(row[4])) - for row in results) - -def group_leader_resources( - conn: db.DbConnection, user: User, group: Group, - res_categories: Dict[UUID, ResourceCategory]) -> Sequence[Resource]: - """Return all the resources available to the group leader""" - if is_group_leader(conn, user, group): - with db.cursor(conn) as cursor: - cursor.execute("SELECT * FROM resources WHERE group_id=?", - (str(group.group_id),)) - return tuple( - Resource(group, UUID(row[1]), row[2], - res_categories[UUID(row[3])], bool(row[4])) - for row in cursor.fetchall()) - return tuple() - -def user_resources(conn: db.DbConnection, user: User) -> Sequence[Resource]: - """List the resources available to the user""" - categories = { # Repeated in `public_resources` function - cat.resource_category_id: cat for cat in resource_categories(conn) - } - with db.cursor(conn) as cursor: - def __all_resources__(group) -> Sequence[Resource]: - gl_resources = group_leader_resources(conn, user, group, 
categories) - - cursor.execute( - ("SELECT resources.* FROM group_user_roles_on_resources " - "LEFT JOIN resources " - "ON group_user_roles_on_resources.resource_id=resources.resource_id " - "WHERE group_user_roles_on_resources.group_id = ? " - "AND group_user_roles_on_resources.user_id = ?"), - (str(group.group_id), str(user.user_id))) - rows = cursor.fetchall() - private_res = tuple( - Resource(group, UUID(row[1]), row[2], categories[UUID(row[3])], - bool(row[4])) - for row in rows) - return tuple({ - res.resource_id: res - for res in - (private_res + gl_resources + public_resources(conn))# type: ignore[operator] - }.values()) - - # Fix the typing here - return user_group(conn, user).map(__all_resources__).maybe(# type: ignore[arg-type,misc] - public_resources(conn), lambda res: res)# type: ignore[arg-type,return-value] - -def resource_data(conn, resource, offset: int = 0, limit: Optional[int] = None) -> tuple[dict, ...]: - """ - Retrieve the data for `resource`, optionally limiting the number of items. - """ - resource_data_function = { - "mrna": mrna_resource_data, - "genotype": genotype_resource_data, - "phenotype": phenotype_resource_data - } - with db.cursor(conn) as cursor: - return tuple( - dict(data_row) for data_row in - resource_data_function[ - resource.resource_category.resource_category_key]( - cursor, resource.resource_id, offset, limit)) - -def attach_resource_data(cursor: db.DbCursor, resource: Resource) -> Resource: - """Attach the linked data to the resource""" - resource_data_function = { - "mrna": mrna_resource_data, - "genotype": genotype_resource_data, - "phenotype": phenotype_resource_data - } - category = resource.resource_category - data_rows = tuple( - dict(data_row) for data_row in - resource_data_function[category.resource_category_key]( - cursor, resource.resource_id)) - return Resource( - resource.group, resource.resource_id, resource.resource_name, - resource.resource_category, resource.public, data_rows) - -def mrna_resource_data(cursor: db.DbCursor, - resource_id: UUID, - offset: int = 0, - limit: Optional[int] = None) -> Sequence[sqlite3.Row]: - """Fetch data linked to a mRNA resource""" - cursor.execute( - (("SELECT * FROM mrna_resources AS mr " - "INNER JOIN linked_mrna_data AS lmr " - "ON mr.data_link_id=lmr.data_link_id " - "WHERE mr.resource_id=?") + ( - f" LIMIT {limit} OFFSET {offset}" if bool(limit) else "")), - (str(resource_id),)) - return cursor.fetchall() - -def genotype_resource_data( - cursor: db.DbCursor, - resource_id: UUID, - offset: int = 0, - limit: Optional[int] = None) -> Sequence[sqlite3.Row]: - """Fetch data linked to a Genotype resource""" - cursor.execute( - (("SELECT * FROM genotype_resources AS gr " - "INNER JOIN linked_genotype_data AS lgd " - "ON gr.data_link_id=lgd.data_link_id " - "WHERE gr.resource_id=?") + ( - f" LIMIT {limit} OFFSET {offset}" if bool(limit) else "")), - (str(resource_id),)) - return cursor.fetchall() - -def phenotype_resource_data( - cursor: db.DbCursor, - resource_id: UUID, - offset: int = 0, - limit: Optional[int] = None) -> Sequence[sqlite3.Row]: - """Fetch data linked to a Phenotype resource""" - cursor.execute( - ("SELECT * FROM phenotype_resources AS pr " - "INNER JOIN linked_phenotype_data AS lpd " - "ON pr.data_link_id=lpd.data_link_id " - "WHERE pr.resource_id=?") + ( - f" LIMIT {limit} OFFSET {offset}" if bool(limit) else ""), - (str(resource_id),)) - return cursor.fetchall() - -def resource_by_id( - conn: db.DbConnection, user: User, resource_id: UUID) -> Resource: - """Retrieve a resource by 
its ID.""" - if not authorised_for( - conn, user, ("group:resource:view-resource",), - (resource_id,))[resource_id]: - raise AuthorisationError( - "You are not authorised to access resource with id " - f"'{resource_id}'.") - - with db.cursor(conn) as cursor: - cursor.execute("SELECT * FROM resources WHERE resource_id=:id", - {"id": str(resource_id)}) - row = cursor.fetchone() - if row: - return Resource( - group_by_id(conn, UUID(row["group_id"])), - UUID(row["resource_id"]), row["resource_name"], - resource_category_by_id(conn, row["resource_category_id"]), - bool(int(row["public"]))) - - raise NotFoundError(f"Could not find a resource with id '{resource_id}'") - -def __link_mrna_data_to_resource__( - conn: db.DbConnection, resource: Resource, data_link_id: UUID) -> dict: - """Link mRNA Assay data with a resource.""" - with db.cursor(conn) as cursor: - params = { - "group_id": str(resource.group.group_id), - "resource_id": str(resource.resource_id), - "data_link_id": str(data_link_id) - } - cursor.execute( - "INSERT INTO mrna_resources VALUES" - "(:group_id, :resource_id, :data_link_id)", - params) - return params - -def __link_geno_data_to_resource__( - conn: db.DbConnection, resource: Resource, data_link_id: UUID) -> dict: - """Link Genotype data with a resource.""" - with db.cursor(conn) as cursor: - params = { - "group_id": str(resource.group.group_id), - "resource_id": str(resource.resource_id), - "data_link_id": str(data_link_id) - } - cursor.execute( - "INSERT INTO genotype_resources VALUES" - "(:group_id, :resource_id, :data_link_id)", - params) - return params - -def __link_pheno_data_to_resource__( - conn: db.DbConnection, resource: Resource, data_link_id: UUID) -> dict: - """Link Phenotype data with a resource.""" - with db.cursor(conn) as cursor: - params = { - "group_id": str(resource.group.group_id), - "resource_id": str(resource.resource_id), - "data_link_id": str(data_link_id) - } - cursor.execute( - "INSERT INTO phenotype_resources VALUES" - "(:group_id, :resource_id, :data_link_id)", - params) - return params - -def link_data_to_resource( - conn: db.DbConnection, user: User, resource_id: UUID, dataset_type: str, - data_link_id: UUID) -> dict: - """Link data to resource.""" - if not authorised_for( - conn, user, ("group:resource:edit-resource",), - (resource_id,))[resource_id]: - raise AuthorisationError( - "You are not authorised to link data to resource with id " - f"{resource_id}") - - resource = with_db_connection(partial( - resource_by_id, user=user, resource_id=resource_id)) - return { - "mrna": __link_mrna_data_to_resource__, - "genotype": __link_geno_data_to_resource__, - "phenotype": __link_pheno_data_to_resource__, - }[dataset_type.lower()](conn, resource, data_link_id) - -def __unlink_mrna_data_to_resource__( - conn: db.DbConnection, resource: Resource, data_link_id: UUID) -> dict: - """Unlink data from mRNA Assay resources""" - with db.cursor(conn) as cursor: - cursor.execute("DELETE FROM mrna_resources " - "WHERE resource_id=? AND data_link_id=?", - (str(resource.resource_id), str(data_link_id))) - return { - "resource_id": str(resource.resource_id), - "dataset_type": resource.resource_category.resource_category_key, - "data_link_id": data_link_id - } - -def __unlink_geno_data_to_resource__( - conn: db.DbConnection, resource: Resource, data_link_id: UUID) -> dict: - """Unlink data from Genotype resources""" - with db.cursor(conn) as cursor: - cursor.execute("DELETE FROM genotype_resources " - "WHERE resource_id=? 
AND data_link_id=?", - (str(resource.resource_id), str(data_link_id))) - return { - "resource_id": str(resource.resource_id), - "dataset_type": resource.resource_category.resource_category_key, - "data_link_id": data_link_id - } - -def __unlink_pheno_data_to_resource__( - conn: db.DbConnection, resource: Resource, data_link_id: UUID) -> dict: - """Unlink data from Phenotype resources""" - with db.cursor(conn) as cursor: - cursor.execute("DELETE FROM phenotype_resources " - "WHERE resource_id=? AND data_link_id=?", - (str(resource.resource_id), str(data_link_id))) - return { - "resource_id": str(resource.resource_id), - "dataset_type": resource.resource_category.resource_category_key, - "data_link_id": str(data_link_id) - } - -def unlink_data_from_resource( - conn: db.DbConnection, user: User, resource_id: UUID, data_link_id: UUID): - """Unlink data from resource.""" - if not authorised_for( - conn, user, ("group:resource:edit-resource",), - (resource_id,))[resource_id]: - raise AuthorisationError( - "You are not authorised to link data to resource with id " - f"{resource_id}") - - resource = with_db_connection(partial( - resource_by_id, user=user, resource_id=resource_id)) - dataset_type = resource.resource_category.resource_category_key - return { - "mrna": __unlink_mrna_data_to_resource__, - "genotype": __unlink_geno_data_to_resource__, - "phenotype": __unlink_pheno_data_to_resource__, - }[dataset_type.lower()](conn, resource, data_link_id) - -def organise_resources_by_category(resources: Sequence[Resource]) -> dict[ - ResourceCategory, tuple[Resource]]: - """Organise the `resources` by their categories.""" - def __organise__(accumulator, resource): - category = resource.resource_category - return { - **accumulator, - category: accumulator.get(category, tuple()) + (resource,) - } - return reduce(__organise__, resources, {}) - -def __attach_data__( - data_rows: Sequence[sqlite3.Row], - resources: Sequence[Resource]) -> Sequence[Resource]: - def __organise__(acc, row): - resource_id = UUID(row["resource_id"]) - return { - **acc, - resource_id: acc.get(resource_id, tuple()) + (dict(row),) - } - organised: dict[UUID, tuple[dict, ...]] = reduce(__organise__, data_rows, {}) - return tuple( - Resource( - resource.group, resource.resource_id, resource.resource_name, - resource.resource_category, resource.public, - organised.get(resource.resource_id, tuple())) - for resource in resources) - -def attach_mrna_resources_data( - cursor, resources: Sequence[Resource]) -> Sequence[Resource]: - """Attach linked data to mRNA Assay resources""" - placeholders = ", ".join(["?"] * len(resources)) - cursor.execute( - "SELECT * FROM mrna_resources AS mr INNER JOIN linked_mrna_data AS lmd" - " ON mr.data_link_id=lmd.data_link_id " - f"WHERE mr.resource_id IN ({placeholders})", - tuple(str(resource.resource_id) for resource in resources)) - return __attach_data__(cursor.fetchall(), resources) - -def attach_genotype_resources_data( - cursor, resources: Sequence[Resource]) -> Sequence[Resource]: - """Attach linked data to Genotype resources""" - placeholders = ", ".join(["?"] * len(resources)) - cursor.execute( - "SELECT * FROM genotype_resources AS gr " - "INNER JOIN linked_genotype_data AS lgd " - "ON gr.data_link_id=lgd.data_link_id " - f"WHERE gr.resource_id IN ({placeholders})", - tuple(str(resource.resource_id) for resource in resources)) - return __attach_data__(cursor.fetchall(), resources) - -def attach_phenotype_resources_data( - cursor, resources: Sequence[Resource]) -> Sequence[Resource]: - """Attach 
linked data to Phenotype resources""" - placeholders = ", ".join(["?"] * len(resources)) - cursor.execute( - "SELECT * FROM phenotype_resources AS pr " - "INNER JOIN linked_phenotype_data AS lpd " - "ON pr.data_link_id=lpd.data_link_id " - f"WHERE pr.resource_id IN ({placeholders})", - tuple(str(resource.resource_id) for resource in resources)) - return __attach_data__(cursor.fetchall(), resources) - -def attach_resources_data( - conn: db.DbConnection, resources: Sequence[Resource]) -> Sequence[ - Resource]: - """Attach linked data for each resource in `resources`""" - resource_data_function = { - "mrna": attach_mrna_resources_data, - "genotype": attach_genotype_resources_data, - "phenotype": attach_phenotype_resources_data - } - organised = organise_resources_by_category(resources) - with db.cursor(conn) as cursor: - return tuple( - resource for categories in - (resource_data_function[category.resource_category_key]( - cursor, rscs) - for category, rscs in organised.items()) - for resource in categories) - -@authorised_p( - ("group:user:assign-role",), - "You cannot assign roles to users for this group.", - oauth2_scope="profile group role resource") -def assign_resource_user( - conn: db.DbConnection, resource: Resource, user: User, - role: GroupRole) -> dict: - """Assign `role` to `user` for the specific `resource`.""" - with db.cursor(conn) as cursor: - cursor.execute( - "INSERT INTO " - "group_user_roles_on_resources(group_id, user_id, role_id, " - "resource_id) " - "VALUES (?, ?, ?, ?) " - "ON CONFLICT (group_id, user_id, role_id, resource_id) " - "DO NOTHING", - (str(resource.group.group_id), str(user.user_id), - str(role.role.role_id), str(resource.resource_id))) - return { - "resource": dictify(resource), - "user": dictify(user), - "role": dictify(role), - "description": ( - f"The user '{user.name}'({user.email}) was assigned the " - f"'{role.role.role_name}' role on resource with ID " - f"'{resource.resource_id}'.")} - -@authorised_p( - ("group:user:assign-role",), - "You cannot assign roles to users for this group.", - oauth2_scope="profile group role resource") -def unassign_resource_user( - conn: db.DbConnection, resource: Resource, user: User, - role: GroupRole) -> dict: - """Assign `role` to `user` for the specific `resource`.""" - with db.cursor(conn) as cursor: - cursor.execute( - "DELETE FROM group_user_roles_on_resources " - "WHERE group_id=? AND user_id=? AND role_id=? 
AND resource_id=?", - (str(resource.group.group_id), str(user.user_id), - str(role.role.role_id), str(resource.resource_id))) - return { - "resource": dictify(resource), - "user": dictify(user), - "role": dictify(role), - "description": ( - f"The user '{user.name}'({user.email}) had the " - f"'{role.role.role_name}' role on resource with ID " - f"'{resource.resource_id}' taken away.")} - -def save_resource( - conn: db.DbConnection, user: User, resource: Resource) -> Resource: - """Update an existing resource.""" - resource_id = resource.resource_id - authorised = authorised_for( - conn, user, ("group:resource:edit-resource",), (resource_id,)) - if authorised[resource_id]: - with db.cursor(conn) as cursor: - cursor.execute( - "UPDATE resources SET " - "resource_name=:resource_name, " - "public=:public " - "WHERE group_id=:group_id " - "AND resource_id=:resource_id", - { - "resource_name": resource.resource_name, - "public": 1 if resource.public else 0, - "group_id": str(resource.group.group_id), - "resource_id": str(resource.resource_id) - }) - return resource - - raise AuthorisationError( - "You do not have the appropriate privileges to edit this resource.") diff --git a/gn3/auth/authorisation/resources/views.py b/gn3/auth/authorisation/resources/views.py deleted file mode 100644 index bda67cd..0000000 --- a/gn3/auth/authorisation/resources/views.py +++ /dev/null @@ -1,272 +0,0 @@ -"""The views/routes for the resources package""" -import uuid -import json -import sqlite3 -from functools import reduce - -from flask import request, jsonify, Response, Blueprint, current_app as app - -from gn3.auth.db_utils import with_db_connection -from gn3.auth.authorisation.oauth2.resource_server import require_oauth -from gn3.auth.authorisation.users import User, user_by_id, user_by_email - -from .checks import authorised_for -from .models import ( - Resource, save_resource, resource_data, resource_by_id, resource_categories, - assign_resource_user, link_data_to_resource, unassign_resource_user, - resource_category_by_id, unlink_data_from_resource, - create_resource as _create_resource) - -from ..roles import Role -from ..errors import InvalidData, InconsistencyError, AuthorisationError -from ..groups.models import Group, GroupRole, group_role_by_id - -from ... 
import db -from ...dictify import dictify - -resources = Blueprint("resources", __name__) - -@resources.route("/categories", methods=["GET"]) -@require_oauth("profile group resource") -def list_resource_categories() -> Response: - """Retrieve all resource categories""" - db_uri = app.config["AUTH_DB"] - with db.connection(db_uri) as conn: - return jsonify(tuple( - dictify(category) for category in resource_categories(conn))) - -@resources.route("/create", methods=["POST"]) -@require_oauth("profile group resource") -def create_resource() -> Response: - """Create a new resource""" - with require_oauth.acquire("profile group resource") as the_token: - form = request.form - resource_name = form.get("resource_name") - resource_category_id = uuid.UUID(form.get("resource_category")) - db_uri = app.config["AUTH_DB"] - with db.connection(db_uri) as conn: - try: - resource = _create_resource( - conn, - resource_name, - resource_category_by_id(conn, resource_category_id), - the_token.user, - (form.get("public") == "on")) - return jsonify(dictify(resource)) - except sqlite3.IntegrityError as sql3ie: - if sql3ie.args[0] == ("UNIQUE constraint failed: " - "resources.resource_name"): - raise InconsistencyError( - "You cannot have duplicate resource names.") from sql3ie - app.logger.debug( - f"{type(sql3ie)=}: {sql3ie=}") - raise - -@resources.route("/view/<uuid:resource_id>") -@require_oauth("profile group resource") -def view_resource(resource_id: uuid.UUID) -> Response: - """View a particular resource's details.""" - with require_oauth.acquire("profile group resource") as the_token: - db_uri = app.config["AUTH_DB"] - with db.connection(db_uri) as conn: - return jsonify(dictify(resource_by_id( - conn, the_token.user, resource_id))) - -def __safe_get_requests_page__(key: str = "page") -> int: - """Get the results page if it exists or default to the first page.""" - try: - return abs(int(request.args.get(key, "1"), base=10)) - except ValueError as _valerr: - return 1 - -def __safe_get_requests_count__(key: str = "count_per_page") -> int: - """Get the results page if it exists or default to the first page.""" - try: - count = request.args.get(key, "0") - if count != 0: - return abs(int(count, base=10)) - return 0 - except ValueError as _valerr: - return 0 - -@resources.route("/view/<uuid:resource_id>/data") -@require_oauth("profile group resource") -def view_resource_data(resource_id: uuid.UUID) -> Response: - """Retrieve a particular resource's data.""" - with require_oauth.acquire("profile group resource") as the_token: - db_uri = app.config["AUTH_DB"] - count_per_page = __safe_get_requests_count__("count_per_page") - offset = (__safe_get_requests_page__("page") - 1) - with db.connection(db_uri) as conn: - resource = resource_by_id(conn, the_token.user, resource_id) - return jsonify(resource_data( - conn, - resource, - ((offset * count_per_page) if bool(count_per_page) else offset), - count_per_page)) - -@resources.route("/data/link", methods=["POST"]) -@require_oauth("profile group resource") -def link_data(): - """Link group data to a specific resource.""" - try: - form = request.form - assert "resource_id" in form, "Resource ID not provided." - assert "data_link_id" in form, "Data Link ID not provided." - assert "dataset_type" in form, "Dataset type not specified" - assert form["dataset_type"].lower() in ( - "mrna", "genotype", "phenotype"), "Invalid dataset type provided." 
- - with require_oauth.acquire("profile group resource") as the_token: - def __link__(conn: db.DbConnection): - return link_data_to_resource( - conn, the_token.user, uuid.UUID(form["resource_id"]), - form["dataset_type"], uuid.UUID(form["data_link_id"])) - - return jsonify(with_db_connection(__link__)) - except AssertionError as aserr: - raise InvalidData(aserr.args[0]) from aserr - - - -@resources.route("/data/unlink", methods=["POST"]) -@require_oauth("profile group resource") -def unlink_data(): - """Unlink data bound to a specific resource.""" - try: - form = request.form - assert "resource_id" in form, "Resource ID not provided." - assert "data_link_id" in form, "Data Link ID not provided." - - with require_oauth.acquire("profile group resource") as the_token: - def __unlink__(conn: db.DbConnection): - return unlink_data_from_resource( - conn, the_token.user, uuid.UUID(form["resource_id"]), - uuid.UUID(form["data_link_id"])) - return jsonify(with_db_connection(__unlink__)) - except AssertionError as aserr: - raise InvalidData(aserr.args[0]) from aserr - -@resources.route("<uuid:resource_id>/user/list", methods=["GET"]) -@require_oauth("profile group resource") -def resource_users(resource_id: uuid.UUID): - """Retrieve all users with access to the given resource.""" - with require_oauth.acquire("profile group resource") as the_token: - def __the_users__(conn: db.DbConnection): - resource = resource_by_id(conn, the_token.user, resource_id) - authorised = authorised_for( - conn, the_token.user, ("group:resource:edit-resource",), - (resource_id,)) - if authorised.get(resource_id, False): - with db.cursor(conn) as cursor: - def __organise_users_n_roles__(users_n_roles, row): - user_id = uuid.UUID(row["user_id"]) - user = users_n_roles.get(user_id, {}).get( - "user", User(user_id, row["email"], row["name"])) - role = GroupRole( - uuid.UUID(row["group_role_id"]), - resource.group, - Role(uuid.UUID(row["role_id"]), row["role_name"], - bool(int(row["user_editable"])), tuple())) - return { - **users_n_roles, - user_id: { - "user": user, - "user_group": Group( - uuid.UUID(row["group_id"]), row["group_name"], - json.loads(row["group_metadata"])), - "roles": users_n_roles.get( - user_id, {}).get("roles", tuple()) + (role,) - } - } - cursor.execute( - "SELECT g.*, u.*, r.*, gr.group_role_id " - "FROM groups AS g INNER JOIN " - "group_users AS gu ON g.group_id=gu.group_id " - "INNER JOIN users AS u ON gu.user_id=u.user_id " - "INNER JOIN group_user_roles_on_resources AS guror " - "ON u.user_id=guror.user_id INNER JOIN roles AS r " - "ON guror.role_id=r.role_id " - "INNER JOIN group_roles AS gr ON r.role_id=gr.role_id " - "WHERE guror.resource_id=?", - (str(resource_id),)) - return reduce(__organise_users_n_roles__, cursor.fetchall(), {}) - raise AuthorisationError( - "You do not have sufficient privileges to view the resource " - "users.") - results = ( - { - "user": dictify(row["user"]), - "user_group": dictify(row["user_group"]), - "roles": tuple(dictify(role) for role in row["roles"]) - } for row in ( - user_row for user_id, user_row - in with_db_connection(__the_users__).items())) - return jsonify(tuple(results)) - -@resources.route("<uuid:resource_id>/user/assign", methods=["POST"]) -@require_oauth("profile group resource role") -def assign_role_to_user(resource_id: uuid.UUID) -> Response: - """Assign a role on the specified resource to a user.""" - with require_oauth.acquire("profile group resource role") as the_token: - try: - form = request.form - group_role_id = form.get("group_role_id", 
"") - user_email = form.get("user_email", "") - assert bool(group_role_id), "The role must be provided." - assert bool(user_email), "The user email must be provided." - - def __assign__(conn: db.DbConnection) -> dict: - resource = resource_by_id(conn, the_token.user, resource_id) - user = user_by_email(conn, user_email) - return assign_resource_user( - conn, resource, user, - group_role_by_id(conn, resource.group, - uuid.UUID(group_role_id))) - except AssertionError as aserr: - raise AuthorisationError(aserr.args[0]) from aserr - - return jsonify(with_db_connection(__assign__)) - -@resources.route("<uuid:resource_id>/user/unassign", methods=["POST"]) -@require_oauth("profile group resource role") -def unassign_role_to_user(resource_id: uuid.UUID) -> Response: - """Unassign a role on the specified resource from a user.""" - with require_oauth.acquire("profile group resource role") as the_token: - try: - form = request.form - group_role_id = form.get("group_role_id", "") - user_id = form.get("user_id", "") - assert bool(group_role_id), "The role must be provided." - assert bool(user_id), "The user id must be provided." - - def __assign__(conn: db.DbConnection) -> dict: - resource = resource_by_id(conn, the_token.user, resource_id) - return unassign_resource_user( - conn, resource, user_by_id(conn, uuid.UUID(user_id)), - group_role_by_id(conn, resource.group, - uuid.UUID(group_role_id))) - except AssertionError as aserr: - raise AuthorisationError(aserr.args[0]) from aserr - - return jsonify(with_db_connection(__assign__)) - -@resources.route("<uuid:resource_id>/toggle-public", methods=["POST"]) -@require_oauth("profile group resource role") -def toggle_public(resource_id: uuid.UUID) -> Response: - """Make a resource public if it is private, or private if public.""" - with require_oauth.acquire("profile group resource") as the_token: - def __toggle__(conn: db.DbConnection) -> Resource: - old_rsc = resource_by_id(conn, the_token.user, resource_id) - return save_resource( - conn, the_token.user, Resource( - old_rsc.group, old_rsc.resource_id, old_rsc.resource_name, - old_rsc.resource_category, not old_rsc.public, - old_rsc.resource_data)) - - resource = with_db_connection(__toggle__) - return jsonify({ - "resource": dictify(resource), - "description": ( - "Made resource public" if resource.public - else "Made resource private")}) diff --git a/gn3/auth/authorisation/roles/__init__.py b/gn3/auth/authorisation/roles/__init__.py deleted file mode 100644 index 293a12f..0000000 --- a/gn3/auth/authorisation/roles/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Initialise the `gn3.auth.authorisation.roles` package""" - -from .models import Role diff --git a/gn3/auth/authorisation/roles/models.py b/gn3/auth/authorisation/roles/models.py deleted file mode 100644 index 890d33b..0000000 --- a/gn3/auth/authorisation/roles/models.py +++ /dev/null @@ -1,161 +0,0 @@ -"""Handle management of roles""" -from uuid import UUID, uuid4 -from functools import reduce -from typing import Any, Sequence, Iterable, NamedTuple - -from pymonad.either import Left, Right, Either - -from gn3.auth import db -from gn3.auth.dictify import dictify -from gn3.auth.authorisation.users import User -from gn3.auth.authorisation.errors import AuthorisationError - -from ..checks import authorised_p -from ..privileges import Privilege -from ..errors import NotFoundError - -class Role(NamedTuple): - """Class representing a role: creates immutable objects.""" - role_id: UUID - role_name: str - user_editable: bool - privileges: tuple[Privilege, 
...] - - def dictify(self) -> dict[str, Any]: - """Return a dict representation of `Role` objects.""" - return { - "role_id": self.role_id, "role_name": self.role_name, - "user_editable": self.user_editable, - "privileges": tuple(dictify(priv) for priv in self.privileges) - } - -def check_user_editable(role: Role): - """Raise an exception if `role` is not user editable.""" - if not role.user_editable: - raise AuthorisationError( - f"The role `{role.role_name}` is not user editable.") - -@authorised_p( - privileges = ("group:role:create-role",), - error_description="Could not create role") -def create_role( - cursor: db.DbCursor, role_name: str, - privileges: Iterable[Privilege]) -> Role: - """ - Create a new generic role. - - PARAMS: - * cursor: A database cursor object - This function could be used as part of - a transaction, hence the use of a cursor rather than a connection - object. - * role_name: The name of the role - * privileges: A 'list' of privileges to assign the new role - - RETURNS: An immutable `gn3.auth.authorisation.roles.Role` object - """ - role = Role(uuid4(), role_name, True, tuple(privileges)) - - cursor.execute( - "INSERT INTO roles(role_id, role_name, user_editable) VALUES (?, ?, ?)", - (str(role.role_id), role.role_name, (1 if role.user_editable else 0))) - cursor.executemany( - "INSERT INTO role_privileges(role_id, privilege_id) VALUES (?, ?)", - tuple((str(role.role_id), str(priv.privilege_id)) - for priv in privileges)) - - return role - -def __organise_privileges__(roles_dict, privilege_row): - """Organise the privileges into their roles.""" - role_id_str = privilege_row["role_id"] - if role_id_str in roles_dict: - return { - **roles_dict, - role_id_str: Role( - UUID(role_id_str), - privilege_row["role_name"], - bool(int(privilege_row["user_editable"])), - roles_dict[role_id_str].privileges + ( - Privilege(privilege_row["privilege_id"], - privilege_row["privilege_description"]),)) - } - - return { - **roles_dict, - role_id_str: Role( - UUID(role_id_str), - privilege_row["role_name"], - bool(int(privilege_row["user_editable"])), - (Privilege(privilege_row["privilege_id"], - privilege_row["privilege_description"]),)) - } - -def user_roles(conn: db.DbConnection, user: User) -> Sequence[Role]: - """Retrieve non-resource roles assigned to the user.""" - with db.cursor(conn) as cursor: - cursor.execute( - "SELECT r.*, p.* FROM user_roles AS ur INNER JOIN roles AS r " - "ON ur.role_id=r.role_id INNER JOIN role_privileges AS rp " - "ON r.role_id=rp.role_id INNER JOIN privileges AS p " - "ON rp.privilege_id=p.privilege_id WHERE ur.user_id=?", - (str(user.user_id),)) - - return tuple( - reduce(__organise_privileges__, cursor.fetchall(), {}).values()) - return tuple() - -def user_role(conn: db.DbConnection, user: User, role_id: UUID) -> Either: - """Retrieve a specific non-resource role assigned to the user.""" - with db.cursor(conn) as cursor: - cursor.execute( - "SELECT r.*, p.* FROM user_roles AS ur INNER JOIN roles AS r " - "ON ur.role_id=r.role_id INNER JOIN role_privileges AS rp " - "ON r.role_id=rp.role_id INNER JOIN privileges AS p " - "ON rp.privilege_id=p.privilege_id " - "WHERE ur.user_id=? 
AND ur.role_id=?", - (str(user.user_id), str(role_id))) - - results = cursor.fetchall() - if results: - return Right(tuple( - reduce(__organise_privileges__, results, {}).values())[0]) - return Left(NotFoundError( - f"Could not find role with id '{role_id}'",)) - -def assign_default_roles(cursor: db.DbCursor, user: User): - """Assign `user` some default roles.""" - cursor.execute( - 'SELECT role_id FROM roles WHERE role_name IN ' - '("group-creator")') - role_ids = cursor.fetchall() - str_user_id = str(user.user_id) - params = tuple( - {"user_id": str_user_id, "role_id": row["role_id"]} for row in role_ids) - cursor.executemany( - ("INSERT INTO user_roles VALUES (:user_id, :role_id)"), - params) - -def revoke_user_role_by_name(cursor: db.DbCursor, user: User, role_name: str): - """Revoke a role from `user` by the role's name""" - cursor.execute( - "SELECT role_id FROM roles WHERE role_name=:role_name", - {"role_name": role_name}) - role = cursor.fetchone() - if role: - cursor.execute( - ("DELETE FROM user_roles " - "WHERE user_id=:user_id AND role_id=:role_id"), - {"user_id": str(user.user_id), "role_id": role["role_id"]}) - -def assign_user_role_by_name(cursor: db.DbCursor, user: User, role_name: str): - """Revoke a role from `user` by the role's name""" - cursor.execute( - "SELECT role_id FROM roles WHERE role_name=:role_name", - {"role_name": role_name}) - role = cursor.fetchone() - - if role: - cursor.execute( - ("INSERT INTO user_roles VALUES(:user_id, :role_id) " - "ON CONFLICT DO NOTHING"), - {"user_id": str(user.user_id), "role_id": role["role_id"]}) diff --git a/gn3/auth/authorisation/roles/views.py b/gn3/auth/authorisation/roles/views.py deleted file mode 100644 index d00e596..0000000 --- a/gn3/auth/authorisation/roles/views.py +++ /dev/null @@ -1,25 +0,0 @@ -"""The views/routes for the `gn3.auth.authorisation.roles` package.""" -import uuid - -from flask import jsonify, Response, Blueprint, current_app - -from gn3.auth import db -from gn3.auth.dictify import dictify -from gn3.auth.authorisation.oauth2.resource_server import require_oauth - -from .models import user_role - -roles = Blueprint("roles", __name__) - -@roles.route("/view/<uuid:role_id>", methods=["GET"]) -@require_oauth("profile role") -def view_role(role_id: uuid.UUID) -> Response: - """Retrieve a user role with id `role_id`""" - def __error__(exc: Exception): - raise exc - with require_oauth.acquire("profile role") as the_token: - db_uri = current_app.config["AUTH_DB"] - with db.connection(db_uri) as conn: - the_role = user_role(conn, the_token.user, role_id) - return the_role.either( - __error__, lambda a_role: jsonify(dictify(a_role))) diff --git a/gn3/auth/authorisation/users/__init__.py b/gn3/auth/authorisation/users/__init__.py deleted file mode 100644 index 5f0c89c..0000000 --- a/gn3/auth/authorisation/users/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Initialise the users' package.""" -from .base import ( - User, - users, - save_user, - user_by_id, - # valid_login, - user_by_email, - hash_password, # only used in tests... 
maybe make gn-auth a GN3 dependency - same_password, - set_user_password -) diff --git a/gn3/auth/authorisation/users/base.py b/gn3/auth/authorisation/users/base.py deleted file mode 100644 index 0e72ed2..0000000 --- a/gn3/auth/authorisation/users/base.py +++ /dev/null @@ -1,128 +0,0 @@ -"""User-specific code and data structures.""" -from uuid import UUID, uuid4 -from typing import Any, Tuple, NamedTuple - -from argon2 import PasswordHasher -from argon2.exceptions import VerifyMismatchError - -from gn3.auth import db -from gn3.auth.authorisation.errors import NotFoundError - -class User(NamedTuple): - """Class representing a user.""" - user_id: UUID - email: str - name: str - - def get_user_id(self): - """Return the user's UUID. Mostly for use with Authlib.""" - return self.user_id - - def dictify(self) -> dict[str, Any]: - """Return a dict representation of `User` objects.""" - return {"user_id": self.user_id, "email": self.email, "name": self.name} - -DUMMY_USER = User(user_id=UUID("a391cf60-e8b7-4294-bd22-ddbbda4b3530"), - email="gn3@dummy.user", - name="Dummy user to use as placeholder") - -def user_by_email(conn: db.DbConnection, email: str) -> User: - """Retrieve user from database by their email address""" - with db.cursor(conn) as cursor: - cursor.execute("SELECT * FROM users WHERE email=?", (email,)) - row = cursor.fetchone() - - if row: - return User(UUID(row["user_id"]), row["email"], row["name"]) - - raise NotFoundError(f"Could not find user with email {email}") - -def user_by_id(conn: db.DbConnection, user_id: UUID) -> User: - """Retrieve user from database by their user id""" - with db.cursor(conn) as cursor: - cursor.execute("SELECT * FROM users WHERE user_id=?", (str(user_id),)) - row = cursor.fetchone() - - if row: - return User(UUID(row["user_id"]), row["email"], row["name"]) - - raise NotFoundError(f"Could not find user with ID {user_id}") - -def same_password(password: str, hashed: str) -> bool: - """Check that `raw_password` is hashed to `hash`""" - try: - return hasher().verify(hashed, password) - except VerifyMismatchError as _vme: - return False - -def valid_login(conn: db.DbConnection, user: User, password: str) -> bool: - """Check the validity of the provided credentials for login.""" - with db.cursor(conn) as cursor: - cursor.execute( - ("SELECT * FROM users LEFT JOIN user_credentials " - "ON users.user_id=user_credentials.user_id " - "WHERE users.user_id=?"), - (str(user.user_id),)) - row = cursor.fetchone() - - if row is None: - return False - - return same_password(password, row["password"]) - -def save_user(cursor: db.DbCursor, email: str, name: str) -> User: - """ - Create and persist a user. - - The user creation could be done during a transaction, therefore the function - takes a cursor object rather than a connection. - - The newly created and persisted user is then returned. - """ - user_id = uuid4() - cursor.execute("INSERT INTO users VALUES (?, ?, ?)", - (str(user_id), email, name)) - return User(user_id, email, name) - -def hasher(): - """Retrieve PasswordHasher object""" - # TODO: Maybe tune the parameters here... 
- # Tuneable Parameters: - # - time_cost (default: 2) - # - memory_cost (default: 102400) - # - parallelism (default: 8) - # - hash_len (default: 16) - # - salt_len (default: 16) - # - encoding (default: 'utf-8') - # - type (default: <Type.ID: 2>) - return PasswordHasher() - -def hash_password(password): - """Hash the password.""" - return hasher().hash(password) - -def set_user_password( - cursor: db.DbCursor, user: User, password: str) -> Tuple[User, bytes]: - """Set the given user's password in the database.""" - hashed_password = hash_password(password) - cursor.execute( - ("INSERT INTO user_credentials VALUES (:user_id, :hash) " - "ON CONFLICT (user_id) DO UPDATE SET password=:hash"), - {"user_id": str(user.user_id), "hash": hashed_password}) - return user, hashed_password - -def users(conn: db.DbConnection, - ids: tuple[UUID, ...] = tuple()) -> tuple[User, ...]: - """ - Fetch all users with the given `ids`. If `ids` is empty, return ALL users. - """ - params = ", ".join(["?"] * len(ids)) - with db.cursor(conn) as cursor: - query = "SELECT * FROM users" + ( - f" WHERE user_id IN ({params})" - if len(ids) > 0 else "") - print(query) - cursor.execute(query, tuple(str(the_id) for the_id in ids)) - return tuple(User(UUID(row["user_id"]), row["email"], row["name"]) - for row in cursor.fetchall()) - return tuple() diff --git a/gn3/auth/authorisation/users/collections/__init__.py b/gn3/auth/authorisation/users/collections/__init__.py deleted file mode 100644 index 88ab040..0000000 --- a/gn3/auth/authorisation/users/collections/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Package dealing with user collections.""" diff --git a/gn3/auth/authorisation/users/collections/models.py b/gn3/auth/authorisation/users/collections/models.py deleted file mode 100644 index 7577fa8..0000000 --- a/gn3/auth/authorisation/users/collections/models.py +++ /dev/null @@ -1,269 +0,0 @@ -"""Handle user collections.""" -import json -from uuid import UUID, uuid4 -from datetime import datetime - -from redis import Redis -from email_validator import validate_email, EmailNotValidError - -from gn3.auth.authorisation.errors import InvalidData, NotFoundError - -from ..models import User - -__OLD_REDIS_COLLECTIONS_KEY__ = "collections" -__REDIS_COLLECTIONS_KEY__ = "collections2" - -class CollectionJSONEncoder(json.JSONEncoder): - """Serialise collection objects into JSON.""" - def default(self, obj):# pylint: disable=[arguments-renamed] - if isinstance(obj, UUID): - return str(obj) - if isinstance(obj, datetime): - return obj.strftime("%b %d %Y %I:%M%p") - return json.JSONEncoder.default(self, obj) - -def __valid_email__(email:str) -> bool: - """Check for email validity.""" - try: - validate_email(email, check_deliverability=True) - except EmailNotValidError as _enve: - return False - return True - -def __toggle_boolean_field__( - rconn: Redis, email: str, field: str): - """Toggle the valuen of a boolean field""" - mig_dict = json.loads(rconn.hget("migratable-accounts", email) or "{}") - if bool(mig_dict): - rconn.hset("migratable-accounts", email, - json.dumps({**mig_dict, field: not mig_dict.get(field, True)})) - -def __build_email_uuid_bridge__(rconn: Redis): - """ - Build a connection between new accounts and old user accounts. - - The only thing that is common between the two is the email address, - therefore, we use that to link the two items. 
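The password helpers above delegate hashing and verification to the argon2-cffi library. A minimal sketch of that library's hash/verify round trip, standalone and outside the deleted module:

from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError

hasher = PasswordHasher()  # default time_cost/memory_cost/parallelism settings
hashed = hasher.hash("s3cret-passphrase")
assert hasher.verify(hashed, "s3cret-passphrase")  # True when the password matches
try:
    hasher.verify(hashed, "wrong-passphrase")
except VerifyMismatchError:
    pass  # the deleted same_password() translated this exception into False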
- """ - old_accounts = { - account["email_address"]: { - "user_id": account["user_id"], - "collections-migrated": False, - "resources_migrated": False - } for account in ( - acct for acct in - (json.loads(usr) for usr in rconn.hgetall("users").values()) - if (bool(acct.get("email_address", False)) and - __valid_email__(acct["email_address"]))) - } - if bool(old_accounts): - rconn.hset("migratable-accounts", mapping={ - key: json.dumps(value) for key,value in old_accounts.items() - }) - return old_accounts - -def __retrieve_old_accounts__(rconn: Redis) -> dict: - accounts = rconn.hgetall("migratable-accounts") - if accounts: - return { - key: json.loads(value) for key, value in accounts.items() - } - return __build_email_uuid_bridge__(rconn) - -def parse_collection(coll: dict) -> dict: - """Parse the collection as persisted in redis to a usable python object.""" - created = coll.get("created", coll.get("created_timestamp")) - changed = coll.get("changed", coll.get("changed_timestamp")) - return { - "id": UUID(coll["id"]), - "name": coll["name"], - "created": datetime.strptime(created, "%b %d %Y %I:%M%p"), - "changed": datetime.strptime(changed, "%b %d %Y %I:%M%p"), - "num_members": int(coll["num_members"]), - "members": coll["members"] - } - -def dump_collection(pythoncollection: dict) -> str: - """Convert the collection from a python object to a json string.""" - return json.dumps(pythoncollection, cls=CollectionJSONEncoder) - -def __retrieve_old_user_collections__(rconn: Redis, old_user_id: UUID) -> tuple: - """Retrieve any old collections relating to the user.""" - return tuple(parse_collection(coll) for coll in - json.loads(rconn.hget( - __OLD_REDIS_COLLECTIONS_KEY__, str(old_user_id)) or "[]")) - -def user_collections(rconn: Redis, user: User) -> tuple[dict, ...]: - """Retrieve current user collections.""" - collections = tuple(parse_collection(coll) for coll in json.loads( - rconn.hget(__REDIS_COLLECTIONS_KEY__, str(user.user_id)) or - "[]")) - old_accounts = __retrieve_old_accounts__(rconn) - if (user.email in old_accounts and - not old_accounts[user.email]["collections-migrated"]): - old_user_id = old_accounts[user.email]["user_id"] - collections = tuple({ - coll["id"]: coll for coll in ( - collections + __retrieve_old_user_collections__( - rconn, UUID(old_user_id))) - }.values()) - __toggle_boolean_field__(rconn, user.email, "collections-migrated") - rconn.hset( - __REDIS_COLLECTIONS_KEY__, - key=str(user.user_id), - value=json.dumps(collections, cls=CollectionJSONEncoder)) - return collections - -def save_collections(rconn: Redis, user: User, collections: tuple[dict, ...]) -> tuple[dict, ...]: - """Save the `collections` to redis.""" - rconn.hset( - __REDIS_COLLECTIONS_KEY__, - str(user.user_id), - json.dumps(collections, cls=CollectionJSONEncoder)) - return collections - -def add_to_user_collections(rconn: Redis, user: User, collection: dict) -> dict: - """Add `collection` to list of user collections.""" - ucolls = user_collections(rconn, user) - save_collections(rconn, user, ucolls + (collection,)) - return collection - -def create_collection(rconn: Redis, user: User, name: str, traits: tuple) -> dict: - """Create a new collection.""" - now = datetime.utcnow() - return add_to_user_collections(rconn, user, { - "id": uuid4(), - "name": name, - "created": now, - "changed": now, - "num_members": len(traits), - "members": traits - }) - -def get_collection(rconn: Redis, user: User, collection_id: UUID) -> dict: - """Retrieve the collection with ID `collection_id`.""" - colls = 
tuple(coll for coll in user_collections(rconn, user) - if coll["id"] == collection_id) - if len(colls) == 0: - raise NotFoundError( - f"Could not find a collection with ID `{collection_id}` for user " - f"with ID `{user.user_id}`") - if len(colls) > 1: - err = InvalidData( - "More than one collection was found having the ID " - f"`{collection_id}` for user with ID `{user.user_id}`.") - err.error_code = 513 - raise err - return colls[0] - -def __raise_if_collections_empty__(user: User, collections: tuple[dict, ...]): - """Raise an exception if no collections are found for `user`.""" - if len(collections) < 1: - raise NotFoundError(f"No collections found for user `{user.user_id}`") - -def __raise_if_not_single_collection__( - user: User, collection_id: UUID, collections: tuple[dict, ...]): - """ - Raise an exception there is zero, or more than one collection for `user`. - """ - if len(collections) == 0: - raise NotFoundError(f"No collections found for user `{user.user_id}` " - f"with ID `{collection_id}`.") - if len(collections) > 1: - err = InvalidData( - "More than one collection was found having the ID " - f"`{collection_id}` for user with ID `{user.user_id}`.") - err.error_code = 513 - raise err - -def delete_collections(rconn: Redis, - user: User, - collection_ids: tuple[UUID, ...]) -> tuple[dict, ...]: - """ - Delete collections with the given `collection_ids` returning the deleted - collections. - """ - ucolls = user_collections(rconn, user) - save_collections( - rconn, - user, - tuple(coll for coll in ucolls if coll["id"] not in collection_ids)) - return tuple(coll for coll in ucolls if coll["id"] in collection_ids) - -def add_traits(rconn: Redis, - user: User, - collection_id: UUID, - traits: tuple[str, ...]) -> dict: - """ - Add `traits` to the `user` collection identified by `collection_id`. - - Returns: The collection with the new traits added. - """ - ucolls = user_collections(rconn, user) - __raise_if_collections_empty__(user, ucolls) - - mod_col = tuple(coll for coll in ucolls if coll["id"] == collection_id) - __raise_if_not_single_collection__(user, collection_id, mod_col) - new_members = tuple(set(tuple(mod_col[0]["members"]) + traits)) - new_coll = { - **mod_col[0], - "members": new_members, - "num_members": len(new_members) - } - save_collections( - rconn, - user, - (tuple(coll for coll in ucolls if coll["id"] != collection_id) + - (new_coll,))) - return new_coll - -def remove_traits(rconn: Redis, - user: User, - collection_id: UUID, - traits: tuple[str, ...]) -> dict: - """ - Remove `traits` from the `user` collection identified by `collection_id`. - - Returns: The collection with the specified `traits` removed. - """ - ucolls = user_collections(rconn, user) - __raise_if_collections_empty__(user, ucolls) - - mod_col = tuple(coll for coll in ucolls if coll["id"] == collection_id) - __raise_if_not_single_collection__(user, collection_id, mod_col) - new_members = tuple( - trait for trait in mod_col[0]["members"] if trait not in traits) - new_coll = { - **mod_col[0], - "members": new_members, - "num_members": len(new_members) - } - save_collections( - rconn, - user, - (tuple(coll for coll in ucolls if coll["id"] != collection_id) + - (new_coll,))) - return new_coll - -def change_name(rconn: Redis, - user: User, - collection_id: UUID, - new_name: str) -> dict: - """ - Change the collection's name. - - Returns: The collection with the new name. 
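add_traits above merges incoming traits into a collection with a set union, which both removes duplicates and discards ordering before the member count is recomputed. A tiny sketch of that merge with made-up trait identifiers:

existing = ("trait:1", "trait:2")
incoming = ("trait:2", "trait:3")
new_members = tuple(set(existing + incoming))  # duplicates collapse, order is arbitrary
updated = {"members": new_members, "num_members": len(new_members)}
assert sorted(new_members) == ["trait:1", "trait:2", "trait:3"]
assert updated["num_members"] == 3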
- """ - ucolls = user_collections(rconn, user) - __raise_if_collections_empty__(user, ucolls) - - mod_col = tuple(coll for coll in ucolls if coll["id"] == collection_id) - __raise_if_not_single_collection__(user, collection_id, mod_col) - - new_coll = {**mod_col[0], "name": new_name} - save_collections( - rconn, - user, - (tuple(coll for coll in ucolls if coll["id"] != collection_id) + - (new_coll,))) - return new_coll diff --git a/gn3/auth/authorisation/users/collections/views.py b/gn3/auth/authorisation/users/collections/views.py deleted file mode 100644 index 775e8bc..0000000 --- a/gn3/auth/authorisation/users/collections/views.py +++ /dev/null @@ -1,239 +0,0 @@ -"""Views regarding user collections.""" -from uuid import UUID - -from redis import Redis -from flask import jsonify, request, Response, Blueprint, current_app - -from gn3.auth import db -from gn3.auth.db_utils import with_db_connection -from gn3.auth.authorisation.checks import require_json -from gn3.auth.authorisation.errors import NotFoundError - -from gn3.auth.authorisation.users import User, user_by_id -from gn3.auth.authorisation.oauth2.resource_server import require_oauth - -from .models import ( - add_traits, - change_name, - remove_traits, - get_collection, - user_collections, - save_collections, - create_collection, - delete_collections as _delete_collections) - -collections = Blueprint("collections", __name__) - -@collections.route("/list") -@require_oauth("profile user") -def list_user_collections() -> Response: - """Retrieve the user ids""" - with (require_oauth.acquire("profile user") as the_token, - Redis.from_url(current_app.config["REDIS_URI"], - decode_responses=True) as redisconn): - return jsonify(user_collections(redisconn, the_token.user)) - -@collections.route("/<uuid:anon_id>/list") -def list_anonymous_collections(anon_id: UUID) -> Response: - """Fetch anonymous collections""" - with Redis.from_url( - current_app.config["REDIS_URI"], decode_responses=True) as redisconn: - def __list__(conn: db.DbConnection) -> tuple: - try: - _user = user_by_id(conn, anon_id) - current_app.logger.warning( - "Fetch collections for authenticated user using the " - "`list_user_collections()` endpoint.") - return tuple() - except NotFoundError as _nfe: - return user_collections( - redisconn, User(anon_id, "anon@ymous.user", "Anonymous User")) - - return jsonify(with_db_connection(__list__)) - -@require_oauth("profile user") -def __new_collection_as_authenticated_user__(redisconn, name, traits): - """Create a new collection as an authenticated user.""" - with require_oauth.acquire("profile user") as token: - return create_collection(redisconn, token.user, name, traits) - -def __new_collection_as_anonymous_user__(redisconn, name, traits): - """Create a new collection as an anonymous user.""" - return create_collection(redisconn, - User(UUID(request.json.get("anon_id")), - "anon@ymous.user", - "Anonymous User"), - name, - traits) - -@collections.route("/new", methods=["POST"]) -@require_json -def new_user_collection() -> Response: - """Create a new collection.""" - with (Redis.from_url(current_app.config["REDIS_URI"], - decode_responses=True) as redisconn): - traits = tuple(request.json.get("traits", tuple()))# type: ignore[union-attr] - name = request.json.get("name")# type: ignore[union-attr] - if bool(request.headers.get("Authorization")): - return jsonify(__new_collection_as_authenticated_user__( - redisconn, name, traits)) - return jsonify(__new_collection_as_anonymous_user__( - redisconn, name, traits)) - 
-@collections.route("/<uuid:collection_id>/view", methods=["POST"]) -@require_json -def view_collection(collection_id: UUID) -> Response: - """View a particular collection""" - with (Redis.from_url(current_app.config["REDIS_URI"], - decode_responses=True) as redisconn): - if bool(request.headers.get("Authorization")): - with require_oauth.acquire("profile user") as token: - return jsonify(get_collection(redisconn, - token.user, - collection_id)) - return jsonify(get_collection( - redisconn, - User( - UUID(request.json.get("anon_id")),#type: ignore[union-attr] - "anon@ymous.user", - "Anonymous User"), - collection_id)) - -@collections.route("/anonymous/import", methods=["POST"]) -@require_json -@require_oauth("profile user") -def import_anonymous() -> Response: - """Import anonymous collections.""" - with (require_oauth.acquire("profile user") as token, - Redis.from_url(current_app.config["REDIS_URI"], - decode_responses=True) as redisconn): - anon_id = UUID(request.json.get("anon_id"))#type: ignore[union-attr] - anon_colls = user_collections(redisconn, User( - anon_id, "anon@ymous.user", "Anonymous User")) - save_collections( - redisconn, - token.user, - (user_collections(redisconn, token.user) + - anon_colls)) - redisconn.hdel("collections", str(anon_id)) - return jsonify({ - "message": f"Import of {len(anon_colls)} was successful." - }) - -@collections.route("/anonymous/delete", methods=["POST"]) -@require_json -@require_oauth("profile user") -def delete_anonymous() -> Response: - """Delete anonymous collections.""" - with (require_oauth.acquire("profile user") as _token, - Redis.from_url(current_app.config["REDIS_URI"], - decode_responses=True) as redisconn): - anon_id = UUID(request.json.get("anon_id"))#type: ignore[union-attr] - anon_colls = user_collections(redisconn, User( - anon_id, "anon@ymous.user", "Anonymous User")) - redisconn.hdel("collections", str(anon_id)) - return jsonify({ - "message": f"Deletion of {len(anon_colls)} was successful." 
- }) - -@collections.route("/delete", methods=["POST"]) -@require_json -def delete_collections(): - """Delete specified collections.""" - with (Redis.from_url(current_app.config["REDIS_URI"], - decode_responses=True) as redisconn): - coll_ids = tuple(UUID(cid) for cid in request.json["collection_ids"]) - deleted = _delete_collections( - redisconn, - User(request.json["anon_id"], "anon@ymous.user", "Anonymous User"), - coll_ids) - if bool(request.headers.get("Authorization")): - with require_oauth.acquire("profile user") as token: - deleted = deleted + _delete_collections( - redisconn, token.user, coll_ids) - - return jsonify({ - "message": f"Deleted {len(deleted)} collections."}) - -@collections.route("/<uuid:collection_id>/traits/remove", methods=["POST"]) -@require_json -def remove_traits_from_collection(collection_id: UUID) -> Response: - """Remove specified traits from collection with ID `collection_id`.""" - if len(request.json["traits"]) < 1:#type: ignore[index] - return jsonify({"message": "No trait to remove from collection."}) - - the_traits = tuple(request.json["traits"])#type: ignore[index] - with (Redis.from_url(current_app.config["REDIS_URI"], - decode_responses=True) as redisconn): - if not bool(request.headers.get("Authorization")): - coll = remove_traits( - redisconn, - User(request.json["anon_id"],#type: ignore[index] - "anon@ymous.user", - "Anonymous User"), - collection_id, - the_traits) - else: - with require_oauth.acquire("profile user") as token: - coll = remove_traits( - redisconn, token.user, collection_id, the_traits) - - return jsonify({ - "message": f"Deleted {len(the_traits)} traits from collection.", - "collection": coll - }) - -@collections.route("/<uuid:collection_id>/traits/add", methods=["POST"]) -@require_json -def add_traits_to_collection(collection_id: UUID) -> Response: - """Add specified traits to collection with ID `collection_id`.""" - if len(request.json["traits"]) < 1:#type: ignore[index] - return jsonify({"message": "No trait to add to collection."}) - - the_traits = tuple(request.json["traits"])#type: ignore[index] - with (Redis.from_url(current_app.config["REDIS_URI"], - decode_responses=True) as redisconn): - if not bool(request.headers.get("Authorization")): - coll = add_traits( - redisconn, - User(request.json["anon_id"],#type: ignore[index] - "anon@ymous.user", - "Anonymous User"), - collection_id, - the_traits) - else: - with require_oauth.acquire("profile user") as token: - coll = add_traits( - redisconn, token.user, collection_id, the_traits) - - return jsonify({ - "message": f"Added {len(the_traits)} traits to collection.", - "collection": coll - }) - -@collections.route("/<uuid:collection_id>/rename", methods=["POST"]) -@require_json -def rename_collection(collection_id: UUID) -> Response: - """Rename the given collection""" - if not bool(request.json["new_name"]):#type: ignore[index] - return jsonify({"message": "No new name to change to."}) - - new_name = request.json["new_name"]#type: ignore[index] - with (Redis.from_url(current_app.config["REDIS_URI"], - decode_responses=True) as redisconn): - if not bool(request.headers.get("Authorization")): - coll = change_name(redisconn, - User(UUID(request.json["anon_id"]),#type: ignore[index] - "anon@ymous.user", - "Anonymous User"), - collection_id, - new_name) - else: - with require_oauth.acquire("profile user") as token: - coll = change_name( - redisconn, token.user, collection_id, new_name) - - return jsonify({ - "message": "Collection rename successful.", - "collection": coll - }) diff 
--git a/gn3/auth/authorisation/users/models.py b/gn3/auth/authorisation/users/models.py deleted file mode 100644 index 0157154..0000000 --- a/gn3/auth/authorisation/users/models.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Functions for acting on users.""" -import uuid -from functools import reduce - -from gn3.auth import db -from gn3.auth.authorisation.roles.models import Role -from gn3.auth.authorisation.checks import authorised_p -from gn3.auth.authorisation.privileges import Privilege - -from .base import User - -@authorised_p( - ("system:user:list",), - "You do not have the appropriate privileges to list users.", - oauth2_scope="profile user") -def list_users(conn: db.DbConnection) -> tuple[User, ...]: - """List out all users.""" - with db.cursor(conn) as cursor: - cursor.execute("SELECT * FROM users") - return tuple( - User(uuid.UUID(row["user_id"]), row["email"], row["name"]) - for row in cursor.fetchall()) - -def __build_resource_roles__(rows): - def __build_roles__(roles, row): - role_id = uuid.UUID(row["role_id"]) - priv = Privilege(row["privilege_id"], row["privilege_description"]) - role = roles.get(role_id, Role( - role_id, row["role_name"], bool(row["user_editable"]), tuple())) - return { - **roles, - role_id: Role(role_id, role.role_name, role.user_editable, role.privileges + (priv,)) - } - def __build__(acc, row): - resource_id = uuid.UUID(row["resource_id"]) - return { - **acc, - resource_id: __build_roles__(acc.get(resource_id, {}), row) - } - return { - resource_id: tuple(roles.values()) - for resource_id, roles in reduce(__build__, rows, {}).items() - } - -# @authorised_p( -# ("",), -# ("You do not have the appropriate privileges to view a user's roles on " -# "resources.")) -def user_resource_roles(conn: db.DbConnection, user: User) -> dict[uuid.UUID, tuple[Role, ...]]: - """Fetch all the user's roles on resources.""" - with db.cursor(conn) as cursor: - cursor.execute( - "SELECT res.*, rls.*, p.*" - "FROM resources AS res INNER JOIN " - "group_user_roles_on_resources AS guror " - "ON res.resource_id=guror.resource_id " - "LEFT JOIN roles AS rls " - "ON guror.role_id=rls.role_id " - "LEFT JOIN role_privileges AS rp " - "ON rls.role_id=rp.role_id " - "LEFT JOIN privileges AS p " - "ON rp.privilege_id=p.privilege_id " - "WHERE guror.user_id = ?", - (str(user.user_id),)) - return __build_resource_roles__( - (dict(row) for row in cursor.fetchall())) diff --git a/gn3/auth/authorisation/users/views.py b/gn3/auth/authorisation/users/views.py deleted file mode 100644 index f75b51e..0000000 --- a/gn3/auth/authorisation/users/views.py +++ /dev/null @@ -1,173 +0,0 @@ -"""User authorisation endpoints.""" -import traceback -from typing import Any -from functools import partial - -import sqlite3 -from email_validator import validate_email, EmailNotValidError -from flask import request, jsonify, Response, Blueprint, current_app - -from gn3.auth import db -from gn3.auth.dictify import dictify -from gn3.auth.db_utils import with_db_connection -from gn3.auth.authorisation.oauth2.resource_server import require_oauth -from gn3.auth.authorisation.users import User, save_user, set_user_password -from gn3.auth.authorisation.oauth2.oauth2token import token_by_access_token - -from .models import list_users -from .collections.views import collections - -from ..groups.models import user_group as _user_group -from ..resources.models import user_resources as _user_resources -from ..roles.models import assign_default_roles, user_roles as _user_roles -from ..errors import ( - NotFoundError, UsernameError, 
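user_resource_roles above folds a flat SQL join result into a nested mapping of resources to roles to privileges using functools.reduce. A stripped-down sketch of the same fold, grouping flat rows by one key; the column names are illustrative rather than the real auth schema:

from functools import reduce

rows = ({"resource_id": "r1", "privilege_id": "p1"},
        {"resource_id": "r1", "privilege_id": "p2"},
        {"resource_id": "r2", "privilege_id": "p1"})

def group_by_resource(acc, row):
    """Accumulate each row's privilege under its resource id."""
    rid = row["resource_id"]
    return {**acc, rid: acc.get(rid, tuple()) + (row["privilege_id"],)}

grouped = reduce(group_by_resource, rows, {})
assert grouped == {"r1": ("p1", "p2"), "r2": ("p1",)}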
PasswordError, UserRegistrationError) - -users = Blueprint("users", __name__) -users.register_blueprint(collections, url_prefix="/collections") - -@users.route("/", methods=["GET"]) -@require_oauth("profile") -def user_details() -> Response: - """Return user's details.""" - with require_oauth.acquire("profile") as the_token: - user = the_token.user - user_dets = { - "user_id": user.user_id, "email": user.email, "name": user.name, - "group": False - } - with db.connection(current_app.config["AUTH_DB"]) as conn: - the_group = _user_group(conn, user).maybe(# type: ignore[misc] - False, lambda grp: grp)# type: ignore[arg-type] - return jsonify({ - **user_dets, - "group": dictify(the_group) if the_group else False - }) - -@users.route("/roles", methods=["GET"]) -@require_oauth("role") -def user_roles() -> Response: - """Return the non-resource roles assigned to the user.""" - with require_oauth.acquire("role") as token: - with db.connection(current_app.config["AUTH_DB"]) as conn: - return jsonify(tuple( - dictify(role) for role in _user_roles(conn, token.user))) - -def validate_password(password, confirm_password) -> str: - """Validate the provided password.""" - if len(password) < 8: - raise PasswordError("The password must be at least 8 characters long.") - - if password != confirm_password: - raise PasswordError("Mismatched password values") - - return password - -def validate_username(name: str) -> str: - """Validate the provides name.""" - if name == "": - raise UsernameError("User's name not provided.") - - return name - -def __assert_not_logged_in__(conn: db.DbConnection): - bearer = request.headers.get('Authorization') - if bearer: - token = token_by_access_token(conn, bearer.split(None)[1]).maybe(# type: ignore[misc] - False, lambda tok: tok) - if token: - raise UserRegistrationError( - "Cannot register user while authenticated") - -@users.route("/register", methods=["POST"]) -def register_user() -> Response: - """Register a user.""" - with db.connection(current_app.config["AUTH_DB"]) as conn: - __assert_not_logged_in__(conn) - - try: - form = request.form - email = validate_email(form.get("email", "").strip(), - check_deliverability=True) - password = validate_password( - form.get("password", "").strip(), - form.get("confirm_password", "").strip()) - user_name = validate_username(form.get("user_name", "").strip()) - with db.cursor(conn) as cursor: - user, _hashed_password = set_user_password( - cursor, save_user( - cursor, email["email"], user_name), password) - assign_default_roles(cursor, user) - return jsonify( - { - "user_id": user.user_id, - "email": user.email, - "name": user.name - }) - except sqlite3.IntegrityError as sq3ie: - current_app.logger.debug(traceback.format_exc()) - raise UserRegistrationError( - "A user with that email already exists") from sq3ie - except EmailNotValidError as enve: - current_app.logger.debug(traceback.format_exc()) - raise(UserRegistrationError(f"Email Error: {str(enve)}")) from enve - - raise Exception( - "unknown_error", "The system experienced an unexpected error.") - -@users.route("/group", methods=["GET"]) -@require_oauth("profile group") -def user_group() -> Response: - """Retrieve the group in which the user is a member.""" - with require_oauth.acquire("profile group") as the_token: - db_uri = current_app.config["AUTH_DB"] - with db.connection(db_uri) as conn: - group = _user_group(conn, the_token.user).maybe(# type: ignore[misc] - False, lambda grp: grp)# type: ignore[arg-type] - - if group: - return jsonify(dictify(group)) - raise 
NotFoundError("User is not a member of any group.") - -@users.route("/resources", methods=["GET"]) -@require_oauth("profile resource") -def user_resources() -> Response: - """Retrieve the resources a user has access to.""" - with require_oauth.acquire("profile resource") as the_token: - db_uri = current_app.config["AUTH_DB"] - with db.connection(db_uri) as conn: - return jsonify([ - dictify(resource) for resource in - _user_resources(conn, the_token.user)]) - -@users.route("group/join-request", methods=["GET"]) -@require_oauth("profile group") -def user_join_request_exists(): - """Check whether a user has an active group join request.""" - def __request_exists__(conn: db.DbConnection, user: User) -> dict[str, Any]: - with db.cursor(conn) as cursor: - cursor.execute( - "SELECT * FROM group_join_requests WHERE requester_id=? AND " - "status = 'PENDING'", - (str(user.user_id),)) - res = cursor.fetchone() - if res: - return { - "request_id": res["request_id"], - "exists": True - } - return{ - "status": "Not found", - "exists": False - } - with require_oauth.acquire("profile group") as the_token: - return jsonify(with_db_connection(partial( - __request_exists__, user=the_token.user))) - -@users.route("/list", methods=["GET"]) -@require_oauth("profile user") -def list_all_users() -> Response: - """List all the users.""" - with require_oauth.acquire("profile group") as _the_token: - return jsonify(tuple( - dictify(user) for user in with_db_connection(list_users))) diff --git a/gn3/auth/db_utils.py b/gn3/auth/db_utils.py deleted file mode 100644 index c06b026..0000000 --- a/gn3/auth/db_utils.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Some common auth db utilities""" -from typing import Any, Callable -from flask import current_app - -from . import db - -def with_db_connection(func: Callable[[db.DbConnection], Any]) -> Any: - """ - Takes a function of one argument `func`, whose one argument is a database - connection. 
- """ - db_uri = current_app.config["AUTH_DB"] - with db.connection(db_uri) as conn: - return func(conn) diff --git a/gn3/auth/dictify.py b/gn3/auth/dictify.py deleted file mode 100644 index f9337f6..0000000 --- a/gn3/auth/dictify.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Module for dictifying objects""" - -from typing import Any, Protocol - -class Dictifiable(Protocol):# pylint: disable=[too-few-public-methods] - """Type annotation for generic object with a `dictify` method.""" - def dictify(self): - """Convert the object to a dict""" - -def dictify(obj: Dictifiable) -> dict[str, Any]: - """Turn `obj` to a dict representation.""" - return obj.dictify() diff --git a/gn3/auth/views.py b/gn3/auth/views.py deleted file mode 100644 index da64049..0000000 --- a/gn3/auth/views.py +++ /dev/null @@ -1,16 +0,0 @@ -"""The Auth(oris|entic)ation routes/views""" -from flask import Blueprint - -from .authorisation.data.views import data -from .authorisation.users.views import users -from .authorisation.roles.views import roles -from .authorisation.groups.views import groups -from .authorisation.resources.views import resources - -oauth2 = Blueprint("oauth2", __name__) - -oauth2.register_blueprint(data, url_prefix="/data") -oauth2.register_blueprint(users, url_prefix="/user") -oauth2.register_blueprint(roles, url_prefix="/role") -oauth2.register_blueprint(groups, url_prefix="/group") -oauth2.register_blueprint(resources, url_prefix="/resource") diff --git a/gn3/authentication.py b/gn3/authentication.py index bb717dd..fac6ed9 100644 --- a/gn3/authentication.py +++ b/gn3/authentication.py @@ -61,7 +61,7 @@ def get_user_membership(conn: Redis, user_id: str, """ results = {"member": False, "admin": False} - for key, value in conn.hgetall('groups').items(): + for key, value in conn.hgetall('groups').items():# type: ignore[union-attr] if key == group_id: group_info = json.loads(value) if user_id in group_info.get("admins"): @@ -94,7 +94,8 @@ def get_highest_user_access_role( access_role = {} response = requests.get(urljoin(gn_proxy_url, ("available?resource=" - f"{resource_id}&user={user_id}"))) + f"{resource_id}&user={user_id}")), + timeout=500) for key, value in json.loads(response.content).items(): access_role[key] = max(map(lambda role: role_mapping[role], value)) return access_role @@ -113,7 +114,7 @@ def get_groups_by_user_uid(user_uid: str, conn: Redis) -> Dict: """ admin = [] member = [] - for group_uuid, group_info in conn.hgetall("groups").items(): + for group_uuid, group_info in conn.hgetall("groups").items():# type: ignore[union-attr] group_info = json.loads(group_info) group_info["uuid"] = group_uuid if user_uid in group_info.get('admins'): @@ -130,14 +131,14 @@ def get_user_info_by_key(key: str, value: str, conn: Redis) -> Optional[Dict]: """Given a key, get a user's information if value is matched""" if key != "user_id": - for user_uuid, user_info in conn.hgetall("users").items(): + for user_uuid, user_info in conn.hgetall("users").items():# type: ignore[union-attr] user_info = json.loads(user_info) if (key in user_info and user_info.get(key) == value): user_info["user_id"] = user_uuid return user_info elif key == "user_id": if user_info := conn.hget("users", value): - user_info = json.loads(user_info) + user_info = json.loads(user_info)# type: ignore[arg-type] user_info["user_id"] = value return user_info return None diff --git a/gn3/case_attributes.py b/gn3/case_attributes.py deleted file mode 100644 index efc82e9..0000000 --- a/gn3/case_attributes.py +++ /dev/null @@ -1,642 +0,0 @@ -"""Implement 
case-attribute manipulations.""" -import os -import csv -import json -import uuid -import tempfile -from typing import Union -from enum import Enum, auto -from pathlib import Path -from functools import reduce -from datetime import datetime -from urllib.parse import urljoin - -import requests -from MySQLdb.cursors import DictCursor -from authlib.integrations.flask_oauth2.errors import _HTTPException -from flask import ( - jsonify, - request, - Response, - Blueprint, - current_app, - make_response) - -from gn3.commands import run_cmd - -from gn3.db_utils import Connection, database_connection - -from gn3.oauth2.authorisation import require_token -from gn3.auth.authorisation.errors import AuthorisationError - -caseattr = Blueprint("case-attribute", __name__) - -CATTR_DIFFS_DIR = "case-attribute-diffs" - -class NoDiffError(ValueError): - """Raised if there is no difference between the old and new data.""" - def __init__(self): - """Initialise exception.""" - super().__init__( - self, "No difference between existing data and sent data.") - -class EditStatus(Enum): - """Enumeration for the status of the edits.""" - review = auto() # pylint: disable=[invalid-name] - approved = auto() # pylint: disable=[invalid-name] - rejected = auto() # pylint: disable=[invalid-name] - - def __str__(self): - """Print out human-readable form.""" - return self.name - -class CAJSONEncoder(json.JSONEncoder): - """Encoder for CaseAttribute-specific data""" - def default(self, obj): # pylint: disable=[arguments-renamed] - """Default encoder""" - if isinstance(obj, datetime): - return obj.isoformat() - if isinstance(obj, uuid.UUID): - return str(obj) - return json.JSONEncoder.default(self, obj) - -def required_access( - token: dict, - inbredset_id: int, - access_levels: tuple[str, ...] 
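The EditStatus enumeration above overrides __str__ so that str(status) yields the bare member name, which is the value later written to the caseattributes_audit.status column. A tiny standalone sketch:

from enum import Enum, auto

class EditStatus(Enum):
    review = auto()
    approved = auto()
    rejected = auto()

    def __str__(self):
        return self.name

assert str(EditStatus.review) == "review"
assert f"status: {EditStatus.approved}" == "status: approved"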
-) -> Union[bool, tuple[str, ...]]: - """Check whether the user has the appropriate access""" - def __species_id__(conn): - with conn.cursor() as cursor: - cursor.execute( - "SELECT SpeciesId FROM InbredSet WHERE InbredSetId=%s", - (inbredset_id,)) - return cursor.fetchone()[0] - try: - with database_connection(current_app.config["SQL_URI"]) as conn: - result = requests.get( - # this section fetches the resource ID from the auth server - urljoin(current_app.config["AUTH_SERVER_URL"], - "auth/resource/inbredset/resource-id" - f"/{__species_id__(conn)}/{inbredset_id}")) - if result.status_code == 200: - resource_id = result.json()["resource-id"] - auth = requests.post( - # this section fetches the authorisations/privileges that - # the current user has on the resource we got above - urljoin(current_app.config["AUTH_SERVER_URL"], - "auth/resource/authorisation"), - json={"resource-ids": [resource_id]}, - headers={"Authorization": f"Bearer {token['access_token']}"}) - if auth.status_code == 200: - privs = tuple(priv["privilege_id"] - for role in auth.json()[resource_id]["roles"] - for priv in role["privileges"]) - if all(lvl in privs for lvl in access_levels): - return privs - except _HTTPException as httpe: - raise AuthorisationError("You need to be logged in.") from httpe - - raise AuthorisationError( - f"User does not have the privileges {access_levels}") - -def __inbredset_group__(conn, inbredset_id): - """Return InbredSet group's top-level details.""" - with conn.cursor(cursorclass=DictCursor) as cursor: - cursor.execute( - "SELECT * FROM InbredSet WHERE InbredSetId=%(inbredset_id)s", - {"inbredset_id": inbredset_id}) - return dict(cursor.fetchone()) - -def __inbredset_strains__(conn, inbredset_id): - """Return all samples/strains for given InbredSet group.""" - with conn.cursor(cursorclass=DictCursor) as cursor: - cursor.execute( - "SELECT s.* FROM StrainXRef AS sxr INNER JOIN Strain AS s " - "ON sxr.StrainId=s.Id WHERE sxr.InbredSetId=%(inbredset_id)s " - "ORDER BY s.Name ASC", - {"inbredset_id": inbredset_id}) - return tuple(dict(row) for row in cursor.fetchall()) - -def __case_attribute_labels_by_inbred_set__(conn, inbredset_id): - """Return the case-attribute labels/names for the given InbredSet group.""" - with conn.cursor(cursorclass=DictCursor) as cursor: - cursor.execute( - "SELECT * FROM CaseAttribute WHERE InbredSetId=%(inbredset_id)s", - {"inbredset_id": inbredset_id}) - return tuple(dict(row) for row in cursor.fetchall()) - -@caseattr.route("/<int:inbredset_id>", methods=["GET"]) -def inbredset_group(inbredset_id: int) -> Response: - """Retrieve InbredSet group's details.""" - with database_connection(current_app.config["SQL_URI"]) as conn: - return jsonify(__inbredset_group__(conn, inbredset_id)) - -@caseattr.route("/<int:inbredset_id>/strains", methods=["GET"]) -def inbredset_strains(inbredset_id: int) -> Response: - """Retrieve ALL strains/samples relating to a specific InbredSet group.""" - with database_connection(current_app.config["SQL_URI"]) as conn: - return jsonify(__inbredset_strains__(conn, inbredset_id)) - -@caseattr.route("/<int:inbredset_id>/names", methods=["GET"]) -def inbredset_case_attribute_names(inbredset_id: int) -> Response: - """Retrieve ALL case-attributes for a specific InbredSet group.""" - with database_connection(current_app.config["SQL_URI"]) as conn: - return jsonify( - __case_attribute_labels_by_inbred_set__(conn, inbredset_id)) - -def __by_strain__(accumulator, item): - attr = {item["CaseAttributeName"]: item["CaseAttributeValue"]} - 
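Once required_access has flattened the auth server's response into a tuple of privilege ids, the gate itself is a plain subset test: every requested access level must appear among the granted privileges. A tiny sketch, with privilege names taken from this same module:

granted = ("system:inbredset:view-case-attribute",
           "system:inbredset:edit-case-attribute")
assert all(level in granted
           for level in ("system:inbredset:edit-case-attribute",))  # allowed
assert not all(level in granted
               for level in ("system:inbredset:apply-case-attribute-edit",))  # would raise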
strain_name = item["StrainName"] - if bool(accumulator.get(strain_name)): - return { - **accumulator, - strain_name: { - **accumulator[strain_name], - "case-attributes": { - **accumulator[strain_name]["case-attributes"], - **attr - } - } - } - return { - **accumulator, - strain_name: { - **{ - key: value for key,value in item.items() - if key in ("StrainName", "StrainName2", "Symbol", "Alias") - }, - "case-attributes": attr - } - } - -def __case_attribute_values_by_inbred_set__( - conn: Connection, inbredset_id: int) -> tuple[dict, ...]: - """ - Retrieve Case-Attributes by their InbredSet ID. Do not call this outside - this module. - """ - with conn.cursor(cursorclass=DictCursor) as cursor: - cursor.execute( - "SELECT ca.Name AS CaseAttributeName, " - "caxrn.Value AS CaseAttributeValue, s.Name AS StrainName, " - "s.Name2 AS StrainName2, s.Symbol, s.Alias " - "FROM CaseAttribute AS ca " - "INNER JOIN CaseAttributeXRefNew AS caxrn " - "ON ca.CaseAttributeId=caxrn.CaseAttributeId " - "INNER JOIN Strain AS s " - "ON caxrn.StrainId=s.Id " - "WHERE ca.InbredSetId=%(inbredset_id)s " - "ORDER BY StrainName", - {"inbredset_id": inbredset_id}) - return tuple( - reduce(__by_strain__, cursor.fetchall(), {}).values()) - -@caseattr.route("/<int:inbredset_id>/values", methods=["GET"]) -def inbredset_case_attribute_values(inbredset_id: int) -> Response: - """Retrieve the group's (InbredSet's) case-attribute values.""" - with database_connection(current_app.config["SQL_URI"]) as conn: - return jsonify(__case_attribute_values_by_inbred_set__(conn, inbredset_id)) - -def __process_orig_data__(fieldnames, cadata, strains) -> tuple[dict, ...]: - """Process data from database and return tuple of dicts.""" - data = {item["StrainName"]: item for item in cadata} - return tuple( - { - "Strain": strain["Name"], - **{ - key: data.get( - strain["Name"], {}).get("case-attributes", {}).get(key, "") - for key in fieldnames[1:] - } - } for strain in strains) - -def __process_edit_data__(fieldnames, form_data) -> tuple[dict, ...]: - """Process data from form and return tuple of dicts.""" - def __process__(acc, strain_cattrs): - strain, cattrs = strain_cattrs - return acc + ({ - "Strain": strain, **{ - field: cattrs["case-attributes"].get(field, "") - for field in fieldnames[1:] - } - },) - return reduce(__process__, form_data.items(), tuple()) - -def __write_csv__(fieldnames, data): - """Write the given `data` to a csv file and return the path to the file.""" - fdesc, filepath = tempfile.mkstemp(".csv") - os.close(fdesc) - with open(filepath, "w", encoding="utf-8") as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=fieldnames, dialect="unix") - writer.writeheader() - writer.writerows(data) - - return filepath - -def __compute_diff__( - fieldnames: tuple[str, ...], - original_data: tuple[dict, ...], - edit_data: tuple[dict, ...]): - """Return the diff of the data.""" - basefilename = __write_csv__(fieldnames, original_data) - deltafilename = __write_csv__(fieldnames, edit_data) - diff_results = run_cmd(json.dumps( - ["csvdiff", basefilename, deltafilename, "--format", "json"])) - os.unlink(basefilename) - os.unlink(deltafilename) - if diff_results["code"] == 0: - return json.loads(diff_results["output"]) - return {} - -def __queue_diff__(conn: Connection, diff_data, diff_data_dir: Path) -> Path: - """ - Queue diff for future processing. - - Returns: `diff` - On success, this will return the filename where the diff was saved. - On failure, it will raise a MySQL error. 
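__compute_diff__ above writes the original and edited tables to temporary CSV files and hands an argument list to the external csvdiff tool, JSON-encoded so run_cmd can decode it back. A sketch that only composes the command; actually executing it assumes csvdiff is installed and on PATH, and the column names are illustrative:

import os
import csv
import json
import tempfile

def write_csv(fieldnames, rows):
    """Write rows to a temporary CSV file and return the file's path."""
    fdesc, filepath = tempfile.mkstemp(".csv")
    os.close(fdesc)
    with open(filepath, "w", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, dialect="unix")
        writer.writeheader()
        writer.writerows(rows)
    return filepath

fieldnames = ("Strain", "SampleSize")
base = write_csv(fieldnames, [{"Strain": "BXD1", "SampleSize": "10"}])
delta = write_csv(fieldnames, [{"Strain": "BXD1", "SampleSize": "12"}])
command = json.dumps(["csvdiff", base, delta, "--format", "json"])
print(command)  # run_cmd() json.loads() this back into an argv-style list
os.unlink(base)
os.unlink(delta)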
- """ - diff = diff_data["diff"] - if bool(diff["Additions"]) or bool(diff["Modifications"]) or bool(diff["Deletions"]): - diff_data_dir.mkdir(parents=True, exist_ok=True) - - created = datetime.now() - filepath = Path( - diff_data_dir, - f"{diff_data['inbredset_id']}:::{diff_data['user_id']}:::" - f"{created.isoformat()}.json") - with open(filepath, "w", encoding="utf8") as diff_file: - # We want this to fail if the metadata items below are not provided. - the_diff = {**diff_data, "created": created} - insert_id = __save_diff__(conn, the_diff, EditStatus.review) - diff_file.write(json.dumps({**the_diff, "db_id": insert_id}, - cls=CAJSONEncoder)) - return filepath - raise NoDiffError - -def __save_diff__(conn: Connection, diff_data: dict, status: EditStatus) -> int: - """Save to the database.""" - with conn.cursor() as cursor: - cursor.execute( - "INSERT INTO " - "caseattributes_audit(id, status, editor, json_diff_data, time_stamp) " - "VALUES(%(db_id)s, %(status)s, %(editor)s, %(diff)s, %(ts)s) " - "ON DUPLICATE KEY UPDATE status=%(status)s", - { - "db_id": diff_data.get("db_id"), - "status": str(status), - "editor": str(diff_data["user_id"]), - "diff": json.dumps(diff_data, cls=CAJSONEncoder), - "ts": diff_data["created"].isoformat() - }) - return diff_data.get("db_id") or cursor.lastrowid - -def __parse_diff_json__(json_str): - """Parse the json string to python objects.""" - raw_diff = json.loads(json_str) - return { - **raw_diff, - "db_id": int(raw_diff["db_id"]) if raw_diff.get("db_id") else None, - "inbredset_id": (int(raw_diff["inbredset_id"]) - if raw_diff.get("inbredset_id") else None), - "user_id": (uuid.UUID(raw_diff["user_id"]) - if raw_diff.get("user_id") else None), - "created": (datetime.fromisoformat(raw_diff["created"]) - if raw_diff.get("created") else None) - } - -def __load_diff__(diff_filename): - """Load the diff.""" - with open(diff_filename, encoding="utf8") as diff_file: - return __parse_diff_json__(diff_file.read()) - -def __apply_additions__( - cursor, inbredset_id: int, additions_diff) -> None: - """Apply additions: creates new case attributes.""" - # TODO: Not tested... will probably fail. - cursor.execute( - "INSERT INTO CaseAttribute(InbredSetId, Name, Description) " - "VALUES (:inbredset_id, :name, :desc)", - tuple({ - "inbredset_id": inbredset_id, - "name": diff["name"], - "desc": diff["description"] - } for diff in additions_diff)) - -def __apply_modifications__( - cursor, inbredset_id: int, modifications_diff, fieldnames) -> None: - """Apply modifications: changes values of existing case attributes.""" - cattrs = tuple(field for field in fieldnames if field != "Strain") - - def __retrieve_changes__(acc, row): - orig = dict(zip(fieldnames, row["Original"].split(","))) - new = dict(zip(fieldnames, row["Current"].split(","))) - return acc + tuple({ - "Strain": new["Strain"], - cattr: new[cattr] - } for cattr in cattrs if new[cattr] != orig[cattr]) - - new_rows: tuple[dict, ...] 
= reduce( - __retrieve_changes__, modifications_diff, tuple()) - strain_names = tuple({row["Strain"] for row in new_rows}) - cursor.execute("SELECT Id AS StrainId, Name AS StrainName FROM Strain " - f"WHERE Name IN ({', '.join(['%s'] * len(strain_names))})", - strain_names) - strain_ids = { - row["StrainName"]: int(row["StrainId"]) - for row in cursor.fetchall()} - - cursor.execute("SELECT CaseAttributeId, Name AS CaseAttributeName " - "FROM CaseAttribute WHERE InbredSetId=%s " - f"AND Name IN ({', '.join(['%s'] * len(cattrs))})", - (inbredset_id,) + cattrs) - cattr_ids = { - row["CaseAttributeName"]: row["CaseAttributeId"] - for row in cursor.fetchall() - } - - cursor.executemany( - "INSERT INTO CaseAttributeXRefNew" - "(InbredSetId, StrainId, CaseAttributeId, Value) " - "VALUES(%(isetid)s, %(strainid)s, %(cattrid)s, %(value)s) " - "ON DUPLICATE KEY UPDATE Value=VALUES(value)", - tuple( - { - "isetid": inbredset_id, - "strainid": strain_ids[row["Strain"]], - "cattrid": cattr_ids[cattr], - "value": row[cattr] - } - for cattr in cattrs for row in new_rows - if bool(row.get(cattr, "").strip()))) - cursor.executemany( - "DELETE FROM CaseAttributeXRefNew WHERE " - "InbredSetId=%(isetid)s AND StrainId=%(strainid)s " - "AND CaseAttributeId=%(cattrid)s", - tuple( - { - "isetid": inbredset_id, - "strainid": strain_ids[row["Strain"]], - "cattrid": cattr_ids[cattr] - } - for row in new_rows - for cattr in (key for key in row.keys() if key != "Strain") - if not bool(row[cattr].strip()))) - -def __apply_deletions__( - cursor, inbredset_id: int, deletions_diff) -> None: - """Apply deletions: delete existing case attributes and their values.""" - # TODO: Not tested... will probably fail. - params = tuple({ - "inbredset_id": inbredset_id, - "case_attribute_id": diff["case_attribute_id"] - } for diff in deletions_diff) - cursor.executemany( - "DELETE FROM CaseAttributeXRefNew WHERE " - "InbredSetId=:inbredset_id AND CaseAttributeId=:case_attribute_id", - params) - cursor.executemany( - "DELETE FROM CaseAttribute WHERE " - "InbredSetId=:inbredset_id AND CaseAttributeId=:case_attribute_id", - params) - -def __apply_diff__( - conn: Connection, auth_token, inbredset_id: int, diff_filename, the_diff) -> None: - """ - Apply the changes in the diff at `diff_filename` to the data in the database - if the user has appropriate privileges. - """ - required_access(auth_token, - inbredset_id, - ("system:inbredset:edit-case-attribute", - "system:inbredset:apply-case-attribute-edit")) - diffs = the_diff["diff"] - with conn.cursor(cursorclass=DictCursor) as cursor: - # __apply_additions__(cursor, inbredset_id, diffs["Additions"]) - __apply_modifications__( - cursor, inbredset_id, diffs["Modifications"], the_diff["fieldnames"]) - # __apply_deletions__(cursor, inbredset_id, diffs["Deletions"]) - __save_diff__(conn, the_diff, EditStatus.approved) - new_path = Path( - diff_filename.parent, - f"{diff_filename.stem}-approved{diff_filename.suffix}") - os.rename(diff_filename, new_path) - -def __reject_diff__(conn: Connection, - auth_token: dict, - inbredset_id: int, - diff_filename: Path, - diff: dict) -> Path: - """ - Reject the changes in the diff at `diff_filename` to the data in the - database if the user has appropriate privileges. 
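__apply_modifications__ above recovers per-cell edits from csvdiff's "Original" and "Current" rows by zipping each comma-separated row back onto the field names and keeping only the columns whose value actually changed. A small sketch of that step with made-up rows:

from functools import reduce

fieldnames = ("Strain", "Sex", "Age")
modifications = ({"Original": "BXD1,M,12", "Current": "BXD1,M,14"},)

def retrieve_changes(acc, row):
    """Keep one {"Strain": ..., attribute: new_value} dict per changed cell."""
    orig = dict(zip(fieldnames, row["Original"].split(",")))
    new = dict(zip(fieldnames, row["Current"].split(",")))
    return acc + tuple({"Strain": new["Strain"], attr: new[attr]}
                       for attr in fieldnames[1:] if new[attr] != orig[attr])

changed = reduce(retrieve_changes, modifications, tuple())
assert changed == ({"Strain": "BXD1", "Age": "14"},)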
- """ - required_access(auth_token, - inbredset_id, - ("system:inbredset:edit-case-attribute", - "system:inbredset:apply-case-attribute-edit")) - __save_diff__(conn, diff, EditStatus.rejected) - new_path = Path(diff_filename.parent, f"{diff_filename.stem}-rejected{diff_filename.suffix}") - os.rename(diff_filename, new_path) - return diff_filename - -@caseattr.route("/<int:inbredset_id>/add", methods=["POST"]) -@require_token -def add_case_attributes(inbredset_id: int, auth_token=None) -> Response: - """Add a new case attribute for `InbredSetId`.""" - required_access( - auth_token, inbredset_id, ("system:inbredset:create-case-attribute",)) - with database_connection(current_app.config["SQL_URI"]) as conn: # pylint: disable=[unused-variable] - raise NotImplementedError - -@caseattr.route("/<int:inbredset_id>/delete", methods=["POST"]) -@require_token -def delete_case_attributes(inbredset_id: int, auth_token=None) -> Response: - """Delete a case attribute from `InbredSetId`.""" - required_access( - auth_token, inbredset_id, ("system:inbredset:delete-case-attribute",)) - with database_connection(current_app.config["SQL_URI"]) as conn: # pylint: disable=[unused-variable] - raise NotImplementedError - -@caseattr.route("/<int:inbredset_id>/edit", methods=["POST"]) -@require_token -def edit_case_attributes(inbredset_id: int, auth_token = None) -> Response: - """Edit the case attributes for `InbredSetId` based on data received. - - :inbredset_id: Identifier for the population that the case attribute belongs - :auth_token: A validated JWT from the auth server - """ - with database_connection(current_app.config["SQL_URI"]) as conn: - required_access(auth_token, - inbredset_id, - ("system:inbredset:edit-case-attribute",)) - fieldnames = tuple(["Strain"] + sorted( - attr["Name"] for attr in - __case_attribute_labels_by_inbred_set__(conn, inbredset_id))) - try: - diff_filename = __queue_diff__( - conn, { - "inbredset_id": inbredset_id, - "user_id": auth_token["jwt"]["sub"], - "fieldnames": fieldnames, - "diff": __compute_diff__( - fieldnames, - __process_orig_data__( - fieldnames, - __case_attribute_values_by_inbred_set__( - conn, inbredset_id), - __inbredset_strains__(conn, inbredset_id)), - __process_edit_data__( - fieldnames, request.json["edit-data"])) # type: ignore[index] - }, - Path(current_app.config["TMPDIR"], CATTR_DIFFS_DIR)) - except NoDiffError as _nde: - msg = "There were no changes to make from submitted data." 
- response = jsonify({ - "diff-status": "error", - "error_description": msg - }) - response.status_code = 400 - return response - - try: - __apply_diff__(conn, - auth_token, - inbredset_id, - diff_filename, - __load_diff__(diff_filename)) - return jsonify({ - "diff-status": "applied", - "message": ("The changes to the case-attributes have been " - "applied successfully.") - }) - except AuthorisationError as _auth_err: - return jsonify({ - "diff-status": "queued", - "message": ("The changes to the case-attributes have been " - "queued for approval."), - "diff-filename": str(diff_filename.name) - }) - -@caseattr.route("/<int:inbredset_id>/diff/list", methods=["GET"]) -def list_diffs(inbredset_id: int) -> Response: - """List any changes that have not been approved/rejected.""" - Path(current_app.config["TMPDIR"], CATTR_DIFFS_DIR).mkdir( - parents=True, exist_ok=True) - - def __generate_diff_files__(diffs): - diff_dir = Path(current_app.config["TMPDIR"], CATTR_DIFFS_DIR) - review_files = set(afile.name for afile in diff_dir.iterdir() - if ("-rejected" not in afile.name - and "-approved" not in afile.name)) - for diff in diffs: - the_diff = diff["json_diff_data"] - diff_filepath = Path( - diff_dir, - f"{the_diff['inbredset_id']}:::{the_diff['user_id']}:::" - f"{the_diff['created'].isoformat()}.json") - if diff_filepath not in review_files: - with open(diff_filepath, "w", encoding="utf-8") as dfile: - dfile.write(json.dumps( - {**the_diff, "db_id": diff["id"]}, - cls=CAJSONEncoder)) - - with (database_connection(current_app.config["SQL_URI"]) as conn, - conn.cursor(cursorclass=DictCursor) as cursor): - cursor.execute( - "SELECT * FROM caseattributes_audit WHERE status='review'") - diffs = tuple({ - **row, - "json_diff_data": { - **__parse_diff_json__(row["json_diff_data"]), - "db_id": row["id"], - "created": row["time_stamp"], - "user_id": uuid.UUID(row["editor"]) - } - } for row in cursor.fetchall()) - - __generate_diff_files__(diffs) - resp = make_response(json.dumps( - tuple({ - **diff, - "filename": ( - f"{diff['json_diff_data']['inbredset_id']}:::" - f"{diff['json_diff_data']['user_id']}:::" - f"{diff['time_stamp'].isoformat()}") - } for diff in diffs - if diff["json_diff_data"].get("inbredset_id") == inbredset_id), - cls=CAJSONEncoder)) - resp.headers["Content-Type"] = "application/json" - return resp - -@caseattr.route("/approve/<path:filename>", methods=["POST"]) -@require_token -def approve_case_attributes_diff(filename: str, auth_token = None) -> Response: - """Approve the changes to the case attributes in the diff.""" - diff_dir = Path(current_app.config["TMPDIR"], CATTR_DIFFS_DIR) - diff_filename = Path(diff_dir, filename) - the_diff = __load_diff__(diff_filename) - with database_connection(current_app.config["SQL_URI"]) as conn: - __apply_diff__(conn, auth_token, the_diff["inbredset_id"], diff_filename, the_diff) - return jsonify({ - "message": "Applied the diff successfully.", - "diff_filename": diff_filename.name - }) - -@caseattr.route("/reject/<path:filename>", methods=["POST"]) -@require_token -def reject_case_attributes_diff(filename: str, auth_token=None) -> Response: - """Reject the changes to the case attributes in the diff.""" - diff_dir = Path(current_app.config["TMPDIR"], CATTR_DIFFS_DIR) - diff_filename = Path(diff_dir, filename) - the_diff = __load_diff__(diff_filename) - with database_connection(current_app.config["SQL_URI"]) as conn: - __reject_diff__(conn, - auth_token, - the_diff["inbredset_id"], - diff_filename, - the_diff) - return jsonify({ - "message": 
"Rejected diff successfully", - "diff_filename": diff_filename.name - }) - -@caseattr.route("/<int:inbredset_id>/diff/<int:diff_id>/view", methods=["GET"]) -@require_token -def view_diff(inbredset_id: int, diff_id: int, auth_token=None) -> Response: - """View a diff.""" - with (database_connection(current_app.config["SQL_URI"]) as conn, - conn.cursor(cursorclass=DictCursor) as cursor): - required_access( - auth_token, inbredset_id, ("system:inbredset:view-case-attribute",)) - cursor.execute( - "SELECT * FROM caseattributes_audit WHERE id=%s", - (diff_id,)) - diff = cursor.fetchone() - if diff: - json_diff_data = __parse_diff_json__(diff["json_diff_data"]) - if json_diff_data["inbredset_id"] != inbredset_id: - return jsonify({ - "error": "Not Found", - "error_description": ( - "Could not find diff with the given ID for the " - "InbredSet chosen.") - }) - return jsonify({ - **diff, - "json_diff_data": { - **json_diff_data, - "db_id": diff["id"], - "created": diff["time_stamp"].isoformat(), - "user_id": uuid.UUID(diff["editor"]) - } - }) - return jsonify({ - "error": "Not Found", - "error_description": "Could not find diff with the given ID." - }) - return jsonify({ - "error": "Code Error", - "error_description": "The code should never run this." - }), 500 diff --git a/gn3/commands.py b/gn3/commands.py index 9617663..3852c41 100644 --- a/gn3/commands.py +++ b/gn3/commands.py @@ -1,13 +1,16 @@ """Procedures used to work with the various bio-informatics cli commands""" +import os import sys import json +import shlex import pickle import logging import tempfile import subprocess from datetime import datetime +from typing import Any from typing import Dict from typing import List from typing import Optional @@ -16,7 +19,7 @@ from typing import Union from typing import Sequence from uuid import uuid4 -from flask import current_app +from flask import Flask, current_app from redis.client import Redis # Used only in type hinting from pymonad.either import Either, Left, Right @@ -25,6 +28,8 @@ from gn3.debug import __pk__ from gn3.chancy import random_string from gn3.exceptions import RedisConnectionError +logger = logging.getLogger(__name__) + def compose_gemma_cmd(gemma_wrapper_cmd: str = "gemma-wrapper", gemma_wrapper_kwargs: Optional[Dict] = None, @@ -44,12 +49,14 @@ def compose_gemma_cmd(gemma_wrapper_cmd: str = "gemma-wrapper", cmd += " ".join([f"{arg}" for arg in gemma_args]) return cmd + def compose_rqtl_cmd(rqtl_wrapper_cmd: str, rqtl_wrapper_kwargs: Dict, rqtl_wrapper_bool_kwargs: list) -> str: """Compose a valid R/qtl command given the correct input""" # Add kwargs with values - cmd = f"Rscript { rqtl_wrapper_cmd } " + " ".join( + rscript = os.environ.get("RSCRIPT", "Rscript") + cmd = f"{rscript} { rqtl_wrapper_cmd } " + " ".join( [f"--{key} {val}" for key, val in rqtl_wrapper_kwargs.items()]) # Add boolean kwargs (kwargs that are either on or off, like --interval) @@ -59,18 +66,22 @@ def compose_rqtl_cmd(rqtl_wrapper_cmd: str, return cmd + def compose_pcorrs_command_for_selected_traits( prefix_cmd: Tuple[str, ...], target_traits: Tuple[str, ...]) -> Tuple[ str, ...]: """Build command for partial correlations against selected traits.""" return prefix_cmd + ("against-traits", ",".join(target_traits)) + def compose_pcorrs_command_for_database( prefix_cmd: Tuple[str, ...], target_database: str, criteria: int = 500) -> Tuple[str, ...]: """Build command for partial correlations against an entire dataset.""" return prefix_cmd + ( - "against-db", f"{target_database}", f"--criteria={criteria}") + 
"against-db", f"{target_database}", "--criteria", str(criteria), + "--textdir", current_app.config["TEXTDIR"]) + def compose_pcorrs_command( primary_trait: str, control_traits: Tuple[str, ...], method: str, @@ -82,7 +93,9 @@ def compose_pcorrs_command( return "pearsons" if "spearmans" in mthd: return "spearmans" - raise Exception(f"Invalid method '{method}'") + # pylint: disable=[broad-exception-raised] + raise Exception( + f"Invalid method '{method}'") prefix_cmd = ( f"{sys.executable}", "-m", "scripts.partial_correlations", @@ -96,7 +109,10 @@ def compose_pcorrs_command( kwargs.get("target_database") is None and kwargs.get("target_traits") is not None): return compose_pcorrs_command_for_selected_traits(prefix_cmd, **kwargs) - raise Exception("Invalid state: I don't know what command to generate!") + # pylint: disable=[broad-exception-raised] + raise Exception( + "Invalid state: I don't know what command to generate!") + def queue_cmd(conn: Redis, job_queue: str, @@ -130,6 +146,7 @@ Returns the name of the specific redis hash for the specific task. conn.hset(name=unique_id, key="env", value=json.dumps(env)) return unique_id + def run_sample_corr_cmd(method, this_trait_data, target_dataset_data): "Run the sample correlations in an external process, returning the results." with tempfile.TemporaryDirectory() as tempdir: @@ -152,9 +169,14 @@ def run_sample_corr_cmd(method, this_trait_data, target_dataset_data): return correlation_results + def run_cmd(cmd: str, success_codes: Tuple = (0,), env: Optional[str] = None) -> Dict: """Run CMD and return the CMD's status code and output as a dict""" - parsed_cmd = json.loads(__pk__("Attempting to parse command", cmd)) + try: + parsed_cmd = json.loads(cmd) + except json.decoder.JSONDecodeError as _jderr: + parsed_cmd = shlex.split(cmd) + parsed_env = (json.loads(env) if env is not None else None) results = subprocess.run( @@ -163,17 +185,37 @@ def run_cmd(cmd: str, success_codes: Tuple = (0,), env: Optional[str] = None) -> out = str(results.stdout, 'utf-8') if results.returncode not in success_codes: # Error! 
out = str(results.stderr, 'utf-8') - (# We do not always run this within an app context - current_app.logger.debug if current_app else logging.debug)(out) + logger.debug("Command output: %s", out) return {"code": results.returncode, "output": out} + +def compute_job_queue(app: Flask) -> str: + """Use the app configurations to compute the job queue""" + app_env = app.config["APPLICATION_ENVIRONMENT"] + job_queue = app.config["REDIS_JOB_QUEUE"] + if bool(app_env): + return f"{app_env}::{job_queue}" + return job_queue + + def run_async_cmd( conn: Redis, job_queue: str, cmd: Union[str, Sequence[str]], - email: Optional[str] = None, env: Optional[dict] = None) -> str: + options: Optional[Dict[str, Any]] = None, + log_level: str = "info") -> str: """A utility function to call `gn3.commands.queue_cmd` function and run the worker in the `one-shot` mode.""" + email = options.get("email") if options else None + env = options.get("env") if options else None cmd_id = queue_cmd(conn, job_queue, cmd, email, env) - subprocess.Popen([f"{sys.executable}", "-m", "sheepdog.worker"]) # pylint: disable=[consider-using-with] + worker_command = [ + sys.executable, + "-m", "sheepdog.worker", + "--queue-name", job_queue, + "--log-level", log_level + ] + logging.debug("Launching the worker: %s", worker_command) + subprocess.Popen( # pylint: disable=[consider-using-with] + worker_command) return cmd_id diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py index d805af7..95bd957 100644 --- a/gn3/computations/correlations.py +++ b/gn3/computations/correlations.py @@ -6,6 +6,7 @@ from multiprocessing import Pool, cpu_count from typing import List from typing import Tuple +from typing import Sequence from typing import Optional from typing import Callable from typing import Generator @@ -52,8 +53,10 @@ def normalize_values(a_values: List, b_values: List) -> Generator: yield a_val, b_val -def compute_corr_coeff_p_value(primary_values: List, target_values: List, - corr_method: str) -> Tuple[float, float]: +def compute_corr_coeff_p_value( + primary_values: Sequence, + target_values: Sequence, + corr_method: str) -> Tuple[float, float]: """Given array like inputs calculate the primary and target_value methods -> pearson,spearman and biweight mid correlation return value is rho and p_value @@ -196,7 +199,7 @@ def compute_all_sample_correlation(this_trait, """ this_trait_samples = this_trait["trait_sample_data"] - with Pool(processes=(cpu_count() - 1)) as pool: + with Pool(processes=cpu_count() - 1) as pool: return sorted( ( corr for corr in diff --git a/gn3/computations/ctl.py b/gn3/computations/ctl.py index f881410..5c004ea 100644 --- a/gn3/computations/ctl.py +++ b/gn3/computations/ctl.py @@ -6,13 +6,11 @@ from gn3.computations.wgcna import dump_wgcna_data from gn3.computations.wgcna import compose_wgcna_cmd from gn3.computations.wgcna import process_image -from gn3.settings import TMPDIR - -def call_ctl_script(data): +def call_ctl_script(data, tmpdir): """function to call ctl script""" - data["imgDir"] = TMPDIR - temp_file_name = dump_wgcna_data(data) + data["imgDir"] = tmpdir + temp_file_name = dump_wgcna_data(data, tmpdir) cmd = compose_wgcna_cmd("ctl_analysis.R", temp_file_name) cmd_results = run_cmd(cmd) diff --git a/gn3/computations/gemma.py b/gn3/computations/gemma.py index 6c53ecc..f07628f 100644 --- a/gn3/computations/gemma.py +++ b/gn3/computations/gemma.py @@ -41,12 +41,13 @@ def generate_pheno_txt_file(trait_filename: str, # pylint: disable=R0913 -def generate_gemma_cmd(gemma_cmd: 
str,
-                        output_dir: str,
-                        token: str,
-                        gemma_kwargs: Dict,
-                        gemma_wrapper_kwargs: Optional[Dict] = None,
-                        chromosomes: Optional[str] = None) -> Dict:
+def generate_gemma_cmd(# pylint: disable=[too-many-positional-arguments]
+        gemma_cmd: str,
+        output_dir: str,
+        token: str,
+        gemma_kwargs: Dict,
+        gemma_wrapper_kwargs: Optional[Dict] = None,
+        chromosomes: Optional[str] = None) -> Dict:
     """Compute k values"""
     _hash = get_hash_of_files(
         [v for k, v in gemma_kwargs.items() if k in ["g", "p", "a", "c"]])
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 6eee299..8674910 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -16,7 +16,6 @@ import pandas
 import pingouin
 from scipy.stats import pearsonr, spearmanr
-from gn3.settings import TEXTDIR
 from gn3.chancy import random_string
 from gn3.function_helpers import compose
 from gn3.data_helpers import parse_csv_line
@@ -99,7 +98,7 @@ def fix_samples(
         primary_samples,
         tuple(primary_trait_data["data"][sample]["value"] for sample in primary_samples),
-        control_vals_vars[0],
+        (control_vals_vars[0],),
         tuple(primary_trait_data["data"][sample]["variance"] for sample in primary_samples),
         control_vals_vars[1])
@@ -209,7 +208,7 @@ def good_dataset_samples_indexes(
         samples_from_file.index(good) for good in
         set(samples).intersection(set(samples_from_file))))

-def partial_correlations_fast(# pylint: disable=[R0913, R0914]
+def partial_correlations_fast(# pylint: disable=[R0913, R0914, too-many-positional-arguments]
         samples, primary_vals, control_vals, database_filename,
         fetched_correlations, method: str, correlation_type: str) -> Generator:
     """
@@ -334,7 +333,7 @@ def compute_partial(
     This implementation reworks the child function `compute_partial` which will
     then be used in the place of `determinPartialsByR`.
     """
-    with Pool(processes=(cpu_count() - 1)) as pool:
+    with Pool(processes=cpu_count() - 1) as pool:
         return (
             result for result in (
                 pool.starmap(
@@ -345,7 +344,7 @@ def compute_partial(
                 for target in targets)))
             if result is not None)

-def partial_correlations_normal(# pylint: disable=R0913
+def partial_correlations_normal(# pylint: disable=[R0913, too-many-positional-arguments]
         primary_vals, control_vals, input_trait_gene_id, trait_database,
         data_start_pos: int, db_type: str, method: str) -> Generator:
     """
@@ -381,7 +380,7 @@ def partial_correlations_normal(# pylint: disable=R0913
     return all_correlations

-def partial_corrs(# pylint: disable=[R0913]
+def partial_corrs(# pylint: disable=[R0913, too-many-positional-arguments]
         conn, samples, primary_vals, control_vals, return_number, species,
         input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id,
         method, dataset, database_filename):
@@ -667,10 +666,15 @@ def check_for_common_errors(# pylint: disable=[R0914]
     return non_error_result

-def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911]
-        conn: Any, primary_trait_name: str,
-        control_trait_names: Tuple[str, ...], method: str,
-        criteria: int, target_db_name: str) -> dict:
+def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911, too-many-positional-arguments]
+        conn: Any,
+        primary_trait_name: str,
+        control_trait_names: Tuple[str, ...],
+        method: str,
+        criteria: int,
+        target_db_name: str,
+        textdir: str
+) -> dict:
     """
     This is the 'orchestration' function for the partial-correlation feature.
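Callers are now expected to inject the text-data directory rather than importing TEXTDIR from gn3.settings. A minimal caller sketch, assuming a Flask app context; the method label, dataset name and wrapper function here are illustrative placeholders, not part of the patch:

from flask import current_app
from gn3.db_utils import database_connection
from gn3.computations.partial_correlations import (
    partial_correlations_with_target_db)

def pcorrs_against_db_example(primary: str, controls: tuple) -> dict:
    """Illustrative wrapper that passes the text-files directory explicitly."""
    with database_connection(current_app.config["SQL_URI"]) as conn:
        return partial_correlations_with_target_db(
            conn,
            primary_trait_name=primary,
            control_trait_names=controls,
            method="pearsons",              # placeholder method label
            criteria=500,
            target_db_name="HC_M2_0606_P",  # placeholder dataset name
            textdir=current_app.config["TEXTDIR"])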
@@ -755,7 +759,7 @@ def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911] threshold, conn) - database_filename = get_filename(conn, target_db_name, TEXTDIR) + database_filename = get_filename(conn, target_db_name, textdir) all_correlations = partial_corrs( conn, check_res["common_primary_control_samples"], check_res["fixed_primary_values"], check_res["fixed_control_values"], @@ -837,7 +841,7 @@ def partial_correlations_with_target_traits( return check_res target_traits = { - trait["name"]: trait + trait["trait_name"]: trait for trait in traits_info(conn, threshold, target_trait_names)} target_traits_data = traits_data(conn, tuple(target_traits.values())) @@ -854,12 +858,13 @@ def partial_correlations_with_target_traits( __merge( target_traits[target_name], compute_trait_info( - check_res["primary_values"], check_res["fixed_control_values"], - (export_trait_data( - target_data, - samplelist=check_res["common_primary_control_samples"]), - target_name), - method)) + check_res["primary_values"], + check_res["fixed_control_values"], + (export_trait_data( + target_data, + samplelist=check_res["common_primary_control_samples"]), + target_name), + method)) for target_name, target_data in target_traits_data.items()) return { diff --git a/gn3/computations/pca.py b/gn3/computations/pca.py index 35c9f03..3b3041a 100644 --- a/gn3/computations/pca.py +++ b/gn3/computations/pca.py @@ -13,7 +13,7 @@ import redis from typing_extensions import TypeAlias -fArray: TypeAlias = list[float] +fArray: TypeAlias = list[float] # pylint: disable=[invalid-name] def compute_pca(array: list[fArray]) -> dict[str, Any]: @@ -133,7 +133,7 @@ def generate_pca_temp_traits( """ - # pylint: disable=too-many-arguments + # pylint: disable=[too-many-arguments, too-many-positional-arguments] pca_trait_dict = {} diff --git a/gn3/computations/qtlreaper.py b/gn3/computations/qtlreaper.py index 08c387f..ff83b33 100644 --- a/gn3/computations/qtlreaper.py +++ b/gn3/computations/qtlreaper.py @@ -7,7 +7,6 @@ import subprocess from typing import Union from gn3.chancy import random_string -from gn3.settings import TMPDIR def generate_traits_file(samples, trait_values, traits_filename): """ @@ -38,13 +37,15 @@ def create_output_directory(path: str): # If the directory already exists, do nothing. pass -# pylint: disable=too-many-arguments +# pylint: disable=[too-many-arguments, too-many-positional-arguments] def run_reaper( reaper_cmd: str, - genotype_filename: str, traits_filename: str, + genotype_filename: str, + traits_filename: str, + output_dir: str, other_options: tuple = ("--n_permutations", "1000"), - separate_nperm_output: bool = False, - output_dir: str = TMPDIR): + separate_nperm_output: bool = False +): """ Run the QTLReaper command to compute the QTLs. 
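run_reaper now takes the output directory as a required argument instead of defaulting to the removed gn3.settings.TMPDIR. A minimal sketch of the updated call, assuming a Flask app context where TMPDIR is configured (as the new streaming module does) and with the reaper binary path supplied by the caller:

from flask import current_app
from gn3.computations.qtlreaper import run_reaper

def reaper_example(reaper_cmd: str, genotype_file: str, traits_file: str):
    """Illustrative call: results are written to an explicitly supplied directory."""
    return run_reaper(
        reaper_cmd=reaper_cmd,
        genotype_filename=genotype_file,
        traits_filename=traits_file,
        output_dir=current_app.config["TMPDIR"],
        other_options=("--n_permutations", "1000"),
        separate_nperm_output=True)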
diff --git a/gn3/computations/rqtl.py b/gn3/computations/rqtl.py index 16f1398..3dd8fb2 100644 --- a/gn3/computations/rqtl.py +++ b/gn3/computations/rqtl.py @@ -1,5 +1,6 @@ """Procedures related to R/qtl computations""" import os +import csv from bisect import bisect from typing import Dict, List, Tuple, Union @@ -67,8 +68,8 @@ def process_rqtl_mapping(file_name: str) -> List: # Later I should probably redo this using csv.read to avoid the # awkwardness with removing quotes with [1:-1] outdir = os.path.join(get_tmpdir(),"gn3") - - with open( os.path.join(outdir,file_name),"r",encoding="utf-8") as the_file: + with open(os.path.join(outdir,file_name),"r",encoding="utf-8") as the_file: + column_count = len(the_file.readline().strip().split(",")) for line in the_file: line_items = line.split(",") if line_items[1][1:-1] == "chr" or not line_items: @@ -88,6 +89,16 @@ def process_rqtl_mapping(file_name: str) -> List: "Mb": float(line_items[2]), "lod_score": float(line_items[3]), } + # If 4-way, get extra effect columns + if column_count > 4: + this_marker['mean1'] = line_items[4][1:-1].split(' ± ')[0] + this_marker['se1'] = line_items[4][1:-1].split(' ± ')[1] + this_marker['mean2'] = line_items[5][1:-1].split(' ± ')[0] + this_marker['se2'] = line_items[5][1:-1].split(' ± ')[1] + this_marker['mean3'] = line_items[6][1:-1].split(' ± ')[0] + this_marker['se3'] = line_items[6][1:-1].split(' ± ')[1] + this_marker['mean4'] = line_items[7][1:-1].split(' ± ')[0] + this_marker['se4'] = line_items[7][1:-1].split(' ± ')[1] marker_obs.append(this_marker) return marker_obs @@ -111,7 +122,7 @@ def pairscan_for_figure(file_name: str) -> Dict: # Open the file with the actual results, written as a list of lists outdir = os.path.join(get_tmpdir(),"gn3") - with open( os.path.join(outdir,file_name),"r",encoding="utf-8") as the_file: + with open(os.path.join(outdir, file_name), "r",encoding="utf-8") as the_file: lod_results = [] for i, line in enumerate(the_file): if i == 0: # Skip first line @@ -134,14 +145,17 @@ def pairscan_for_figure(file_name: str) -> Dict: ) as the_file: chr_list = [] # type: List pos_list = [] # type: List + markers = [] # type: List for i, line in enumerate(the_file): if i == 0: # Skip first line continue line_items = [item.rstrip("\n") for item in line.split(",")] chr_list.append(line_items[1][1:-1]) pos_list.append(line_items[2]) + markers.append(line_items[0]) figure_data["chr"] = chr_list figure_data["pos"] = pos_list + figure_data["name"] = markers return figure_data @@ -312,18 +326,13 @@ def process_perm_output(file_name: str) -> Tuple[List, float, float]: suggestive and significant thresholds""" perm_results = [] - outdir = os.path.join(get_tmpdir(),"gn3") - - with open( os.path.join(outdir,file_name),"r",encoding="utf-8") as the_file: - for i, line in enumerate(the_file): - if i == 0: - # Skip header line - continue - - line_items = line.split(",") - perm_results.append(float(line_items[1])) - - suggestive = np.percentile(np.array(perm_results), 67) - significant = np.percentile(np.array(perm_results), 95) - + outdir = os.path.join(get_tmpdir(), "gn3") + + with open(os.path.join(outdir, file_name), + "r", encoding="utf-8") as file_handler: + reader = csv.reader(file_handler) + next(reader) + perm_results = [float(row[1]) for row in reader] # Extract LOD values + suggestive = np.percentile(np.array(perm_results), 67) + significant = np.percentile(np.array(perm_results), 95) return perm_results, suggestive, significant diff --git a/gn3/computations/rqtl2.py b/gn3/computations/rqtl2.py 
new file mode 100644
index 0000000..5d5f68e
--- /dev/null
+++ b/gn3/computations/rqtl2.py
@@ -0,0 +1,228 @@
+"""Module contains functions to parse and process rqtl2 input and output"""
+import os
+import csv
+import uuid
+import json
+from pathlib import Path
+from typing import List
+from typing import Dict
+from typing import Any
+
+def generate_rqtl2_files(data, workspace_dir):
+    """Prepare data and generate necessary CSV files
+    required to write to control_file
+    """
+    file_to_name_map = {
+        "geno_file": "geno_data",
+        "pheno_file": "pheno_data",
+        "geno_map_file": "geno_map_data",
+        "physical_map_file": "physical_map_data",
+        "phenocovar_file": "phenocovar_data",
+        "founder_geno_file": "founder_geno_data",
+        "covar_file": "covar_data"
+    }
+    parsed_files = {}
+    for file_name, data_key in file_to_name_map.items():
+        if data_key in data:
+            file_path = write_to_csv(
+                workspace_dir, f"{file_name}.csv", data[data_key])
+            if file_path:
+                parsed_files[file_name] = file_path
+    return {**data, **parsed_files}
+
+
+def write_to_csv(work_dir, file_name, data: list[dict],
+                 headers=None, delimiter=","):
+    """Function to write a list of dicts to a csv file.
+    If headers are not provided, use the keys of the first object.
+    """
+    if not data:
+        return ""
+    if headers is None:
+        headers = data[0].keys()
+    file_path = os.path.join(work_dir, file_name)
+    with open(file_path, "w", encoding="utf-8") as file_handler:
+        writer = csv.DictWriter(file_handler, fieldnames=headers,
+                                delimiter=delimiter)
+        writer.writeheader()
+        for row in data:
+            writer.writerow(row)
+    # return the relative file to the workspace see rqtl2 docs
+    return file_name
+
+
+def validate_required_keys(required_keys: list, data: dict) -> tuple[bool, str]:
+    """Check for missing keys in data object"""
+    missing_keys = [key for key in required_keys if key not in data]
+    if missing_keys:
+        return False, f"Required key(s) missing: {', '.join(missing_keys)}"
+    return True, ""
+
+
+def compose_rqtl2_cmd(# pylint: disable=[too-many-positional-arguments]
+        rqtl_path, input_file, output_file, workspace_dir, data, config):
+    """Compose the command for running the R/QTL2 analysis."""
+    # pylint: disable=R0913
+    params = {
+        "input_file": input_file,
+        "directory": workspace_dir,
+        "output_file": output_file,
+        "nperm": data.get("nperm", 0),
+        "method": data.get("method", "HK"),
+        "threshold": data.get("threshold", 1),
+        "cores": config.get('MULTIPROCESSOR_PROCS', 1)
+    }
+    rscript_path = config.get("RSCRIPT", os.environ.get("RSCRIPT", "Rscript"))
+    return f"{rscript_path} { rqtl_path } " + " ".join(
+        [f"--{key} {val}" for key, val in params.items()])
+
+
+def create_file(file_path):
+    """Utility function to create file given a file_path"""
+    try:
+        with open(file_path, "x", encoding="utf-8") as _file_handler:
+            return True, f"File created at {file_path}"
+    except FileExistsError:
+        return False, "File Already Exists"
+
+
+def prepare_files(tmpdir):
+    """Prepare necessary files and workspace dir for computation."""
+    workspace_dir = os.path.join(tmpdir, str(uuid.uuid4()))
+    Path(workspace_dir).mkdir(parents=False, exist_ok=True)
+    input_file = os.path.join(
+        workspace_dir, f"rqtl2-input-{uuid.uuid4()}.json")
+    output_file = os.path.join(
+        workspace_dir, f"rqtl2-output-{uuid.uuid4()}.json")
+
+    # to ensure streaming api has access to file even after computation ends
+    # .. Create the log file outside the workspace_dir
+    log_file = os.path.join(tmpdir, f"rqtl2-log-{uuid.uuid4()}")
+    for file_path in [input_file, output_file, log_file]:
+        create_file(file_path)
+    return workspace_dir, input_file, output_file, log_file
+
+
+def write_input_file(input_file, workspace_dir, data):
+    """
+    Write input data to a json file to be passed
+    as input to the rqtl2 script
+    """
+    with open(input_file, "w+", encoding="UTF-8") as file_handler:
+        # todo choose a better variable name
+        rqtl2_files = generate_rqtl2_files(data, workspace_dir)
+        json.dump(rqtl2_files, file_handler)
+
+
+def read_output_file(output_path: str) -> dict:
+    """Function to read the output json file generated from rqtl2.
+    See the rqtl2_wrapper.R script for the expected output.
+    """
+    with open(output_path, "r", encoding="utf-8") as file_handler:
+        results = json.load(file_handler)
+        return results
+
+
+def process_permutation(data):
+    """This function processes the permutation output data from the results.
+    input: data object extracted from the output_file
+    returns:
+        dict: A dict containing
+        * phenotypes array
+        * permutations as dict with keys as permutation_id
+        * significance_results with keys as threshold values
+    """
+
+    perm_file = data.get("permutation_file")
+    with open(perm_file, "r", encoding="utf-8") as file_handler:
+        reader = csv.reader(file_handler)
+        phenotypes = next(reader)[1:]
+        perm_results = {_id: float(val) for _id, val, *_ in reader}
+    _, significance = fetch_significance_results(data.get("significance_file"))
+    return {
+        "phenotypes": phenotypes,
+        "perm_results": perm_results,
+        "significance": significance,
+    }
+
+
+def fetch_significance_results(file_path: str):
+    """
+    Processes the 'significance_file' from the given data object to extract
+    phenotypes and significance values.
+    Threshold values are: (0.05, 0.01)
+    Args:
+        file_path (str): file path for the significance output
+
+    Returns:
+        tuple: A tuple containing
+        * phenotypes (list): List of phenotypes
+        * significances (dict): A dictionary where keys
+          are threshold values and values are lists
+          of significant results corresponding to each threshold.
+    """
+    with open(file_path, "r", encoding="utf-8") as file_handler:
+        reader = csv.reader(file_handler)
+        results = {}
+        phenotypes = next(reader)[1:]
+        for line in reader:
+            threshold, significance = line[0], line[1:]
+            results[threshold] = significance
+        return (phenotypes, results)
+
+
+def process_scan_results(qtl_file_path: str, map_file_path: str) -> List[Dict[str, Any]]:
+    """Function to process genome scanning results and obtain marker_name, LOD score,
+    marker_position, and chromosome.
+    Args:
+        qtl_file_path (str): Path to the QTL scan results CSV file.
+        map_file_path (str): Path to the map file from the script.
+
+    Returns:
+        List[Dict[str, Any]]: A list of dictionaries containing the marker data.
+    """
+    map_data = {}
+    # read the genetic map
+    with open(map_file_path, "r", encoding="utf-8") as file_handler:
+        reader = csv.reader(file_handler)
+        next(reader)
+        for line in reader:
+            marker, chr_, cm_, mb_ = line
+            cm: float | None = float(cm_) if cm_ and cm_ != "NA" else None
+            mb: float | None = float(mb_) if mb_ and mb_ != "NA" else None
+            map_data[marker] = {"chr": chr_, "cM": cm, "Mb": mb}
+
+    # Process QTL scan results and merge the positional data
+    results = []
+    with open(qtl_file_path, "r", encoding="utf-8") as file_handler:
+        reader = csv.reader(file_handler)
+        next(reader)
+        for line in reader:
+            marker = line[0]
+            lod_score = line[1]
+            results.append({
+                "name": marker,
+                "lod_score": float(lod_score),
+                **map_data.get(marker, {})  # Add chromosome and positions if available
+            })
+    return results
+
+
+def process_qtl2_results(output_file: str) -> Dict[str, Any]:
+    """Function provides abstraction for processing all QTL2 mapping results.
+
+    Args: * File path to the output file generated
+
+    Returns:
+        Dict[str, Any]: A dictionary containing both QTL
+        and permutation results along with input data.
+    """
+    results = read_output_file(output_file)
+    qtl_results = process_scan_results(results["scan_file"],
+                                       results["map_file"])
+    permutation_results = process_permutation(results) if results["permutations"] > 0 else {}
+    return {
+        **results,
+        "qtl_results": qtl_results,
+        "permutation_results": permutation_results
+    }
diff --git a/gn3/computations/rust_correlation.py b/gn3/computations/rust_correlation.py
index 5ce097d..359b73a 100644
--- a/gn3/computations/rust_correlation.py
+++ b/gn3/computations/rust_correlation.py
@@ -3,27 +3,27 @@
 https://github.com/Alexanderlacuna/correlation_rust
 """
-import subprocess
-import json
-import csv
 import os
+import csv
+import json
+import traceback
+import subprocess
 from flask import current_app
 from gn3.computations.qtlreaper import create_output_directory
 from gn3.chancy import random_string
-from gn3.settings import TMPDIR

-def generate_input_files(dataset: list[str],
-                         output_dir: str = TMPDIR) -> tuple[str, str]:
+def generate_input_files(
+        dataset: list[str], output_dir: str) -> tuple[str, str]:
     """function generates outputfiles and inputfiles"""
     tmp_dir = f"{output_dir}/correlation"
     create_output_directory(tmp_dir)
     tmp_file = os.path.join(tmp_dir, f"{random_string(10)}.txt")
     with open(tmp_file, "w", encoding="utf-8") as op_file:
         writer = csv.writer(
-            op_file, delimiter=",", dialect="unix", quotechar="",
+            op_file, delimiter=",", dialect="unix", quoting=csv.QUOTE_NONE,
             escapechar="\\")
         writer.writerows(dataset)
@@ -49,17 +49,23 @@ def generate_json_file(

 def run_correlation(
-        dataset, trait_vals: str, method: str, delimiter: str,
-        corr_type: str = "sample", top_n: int = 500):
+        dataset,
+        trait_vals: str,
+        method: str,
+        delimiter: str,
+        tmpdir: str,
+        corr_type: str = "sample",
+        top_n: int = 500
+):
     """entry function to call rust correlation"""
-    # pylint: disable=too-many-arguments
+    # pylint: disable=[too-many-arguments, too-many-positional-arguments]
     correlation_command = current_app.config["CORRELATION_COMMAND"]  # make arg?
- (tmp_dir, tmp_file) = generate_input_files(dataset) + (tmp_dir, tmp_file) = generate_input_files(dataset, tmpdir) (output_file, json_file) = generate_json_file( tmp_dir=tmp_dir, tmp_file=tmp_file, method=method, delimiter=delimiter, x_vals=trait_vals) - command_list = [correlation_command, json_file, TMPDIR] + command_list = [correlation_command, json_file, tmpdir] try: subprocess.run(command_list, check=True, capture_output=True) except subprocess.CalledProcessError as cpe: @@ -67,7 +73,12 @@ def run_correlation( os.readlink(correlation_command) if os.path.islink(correlation_command) else correlation_command) - raise Exception(command_list, actual_command, cpe.stdout) from cpe + raise Exception(# pylint: disable=[broad-exception-raised] + command_list, + actual_command, + cpe.stdout, + traceback.format_exc().split() + ) from cpe return parse_correlation_output(output_file, corr_type, top_n) diff --git a/gn3/computations/streaming.py b/gn3/computations/streaming.py new file mode 100644 index 0000000..6e02694 --- /dev/null +++ b/gn3/computations/streaming.py @@ -0,0 +1,62 @@ +"""Module contains streaming procedures for genenetwork. """ +import os +import subprocess +from functools import wraps +from flask import current_app, request + + +def read_file(file_path): + """Add utility function to read files""" + with open(file_path, "r", encoding="UTF-8") as file_handler: + return file_handler.read() + +def run_process(cmd, log_file, run_id): + """Function to execute an external process and + capture the stdout in a file + input: + cmd: the command to execute as a list of args. + log_file: abs file path to write the stdout. + run_id: unique id to identify the process + + output: + Dict with the results for either success or failure. + """ + try: + # phase: execute the rscript cmd + with subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) as process: + for line in iter(process.stdout.readline, b""): + # phase: capture the stdout for each line allowing read and write + with open(log_file, "a+", encoding="utf-8") as file_handler: + file_handler.write(line.decode("utf-8")) + process.wait() + return {"msg": "success" if process.returncode == 0 else "Process failed", + "run_id": run_id, + "log" : read_file(log_file), + "code": process.returncode} + except subprocess.CalledProcessError as error: + return {"msg": "error occurred", + "code": error.returncode, + "error": str(error), + "run_id": run_id, + "log" : read_file(log_file)} + + +def enable_streaming(func): + """Decorator function to enable streaming for an endpoint + Note: should only be used in an app context + """ + @wraps(func) + def decorated_function(*args, **kwargs): + run_id = request.args.get("id") + stream_output_file = os.path.join(current_app.config.get("TMPDIR"), + f"{run_id}.txt") + with open(stream_output_file, "w+", encoding="utf-8", + ) as file_handler: + file_handler.write("File created for streaming\n" + ) + return func(stream_output_file, *args, **kwargs) + return decorated_function diff --git a/gn3/computations/wgcna.py b/gn3/computations/wgcna.py index d1f7b32..3229a0e 100644 --- a/gn3/computations/wgcna.py +++ b/gn3/computations/wgcna.py @@ -7,17 +7,16 @@ import subprocess from pathlib import Path -from gn3.settings import TMPDIR from gn3.commands import run_cmd -def dump_wgcna_data(request_data: dict): +def dump_wgcna_data(request_data: dict, tmpdir: str): """function to dump request data to json file""" filename = f"{str(uuid.uuid4())}.json" - temp_file_path = os.path.join(TMPDIR, filename) 
+    temp_file_path = os.path.join(tmpdir, filename)

-    request_data["TMPDIR"] = TMPDIR
+    request_data["TMPDIR"] = tmpdir

     with open(temp_file_path, "w", encoding="utf-8") as output_file:
         json.dump(request_data, output_file)
@@ -65,9 +64,9 @@ def compose_wgcna_cmd(rscript_path: str, temp_file_path: str):
     return cmd

-def call_wgcna_script(rscript_path: str, request_data: dict):
+def call_wgcna_script(rscript_path: str, request_data: dict, tmpdir: str):
     """function to call wgcna script"""
-    generated_file = dump_wgcna_data(request_data)
+    generated_file = dump_wgcna_data(request_data, tmpdir)
     cmd = compose_wgcna_cmd(rscript_path, generated_file)

     # stream_cmd_output(request_data, cmd) disable streaming of data
diff --git a/gn3/db/case_attributes.py b/gn3/db/case_attributes.py
index bb15248..a55f2d8 100644
--- a/gn3/db/case_attributes.py
+++ b/gn3/db/case_attributes.py
@@ -1,126 +1,389 @@
 """Module that contains functions for editing case-attribute data"""
-from typing import Any, Optional, Tuple
+from pathlib import Path
+from typing import Optional
+from dataclasses import dataclass
+from enum import Enum, auto
 import json
-import MySQLdb
-
-
-def get_case_attributes(conn) -> Optional[Tuple]:
-    """Get all the case attributes from the database."""
-    with conn.cursor() as cursor:
-        cursor.execute("SELECT Id, Name, Description FROM CaseAttribute")
-        return cursor.fetchall()
-
-
-def get_unreviewed_diffs(conn: Any) -> Optional[tuple]:
-    """Fetch all case attributes in GN"""
-    with conn.cursor() as cursor:
-        cursor.execute(
-            "SELECT id, editor, json_diff_data FROM "
-            "caseattributes_audit WHERE status = 'review'"
-        )
-        return cursor.fetchall()
-
-
-def insert_case_attribute_audit(
-    conn: Any, status: str, author: str, data: str
-) -> int:
-    """Update the case_attribute_audit table"""
-    rowcount = 0
-    try:
-        with conn.cursor() as cursor:
-            cursor.execute(
-                "INSERT INTO caseattributes_audit "
-                "(status, editor, json_diff_data) "
-                "VALUES (%s, %s, %s)",
-                (status, author, data,),
+import pickle
+import lmdb
+
+
+class EditStatus(Enum):
+    """Enumeration for the status of the edits."""
+    review = auto()  # pylint: disable=[invalid-name]
+    approved = auto()  # pylint: disable=[invalid-name]
+    rejected = auto()  # pylint: disable=[invalid-name]
+
+    def __str__(self):
+        """Print out human-readable form."""
+        return self.name
+
+
+@dataclass
+class CaseAttributeEdit:
+    """Represents an edit operation for case attributes in the database.
+
+    Attributes:
+    - inbredset_id (int): The ID of the inbred set associated with
+      the edit.
+    - status (EditStatus): The status of this edit.
+    - user_id (str): The ID of the user performing the edit.
+    - changes (dict): A dictionary containing the changes to be
+      applied to the case attributes.
+
+    """
+    inbredset_id: int
+    status: EditStatus
+    user_id: str
+    changes: dict
+
+
+def queue_edit(cursor, directory: Path, edit: CaseAttributeEdit) -> Optional[int]:
+    """Queues a case attribute edit for review by inserting it into
+    the audit table and storing its review ID in an LMDB database.
+
+    Args:
+        cursor: A database cursor for executing SQL queries.
+        directory (Path): The base directory path for the LMDB database.
+        edit (CaseAttributeEdit): A dataclass containing the edit details, including
+            inbredset_id, status, user_id, and changes.
+
+    Returns:
+        int: The id of the audit record created for the queued edit.
+    """
+    cursor.execute(
+        "INSERT INTO "
+        "caseattributes_audit(status, editor, json_diff_data) "
+        "VALUES (%s, %s, %s) "
+        "ON DUPLICATE KEY UPDATE status=%s",
+        (str(edit.status), edit.user_id,
+         json.dumps(edit.changes), str(EditStatus.review),))
+    directory.mkdir(parents=True, exist_ok=True)
+    env = lmdb.open(directory.as_posix(), map_size=8_000_000)  # ~8 MB
+    with env.begin(write=True) as txn:
+        review_ids = set()
+        if reviews := txn.get(b"review"):
+            review_ids = pickle.loads(reviews)
+        _id = cursor.lastrowid
+        review_ids.add(_id)
+        txn.put(b"review", pickle.dumps(review_ids))
+        return _id
+
+
+def __fetch_case_attrs_changes__(cursor, change_ids: tuple) -> list:
+    """Fetches case attribute change records from the audit table for
+    given change IDs.
+
+    Retrieves records from the `caseattributes_audit` table for the
+    specified `change_ids`, including the editor, JSON diff data, and
+    timestamp. The JSON diff data is deserialized into a Python
+    dictionary for each record. Results are ordered by timestamp in
+    descending order (most recent first).
+
+    Args:
+        cursor: A MySQLdb cursor for executing SQL queries.
+        change_ids (tuple): A tuple of integers representing the IDs
+            of changes to fetch.
+
+    Returns:
+        list: A list of dictionaries, each containing the `editor`,
+        `json_diff_data` (as a deserialized dictionary), and `time_stamp`
+        for the matching change IDs. Returns an empty list if no records
+        are found.
+
+    Notes:
+        - The function assumes `change_ids` is a non-empty tuple of valid integers.
+        - The `json_diff_data` column in `caseattributes_audit` is expected to contain valid
+          JSON strings, which are deserialized into dictionaries.
+        - The query uses parameterized placeholders to prevent SQL injection.
+        - This is an internal helper function (indicated by double underscores) used by
+          other functions like `get_changes`.
+
+    Raises:
+        json.JSONDecodeError: If any `json_diff_data` value cannot be deserialized.
+        TypeError: If `change_ids` contains non-integer values, potentially
+            causing a database error.
+
+    """
+    if not change_ids:
+        return []
+    placeholders = ','.join(['%s'] * len(change_ids))
+    cursor.execute(
+        "SELECT editor, json_diff_data, time_stamp "
+        f"FROM caseattributes_audit WHERE id IN ({placeholders}) "
+        "ORDER BY time_stamp DESC",
+        change_ids
+    )
+    results = cursor.fetchall()
+    for el in results:
+        el["json_diff_data"] = json.loads(el["json_diff_data"])
+    return results
+
+
+def view_change(cursor, change_id: int) -> dict:
+    """Queries the `caseattributes_audit` table to fetch the
+    `json_diff_data` column for the given `change_id`. The JSON data
+    is deserialized into a Python dictionary and returned. If no
+    record is found or the `json_diff_data` is None, an empty
+    dictionary is returned.
+
+    Args:
+        cursor: A MySQLdb cursor for executing SQL queries.
+        change_id (int): The ID of the change to retrieve from the
+            `caseattributes_audit` table.
+
+    Returns:
+        dict: The deserialized JSON diff data as a dictionary if the
+            record exists and contains valid JSON; otherwise, an
+            empty dictionary.
+
+    Raises:
+        json.JSONDecodeError: If the `json_diff_data` cannot be
+            deserialized due to invalid JSON.
+
+    """
+    cursor.execute(
+        "SELECT json_diff_data "
+        "FROM caseattributes_audit "
+        "WHERE id = %s",
+        (change_id,)
+    )
+    row = cursor.fetchone()
+    if row and row[0]:
+        return json.loads(row[0])
+    return {}
+
+
+def get_changes(cursor, change_type: EditStatus, directory: Path) -> dict:
+    """Retrieves case attribute changes for given lmdb data in
+    directory categorized by review status.
+
+    Fetches change IDs from an LMDB database, categorized into the
+    "data" key based on the EditStatus.
+
+    Args:
+    - cursor: A MySQLdb cursor for executing SQL queries.
+    - change_type (EditStatus): The status of changes to retrieve
+      ('review', 'approved', or 'rejected').
+    - directory (Path): The base directory path for the LMDB
+      database.
+
+    Returns:
+        dict: A dictionary with the keys:
+        - 'change-type': the requested status as a string.
+        - 'count': A dictionary with counts of 'reviews',
+          'approvals' and 'rejections'.
+        - 'data': contains the json diff data of the modified data
+
+    Raises:
+        json.JSONDecodeError: If any `json_diff_data` in the audit
+            table cannot be deserialized by
+            `__fetch_case_attrs_changes__`.
+        TypeError: If `inbredset_id` is not an integer or if LMDB data
+            cannot be deserialized. Also raised when an invalid change_type
+            is used.
+
+    """
+    directory.mkdir(parents=True, exist_ok=True)
+    review_ids, approved_ids, rejected_ids = set(), set(), set()
+    directory.mkdir(parents=True, exist_ok=True)
+    env = lmdb.open(directory.as_posix(), map_size=8_000_000)  # ~8 MB
+    with env.begin(write=False) as txn:
+        if reviews := txn.get(b"review"):
+            review_ids = pickle.loads(reviews)
+        if approvals := txn.get(b"approved"):
+            approved_ids = pickle.loads(approvals)
+        if rejections := txn.get(b"rejected"):
+            rejected_ids = pickle.loads(rejections)
+    changes = {}
+    match change_type:
+        case EditStatus.review:
+            changes = dict(
+                zip(review_ids,
+                    __fetch_case_attrs_changes__(cursor, tuple(review_ids)))
             )
-            rowcount = cursor.rowcount
-    except Exception as _e:
-        raise MySQLdb.Error(_e) from _e
-    return rowcount
-
-
-def reject_case_attribute(conn: Any, case_attr_audit_id: int) -> int:
-    """Given the id of the json_diff in the case_attribute_audit table, reject
-    it"""
-    rowcount = 0
-    try:
-        with conn.cursor() as cursor:
-            cursor.execute(
-                "UPDATE caseattributes_audit SET "
-                "status = 'rejected' WHERE id = %s",
-                (case_attr_audit_id,),
+        case EditStatus.approved:
+            changes = dict(
+                zip(approved_ids,
+                    __fetch_case_attrs_changes__(cursor, tuple(approved_ids)))
             )
-            rowcount = cursor.rowcount
-    except Exception as _e:
-        raise MySQLdb.Error(_e) from _e
-    return rowcount
+        case EditStatus.rejected:
+            changes = dict(zip(rejected_ids,
+                               __fetch_case_attrs_changes__(cursor, tuple(rejected_ids))))
+        case _:
+            raise TypeError
+    return {
+        "change-type": str(change_type),
+        "count": {
+            "reviews": len(review_ids),
+            "approvals": len(approved_ids),
+            "rejections": len(rejected_ids)
+        },
+        "data": changes
+    }
+
+
+# pylint: disable=[too-many-locals, too-many-branches]
+def apply_change(cursor, change_type: EditStatus, change_id: int, directory: Path) -> bool:
+    """Applies or rejects a case attribute change and updates its
+    status in the audit table and LMDB.
+
+    Processes a change identified by `change_id` based on the
+    specified `change_type` (approved or rejected). For approved
+    changes, applies modifications to the `CaseAttributeXRefNew` table
+    using bulk inserts and updates the audit status. For rejected
+    changes, updates the audit status only.
Manages change IDs in + LMDB by moving them from the 'review' set to either 'approved' or + 'rejected' sets. Returns False if the `change_id` is not in the + review set. + + Args: + cursor: A MySQLdb cursor for executing SQL queries. + change_type (EditStatus): The action to perform, either + `EditStatus.approved` or `EditStatus.rejected`. + change_id (int): The ID of the change to process, + corresponding to a record in `caseattributes_audit`. + directory (Path): The base directory path for the LMDB + database. + + Returns: + bool: True if the change was successfully applied or rejected, + False if `change_id` is not found in the LMDB 'review' + set. + Notes: + - Opens an LMDB environment in the specified `directory` with + a map size of 8 MB. + - For `EditStatus.approved`, fetches `json_diff_data` from + `caseattributes_audit`, extracts modifications, and performs + bulk inserts into `CaseAttributeXRefNew` with `ON DUPLICATE + KEY UPDATE`. + - For `EditStatus.rejected`, updates the + `caseattributes_audit` status without modifying case + attributes. + - Uses bulk `SELECT` queries to fetch `StrainId` and + `CaseAttributeId` values efficiently. + - Assumes `CaseAttributeXRefNew` has a unique key on + `(InbredSetId, StrainId, CaseAttributeId)` for `ON DUPLICATE + KEY UPDATE`. + - The `json_diff_data` is expected to contain valid JSON with + an `inbredset_id` and `Modifications.Current` structure. + - The second column from `fetchone()` is ignored (denoted by + `_`). -def approve_case_attribute(conn: Any, case_attr_audit_id: int) -> int: - """Given the id of the json_diff in the case_attribute_audit table, - approve it + Raises: + ValueError: If `change_type` is neither `EditStatus.approved` + nor `EditStatus.rejected`. + json.JSONDecodeError: If `json_diff_data` cannot be + deserialized for approved changes. + TypeError: If `cursor.fetchone()` returns None for + `json_diff_data` or if `strain_id` or `caseattr_id` are + missing during bulk insert preparation. + pickle.UnpicklingError: If LMDB data (e.g., 'review' or + 'approved' sets) cannot be deserialized. 
""" - rowcount = 0 - try: - with conn.cursor() as cursor: - cursor.execute( - "SELECT json_diff_data FROM caseattributes_audit " - "WHERE id = %s", - (case_attr_audit_id,), - ) - diff_data = cursor.fetchone() - if diff_data: - diff_data = json.loads(diff_data[0]) - # Insert (Most Important) - if diff_data.get("Insert"): - data = diff_data.get("Insert") + review_ids, approved_ids, rejected_ids = set(), set(), set() + directory.mkdir(parents=True, exist_ok=True) + env = lmdb.open(directory.as_posix(), map_size=8_000_000) # 1 MB + with env.begin(write=True) as txn: + if reviews := txn.get(b"review"): + review_ids = pickle.loads(reviews) + if change_id not in review_ids: + return False + match change_type: + case EditStatus.rejected: + cursor.execute( + "UPDATE caseattributes_audit " + "SET status = %s " + "WHERE id = %s", + (str(change_type), change_id)) + if rejections := txn.get(b"rejected"): + rejected_ids = pickle.loads(rejections) + rejected_ids.add(change_id) + review_ids.discard(change_id) + txn.put(b"review", pickle.dumps(review_ids)) + txn.put(b"rejected", pickle.dumps(rejected_ids)) + return True + case EditStatus.approved: + cursor.execute( + "SELECT json_diff_data " + "FROM caseattributes_audit WHERE " + "id = %s", + (change_id,) + ) + result = cursor.fetchone() + if result is None: + return False + json_diff_data = result[0] + json_diff_data = json.loads(json_diff_data) + inbredset_id = json_diff_data.get("inbredset_id") + modifications = json_diff_data.get( + "Modifications", {}).get("Current", {}) + strains = tuple(modifications.keys()) + case_attrs = set() + for data in modifications.values(): + case_attrs.update(data.keys()) + + # Bulk fetch strain ids + strain_id_map = {} + if strains: cursor.execute( - "INSERT INTO CaseAttribute " - "(Name, Description) VALUES " - "(%s, %s)", - ( - data.get("name").strip(), - data.get("description").strip(), - ), + "SELECT Name, Id FROM Strain WHERE Name IN " + f"({', '.join(['%s'] * len(strains))})", + strains ) - # Delete - elif diff_data.get("Deletion"): - data = diff_data.get("Deletion") + for name, strain_id in cursor.fetchall(): + strain_id_map[name] = strain_id + + # Bulk fetch case attr ids + caseattr_id_map = {} + if case_attrs: cursor.execute( - "DELETE FROM CaseAttribute WHERE Id = %s", - (data.get("id"),), + "SELECT Name, CaseAttributeId FROM CaseAttribute " + "WHERE InbredSetId = %s AND Name IN " + f"({', '.join(['%s'] * len(case_attrs))})", + (inbredset_id, *case_attrs) ) - # Modification - elif diff_data.get("Modification"): - data = diff_data.get("Modification") - if desc_ := data.get("description"): - cursor.execute( - "UPDATE CaseAttribute SET " - "Description = %s WHERE Id = %s", - ( - desc_.get("Current"), - diff_data.get("id"), - ), - ) - if name_ := data.get("name"): - cursor.execute( - "UPDATE CaseAttribute SET " - "Name = %s WHERE Id = %s", - ( - name_.get("Current"), - diff_data.get("id"), - ), - ) - if cursor.rowcount: - cursor.execute( - "UPDATE caseattributes_audit SET " - "status = 'approved' WHERE id = %s", - (case_attr_audit_id,), + for name, caseattr_id in cursor.fetchall(): + caseattr_id_map[name] = caseattr_id + + # Bulk insert data + insert_data = [] + for strain, data in modifications.items(): + strain_id = strain_id_map.get(strain) + for case_attr, value in data.items(): + insert_data.append({ + "inbredset_id": inbredset_id, + "strain_id": strain_id, + "caseattr_id": caseattr_id_map.get(case_attr), + "value": value, + }) + if insert_data: + cursor.executemany( + "INSERT INTO CaseAttributeXRefNew " + 
"(InbredSetId, StrainId, CaseAttributeId, Value) " + "VALUES (%(inbredset_id)s, %(strain_id)s, %(caseattr_id)s, %(value)s) " + "ON DUPLICATE KEY UPDATE Value = VALUES(Value)", + insert_data ) - rowcount = cursor.rowcount - except Exception as _e: - raise MySQLdb.Error(_e) from _e - return rowcount + + # Update LMDB and audit table + cursor.execute( + "UPDATE caseattributes_audit " + "SET status = %s " + "WHERE id = %s", + (str(change_type), change_id)) + if approvals := txn.get(b"approved"): + approved_ids = pickle.loads(approvals) + approved_ids.add(change_id) + review_ids.discard(change_id) + txn.put(b"review", pickle.dumps(review_ids)) + txn.put(b"approved", pickle.dumps(approved_ids)) + return True + case _: + raise ValueError diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py index aec8eac..5d6cfb3 100644 --- a/gn3/db/correlations.py +++ b/gn3/db/correlations.py @@ -328,7 +328,7 @@ def build_temporary_tissue_correlations_table( return temp_table_name -def fetch_tissue_correlations(# pylint: disable=R0913 +def fetch_tissue_correlations(# pylint: disable=[R0913, too-many-arguments, too-many-positional-arguments] dataset: dict, trait_symbol: str, probeset_freeze_id: int, method: str, return_number: int, conn: Any) -> dict: """ @@ -529,7 +529,7 @@ def __build_query__( f"ORDER BY {db_type}.Id"), 1) -# pylint: disable=too-many-arguments +# pylint: disable=[too-many-arguments, too-many-positional-arguments] def __fetch_data__( conn, sample_ids: tuple, db_name: str, db_type: str, method: str, temp_table: Optional[str]) -> Tuple[Tuple[Any], int]: diff --git a/gn3/db/datasets.py b/gn3/db/datasets.py index f3b4f9f..fea207b 100644 --- a/gn3/db/datasets.py +++ b/gn3/db/datasets.py @@ -79,6 +79,21 @@ def retrieve_mrna_group_name(connection: Any, probeset_id: int, dataset_name: st return res[0] return None +def retrieve_group_id(connection: Any, group_name: str): + """ + Given the group name, retrieve the group ID + """ + query = ( + "SELECT iset.Id " + "FROM InbredSet AS iset " + "WHERE iset.Name = %(group_name)s") + with connection.cursor() as cursor: + cursor.execute(query, {"group_name": group_name}) + res = cursor.fetchone() + if res: + return res[0] + return None + def retrieve_phenotype_group_name(connection: Any, dataset_id: int): """ Given the dataset id (PublishFreeze.Id in the database), retrieve the name diff --git a/gn3/db/menu.py b/gn3/db/menu.py index 8dccabf..34dedde 100644 --- a/gn3/db/menu.py +++ b/gn3/db/menu.py @@ -1,10 +1,12 @@ """Menu generation code for the data in the dropdowns in the index page.""" - +import logging from typing import Tuple from functools import reduce from gn3.db.species import get_all_species +logger = logging.getLogger(__name__) + def gen_dropdown_json(conn): """ Generates and outputs (as json file) the data for the main dropdown menus on @@ -14,10 +16,12 @@ def gen_dropdown_json(conn): groups = get_groups(conn, tuple(row[0] for row in species)) types = get_types(conn, groups) datasets = get_datasets(conn, types) - return dict(species=species, - groups=groups, - types=types, - datasets=datasets) + return { + "species": species, + "groups": groups, + "types": types, + "datasets": datasets + } def get_groups(conn, species_names: Tuple[str, ...]): """Build groups list""" @@ -35,6 +39,7 @@ def get_groups(conn, species_names: Tuple[str, ...]): "IFNULL(InbredSet.Family, InbredSet.FullName) ASC, " "InbredSet.FullName ASC, " "InbredSet.MenuOrderId ASC") + logger.debug("'get_groups' QUERY: %s, %s", query, species_names) cursor.execute(query, 
tuple(species_names)) results = cursor.fetchall() diff --git a/gn3/db/probesets.py b/gn3/db/probesets.py index 910f05b..e725add 100644 --- a/gn3/db/probesets.py +++ b/gn3/db/probesets.py @@ -8,6 +8,9 @@ from gn3.db_utils import Connection as DBConnection from .query_tools import mapping_to_query_columns + +# pylint: disable = line-too-long + @dataclass(frozen=True) class Probeset: # pylint: disable=[too-many-instance-attributes] """Data Type that represents a Probeset""" @@ -40,40 +43,42 @@ class Probeset: # pylint: disable=[too-many-instance-attributes] # Mapping from the Phenotype dataclass to the actual column names in the # database probeset_mapping = { - "id_": "Id", - "name": "Name", - "symbol": "symbol", - "description": "description", - "probe_target_description": "Probe_Target_Description", - "chr_": "Chr", - "mb": "Mb", - "alias": "alias", - "geneid": "GeneId", - "homologeneid": "HomoloGeneID", - "unigeneid": "UniGeneId", - "omim": "OMIM", - "refseq_transcriptid": "RefSeq_TranscriptId", - "blatseq": "BlatSeq", - "targetseq": "TargetSeq", - "strand_probe": "Strand_Probe", - "probe_set_target_region": "Probe_set_target_region", - "probe_set_specificity": "Probe_set_specificity", - "probe_set_blat_score": "Probe_set_BLAT_score", - "probe_set_blat_mb_start": "Probe_set_Blat_Mb_start", - "probe_set_blat_mb_end": "Probe_set_Blat_Mb_end", - "probe_set_strand": "Probe_set_strand", - "probe_set_note_by_rw": "Probe_set_Note_by_RW", - "flag": "flag" + "id_": "ProbeSet.Id", + "name": "ProbeSet.Name", + "symbol": "ProbeSet.symbol", + "description": "ProbeSet.description", + "probe_target_description": "ProbeSet.Probe_Target_Description", + "chr_": "ProbeSet.Chr", + "mb": "ProbeSet.Mb", + "alias": "ProbeSet.alias", + "geneid": "ProbeSet.GeneId", + "homologeneid": "ProbeSet.HomoloGeneID", + "unigeneid": "ProbeSet.UniGeneId", + "omim": "ProbeSet.OMIM", + "refseq_transcriptid": "ProbeSet.RefSeq_TranscriptId", + "blatseq": "ProbeSet.BlatSeq", + "targetseq": "ProbeSet.TargetSeq", + "strand_probe": "ProbeSet.Strand_Probe", + "probe_set_target_region": "ProbeSet.Probe_set_target_region", + "probe_set_specificity": "ProbeSet.Probe_set_specificity", + "probe_set_blat_score": "ProbeSet.Probe_set_BLAT_score", + "probe_set_blat_mb_start": "ProbeSet.Probe_set_Blat_Mb_start", + "probe_set_blat_mb_end": "ProbeSet.Probe_set_Blat_Mb_end", + "probe_set_strand": "ProbeSet.Probe_set_strand", + "probe_set_note_by_rw": "ProbeSet.Probe_set_Note_by_RW", + "flag": "ProbeSet.flag" } -def fetch_probeset_metadata_by_name(conn: DBConnection, name: str) -> dict: +def fetch_probeset_metadata_by_name(conn: DBConnection, trait_name: str, dataset_name: str) -> dict: """Fetch a ProbeSet's metadata by its `name`.""" with conn.cursor(cursorclass=DictCursor) as cursor: cols = ", ".join(mapping_to_query_columns(probeset_mapping)) cursor.execute((f"SELECT {cols} " - "FROM ProbeSet " - "WHERE Name = %(name)s"), - {"name": name}) + "FROM ProbeSetFreeze " + "INNER JOIN ProbeSetXRef ON ProbeSetXRef.`ProbeSetFreezeId` = ProbeSetFreeze.`Id` " + "INNER JOIN ProbeSet ON ProbeSet.`Id` = ProbeSetXRef.`ProbeSetId` " + "WHERE ProbeSet.Name = %(trait_name)s AND ProbeSetFreeze.Name = %(ds_name)s"), + {"trait_name": trait_name, "ds_name": dataset_name}) return cursor.fetchone() def update_probeset(conn, probeset_id, data:dict) -> int: diff --git a/gn3/db/rdf/wiki.py b/gn3/db/rdf/wiki.py index b2b301a..dd8d204 100644 --- a/gn3/db/rdf/wiki.py +++ b/gn3/db/rdf/wiki.py @@ -15,6 +15,7 @@ from gn3.db.rdf import ( RDF_PREFIXES, query_frame_and_compact, 
update_rdf, + sparql_query, ) @@ -41,6 +42,10 @@ def __sanitize_result(result: dict) -> dict: if not result: return {} categories = result.get("categories") + if (version := result.get("version")) and isinstance(version, str): + result["version"] = int(version) + if (wiki_id := result.get("id")) and isinstance(version, str): + result["id"] = int(wiki_id) if isinstance(categories, str): result["categories"] = [categories] if categories else [] result["categories"] = sorted(result["categories"]) @@ -79,7 +84,7 @@ CONSTRUCT { gnt:belongsToCategory ?category ; gnt:hasVersion ?max ; dct:created ?created ; - dct:identifier ?id_ . + dct:identifier ?id . } FROM $graph WHERE { ?comment rdfs:label ?text_ ; gnt:symbol ?symbol ; @@ -88,12 +93,12 @@ CONSTRUCT { dct:created ?createTime . FILTER ( LCASE(STR(?symbol)) = LCASE("$symbol") ) . { - SELECT (MAX(?vers) AS ?max) ?id_ WHERE { + SELECT (MAX(?vers) AS ?max_) ?id_ WHERE { ?comment dct:identifier ?id_ ; dct:hasVersion ?vers . } } - ?comment dct:hasVersion ?max . + ?comment dct:hasVersion ?max_ . OPTIONAL { ?comment gnt:reason ?reason_ } . OPTIONAL { ?comment gnt:belongsToSpecies ?speciesId . @@ -106,6 +111,8 @@ CONSTRUCT { OPTIONAL { ?comment gnt:belongsToCategory ?category_ } . BIND (str(?createTime) AS ?created) . BIND (str(?text_) AS ?text) . + BIND (str(?max_) AS ?max) . + BIND (str(?id_) AS ?id) . BIND (STR(COALESCE(?pmid_, "")) AS ?pmid) . BIND (COALESCE(?reason_, "") AS ?reason) . BIND (STR(COALESCE(?weburl_, "")) AS ?weburl) . @@ -154,7 +161,7 @@ CONSTRUCT { rdfs:label ?text_ ; gnt:symbol ?symbol ; dct:created ?createTime ; - dct:hasVersion ?version ; + dct:hasVersion ?version_ ; dct:identifier $comment_id . OPTIONAL { ?comment gnt:reason ?reason_ } . OPTIONAL { @@ -167,6 +174,7 @@ CONSTRUCT { OPTIONAL { ?comment foaf:mbox ?email_ . } . OPTIONAL { ?comment gnt:belongsToCategory ?category_ . } . BIND (str(?text_) AS ?text) . + BIND (str(?version_) AS ?version) . BIND (str(?createTime) AS ?created) . BIND (STR(COALESCE(?pmid_, "")) AS ?pmid) . BIND (COALESCE(?reason_, "") AS ?reason) . @@ -186,38 +194,42 @@ CONSTRUCT { def update_wiki_comment( - insert_dict: dict, - sparql_user: str, - sparql_password: str, - sparql_auth_uri: str, - graph: str = "<http://genenetwork.org>", + insert_dict: dict, + sparql_user: str, + sparql_password: str, + sparql_auth_uri: str, + graph: str = "<http://genenetwork.org>", ) -> str: """Update a wiki comment by inserting a comment with the same -identifier but an updated version id. + identifier but an updated version id. """ name = f"gn:wiki-{insert_dict['Id']}-{insert_dict['versionId']}" - comment_triple = Template("""$name rdfs:label '''$comment'''@en ; + comment_triple = Template( + """$name rdfs:label '''$comment'''@en ; rdf:type gnc:GNWikiEntry ; gnt:symbol "$symbol" ; dct:identifier "$comment_id"^^xsd:integer ; dct:hasVersion "$next_version"^^xsd:integer ; dct:created "$created"^^xsd:datetime . 
-""").substitute( +""" + ).substitute( comment=insert_dict["comment"], - name=name, symbol=insert_dict['symbol'], - comment_id=insert_dict["Id"], next_version=insert_dict["versionId"], - created=insert_dict["createtime"]) + name=name, + symbol=insert_dict["symbol"], + comment_id=insert_dict["Id"], + next_version=insert_dict["versionId"], + created=insert_dict["createtime"], + ) using = "" if insert_dict["email"]: comment_triple += f"{name} foaf:mbox <{insert_dict['email']}> .\n" if insert_dict["initial"]: comment_triple += f"{name} gnt:initial \"{insert_dict['initial']}\" .\n" - if insert_dict["species"]: + if insert_dict["species"] and insert_dict["species"].lower() != "no specific species": comment_triple += f"{name} gnt:belongsToSpecies ?speciesId .\n" using = Template( - """ USING $graph WHERE { ?speciesId gnt:shortName "$species" . } """).substitute( - graph=graph, species=insert_dict["species"] - ) + """ USING $graph WHERE { ?speciesId gnt:shortName "$species" . } """ + ).substitute(graph=graph, species=insert_dict["species"]) if insert_dict["reason"]: comment_triple += f"{name} gnt:reason \"{insert_dict['reason']}\" .\n" if insert_dict["weburl"]: @@ -236,10 +248,110 @@ INSERT { GRAPH $graph { $comment_triple} } $using -""").substitute(prefix=RDF_PREFIXES, - graph=graph, - comment_triple=comment_triple, - using=using), +""" + ).substitute( + prefix=RDF_PREFIXES, graph=graph, comment_triple=comment_triple, using=using + ), + sparql_user=sparql_user, + sparql_password=sparql_password, + sparql_auth_uri=sparql_auth_uri, + ) + + +def get_rif_entries_by_symbol( + symbol: str, sparql_uri: str, graph: str = "<http://genenetwork.org>" +) -> dict: + """Fetch NCBI RIF entries for a given symbol (case-insensitive). + +This function retrieves NCBI RIF entries using a SPARQL `SELECT` query +instead of a `CONSTRUCT` to avoid truncation. The Virtuoso SPARQL +engine limits query results to 1,048,576 triples per solution, and +NCBI entries can exceed this limit. Since there may be more than +2,000 entries, which could result in the number of triples surpassing +the limit, `SELECT` is used to ensure complete data retrieval without +truncation. See: + +<https://community.openlinksw.com/t/sparql-query-limiting-results-to-100000-triples/2131> + + """ + # XXX: Consider pagination + query = Template( + """ +$prefix + +SELECT ?comment ?symbol ?species ?pubmed_id ?version ?created ?gene_id ?taxonomic_id +FROM $graph WHERE { + ?comment_id rdfs:label ?text_ ; + gnt:symbol ?symbol ; + rdf:type gnc:NCBIWikiEntry ; + gnt:hasGeneId ?gene_id_ ; + dct:hasVersion ?version ; + dct:references ?pmid_ ; + dct:created ?createTime ; + gnt:belongsToSpecies ?speciesId . + ?speciesId rdfs:label ?species . + FILTER ( LCASE(?symbol) = LCASE("$symbol") ) . + OPTIONAL { ?comment_id skos:notation ?taxonId_ . } . + BIND (STR(?text_) AS ?comment) . + BIND (xsd:integer(STRAFTER(STR(?taxonId_), STR(taxon:))) AS ?taxonomic_id) . + BIND (xsd:integer(STRAFTER(STR(?pmid_), STR(pubmed:))) AS ?pubmed_id) . + BIND (xsd:integer(STRAFTER(STR(?gene_id_), STR(generif:))) AS ?gene_id) . + BIND (STR(?createTime) AS ?created) . 
+} ORDER BY ?species ?createTime +""" + ).substitute(prefix=RDF_PREFIXES, graph=graph, symbol=symbol) + results: dict[str, dict | list] = { + "@context": { + "dct": "http://purl.org/dc/terms/", + "gnt": "http://genenetwork.org/term/", + "rdfs": "http://www.w3.org/2000/01/rdf-schema#", + "skos": "http://www.w3.org/2004/02/skos/core#", + "symbol": "gnt:symbol", + "species": "gnt:species", + "taxonomic_id": "skos:notation", + "gene_id": "gnt:hasGeneId", + "pubmed_id": "dct:references", + "created": "dct:created", + "comment": "rdfs:comment", + "version": "dct:hasVersion", + } + } + data: list[dict[str, int | str]] = [] + for entry in sparql_query(query=query, endpoint=sparql_uri, format_type="json"): + data.append( + { + key: int(metadata.get("value")) + if metadata.get("value").isdigit() + else metadata.get("value") + for key, metadata in entry.items() + } + ) + results["data"] = data + return results + + +def delete_wiki_entries_by_id( + wiki_id: int, + sparql_user: str, + sparql_password: str, + sparql_auth_uri: str, + graph: str = "<http://genenetwork.org>", +) -> str: + """Delete all wiki entries associated with a given ID.""" + query = Template( + """ +$prefix + +DELETE WHERE { + GRAPH $graph { + ?comment dct:identifier \"$wiki_id\"^^xsd:integer . + ?comment ?p ?o . + } +} +""" + ).substitute(prefix=RDF_PREFIXES, graph=graph, wiki_id=wiki_id) + return update_rdf( + query=query, sparql_user=sparql_user, sparql_password=sparql_password, sparql_auth_uri=sparql_auth_uri, diff --git a/gn3/db/sample_data.py b/gn3/db/sample_data.py index 8db40e3..4e01a3a 100644 --- a/gn3/db/sample_data.py +++ b/gn3/db/sample_data.py @@ -59,20 +59,32 @@ def __extract_actions( return result def get_mrna_sample_data( - conn: Any, probeset_id: str, dataset_name: str + conn: Any, probeset_id: int, dataset_name: str, probeset_name: str = None # type: ignore ) -> Dict: """Fetch a mRNA Assay (ProbeSet in the DB) trait's sample data and return it as a dict""" with conn.cursor() as cursor: - cursor.execute(""" -SELECT st.Name, ifnull(psd.value, 'x'), ifnull(psse.error, 'x'), ifnull(ns.count, 'x') -FROM ProbeSetFreeze psf - JOIN ProbeSetXRef psx ON psx.ProbeSetFreezeId = psf.Id - JOIN ProbeSet ps ON ps.Id = psx.ProbeSetId - JOIN ProbeSetData psd ON psd.Id = psx.DataId - JOIN Strain st ON psd.StrainId = st.Id - LEFT JOIN ProbeSetSE psse ON psse.DataId = psd.Id AND psse.StrainId = psd.StrainId - LEFT JOIN NStrain ns ON ns.DataId = psd.Id AND ns.StrainId = psd.StrainId -WHERE ps.Id = %s AND psf.Name= %s""", (probeset_id, dataset_name)) + if probeset_name: + cursor.execute(""" + SELECT st.Name, ifnull(psd.value, 'x'), ifnull(psse.error, 'x'), ifnull(ns.count, 'x') + FROM ProbeSetFreeze psf + JOIN ProbeSetXRef psx ON psx.ProbeSetFreezeId = psf.Id + JOIN ProbeSet ps ON ps.Id = psx.ProbeSetId + JOIN ProbeSetData psd ON psd.Id = psx.DataId + JOIN Strain st ON psd.StrainId = st.Id + LEFT JOIN ProbeSetSE psse ON psse.DataId = psd.Id AND psse.StrainId = psd.StrainId + LEFT JOIN NStrain ns ON ns.DataId = psd.Id AND ns.StrainId = psd.StrainId + WHERE ps.Name = %s AND psf.Name= %s""", (probeset_name, dataset_name)) + else: + cursor.execute(""" + SELECT st.Name, ifnull(psd.value, 'x'), ifnull(psse.error, 'x'), ifnull(ns.count, 'x') + FROM ProbeSetFreeze psf + JOIN ProbeSetXRef psx ON psx.ProbeSetFreezeId = psf.Id + JOIN ProbeSet ps ON ps.Id = psx.ProbeSetId + JOIN ProbeSetData psd ON psd.Id = psx.DataId + JOIN Strain st ON psd.StrainId = st.Id + LEFT JOIN ProbeSetSE psse ON psse.DataId = psd.Id AND psse.StrainId = psd.StrainId + LEFT 
JOIN NStrain ns ON ns.DataId = psd.Id AND ns.StrainId = psd.StrainId + WHERE ps.Id = %s AND psf.Name= %s""", (probeset_id, dataset_name)) sample_data = {} for data in cursor.fetchall(): @@ -118,18 +130,28 @@ WHERE ps.Id = %s AND psf.Name= %s""", (probeset_id, dataset_name)) return "\n".join(trait_csv) def get_pheno_sample_data( - conn: Any, trait_name: int, phenotype_id: int + conn: Any, trait_name: int, phenotype_id: int, group_id: int = None # type: ignore ) -> Dict: """Fetch a phenotype (Publish in the DB) trait's sample data and return it as a dict""" with conn.cursor() as cursor: - cursor.execute(""" -SELECT st.Name, ifnull(pd.value, 'x'), ifnull(ps.error, 'x'), ifnull(ns.count, 'x') -FROM PublishFreeze pf JOIN PublishXRef px ON px.InbredSetId = pf.InbredSetId - JOIN PublishData pd ON pd.Id = px.DataId JOIN Strain st ON pd.StrainId = st.Id - LEFT JOIN PublishSE ps ON ps.DataId = pd.Id AND ps.StrainId = pd.StrainId - LEFT JOIN NStrain ns ON ns.DataId = pd.Id AND ns.StrainId = pd.StrainId -WHERE px.Id = %s AND px.PhenotypeId = %s -ORDER BY st.Name""", (trait_name, phenotype_id)) + if group_id: + cursor.execute(""" + SELECT st.Name, ifnull(ROUND(pd.value, 2), 'x'), ifnull(ROUND(ps.error, 3), 'x'), ifnull(ns.count, 'x') + FROM PublishFreeze pf JOIN PublishXRef px ON px.InbredSetId = pf.InbredSetId + JOIN PublishData pd ON pd.Id = px.DataId JOIN Strain st ON pd.StrainId = st.Id + LEFT JOIN PublishSE ps ON ps.DataId = pd.Id AND ps.StrainId = pd.StrainId + LEFT JOIN NStrain ns ON ns.DataId = pd.Id AND ns.StrainId = pd.StrainId + WHERE px.Id = %s AND px.InbredSetId = %s + ORDER BY st.Name""", (trait_name, group_id)) + else: + cursor.execute(""" + SELECT st.Name, ifnull(pd.value, 'x'), ifnull(ps.error, 'x'), ifnull(ns.count, 'x') + FROM PublishFreeze pf JOIN PublishXRef px ON px.InbredSetId = pf.InbredSetId + JOIN PublishData pd ON pd.Id = px.DataId JOIN Strain st ON pd.StrainId = st.Id + LEFT JOIN PublishSE ps ON ps.DataId = pd.Id AND ps.StrainId = pd.StrainId + LEFT JOIN NStrain ns ON ns.DataId = pd.Id AND ns.StrainId = pd.StrainId + WHERE px.Id = %s AND px.PhenotypeId = %s + ORDER BY st.Name""", (trait_name, phenotype_id)) sample_data = {} for data in cursor.fetchall(): @@ -302,8 +324,8 @@ def update_sample_data( if data_type == "mrna": strain_id, data_id, inbredset_id = get_mrna_sample_data_ids( conn=conn, - probeset_id=int(probeset_id), - dataset_name=dataset_name, + probeset_id=int(probeset_id),# pylint: disable=[possibly-used-before-assignment] + dataset_name=dataset_name,# pylint: disable=[possibly-used-before-assignment] strain_name=extract_strain_name(csv_header, original_data), ) none_case_attrs = { @@ -315,8 +337,8 @@ def update_sample_data( else: strain_id, data_id, inbredset_id = get_pheno_sample_data_ids( conn=conn, - publishxref_id=int(trait_name), - phenotype_id=phenotype_id, + publishxref_id=int(trait_name),# pylint: disable=[possibly-used-before-assignment] + phenotype_id=phenotype_id,# pylint: disable=[possibly-used-before-assignment] strain_name=extract_strain_name(csv_header, original_data), ) none_case_attrs = { @@ -422,8 +444,8 @@ def delete_sample_data( if data_type == "mrna": strain_id, data_id, inbredset_id = get_mrna_sample_data_ids( conn=conn, - probeset_id=int(probeset_id), - dataset_name=dataset_name, + probeset_id=int(probeset_id),# pylint: disable=[possibly-used-before-assignment] + dataset_name=dataset_name,# pylint: disable=[possibly-used-before-assignment] strain_name=extract_strain_name(csv_header, data), ) none_case_attrs: Dict[str, Any] = { @@ -435,8 
+457,8 @@ def delete_sample_data( else: strain_id, data_id, inbredset_id = get_pheno_sample_data_ids( conn=conn, - publishxref_id=int(trait_name), - phenotype_id=phenotype_id, + publishxref_id=int(trait_name),# pylint: disable=[possibly-used-before-assignment] + phenotype_id=phenotype_id,# pylint: disable=[possibly-used-before-assignment] strain_name=extract_strain_name(csv_header, data), ) none_case_attrs = { @@ -528,8 +550,8 @@ def insert_sample_data( if data_type == "mrna": strain_id, data_id, inbredset_id = get_mrna_sample_data_ids( conn=conn, - probeset_id=int(probeset_id), - dataset_name=dataset_name, + probeset_id=int(probeset_id),# pylint: disable=[possibly-used-before-assignment] + dataset_name=dataset_name,# pylint: disable=[possibly-used-before-assignment] strain_name=extract_strain_name(csv_header, data), ) none_case_attrs = { @@ -541,8 +563,8 @@ def insert_sample_data( else: strain_id, data_id, inbredset_id = get_pheno_sample_data_ids( conn=conn, - publishxref_id=int(trait_name), - phenotype_id=phenotype_id, + publishxref_id=int(trait_name),# pylint: disable=[possibly-used-before-assignment] + phenotype_id=phenotype_id,# pylint: disable=[possibly-used-before-assignment] strain_name=extract_strain_name(csv_header, data), ) none_case_attrs = { @@ -584,3 +606,145 @@ def insert_sample_data( return count except Exception as _e: raise MySQLdb.Error(_e) from _e + +def batch_update_sample_data( + conn: Any, diff_data: Dict +): + """Given sample data diffs, execute all relevant update/insert/delete queries""" + def __fetch_data_id(conn, db_type, trait_id, dataset_name): + with conn.cursor() as cursor: + if db_type == "Publish": + cursor.execute( + ( + f"SELECT {db_type}XRef.DataId " + f"FROM {db_type}XRef, {db_type}Freeze " + f"WHERE {db_type}XRef.InbredSetId = {db_type}Freeze.InbredSetId AND " + f"{db_type}XRef.Id = %s AND " + f"{db_type}Freeze.Name = %s" + ), (trait_id, dataset_name) + ) + elif db_type == "ProbeSet": + cursor.execute( + ( + f"SELECT {db_type}XRef.DataId " + f"FROM {db_type}XRef, {db_type}, {db_type}Freeze " + f"WHERE {db_type}XRef.InbredSetId = {db_type}Freeze.InbredSetId AND " + f"{db_type}XRef.ProbeSetId = {db_type}.Id AND " + f"{db_type}.Name = %s AND " + f"{db_type}Freeze.Name = %s" + ), (trait_id, dataset_name) + ) + return cursor.fetchone()[0] + + def __fetch_strain_id(conn, strain_name): + with conn.cursor() as cursor: + cursor.execute( + "SELECT Id FROM Strain WHERE Name = %s", (strain_name,) + ) + return cursor.fetchone()[0] + + def __update_query(conn, db_type, data_id, strain_id, diffs): + with conn.cursor() as cursor: + if 'value' in diffs: + cursor.execute( + ( + f"UPDATE {db_type}Data " + "SET value = %s " + "WHERE Id = %s AND StrainId = %s" + ), (diffs['value']['Current'], data_id, strain_id) + ) + if 'error' in diffs: + cursor.execute( + ( + f"UPDATE {db_type}SE " + "SET error = %s " + "WHERE DataId = %s AND StrainId = %s" + ), (diffs['error']['Current'], data_id, strain_id) + ) + if 'n_cases' in diffs: + cursor.execute( + ( + "UPDATE NStrain " + "SET count = %s " + "WHERE DataId = %s AND StrainId = %s" + ), (diffs['n_cases']['Current'], data_id, strain_id) + ) + + conn.commit() + + def __insert_query(conn, db_type, data_id, strain_id, diffs): + with conn.cursor() as cursor: + if 'value' in diffs: + cursor.execute( + ( + f"INSERT INTO {db_type}Data (Id, StrainId, value)" + "VALUES (%s, %s, %s)" + ), (data_id, strain_id, diffs['value']) + ) + if 'error' in diffs: + cursor.execute( + ( + f"INSERT INTO {db_type}SE (DataId, StrainId, error)" + "VALUES 
(%s, %s, %s)" + ), (data_id, strain_id, diffs['error']) + ) + if 'n_cases' in diffs: + cursor.execute( + ( + "INSERT INTO NStrain (DataId, StrainId, count)" + "VALUES (%s, %s, %s)" + ), (data_id, strain_id, diffs['n_cases']) + ) + + conn.commit() + + def __delete_query(conn, db_type, data_id, strain_id, diffs): + with conn.cursor() as cursor: + if 'value' in diffs: + cursor.execute( + ( + f"DELETE FROM {db_type}Data " + "WHERE Id = %s AND StrainId = %s" + ), (data_id, strain_id) + ) + if 'error' in diffs: + cursor.execute( + ( + f"DELETE FROM {db_type}SE " + "WHERE DataId = %s AND StrainId = %s" + ), (data_id, strain_id) + ) + if 'n_cases' in diffs: + cursor.execute( + ( + "DELETE FROM NStrain " + "WHERE DataId = %s AND StrainId = %s" + ), (data_id, strain_id) + ) + + conn.commit() + + def __update_data(conn, db_type, data_id, diffs, update_type): + for strain in diffs: + strain_id = __fetch_strain_id(conn, strain) + if update_type == "update": + __update_query(conn, db_type, data_id, strain_id, diffs[strain]) + elif update_type == "insert": + __insert_query(conn, db_type, data_id, strain_id, diffs[strain]) + elif update_type == "delete": + __delete_query(conn, db_type, data_id, strain_id, diffs[strain]) + + for key in diff_data: + dataset, trait = key.split(":") + if "Publish" in dataset: + db_type = "Publish" + else: + db_type = "ProbeSet" + + data_id = __fetch_data_id(conn, db_type, trait, dataset) + + __update_data(conn, db_type, data_id, diff_data[key]['Modifications'], 'update') + __update_data(conn, db_type, data_id, diff_data[key]['Additions'], 'insert') + __update_data(conn, db_type, data_id, diff_data[key]['Deletions'], 'delete') + + return diff_data diff --git a/gn3/db/traits.py b/gn3/db/traits.py index fa13fcc..fbac0da 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -3,7 +3,6 @@ import os from functools import reduce from typing import Any, Dict, Sequence -from gn3.settings import TMPDIR from gn3.chancy import random_string from gn3.function_helpers import compose from gn3.db.datasets import retrieve_trait_dataset @@ -690,7 +689,7 @@ def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tupl return {} -def generate_traits_filename(base_path: str = TMPDIR): +def generate_traits_filename(base_path: str): """Generate a unique filename for use with generated traits files.""" return ( f"{os.path.abspath(base_path)}/traits_test_file_{random_string(10)}.txt") diff --git a/gn3/db/wiki.py b/gn3/db/wiki.py index 0f46855..e702569 100644 --- a/gn3/db/wiki.py +++ b/gn3/db/wiki.py @@ -22,12 +22,20 @@ def _decode_dict(result: dict): def get_latest_comment(connection, comment_id: int) -> int: """ Latest comment is one with the highest versionId """ cursor = connection.cursor(DictCursor) - query = """ SELECT versionId AS version, symbol, PubMed_ID AS pubmed_ids, sp.Name AS species, - comment, email, weburl, initial, reason - FROM `GeneRIF` gr - INNER JOIN Species sp USING(SpeciesId) - WHERE gr.Id = %s - ORDER BY versionId DESC LIMIT 1; + query = """SELECT versionId AS version, + symbol, + PubMed_ID AS pubmed_ids, + COALESCE(sp.Name, 'no specific species') AS species, + comment, + email, + weburl, + initial, + reason +FROM `GeneRIF` gr +LEFT JOIN Species sp USING(SpeciesId) +WHERE gr.Id = %s +ORDER BY versionId DESC +LIMIT 1; """ cursor.execute(query, (str(comment_id),)) result = _decode_dict(cursor.fetchone()) @@ -48,6 +56,8 @@ def get_latest_comment(connection, comment_id: int) -> int: def get_species_id(cursor, species_name: str) -> int: """Find species id given 
species `Name`""" + if species_name.lower() == "no specific species": + return 0 cursor.execute( "SELECT SpeciesID from Species WHERE Name = %s", (species_name,)) species_ids = cursor.fetchall() @@ -70,6 +80,14 @@ def get_next_comment_version(cursor, comment_id: int) -> int: return latest_version + 1 +def get_next_comment_id(cursor) -> int: + """Get the next GeneRIF.Id""" + cursor.execute( + "SELECT MAX(Id) from GeneRIF" + ) + return cursor.fetchone()[0] + 1 + + def get_categories_ids(cursor, categories: List[str]) -> List[int]: """Get the categories_ids from a list of category strings""" dict_cats = get_categories(cursor) @@ -93,7 +111,7 @@ def get_categories(cursor) -> Dict[str, int]: def get_species(cursor) -> Dict[str, str]: """Get all species""" - cursor.execute("SELECT Name, SpeciesName from Species") + cursor.execute("SELECT Name, SpeciesName from Species ORDER BY Species.Id") raw_species = cursor.fetchall() dict_cats = dict(raw_species) return dict_cats diff --git a/gn3/db_utils.py b/gn3/db_utils.py index 0d9bd0a..e1816b0 100644 --- a/gn3/db_utils.py +++ b/gn3/db_utils.py @@ -1,49 +1,76 @@ """module contains all db related stuff""" -import contextlib import logging -from typing import Any, Iterator, Protocol, Tuple +import contextlib from urllib.parse import urlparse -import MySQLdb as mdb +from typing import Callable + import xapian +# XXXX: Replace instances that call db_utils.Connection or +# db_utils.database_connection with a direct call to gn_libs. +# pylint: disable=[W0611] +from gn_libs.mysqldb import Connection, database_connection # type: ignore + LOGGER = logging.getLogger(__file__) -def parse_db_url(sql_uri: str) -> Tuple: - """function to parse SQL_URI env variable note:there\ - is a default value for SQL_URI so a tuple result is\ - always expected""" - parsed_db = urlparse(sql_uri) - return ( - parsed_db.hostname, parsed_db.username, parsed_db.password, - parsed_db.path[1:], parsed_db.port) +def __check_true__(val: str) -> bool: + """Check whether the variable 'val' has the string value `true`.""" + return val.strip().lower() == "true" -# pylint: disable=missing-class-docstring, missing-function-docstring, too-few-public-methods -class Connection(Protocol): - """Type Annotation for MySQLdb's connection object""" - def cursor(self, *args, **kwargs) -> Any: - """A cursor in which queries may be performed""" +def __parse_db_opts__(opts: str) -> dict: + """Parse database options into their appropriate values. + This assumes use of python-mysqlclient library.""" + allowed_opts = ( + "unix_socket", "connect_timeout", "compress", "named_pipe", + "init_command", "read_default_file", "read_default_group", + "cursorclass", "use_unicode", "charset", "collation", "auth_plugin", + "sql_mode", "client_flag", "multi_statements", "ssl_mode", "ssl", + "local_infile", "autocommit", "binary_prefix") + conversion_fns: dict[str, Callable] = { + **{opt: str for opt in allowed_opts}, + "connect_timeout": int, + "compress": __check_true__, + "use_unicode": __check_true__, + # "cursorclass": __load_cursor_class__ + "client_flag": int, + "multi_statements": __check_true__, + # "ssl": __parse_ssl_options__, + "local_infile": __check_true__, + "autocommit": __check_true__, + "binary_prefix": __check_true__ + } + queries = tuple(filter(bool, opts.split("&"))) + if len(queries) > 0: + keyvals: tuple[tuple[str, ...], ...] 
= tuple( + tuple(item.strip() for item in query.split("=")) + for query in queries) -@contextlib.contextmanager -def database_connection(sql_uri: str, logger: logging.Logger = LOGGER) -> Iterator[Connection]: - """Connect to MySQL database.""" - host, user, passwd, db_name, port = parse_db_url(sql_uri) - connection = mdb.connect(db=db_name, - user=user, - passwd=passwd or '', - host=host, - port=port or 3306) - try: - yield connection - except mdb.Error as _mbde: - logger.error("DB error encountered", exc_info=True) - connection.rollback() - finally: - connection.commit() - connection.close() + def __check_opt__(opt): + assert opt in allowed_opts, ( + f"Invalid database connection option ({opt}) provided.") + return opt + return { + __check_opt__(key): conversion_fns[key](val) + for key, val in keyvals + } + return {} + + +def parse_db_url(sql_uri: str) -> dict: + """Parse the `sql_uri` variable into a dict of connection parameters.""" + parsed_db = urlparse(sql_uri) + return { + "host": parsed_db.hostname, + "port": parsed_db.port or 3306, + "user": parsed_db.username, + "password": parsed_db.password, + "database": parsed_db.path.strip("/").strip(), + **__parse_db_opts__(parsed_db.query) + } @contextlib.contextmanager diff --git a/gn3/errors.py b/gn3/errors.py index cd795e8..46483db 100644 --- a/gn3/errors.py +++ b/gn3/errors.py @@ -16,7 +16,7 @@ from authlib.oauth2.rfc6749.errors import OAuth2Error from flask import Flask, jsonify, Response, current_app from gn3.oauth2 import errors as oautherrors -from gn3.auth.authorisation.errors import AuthorisationError +from gn3.oauth2.errors import AuthorisationError from gn3.llms.errors import LLMError def add_trace(exc: Exception, jsonmsg: dict) -> dict: @@ -60,7 +60,7 @@ def handle_authorisation_error(exc: AuthorisationError): return jsonify(add_trace(exc, { "error": type(exc).__name__, "error_description": " :: ".join(exc.args) - })), exc.error_code + })), 500 def handle_oauth2_errors(exc: OAuth2Error): diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index b6822d4..80c38e8 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -145,6 +145,7 @@ def build_heatmap( app.config['REAPER_COMMAND'], genotype_filename, traits_filename, + output_dir=app.config["TMPDIR"], separate_nperm_output=True ) @@ -292,7 +293,7 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names): for chr_name in chromosome_names] return hdata -def clustered_heatmap( +def clustered_heatmap(# pylint: disable=[too-many-positional-arguments] data: Sequence[Sequence[float]], clustering_data: Sequence[float], x_axis,#: Dict[Union[str, int], Union[str, Sequence[str]]], y_axis: Dict[str, Union[str, Sequence[str]]], @@ -335,7 +336,7 @@ def clustered_heatmap( fig.add_trace( heatmap, row=((i + 2) if vertical else 1), - col=(1 if vertical else (i + 2))) + col=(1 if vertical else i + 2)) axes_layouts = { "{axis}axis{count}".format( # pylint: disable=[C0209] diff --git a/gn3/jobs.py b/gn3/jobs.py index 1af63f7..898517c 100644 --- a/gn3/jobs.py +++ b/gn3/jobs.py @@ -24,7 +24,7 @@ def job(redisconn: Redis, job_id: UUID) -> Either: if the_job: return Right({ key: json.loads(value, object_hook=jed.custom_json_decoder) - for key, value in the_job.items() + for key, value in the_job.items()# type: ignore[union-attr] }) return Left({ "error": "NotFound", diff --git a/gn3/llms/client.py b/gn3/llms/client.py index 41f7292..cac83be 100644 --- a/gn3/llms/client.py +++ b/gn3/llms/client.py @@ -50,7 +50,7 @@ class GeneNetworkQAClient(Session): super().__init__() self.headers.update( 
{"Authorization": "Bearer " + api_key}) - self.base_url = "https://genenetwork.fahamuai.com/api/tasks" + self.base_url = "https://balg-qa.genenetwork.org/api/tasks" self.answer_url = f"{self.base_url}/answers" self.feedback_url = f"{self.base_url}/feedback" self.query = "" diff --git a/gn3/llms/process.py b/gn3/llms/process.py index b8e47e7..c007057 100644 --- a/gn3/llms/process.py +++ b/gn3/llms/process.py @@ -10,7 +10,7 @@ from urllib.parse import quote from gn3.llms.client import GeneNetworkQAClient -BASE_URL = 'https://genenetwork.fahamuai.com/api/tasks' +BASE_URL = 'https://balg-qa.genenetwork.org/api/tasks' BASEDIR = os.path.abspath(os.path.dirname(__file__)) diff --git a/gn3/loggers.py b/gn3/loggers.py index 5e52a9f..2cb0ca0 100644 --- a/gn3/loggers.py +++ b/gn3/loggers.py @@ -1,8 +1,14 @@ """Setup loggers""" +import os import sys import logging from logging import StreamHandler +logging.basicConfig( + format=("%(asctime)s — %(filename)s:%(lineno)s — %(levelname)s " + "(%(thread)d:%(threadName)s): %(message)s") +) + # ========== Setup formatters ========== # ========== END: Setup formatters ========== @@ -10,12 +16,35 @@ def loglevel(app): """'Compute' the LOGLEVEL from the application.""" return logging.DEBUG if app.config.get("DEBUG", False) else logging.WARNING -def setup_app_handlers(app): - """Setup the logging handlers for the application `app`.""" - # ========== Setup handlers ========== + +def setup_modules_logging(level, modules): + """Configure logging levels for a list of modules.""" + for module in modules: + _logger = logging.getLogger(module) + _logger.setLevel(level) + + +def __add_default_handlers__(app): + """Add some default handlers, if running in dev environment.""" stderr_handler = StreamHandler(stream=sys.stderr) app.logger.addHandler(stderr_handler) - # ========== END: Setup handlers ========== root_logger = logging.getLogger() root_logger.addHandler(stderr_handler) root_logger.setLevel(loglevel(app)) + + +def __add_gunicorn_handlers__(app): + """Set up logging for the WSGI environment with GUnicorn""" + logger = logging.getLogger("gunicorn.error") + app.logger.handlers = logger.handlers + app.logger.setLevel(logger.level) + return app + + +def setup_app_logging(app): + """Setup the logging handlers for the application `app`.""" + software, *_version_and_comments = os.environ.get( + "SERVER_SOFTWARE", "").split('/') + return (__add_gunicorn_handlers__(app) + if bool(software) + else __add_default_handlers__(app)) diff --git a/gn3/oauth2/jwks.py b/gn3/oauth2/jwks.py index 8798a3f..c670bf7 100644 --- a/gn3/oauth2/jwks.py +++ b/gn3/oauth2/jwks.py @@ -12,7 +12,7 @@ from gn3.oauth2.errors import TokenValidationError def fetch_jwks(authserveruri: str, path: str = "auth/public-jwks") -> KeySet: """Fetch the JWKs from a particular URI""" try: - response = requests.get(urljoin(authserveruri, path)) + response = requests.get(urljoin(authserveruri, path), timeout=300) if response.status_code == 200: return KeySet([ JsonWebKey.import_key(key) for key in response.json()["jwks"]]) diff --git a/gn3/settings.json b/gn3/settings.json new file mode 100644 index 0000000..a2f427e --- /dev/null +++ b/gn3/settings.json @@ -0,0 +1,59 @@ +{ + "==": "================ FLASK SETTINGS ================", + "--": "-- Base --", + "SECRET_KEY": "password", + "APPLICATION_ENVIRONMENT": "", + + "--": "-- Flask-CORS --", + "CORS_ORIGINS": "*", + "CORS_HEADERS": [ + "Content-Type", + "Authorization", + "Access-Control-Allow-Credentials" + ], + "==": 
"================================================", + + + "==": "================ Filesystem Paths SETTINGS ================", + "TMPDIR": "/tmp", + "DATA_DIR": "", + "CACHEDIR": "", + "LMDB_DATA_PATH": "/var/lib/lmdb", + "XAPIAN_DB_PATH": "xapian", + "LLM_DB_PATH": "", + "GENOTYPE_FILES": "/var/lib/genenetwork/genotype-files/genotype", + "TEXTDIR": "/gnshare/gn/web/ProbeSetFreeze_DataMatrix", + "_comment_TEXTDIR": "The configuration variable `TEXTDIR` points to a directory containing text files used for certain processes. On tux01 this path is '/home/gn1/production/gnshare/gn/web/ProbeSetFreeze_DataMatrix'.", + "==": "================================================", + + + "==": "================ Connection URIs ================", + "REDIS_URI": "redis://localhost:6379/0", + "SQL_URI": "mysql://user:password@host/db", + "SPARQL_ENDPOINT": "http://localhost:9082/sparql", + "AUTH_SERVER_URL": "", + "==": "================================================", + + + "==": "================ CLI Commands ================", + "GEMMA_WRAPPER_CMD": "gemma-wrapper", + "WGCNA_RSCRIPT": "wgcna_analysis.R", + "REAPER_COMMAND": "qtlreaper", + "CORRELATION_COMMAND": "correlation_rust", + "==": "================================================", + + + "==": "================ Service-Specific Settings ================", + "--": "-- Redis --", + "REDIS_JOB_QUEUE": "GN3::job-queue", + + "--": "-- Fahamu --", + "FAHAMU_AUTH_TOKEN": "", + "==": "================================================", + + + "==": "================ Application-Specific Settings ================", + "ROUND_TO": 10, + "MULTIPROCESSOR_PROCS": 6, + "==": "================================================" +} diff --git a/gn3/settings.py b/gn3/settings.py deleted file mode 100644 index 04aa129..0000000 --- a/gn3/settings.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -Default configuration settings for this project. - -DO NOT import from this file, use `flask.current_app.config` instead to get the -application settings. -""" -import os -import uuid -import tempfile - -BCRYPT_SALT = "$2b$12$mxLvu9XRLlIaaSeDxt8Sle" # Change this! 
-DATA_DIR = "" -GEMMA_WRAPPER_CMD = os.environ.get("GEMMA_WRAPPER", "gemma-wrapper") -CACHEDIR = "" -REDIS_URI = "redis://localhost:6379/0" -REDIS_JOB_QUEUE = "GN3::job-queue" -TMPDIR = os.environ.get("TMPDIR", tempfile.gettempdir()) - -# SPARQL endpoint -SPARQL_ENDPOINT = os.environ.get( - "SPARQL_ENDPOINT", - "http://localhost:9082/sparql") - -# LMDB path -LMDB_PATH = os.environ.get( - "LMDB_PATH", f"{os.environ.get('HOME')}/tmp/dataset") - -# SQL confs -SQL_URI = os.environ.get( - "SQL_URI", "mysql://webqtlout:webqtlout@localhost/db_webqtl") -SECRET_KEY = "password" -# gn2 results only used in fetching dataset info - - -# FAHAMU API TOKEN -FAHAMU_AUTH_TOKEN = "" - -GN2_BASE_URL = "http://www.genenetwork.org/" - -# wgcna script -WGCNA_RSCRIPT = "wgcna_analysis.R" -# qtlreaper command -REAPER_COMMAND = f"{os.environ.get('GUIX_ENVIRONMENT')}/bin/qtlreaper" - -# correlation command - -CORRELATION_COMMAND = f"{os.environ.get('GN2_PROFILE')}/bin/correlation_rust" - -# genotype files -GENOTYPE_FILES = os.environ.get( - "GENOTYPE_FILES", f"{os.environ.get('HOME')}/genotype_files/genotype") - -# Xapian index -XAPIAN_DB_PATH = "xapian" - -# sqlite path - -LLM_DB_PATH = "" -# CROSS-ORIGIN SETUP - - -def parse_env_cors(default): - """Parse comma-separated configuration into list of strings.""" - origins_str = os.environ.get("CORS_ORIGINS", None) - if origins_str: - return [ - origin.strip() for origin in origins_str.split(",") if origin != ""] - return default - - -CORS_ORIGINS = parse_env_cors("*") - -CORS_HEADERS = [ - "Content-Type", - "Authorization", - "Access-Control-Allow-Credentials" -] - -GNSHARE = os.environ.get("GNSHARE", "/gnshare/gn/") -TEXTDIR = f"{GNSHARE}/web/ProbeSetFreeze_DataMatrix" - -ROUND_TO = 10 - -MULTIPROCESSOR_PROCS = 6 # Number of processes to spawn - -AUTH_SERVER_URL = "https://auth.genenetwork.org" -AUTH_MIGRATIONS = "migrations/auth" -AUTH_DB = os.environ.get( - "AUTH_DB", f"{os.environ.get('HOME')}/genenetwork/gn3_files/db/auth.db") -OAUTH2_SCOPE = ( - "profile", "group", "role", "resource", "user", "masquerade", - "introspect") - - -try: - # *** SECURITY CONCERN *** - # Clients with access to this privileges create a security concern. - # Be careful when adding to this configuration - OAUTH2_CLIENTS_WITH_INTROSPECTION_PRIVILEGE = tuple( - uuid.UUID(client_id) for client_id in - os.environ.get( - "OAUTH2_CLIENTS_WITH_INTROSPECTION_PRIVILEGE", "").split(",")) -except ValueError as _valerr: - OAUTH2_CLIENTS_WITH_INTROSPECTION_PRIVILEGE = tuple() - -try: - # *** SECURITY CONCERN *** - # Clients with access to this privileges create a security concern. - # Be careful when adding to this configuration - OAUTH2_CLIENTS_WITH_DATA_MIGRATION_PRIVILEGE = tuple( - uuid.UUID(client_id) for client_id in - os.environ.get( - "OAUTH2_CLIENTS_WITH_DATA_MIGRATION_PRIVILEGE", "").split(",")) -except ValueError as _valerr: - OAUTH2_CLIENTS_WITH_DATA_MIGRATION_PRIVILEGE = tuple() diff --git a/gn3/auth/db.py b/gn3/sqlite_db_utils.py index 5cd230f..5cd230f 100644 --- a/gn3/auth/db.py +++ b/gn3/sqlite_db_utils.py diff --git a/main.py b/main.py index 879b344..ccbd14f 100644 --- a/main.py +++ b/main.py @@ -1,105 +1,9 @@ """Main entry point for project""" -import sys -import uuid -import json -from math import ceil -from datetime import datetime - -import click from gn3.app import create_app -from gn3.auth.authorisation.users import hash_password - -from gn3.auth import db app = create_app() -##### BEGIN: CLI Commands ##### - -def __init_dev_users__(): - """Initialise dev users. 
Get's used in more than one place""" - dev_users_query = "INSERT INTO users VALUES (:user_id, :email, :name)" - dev_users_passwd = "INSERT INTO user_credentials VALUES (:user_id, :hash)" - dev_users = ({ - "user_id": "0ad1917c-57da-46dc-b79e-c81c91e5b928", - "email": "test@development.user", - "name": "Test Development User", - "password": "testpasswd"},) - - with db.connection(app.config["AUTH_DB"]) as conn, db.cursor(conn) as cursor: - cursor.executemany(dev_users_query, dev_users) - cursor.executemany(dev_users_passwd, ( - {**usr, "hash": hash_password(usr["password"])} - for usr in dev_users)) - -@app.cli.command() -def init_dev_users(): - """ - Initialise development users for OAuth2 sessions. - - **NOTE**: You really should not run this in production/staging - """ - __init_dev_users__() - -@app.cli.command() -def init_dev_clients(): - """ - Initialise a development client for OAuth2 sessions. - - **NOTE**: You really should not run this in production/staging - """ - __init_dev_users__() - dev_clients_query = ( - "INSERT INTO oauth2_clients VALUES (" - ":client_id, :client_secret, :client_id_issued_at, " - ":client_secret_expires_at, :client_metadata, :user_id" - ")") - dev_clients = ({ - "client_id": "0bbfca82-d73f-4bd4-a140-5ae7abb4a64d", - "client_secret": "yadabadaboo", - "client_id_issued_at": ceil(datetime.now().timestamp()), - "client_secret_expires_at": 0, - "client_metadata": json.dumps({ - "client_name": "GN2 Dev Server", - "token_endpoint_auth_method": [ - "client_secret_post", "client_secret_basic"], - "client_type": "confidential", - "grant_types": ["password", "authorization_code", "refresh_token"], - "default_redirect_uri": "http://localhost:5033/oauth2/code", - "redirect_uris": ["http://localhost:5033/oauth2/code", - "http://localhost:5033/oauth2/token"], - "response_type": ["code", "token"], - "scope": ["profile", "group", "role", "resource", "register-client", - "user", "masquerade", "migrate-data", "introspect"] - }), - "user_id": "0ad1917c-57da-46dc-b79e-c81c91e5b928"},) - - with db.connection(app.config["AUTH_DB"]) as conn, db.cursor(conn) as cursor: - cursor.executemany(dev_clients_query, dev_clients) - - -@app.cli.command() -@click.argument("user_id", type=click.UUID) -def assign_system_admin(user_id: uuid.UUID): - """Assign user with ID `user_id` administrator role.""" - dburi = app.config["AUTH_DB"] - with db.connection(dburi) as conn, db.cursor(conn) as cursor: - cursor.execute("SELECT * FROM users WHERE user_id=?", - (str(user_id),)) - row = cursor.fetchone() - if row: - cursor.execute( - "SELECT * FROM roles WHERE role_name='system-administrator'") - admin_role = cursor.fetchone() - cursor.execute("INSERT INTO user_roles VALUES (?,?)", - (str(user_id), admin_role["role_id"])) - return 0 - print(f"ERROR: Could not find user with ID {user_id}", - file=sys.stderr) - sys.exit(1) - -##### END: CLI Commands ##### - -if __name__ == '__main__': +if __name__ == "__main__": print("Starting app...") app.run() diff --git a/mypy.ini b/mypy.ini index 426d21f..677eb6f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,6 +1,9 @@ [mypy] mypy_path = stubs +[gn-libs.*] +ignore_missing_imports = True + [mypy-scipy.*] ignore_missing_imports = True diff --git a/scripts/index-genenetwork b/scripts/index-genenetwork index 2779abc..0544bc7 100755 --- a/scripts/index-genenetwork +++ b/scripts/index-genenetwork @@ -514,9 +514,14 @@ def xapian_compact(combined_index: pathlib.Path, indices: List[pathlib.Path]) -> @click.argument("xapian_directory") @click.argument("sql_uri") 
@click.argument("sparql_uri") +@click.option("-v", "--virtuoso-ttl-directory", + type=pathlib.Path, + default=pathlib.Path("/var/lib/data/"), + show_default=True) def is_data_modified(xapian_directory: str, sql_uri: str, - sparql_uri: str) -> None: + sparql_uri: str, + virtuoso_ttl_directory: pathlib.Path) -> None: dir_ = pathlib.Path(xapian_directory) with locked_xapian_writable_database(dir_) as db, database_connection(sql_uri) as conn: checksums = "-1" @@ -529,7 +534,7 @@ def is_data_modified(xapian_directory: str, ]) # Return a zero exit status code when the data has changed; # otherwise exit with a 1 exit status code. - generif = pathlib.Path("/var/lib/data/") + generif = virtuoso_ttl_directory if (db.get_metadata("generif-checksum").decode() == md5hash_ttl_dir(generif) and db.get_metadata("checksums").decode() == checksums): sys.exit(1) @@ -540,9 +545,14 @@ def is_data_modified(xapian_directory: str, @click.argument("xapian_directory") @click.argument("sql_uri") @click.argument("sparql_uri") +@click.option("-v", "--virtuoso-ttl-directory", + type=pathlib.Path, + default=pathlib.Path("/var/lib/data/"), + show_default=True) # pylint: disable=missing-function-docstring def create_xapian_index(xapian_directory: str, sql_uri: str, - sparql_uri: str) -> None: + sparql_uri: str, + virtuoso_ttl_directory: pathlib.Path) -> None: logging.basicConfig(level=os.environ.get("LOGLEVEL", "DEBUG"), format='%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S %Z') @@ -587,7 +597,7 @@ def create_xapian_index(xapian_directory: str, sql_uri: str, logging.info("Writing generif checksums into index") db.set_metadata( "generif-checksum", - md5hash_ttl_dir(pathlib.Path("/var/lib/data/")).encode()) + md5hash_ttl_dir(virtuoso_ttl_directory).encode()) for child in combined_index.iterdir(): shutil.move(child, xapian_directory) logging.info("Index built") diff --git a/scripts/lmdb_matrix.py b/scripts/lmdb_matrix.py new file mode 100644 index 0000000..22173af --- /dev/null +++ b/scripts/lmdb_matrix.py @@ -0,0 +1,437 @@ +"""This scripts reads and store genotype files to an LMDB store. +Similarly, it can be use to read this data. 
+ +Example: + +guix shell python-click python-lmdb python-wrapper python-numpy -- \ + python lmdb_matrix.py import-genotype \ + <path-to-genotype-file> <path-to-lmdb-store> + +guix shell python-click python-lmdb python-wrapper python-numpy -- \ + python lmdb_matrix.py print-current-matrix \ + <path-to-lmdb-store> + +""" +from dataclasses import dataclass +from pathlib import Path +from subprocess import check_output + +import json +import click +import lmdb +import numpy as np + + +@dataclass +class GenotypeMatrix: + """Store the actual Genotype Matrix""" + matrix: np.ndarray + metadata: dict + file_info: dict + + +def count_trailing_newlines(file_path): + """Count trailing newlines in a file""" + with open(file_path, 'rb', encoding="utf-8") as stream: + stream.seek(0, 2) # Move to the end of the file + file_size = stream.tell() + if file_size == 0: + return 0 + chunk_size = 1024 # Read in small chunks + empty_lines = 0 + _buffer = b"" + + # Read chunks from the end backward + while stream.tell() > 0: + # Don't read beyond start + chunk_size = min(chunk_size, stream.tell()) + stream.seek(-chunk_size, 1) # Move backward + chunk = stream.read(chunk_size) + _buffer + stream.seek(-chunk_size, 1) # Move back to start of chunk + # Decode chunk to text + try: + chunk_text = chunk.decode('utf-8', errors='ignore') + except UnicodeDecodeError: + # If decoding fails, keep some bytes for the next + # chunk Keep last 16 bytes to avoid splitting + # characters + _buffer = chunk[-16:] + continue + + # Split into lines from the end + lines = chunk_text.splitlines() + + # Count empty lines from the end + for line in reversed(lines): + if line.strip() != "": + return empty_lines + empty_lines += 1 + if stream.tell() == 0: + return empty_lines + return empty_lines + + +def wc(filename): + """Get total file count of a file""" + return int(check_output(["wc", "-l", filename]).split()[0]) - \ + count_trailing_newlines(filename) + + +def get_genotype_metadata(genotype_file: str) -> tuple[dict, dict]: + """Parse metadata from a genotype file, separating '@'-prefixed + and '#'-prefixed entries. + + This function reads a genotype file and extracts two types of metadata: + - '@'-prefixed metadata (e.g., '@name:BXD'), stored as key-value + pairs for dataset attributes. + - '#'-prefixed metadata (e.g., '# File name: BXD_Geno...'), stored + as key-value pairs for file information. Lines starting with + '#' without a colon are skipped as comments. Parsing stops at + the first non-metadata line. + + Args: + genotype_file (str): Path to the genotype file to be parsed. + + Returns: + tuple[dict, dict]: A tuple containing two dictionaries: + - First dict: '@'-prefixed metadata (e.g., {'name': 'BXD', + 'type': 'riset'}). + - Second dict: '#'-prefixed metadata with colons (e.g., + {'File name': 'BXD_Geno...', 'Citation': '...'}). 
+ + Example: + >>> meta, file_info = get_genotype_metadata("BXD.small.geno") + >>> print(meta) + {'name': 'BXD', 'type': 'riset', 'mat': 'B', 'pat': 'D', + 'het': 'H', 'unk': 'U'} + >>> print(file_info) + {'File name': 'BXD_Geno-19Jan2017b_forGN.xls', 'Metadata': + 'Please retain...'} + + """ + metadata = {} + file_metadata = {} + with open(genotype_file, "r", encoding="utf-8") as stream: + while True: + line = stream.readline().strip() + match line: + case meta if line.startswith("#"): + if ":" in meta: + key, value = meta[2:].split(":", 1) + file_metadata[key] = value + case meta if line.startswith("#"): + continue + case meta if line.startswith("@") and ":" in line: + key, value = meta[1:].split(":", 1) + if value: + metadata[key] = value.strip() + case _: + break + return metadata, file_metadata + + +def get_genotype_dimensions(genotype_file: str) -> tuple[int, int]: + """Calculate the dimensions of the data section in a genotype + file. + + This function determines the number of data rows and columns in a + genotype file. It skips metadata lines (starting with '#' or '@') + and uses the first non-metadata line as the header to count + columns. The total number of lines is counted in binary mode to + efficiently handle large files, and the number of data rows is + calculated by subtracting metadata and header lines. Accounts for + a potential trailing newline. + + Args: + genotype_file (str): Path to the genotype file to be analyzed. + + Returns: + tuple[int, int]: A tuple containing: + - First int: Number of data rows (excluding metadata and + header). + - Second int: Number of columns (based on the header row). + + Example: + >>> rows, cols = get_genotype_dimensions("BXD.small.geno") + >>> print(rows, cols) + 2, 202 # Example: 2 data rows, 202 columns (from header) + + Note: + Assumes the first non-metadata line is the header row, split + by whitespace. A trailing newline may be included in the line + count but is accounted for in the returned row count. + + """ + counter = 0 + rows = [] + + with open(genotype_file, "r", encoding="utf-8") as stream: + while True: + line = stream.readline() + counter += 1 + match line: + case "" | _ if line.startswith(("#", "@", "\n")): + continue + case _: + rows = line.split() + break + return wc(genotype_file) - counter, len(rows) + + +def read_genotype_headers(genotype_file: str) -> list[str]: + """Extract the header row from a genotype file. + + This function reads a genotype file and returns the first + non-metadata line as a list of column headers. It skips lines + starting with '#' (comments), '@' (metadata), or empty lines, + assuming the first non-skipped line contains the header (e.g., + 'Chr', 'Locus', 'cM', 'Mb', followed by strain names like 'BXD1', + 'BXD2', etc.). The header is split by whitespace to create the + list of column names. + + Args: + genotype_file (str): Path to the genotype file to be parsed. + + Returns: + list[str]: A list of column headers from the first non-metadata line. 
+ + Example: + >>> headers = read_genotype_headers("BXD.small.geno") + >>> print(headers) + ['Chr', 'Locus', 'cM', 'Mb', 'BXD1', 'BXD2', ..., 'BXD220'] + """ + rows = [] + with open(genotype_file, "r", encoding="utf-8") as stream: + while True: + line = stream.readline() + match line: + case _ if line.startswith("#") or line.startswith("@") or line == "": + continue + case _: + rows = line.split() + break + return rows + + +# pylint: disable=too-many-locals +def read_genotype_file(genotype_file: str) -> GenotypeMatrix: + """Read a genotype file and construct a GenotypeMatrix object. + + This function parses a genotype file to extract metadata and + genotype data, creating a numerical matrix of genotype values and + associated metadata. It processes: + - '@'-prefixed metadata (e.g., '@mat:B') for dataset attributes + like maternal/paternal alleles. + - '#'-prefixed metadata (e.g., '# File name:...') for file + information. + - Header row for column names (e.g., 'Chr', 'Locus', BXD strains). + - Data rows, converting genotype symbols (e.g., 'B', 'D', 'H', + 'U') to numeric values (0, 1, 2, 3) based on metadata mappings. + + The function skips comment lines ('#'), metadata lines ('@'), and + empty lines, and organizes the data into a GenotypeMatrix with a + numpy array and metadata dictionaries. + + Args: + genotype_file (str): Path to the genotype file to be parsed. + + Returns: + GenotypeMatrix: An object containing: + - matrix: A numpy array (nrows x ncols) with genotype values (0: + maternal, 1: paternal, 2: heterozygous, 3: unknown). + - metadata: A dictionary with '@'-prefixed metadata, row/column + counts, individuals (BXD strains), metadata columns (e.g., 'Chr', + 'Locus'), and lists of metadata values per row. + - file_info: A dictionary with '#'-prefixed metadata (e.g., 'File + name', 'Citation'). + + Raises: + ValueError: If an unrecognized genotype symbol is encountered in + the data. 
+ + Example: + >>> geno_matrix = read_genotype_file("BXD.small.geno") + >>> print(geno_matrix.matrix.shape) + (2, 198) # Example: 2 rows, 198 BXD strains + >>> print(geno_matrix.metadata["name"]) + 'BXD' + >>> print(geno_matrix.file_info["File name"]) + 'BXD_Geno-19Jan2017b_forGN.xls' + """ + header = read_genotype_headers(genotype_file) + + counter = 0 + for i, el in enumerate(header): + if el not in ["Chr", "Locus", "cM", "Mb"]: + break + counter = i + + metadata_columns, individuals = header[:counter], header[counter:] + nrows, ncols = get_genotype_dimensions(genotype_file) + ncols -= len(metadata_columns) + matrix = np.zeros((nrows, ncols), dtype=np.uint8) + + metadata, file_metadata = get_genotype_metadata(genotype_file) + metadata = metadata | { + "nrows": nrows, + "ncols": ncols, + "individuals": individuals, + "metadata_columns": metadata_columns + } + for key in metadata_columns: + metadata[key] = [] + + maternal = metadata.get("mat") + paternal = metadata.get("pat") + heterozygous = metadata.get("het") + unknown = metadata.get("unk") + i = 0 + sentinel = True + with open(genotype_file, "r", encoding="utf-8") as stream: + while True: + if i == nrows: + break + line = stream.readline().split() + meta, data = [], [] + if line and line[0] in metadata_columns: + # Skip the metadata column + line = stream.readline().split() + sentinel = False + if len(line) == 0 or (line[0].startswith("#") and sentinel) \ + or line[0].startswith("@"): + continue + meta, data = line[:len(metadata_columns) + ], line[len(metadata_columns):] + # KLUDGE: It's not clear whether chromosome rows that + # start with a '#' should be a comment or not. For some + # there's a mismatch between (E.g. B6D2F2_mm8) the size of + # the data values and ncols. For now, skip them. + if len(data) != ncols: + i += 1 + continue + for j, el in enumerate(data): + match el: + case _ if el.isdigit(): + matrix[i, j] = int(el) + case _ if maternal == el: + matrix[i, j] = 0 + case _ if paternal == el: + matrix[i, j] = 1 + case _ if heterozygous == el: + matrix[i, j] = 2 + case _ if unknown == el: + matrix[i, j] = 3 + case _: + # KLUDGE: It's not clear how to handle float + # types in a geno file + # E.g. HSNIH-Palmer_true.geno which has float + # values such as: 0.997. 
Ideally maybe: + # raise ValueError + continue + i += 1 + __map = dict(zip(metadata_columns, meta)) + for key in metadata_columns: + metadata[key].append(__map.get(key)) + + genotype_matrix = GenotypeMatrix( + matrix=matrix, + metadata=metadata, + file_info=file_metadata + ) + return genotype_matrix + + +def create_database(db_path: str) -> lmdb.Environment: + """Create or open an LMDB environment.""" + return lmdb.open(db_path, map_size=100 * 1024 * 1024 * 1024, create=True) + + +def genotype_db_put(db: lmdb.Environment, genotype: GenotypeMatrix) -> bool: + """Put genotype GENOTYPEMATRIX from DB environment""" + metadata = json.dumps(genotype.metadata).encode("utf-8") + file_info = json.dumps(genotype.file_info).encode("utf-8") + with db.begin(write=True) as txn: + txn.put(b"matrix", genotype.matrix.tobytes()) + txn.put(b"metadata", metadata) + # XXXX: KLUDGE: Put this in RDF instead + txn.put(b"file_info", file_info) + return True + + +def genotype_db_get(db: lmdb.Environment) -> GenotypeMatrix: + """Get genotype GENOTYPEMATRIX from DB environment""" + with db.begin() as txn: + metadata = json.loads(txn.get(b"metadata").decode("utf-8")) + nrows, ncols = metadata.get("nrows"), metadata.get("ncols") + matrix = np.frombuffer( + txn.get(b"matrix"), dtype=np.uint8).reshape(nrows, ncols) + return GenotypeMatrix( + matrix=matrix, + metadata=metadata, + file_info=json.loads(txn.get(b"file_info")) + ) + + +def get_genotype_files(directory: str) -> list[tuple[str, int]]: + """Return a list of all the genotype files from a given + directory.""" + geno_files = [ + (_file.as_posix(), _file.stat().st_size) + for _file in Path(directory).glob("*.geno") if _file.is_file()] + return sorted(geno_files, key=lambda x: x[1]) + + +def __import_directory(directory: str, lmdb_path: str): + """Import all the genotype files from a given directory into + LMDB.""" + for file_, file_size in get_genotype_files(directory): + genofile = Path(file_) + size_mb = file_size / (1024 ** 2) + lmdb_store = (Path(lmdb_path) / genofile.stem).as_posix() + print(f"Processing file: {genofile.name}") + with create_database(lmdb_store) as db: + genotype_db_put( + db=db, genotype=read_genotype_file(genofile.as_posix())) + print(f"\nSuccessfuly created: [{size_mb:.2f} MB] {genofile.stem}") + + +@click.command(help="Import the genotype directory") +@click.argument("genotype_directory") +@click.argument("lmdb_path") +def import_directory(genotype_directory: str, lmdb_path: str): + "Import a genotype directory into genotype_database path" + __import_directory(directory=genotype_directory, lmdb_path=lmdb_path) + + +@click.command(help="Import the genotype file") +@click.argument("geno_file") +@click.argument("genotype_database") +def import_genotype(geno_file: str, genotype_database: str): + "Import a genotype file into genotype_database path" + with create_database(genotype_database) as db: + genotype_db_put(db=db, genotype=read_genotype_file(geno_file)) + + +@click.command(help="Print the current matrix") +@click.argument("database_directory") +def print_current_matrix(database_directory: str): + """Print the current matrix in the database.""" + with create_database(database_directory) as db: + current = genotype_db_get(db) + print(f"Matrix: {current.matrix}") + print(f"Metadata: {current.metadata}") + print(f"File Info: {current.file_info}") + + +# pylint: disable=missing-function-docstring +@click.group() +def cli(): + pass + + +cli.add_command(print_current_matrix) +cli.add_command(import_genotype) +cli.add_command(import_directory) + 
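Editorial sketch (not part of the patch) showing how a store written by the helpers above can be read back; the store path is hypothetical, and the numeric codes follow the mapping used in read_genotype_file (0 = mat, 1 = pat, 2 = het, 3 = unk):

    with create_database("/path/to/lmdb/BXD") as db:
        geno = genotype_db_get(db)
        codes = {0: geno.metadata.get("mat"), 1: geno.metadata.get("pat"),
                 2: geno.metadata.get("het"), 3: geno.metadata.get("unk")}
        # First marker row, decoded back to the original genotype symbols.
        first_row = [codes.get(int(v), "?") for v in geno.matrix[0]]
        print(geno.metadata["individuals"][:5], first_row[:5])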
+if __name__ == "__main__": + cli() diff --git a/scripts/partial_correlations.py b/scripts/partial_correlations.py index aab8f08..f47d9d6 100644 --- a/scripts/partial_correlations.py +++ b/scripts/partial_correlations.py @@ -1,7 +1,7 @@ """Script to run partial correlations""" - import json import traceback +from pathlib import Path from argparse import ArgumentParser from gn3.db_utils import database_connection @@ -48,7 +48,8 @@ def pcorrs_against_traits(dbconn, args): def pcorrs_against_db(dbconn, args): """Run partial correlations agaist the entire dataset provided.""" - return partial_correlations_with_target_db(dbconn, **process_db_args(args)) + return partial_correlations_with_target_db( + dbconn, **process_db_args(args), textdir=args.textdir) def run_pcorrs(dbconn, args): """Run the selected partial correlations function.""" @@ -89,6 +90,11 @@ def against_db_parser(parent_parser): "--criteria", help="Number of results to return", type=int, default=500) + parser.add_argument( + "--textdir", + help="Directory to read text files from", + type=Path, + default=Path("/tmp/")) parser.set_defaults(func=pcorrs_against_db) return parent_parser diff --git a/scripts/pub_med.py b/scripts/pub_med.py index 82b1730..0a94355 100644 --- a/scripts/pub_med.py +++ b/scripts/pub_med.py @@ -155,8 +155,8 @@ def fetch_id_lossy_search(query, db_name, max_results): try: response = requests.get(f"http://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db={db_name}&retmode=json&retmax={max_results}&term={query}", - headers={"content-type": "application/json"} - ) + headers={"content-type": "application/json"}, + timeout=300) return response["esearchresult"]["idlist"] except requests.exceptions.RequestException as error: @@ -174,7 +174,7 @@ def search_pubmed_lossy(pubmed_id, db_name): - dict: Records fetched based on PubMed ID. 
""" url = f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?db={db_name}&id={",".join(pubmed_id)}&retmode=json' - response = requests.get(url) + response = requests.get(url, timeout=300) response.raise_for_status() data = response.json() if db_name.lower() == "pmc": diff --git a/scripts/rqtl2_wrapper.R b/scripts/rqtl2_wrapper.R new file mode 100644 index 0000000..af13efa --- /dev/null +++ b/scripts/rqtl2_wrapper.R @@ -0,0 +1,358 @@ +library(qtl2) +library(rjson) +library(stringi) +library(optparse) + +# Define command-line options +option_list <- list( + make_option(c("-d", "--directory"), action = "store", default = NULL, type = "character", + help = "Temporary working directory: should also host the input file."), + make_option(c("-i", "--input_file"), action = "store", default = NULL, type = 'character', + help = "A YAML or JSON file with required data to create the cross file."), + make_option(c("-o", "--output_file"), action = "store", default = NULL, type = 'character', + help = "A file path of where to write the output JSON results."), + make_option(c("-c", "--cores"), type = "integer", default = 1, + help = "Number of cores to use while making computation."), + make_option(c("-p", "--nperm"), type = "integer", default = 0, + help = "Number of permutations."), + make_option(c("-m", "--method"), action = "store", default = "HK", type = "character", + help = "Scan Mapping Method - HK (Haley Knott), LMM (Linear Mixed Model), LOCO (Leave One Chromosome Out)."), + make_option(c("--pstrata"), action = "store_true", default = NULL, + help = "Use permutation strata."), + make_option(c("-t", "--threshold"), type = "integer", default = 1, + help = "Minimum LOD score for a Peak.") +) + +# Parse command-line arguments +opt_parser <- OptionParser(option_list = option_list) +opt <- parse_args(opt_parser) + +# Assign parsed arguments to variables +NO_OF_CORES <- opt$cores +SCAN_METHOD <- opt$method +NO_OF_PERMUTATION <- opt$nperm + +NO_OF_CORES <- 20 + +# Validate input and output file paths +validate_file_paths <- function(opt) { + if (is.null(opt$directory) || !dir.exists(opt$directory)) { + print_help(opt_parser) + stop("The working directory does not exist or is NULL.\n") + } + + INPUT_FILE_PATH <- opt$input_file + OUTPUT_FILE_PATH <- opt$output_file + + if (!file.exists(INPUT_FILE_PATH)) { + print_help(opt_parser) + stop("The input file ", INPUT_FILE_PATH, " you provided does not exist.\n") + } else { + cat("Input file exists. 
Reading the input file...\n") + } + + if (!file.exists(OUTPUT_FILE_PATH)) { + print_help(opt_parser) + stop("The output file ", OUTPUT_FILE_PATH, " you provided does not exist.\n") + } else { + cat("Output file exists...", OUTPUT_FILE_PATH, "\n") + } + + return(list(input = INPUT_FILE_PATH, output = OUTPUT_FILE_PATH)) +} + +file_paths <- validate_file_paths(opt) +INPUT_FILE_PATH <- file_paths$input +OUTPUT_FILE_PATH <- file_paths$output + +# Utility function to generate random file names +genRandomFileName <- function(prefix, string_size = 9, file_ext = ".txt") { + randStr <- paste(prefix, stri_rand_strings(1, string_size, pattern = "[A-Za-z0-9]"), sep = "_") + return(paste(randStr, file_ext, sep = "")) +} + +# Generate control file path +control_file_path <- file.path(opt$directory, genRandomFileName(prefix = "control", file_ext = ".json")) + +cat("Generated the control file path at", control_file_path, "\n") +# Read and parse the input file +cat("Reading and parsing the input file.\n") +json_data <- fromJSON(file = INPUT_FILE_PATH) + +# Set default values for JSON data +set_default_values <- function(json_data) { + if (is.null(json_data$sep)) { + cat("Using ',' as a default separator for cross file.\n") + json_data$sep <- "," + } + if (is.null(json_data$na.strings)) { + cat("Using '-' and 'NA' as the default na.strings.\n") + json_data$na.strings <- c("-", "NA") + } + + default_keys <- c("geno_transposed", "founder_geno_transposed", "pheno_transposed", + "covar_transposed", "phenocovar_transposed") + + for (item in default_keys) { + if (!(item %in% names(json_data))) { + cat("Using FALSE as default parameter for", item, "\n") + json_data[item] <- FALSE + } + } + + return(json_data) +} + +json_data <- set_default_values(json_data) + +# Function to generate the cross object +generate_cross_object <- function(control_file_path, json_data) { + write_control_file( + control_file_path, + crosstype = json_data$crosstype, + geno_file = json_data$geno_file, + pheno_file = json_data$pheno_file, + gmap_file = json_data$geno_map_file, + pmap_file = json_data$physical_map_file, + phenocovar_file = json_data$phenocovar_file, + geno_codes = json_data$geno_codes, + alleles = json_data$alleles, + na.strings = json_data$na.strings, + sex_file = json_data$sex_file, + founder_geno_file = json_data$founder_geno_file, + covar_file = json_data$covar_file, + sex_covar = json_data$sex_covar, + sex_codes = json_data$sex_codes, + crossinfo_file = json_data$crossinfo_file, + crossinfo_covar = json_data$crossinfo_covar, + crossinfo_codes = json_data$crossinfo_codes, + xchr = json_data$xchr, + overwrite = TRUE, + founder_geno_transposed = json_data$founder_geno_transposed, + geno_transposed = json_data$geno_transposed + ) +} + +# Generate the cross object +cat("Generating the cross object at", control_file_path, "\n") +generate_cross_object(control_file_path, json_data) + +# Read the cross object +cat("Reading the cross object from", control_file_path, "\n") + +cross <- read_cross2(control_file_path, quiet = FALSE) + +# Check the integrity of the cross object +cat("Checking the integrity of the cross object.\n") +if (check_cross2(cross)) { + cat("Cross meets required specifications for a cross.\n") +} else { + cat("Cross does not meet required specifications.\n") +} + +# Print cross summary +cat("A summary about the cross you provided:\n") +summary(cross) + +# Function to compute genetic probabilities +perform_genetic_pr <- function(cross, cores = NO_OF_CORES, step = 1, map = NULL, + map_function = c("haldane", 
"kosambi", "c-f", "morgan"), + error_prob = 0.002) { + calc_genoprob(cross, map = map, error_prob = error_prob, map_function = map_function, + quiet = FALSE, cores = cores) +} + +# Insert pseudomarkers to the genetic map +cat("Inserting pseudomarkers to the genetic map with step 1 and stepwidth fixed.\n") + +MAP <- insert_pseudomarkers(cross$gmap, step = 1, stepwidth = "fixed", cores = NO_OF_CORES) + +# Calculate genetic probabilities +cat("Calculating the genetic probabilities.\n") +Pr <- perform_genetic_pr(cross) + +# Calculate allele probabilities for 4-way cross +if (cross$crosstype == "4way") { + cat("Calculating allele genetic probability for 4-way cross.\n") + aPr <- genoprob_to_alleleprob(Pr) +} + +# Calculate genotyping error LOD scores +cat("Calculating the genotype error LOD scores.\n") +error_lod <- calc_errorlod(cross, Pr, quiet = FALSE, cores = NO_OF_CORES) +error_lod <- do.call("cbind", error_lod) + +# Get phenotypes and covariates +cat("Getting the phenotypes and covariates.\n") +pheno <- cross$pheno +# covar <- match(cross$covar$sex, c("f", "m")) # make numeric +# TODO rework on this +covar <- NULL +if (!is.null(covar)) { + names(covar) <- rownames(cross$covar) +} + +Xcovar <- get_x_covar(cross) +cat("The covariates are:\n") +print(covar) +cat("The Xcovar are:\n") +print(Xcovar) + +# Function to calculate kinship +get_kinship <- function(probability, method = "LMM") { + if (method == "LMM") { + kinship <- calc_kinship(probability) + } else if (method == "LOCO") { + kinship <- calc_kinship(probability, "loco") + } else { + kinship <- NULL + } + return(kinship) +} + +# Calculate kinship for the genetic probability +cat("Calculating the kinship for the genetic probability.\n") +if (cross$crosstype == "4way") { + kinship <- get_kinship(aPr, opt$method) +} else { + kinship <- get_kinship(Pr, "loco") +} + +# Function to perform genome scan +perform_genome_scan <- function(cross, genome_prob, method, addcovar = NULL, intcovar = NULL, + kinship = NULL, model = c("normal", "binary"), Xcovar = NULL) { + if (method == "LMM") { + cat("Performing scan1 using Linear Mixed Model.\n") + out <- scan1(genome_prob, cross$pheno, kinship = kinship, model = model, cores = NO_OF_CORES) + } else if (method == "LOCO") { + cat("Performing scan1 using Leave One Chromosome Out.\n") + out <- scan1(genome_prob, cross$pheno, kinship = kinship, model = model, cores = NO_OF_CORES) + } else if (method == "HK") { + cat("Performing scan1 using Haley Knott.\n") + out <- scan1(genome_prob, cross$pheno, addcovar = addcovar, intcovar = intcovar, + model = model, Xcovar = Xcovar, cores = NO_OF_CORES) + } + return(out) +} + +# Perform the genome scan for the cross object +if (cross$crosstype == "4way") { + sex <- setNames((cross$covar$Sex == "male") * 1, rownames(cross$covar)) + scan_results <- perform_genome_scan(aPr, cross, kinship = kinship, method = "LOCO", addcovar = sex) +} else { + scan_results <- perform_genome_scan(cross = cross, genome_prob = Pr, kinship = kinship, + method = SCAN_METHOD) +} + +# Save scan results +scan_file <- file.path(opt$directory, "scan_results.csv") +write.csv(scan_results, scan_file) + +# Function to perform permutation tests +perform_permutation_test <- function(cross, genome_prob, n_perm, method = opt$method, + covar = NULL, Xcovar = NULL, addcovar = NULL, + intcovar = NULL, perm_Xsp = FALSE, kinship = NULL, + model = c("normal", "binary"), chr_lengths = NULL, + perm_strata = NULL) { + scan1perm(genome_prob, cross$pheno, kinship = kinship, Xcovar = Xcovar, intcovar = intcovar, + 
addcovar = addcovar, n_perm = n_perm, perm_Xsp = perm_Xsp, model = model, + chr_lengths = chr_lengths, cores = NO_OF_CORES) +} + +# Check if permutation strata is needed +if (!is.null(opt$pstrata) && !is.null(Xcovar)) { + perm_strata <- mat2strata(Xcovar) +} else { + perm_strata <- NULL +} + +# Perform permutation test if requested +permutation_results_file <- file.path(opt$directory, "permutation.csv") +significance_results_file <- file.path(opt$directory, "significance.csv") + +if (NO_OF_PERMUTATION > 0) { + cat("Performing permutation test for the cross object with", NO_OF_PERMUTATION, "permutations.\n") + perm <- perform_permutation_test(cross, Pr, n_perm = NO_OF_PERMUTATION, perm_strata = perm_strata, + method = opt$method) + + # Function to get LOD significance thresholds + get_lod_significance <- function(perm, thresholds = c(0.01, 0.05, 0.63)) { + cat("Getting the permutation summary with significance thresholds:", thresholds, "\n") + summary(perm, alpha = thresholds) + } + + # Compute LOD significance + lod_significance <- get_lod_significance(perm) + + # Save results + write.csv(lod_significance, significance_results_file) + write.csv(perm, permutation_results_file) +} + + + +# Function to get QTL effects +get_qtl_effect <- function(chromosome, geno_prob, pheno, covar = NULL, LOCO = NULL) { + cat("Finding the QTL effect for chromosome", chromosome, "\n") + chr_Pr <- geno_prob[, chromosome] + if (!is.null(chr_Pr)) { + if (!is.null(LOCO)) { + cat("Finding QTL effect for chromosome", chromosome, "with LOCO.\n") + kinship <- calc_kinship(chr_Pr, "loco")[[chromosome]] + return(scan1coef(chr_Pr, pheno, kinship, addcovar = covar)) + } else { + return(scan1coef(chr_Pr, pheno, addcovar = covar)) + } + } + return(NULL) +} + +# Get QTL effects for each chromosome +# TODO + +# Prepare output data +gmap_file <- file.path(opt$directory, json_data$geno_map_file) +pmap_file <- file.path(opt$directory, json_data$physical_map_file) + + + + + +# Construct the Map object from cross with columns (Marker, chr, cM, Mb) +gmap <- cross$gmap # Genetic map in cM +pmap <- cross$pmap # Physical map in Mb +# Convert lists to data frames +gmap_df <- data.frame( + marker = unlist(lapply(gmap, names)), + chr = rep(names(gmap), sapply(gmap, length)), # Add chromosome info + CM = unlist(gmap), + stringsAsFactors = FALSE +) + +pmap_df <- data.frame( + marker = unlist(lapply(pmap, names)), + chr = rep(names(pmap), sapply(pmap, length)), # Add chromosome info + MB = unlist(pmap), + stringsAsFactors = FALSE +) +# Merge using full outer join (by marker and chromosome) +merged_map <- merge(gmap_df, pmap_df, by = c("marker", "chr"), all = TRUE) +map_file <- file.path(opt$directory, "map.csv") +write.csv(merged_map, map_file, row.names = FALSE) + +output <- list( + permutation_file = permutation_results_file, + significance_file = significance_results_file, + scan_file = scan_file, + gmap_file = gmap_file, + pmap_file = pmap_file, + map_file = map_file, + permutations = NO_OF_PERMUTATION, + scan_method = SCAN_METHOD +) + +# Write output to JSON file +output_json_data <- toJSON(output) +cat("The output file path generated is", OUTPUT_FILE_PATH, "\n") +cat("Writing to the output file.\n") +write(output_json_data, file = OUTPUT_FILE_PATH) \ No newline at end of file diff --git a/scripts/rqtl_wrapper.R b/scripts/rqtl_wrapper.R index 31c1277..0b39a6a 100644 --- a/scripts/rqtl_wrapper.R +++ b/scripts/rqtl_wrapper.R @@ -4,6 +4,9 @@ library(stringi) library(stringr) + +cat("Running the qtl script.\n") + tmp_dir = 
Sys.getenv("TMPDIR") if (!dir.exists(tmp_dir)) { tmp_dir = "/tmp" @@ -24,7 +27,7 @@ option_list = list( make_option(c("--control"), type="character", default=NULL, help="Name of marker (contained in genotype file) to be used as a control"), make_option(c("-o", "--outdir"), type="character", default=file.path(tmp_dir, "gn3"), help="Directory in which to write result file"), make_option(c("-f", "--filename"), type="character", default=NULL, help="Name to use for result file"), - make_option(c("-v", "--verbose"), action="store_true", default=NULL, help="Show extra information") + make_option(c("-v", "--verbose"), action="store_true", default=TRUE, help="Show extra information") ); opt_parser = OptionParser(option_list=option_list); @@ -353,5 +356,8 @@ if (!is.null(opt$pairscan)) { colnames(qtl_results)[4:7] <- c("AC", "AD", "BC", "BD") } + write.csv(qtl_results, out_file) } + +cat("End of script. Now working on processing the results.\n") diff --git a/scripts/update_rif_table.py b/scripts/update_rif_table.py index 24edf3d..f936f5b 100755 --- a/scripts/update_rif_table.py +++ b/scripts/update_rif_table.py @@ -35,7 +35,7 @@ VALUES (%s, %s, %s, %s, %s, %s, %s, %s) def download_file(url: str, dest: pathlib.Path): """Saves the contents of url in dest""" - with requests.get(url, stream=True) as resp: + with requests.get(url, stream=True, timeout=300) as resp: resp.raise_for_status() with open(dest, "wb") as downloaded_file: for chunk in resp.iter_content(chunk_size=8192): diff --git a/setup.py b/setup.py index 4c1d026..7ad60df 100755 --- a/setup.py +++ b/setup.py @@ -13,22 +13,21 @@ setup(author='Bonface M. K.', description=('GeneNetwork3 REST API for data ' 'science and machine learning'), install_requires=[ - "bcrypt>=3.1.7" - "click" - "Flask==1.1.2" - "mypy==0.790" - "mypy-extensions==0.4.3" - "mysqlclient==2.0.1" - "numpy==1.20.1" - "pylint==2.5.3" - "pymonad" - "redis==3.5.3" - "requests==2.25.1" - "scipy==1.6.0" - "plotly==4.14.3" - "pyld" - "flask-cors==3.0.9" - "xapian-bindings" + "click", + "Flask>=1.1.2", + "mypy>=0.790", + "mypy-extensions>=0.4.3", + "mysqlclient>=2.0.1", + "numpy>=1.20.1", + "pylint>=2.5.3", + "pymonad", + "redis>=3.5.3", + "requests>=2.25.1", + "scipy>=1.6.0", + "plotly>=4.14.3", + "pyld", + "flask-cors", # with the `>=3.0.9` specification, it breaks the build + # "xapian-bindings" # this line breaks `guix shell …` for some reason ], include_package_data=True, scripts=["scripts/index-genenetwork"], @@ -39,7 +38,11 @@ setup(author='Bonface M. K.', packages=find_packages(), url='https://github.com/genenetwork/genenetwork3', version='3.12', - tests_require=["pytest", "hypothesis"], + tests_require=[ + "pytest", + "hypothesis", + "pytest-mock" + ], cmdclass={ - "run_tests": RunTests ## testing + "run_tests": RunTests # type: ignore[dict-item] }) diff --git a/setup_commands/run_tests.py b/setup_commands/run_tests.py index 37d7ffa..8fb0b25 100644 --- a/setup_commands/run_tests.py +++ b/setup_commands/run_tests.py @@ -27,10 +27,11 @@ class RunTests(Command): def finalize_options(self): """Set final value of all the options once they are processed.""" if self.type not in RunTests.test_types: - raise Exception(f""" - Invalid test type (self.type) requested! - Valid types are - {tuple(RunTests.test_types)}""") + raise Exception(# pylint: disable=[broad-exception-raised] + f""" + Invalid test type (self.type) requested! 
+ Valid types are + {tuple(RunTests.test_types)}""") if self.type != "all": self.command = f"pytest -m {self.type}_test" diff --git a/sheepdog/__init__.py b/sheepdog/__init__.py new file mode 100644 index 0000000..6aa513a --- /dev/null +++ b/sheepdog/__init__.py @@ -0,0 +1 @@ +"""The sheepdog scripts package.""" diff --git a/sheepdog/worker.py b/sheepdog/worker.py index c08edec..e8a7177 100644 --- a/sheepdog/worker.py +++ b/sheepdog/worker.py @@ -2,27 +2,37 @@ import os import sys import time +import logging import argparse import redis import redis.connection +from gn3.loggers import setup_modules_logging + # Enable importing from one dir up: put as first to override any other globally # accessible GN3 -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +sys.path.insert(0, os.path.abspath( + os.path.join(os.path.dirname(__file__), '..'))) +logging.basicConfig( + format=("%(asctime)s — %(filename)s:%(lineno)s — %(levelname)s: " + "CommandWorker: %(message)s")) +logger = logging.getLogger(__name__) + def update_status(conn, cmd_id, status): """Helper to update command status""" conn.hset(name=f"{cmd_id}", key="status", value=f"{status}") -def make_incremental_backoff(init_val: float=0.1, maximum: int=420): + +def make_incremental_backoff(init_val: float = 0.1, maximum: int = 420): """ Returns a closure that can be used to increment the returned value up to `maximum` or reset it to `init_val`. """ current = init_val - def __increment_or_reset__(command: str, value: float=0.1): + def __increment_or_reset__(command: str, value: float = 0.1): nonlocal current if command == "reset": current = init_val @@ -36,7 +46,8 @@ def make_incremental_backoff(init_val: float=0.1, maximum: int=420): return __increment_or_reset__ -def run_jobs(conn, queue_name: str = "GN3::job-queue"): + +def run_jobs(conn, queue_name): """Process the redis using a redis connection, CONN""" # pylint: disable=E0401, C0415 from gn3.commands import run_cmd @@ -44,6 +55,7 @@ def run_jobs(conn, queue_name: str = "GN3::job-queue"): if bool(cmd_id): cmd = conn.hget(name=cmd_id, key="cmd") if cmd and (conn.hget(cmd_id, "status") == b"queued"): + logger.debug("Updating status for job '%s' to 'running'", cmd_id) update_status(conn, cmd_id, "running") result = run_cmd( cmd.decode("utf-8"), env=conn.hget(name=cmd_id, key="env")) @@ -56,6 +68,7 @@ def run_jobs(conn, queue_name: str = "GN3::job-queue"): return cmd_id return None + def parse_cli_arguments(): """Parse the command-line arguments.""" parser = argparse.ArgumentParser( @@ -65,17 +78,37 @@ def parse_cli_arguments(): help=( "Run process as a daemon instead of the default 'one-shot' " "process")) + parser.add_argument( + "--queue-name", default="GN3::job-queue", type=str, + help="The redis list that holds the unique command ids") + parser.add_argument( + "--log-level", default="info", type=str, + choices=("debug", "info", "warning", "error", "critical"), + help="What level to output the logs at.") return parser.parse_args() + if __name__ == "__main__": args = parse_cli_arguments() + logger.setLevel(args.log_level.upper()) + logger.debug("Worker Script: Initialising worker") + setup_modules_logging( + logging.getLevelName(logger.getEffectiveLevel()), + ("gn3.commands",)) with redis.Redis() as redis_conn: if not args.daemon: - run_jobs(redis_conn) + logger.info("Worker Script: Running worker in one-shot mode.") + run_jobs(redis_conn, args.queue_name) + logger.debug("Job completed!") else: + logger.debug("Worker Script: Running worker in daemon-mode.") 
sleep_time = make_incremental_backoff() while True: # Daemon that keeps running forever: - if run_jobs(redis_conn): + if run_jobs(redis_conn, args.queue_name): + logger.debug("Ran a job. Pausing for a while...") time.sleep(sleep_time("reset")) continue - time.sleep(sleep_time("increment", sleep_time("return_current"))) + time.sleep(sleep_time( + "increment", sleep_time("return_current"))) + + logger.info("Worker exiting …") diff --git a/tests/fixtures/rdf.py b/tests/fixtures/rdf.py index 98c4058..0811d3c 100644 --- a/tests/fixtures/rdf.py +++ b/tests/fixtures/rdf.py @@ -59,7 +59,8 @@ def rdf_setup(): # Make sure this graph does not exist before running anything requests.delete( - SPARQL_CONF["sparql_crud_auth_uri"], params=params, auth=auth + SPARQL_CONF["sparql_crud_auth_uri"], params=params, auth=auth, + timeout=300 ) # Open the file in binary mode and send the request @@ -69,9 +70,11 @@ def rdf_setup(): params=params, auth=auth, data=file, + timeout=300 ) yield response requests.delete( - SPARQL_CONF["sparql_crud_auth_uri"], params=params, auth=auth + SPARQL_CONF["sparql_crud_auth_uri"], params=params, auth=auth, + timeout=300 ) pid.terminate() diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 8e39726..bdbab09 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,4 +1,5 @@ """Module that holds fixtures for integration tests""" +from pathlib import Path import pytest import MySQLdb @@ -6,19 +7,25 @@ from gn3.app import create_app from gn3.chancy import random_string from gn3.db_utils import parse_db_url, database_connection + @pytest.fixture(scope="session") def client(): """Create a test client fixture for tests""" # Do some setup - app = create_app() - app.config.update({"TESTING": True}) - app.testing = True + app = create_app({ + "TESTING": True, + "LMDB_DATA_PATH": str( + Path(__file__).parent.parent / + Path("test_data/lmdb-test-data") + ), + "AUTH_SERVER_URL": "http://127.0.0.1:8081", + }) yield app.test_client() # Do some teardown/cleanup @pytest.fixture(scope="session") -def db_conn(client): # pylint: disable=[redefined-outer-name] +def db_conn(client): # pylint: disable=[redefined-outer-name] """Create a db connection fixture for tests""" # 01) Generate random string to append to all test db artifacts for the session live_db_uri = client.application.config["SQL_URI"] diff --git a/tests/integration/test_gemma.py b/tests/integration/test_gemma.py index 53a1596..7bc1df9 100644 --- a/tests/integration/test_gemma.py +++ b/tests/integration/test_gemma.py @@ -63,10 +63,9 @@ class GemmaAPITest(unittest.TestCase): @mock.patch("gn3.api.gemma.assert_paths_exist") @mock.patch("gn3.api.gemma.redis.Redis") @mock.patch("gn3.api.gemma.cache_ipfs_file") - def test_k_compute(self, mock_ipfs_cache, - mock_redis, - mock_path_exist, mock_json, mock_hash, - mock_queue_cmd): + def test_k_compute(# pylint: disable=[too-many-positional-arguments] + self, mock_ipfs_cache, mock_redis, mock_path_exist, mock_json, + mock_hash, mock_queue_cmd): """Test /gemma/k-compute/<token>""" mock_ipfs_cache.return_value = ("/tmp/cache/" "QmQPeNsJPyVWPFDVHb" @@ -106,9 +105,9 @@ class GemmaAPITest(unittest.TestCase): @mock.patch("gn3.api.gemma.assert_paths_exist") @mock.patch("gn3.api.gemma.redis.Redis") @mock.patch("gn3.api.gemma.cache_ipfs_file") - def test_k_compute_loco(self, mock_ipfs_cache, - mock_redis, mock_path_exist, mock_json, - mock_hash, mock_queue_cmd): + def test_k_compute_loco(# pylint: disable=[too-many-positional-arguments] + self, 
mock_ipfs_cache, mock_redis, mock_path_exist, mock_json, + mock_hash, mock_queue_cmd): """Test /gemma/k-compute/loco/<chromosomes>/<token>""" mock_ipfs_cache.return_value = ("/tmp/cache/" "QmQPeNsJPyVWPFDVHb" @@ -150,9 +149,9 @@ class GemmaAPITest(unittest.TestCase): @mock.patch("gn3.api.gemma.assert_paths_exist") @mock.patch("gn3.api.gemma.redis.Redis") @mock.patch("gn3.api.gemma.cache_ipfs_file") - def test_gwa_compute(self, mock_ipfs_cache, - mock_redis, mock_path_exist, mock_json, - mock_hash, mock_queue_cmd): + def test_gwa_compute(# pylint: disable=[too-many-positional-arguments] + self, mock_ipfs_cache, mock_redis, mock_path_exist, mock_json, + mock_hash, mock_queue_cmd): """Test /gemma/gwa-compute/<k-inputfile>/<token>""" mock_ipfs_cache.return_value = ("/tmp/cache/" "QmQPeNsJPyVWPFDVHb" @@ -201,9 +200,9 @@ class GemmaAPITest(unittest.TestCase): @mock.patch("gn3.api.gemma.assert_paths_exist") @mock.patch("gn3.api.gemma.redis.Redis") @mock.patch("gn3.api.gemma.cache_ipfs_file") - def test_gwa_compute_with_covars(self, mock_ipfs_cache, - mock_redis, mock_path_exist, - mock_json, mock_hash, mock_queue_cmd): + def test_gwa_compute_with_covars(# pylint: disable=[too-many-positional-arguments] + self, mock_ipfs_cache, mock_redis, mock_path_exist, mock_json, + mock_hash, mock_queue_cmd): """Test /gemma/gwa-compute/covars/<k-inputfile>/<token>""" mock_ipfs_cache.return_value = ("/tmp/cache/" "QmQPeNsJPyVWPFDVHb" @@ -255,9 +254,9 @@ class GemmaAPITest(unittest.TestCase): @mock.patch("gn3.api.gemma.assert_paths_exist") @mock.patch("gn3.api.gemma.redis.Redis") @mock.patch("gn3.api.gemma.cache_ipfs_file") - def test_gwa_compute_with_loco_only(self, mock_ipfs_cache, - mock_redis, mock_path_exist, - mock_json, mock_hash, mock_queue_cmd): + def test_gwa_compute_with_loco_only(# pylint: disable=[too-many-positional-arguments] + self, mock_ipfs_cache, mock_redis, mock_path_exist, mock_json, + mock_hash, mock_queue_cmd): """Test /gemma/gwa-compute/<k-inputfile>/loco/maf/<maf>/<token> """ @@ -308,10 +307,9 @@ class GemmaAPITest(unittest.TestCase): @mock.patch("gn3.api.gemma.assert_paths_exist") @mock.patch("gn3.api.gemma.redis.Redis") @mock.patch("gn3.api.gemma.cache_ipfs_file") - def test_gwa_compute_with_loco_covars(self, mock_ipfs_cache, - mock_redis, mock_path_exist, - mock_json, mock_hash, - mock_queue_cmd): + def test_gwa_compute_with_loco_covars(# pylint: disable=[too-many-positional-arguments] + self, mock_ipfs_cache, mock_redis, mock_path_exist, mock_json, + mock_hash, mock_queue_cmd): """Test /gemma/gwa-compute/<k-inputfile>/loco/covars/maf/<maf>/<token> """ @@ -363,10 +361,9 @@ class GemmaAPITest(unittest.TestCase): @mock.patch("gn3.api.gemma.assert_paths_exist") @mock.patch("gn3.api.gemma.redis.Redis") @mock.patch("gn3.api.gemma.cache_ipfs_file") - def test_k_gwa_compute_without_loco_covars(self, mock_ipfs_cache, - mock_redis, - mock_path_exist, mock_json, - mock_hash, mock_queue_cmd): + def test_k_gwa_compute_without_loco_covars(# pylint: disable=[too-many-positional-arguments] + self, mock_ipfs_cache, mock_redis, mock_path_exist, mock_json, + mock_hash, mock_queue_cmd): """Test /gemma/k-gwa-compute/<token> """ @@ -419,10 +416,9 @@ class GemmaAPITest(unittest.TestCase): @mock.patch("gn3.api.gemma.assert_paths_exist") @mock.patch("gn3.api.gemma.redis.Redis") @mock.patch("gn3.api.gemma.cache_ipfs_file") - def test_k_gwa_compute_with_covars_only(self, mock_ipfs_cache, - mock_redis, mock_path_exist, - mock_json, mock_hash, - mock_queue_cmd): + def test_k_gwa_compute_with_covars_only(# 
pylint: disable=[too-many-positional-arguments] + self, mock_ipfs_cache, mock_redis, mock_path_exist, mock_json, + mock_hash, mock_queue_cmd): """Test /gemma/k-gwa-compute/covars/<token> """ @@ -484,10 +480,9 @@ class GemmaAPITest(unittest.TestCase): @mock.patch("gn3.api.gemma.assert_paths_exist") @mock.patch("gn3.api.gemma.redis.Redis") @mock.patch("gn3.api.gemma.cache_ipfs_file") - def test_k_gwa_compute_with_loco_only(self, mock_ipfs_cache, - mock_redis, mock_path_exist, - mock_json, mock_hash, - mock_queue_cmd): + def test_k_gwa_compute_with_loco_only(# pylint: disable=[too-many-positional-arguments] + self, mock_ipfs_cache, mock_redis, mock_path_exist, mock_json, + mock_hash, mock_queue_cmd): """Test /gemma/k-gwa-compute/loco/<chromosomes>/maf/<maf>/<token> """ @@ -550,10 +545,9 @@ class GemmaAPITest(unittest.TestCase): @mock.patch("gn3.api.gemma.assert_paths_exist") @mock.patch("gn3.api.gemma.redis.Redis") @mock.patch("gn3.api.gemma.cache_ipfs_file") - def test_k_gwa_compute_with_loco_and_covar(self, mock_ipfs_cache, - mock_redis, - mock_path_exist, mock_json, - mock_hash, mock_queue_cmd): + def test_k_gwa_compute_with_loco_and_covar(# pylint: disable=[too-many-positional-arguments] + self, mock_ipfs_cache, mock_redis, mock_path_exist, mock_json, + mock_hash, mock_queue_cmd): """Test /k-gwa-compute/covars/loco/<chromosomes>/maf/<maf>/<token> """ diff --git a/tests/integration/test_lmdb_sample_data.py b/tests/integration/test_lmdb_sample_data.py new file mode 100644 index 0000000..30a23f4 --- /dev/null +++ b/tests/integration/test_lmdb_sample_data.py @@ -0,0 +1,31 @@ +"""Tests for the LMDB sample data API endpoint""" +import pytest + + +@pytest.mark.unit_test +def test_nonexistent_data(client): + """Test endpoint returns 404 when data doesn't exist""" + response = client.get("/api/lmdb/sample-data/nonexistent/123") + assert response.status_code == 404 + assert response.json["error"] == "No data found for given dataset and trait" + + +@pytest.mark.unit_test +def test_successful_retrieval(client): + """Test successful data retrieval using test LMDB data""" + # Use known test data hash: 7308efbd84b33ad3d69d14b5b1f19ccc + response = client.get("/api/lmdb/sample-data/BXDPublish/10007") + assert response.status_code == 200 + + data = response.json + assert len(data) == 31 + # Verify some known values from the test database + assert data["BXD1"] == 18.700001 + assert data["BXD11"] == 18.9 + + +@pytest.mark.unit_test +def test_invalid_trait_id(client): + """Test endpoint handles invalid trait IDs appropriately""" + response = client.get("/api/lmdb/sample-data/BXDPublish/999999") + assert response.status_code == 404 diff --git a/tests/integration/test_partial_correlations.py b/tests/integration/test_partial_correlations.py index fc9f64f..56af260 100644 --- a/tests/integration/test_partial_correlations.py +++ b/tests/integration/test_partial_correlations.py @@ -221,4 +221,4 @@ def test_part_corr_api_with_mix_of_existing_and_non_existing_control_traits( criteria = 10 with pytest.warns(UserWarning): partial_correlations_with_target_db( - db_conn, primary, controls, method, criteria, target) + db_conn, primary, controls, method, criteria, target, "/tmp") diff --git a/tests/test_data/lmdb-test-data/7308efbd84b33ad3d69d14b5b1f19ccc/data.mdb b/tests/test_data/lmdb-test-data/7308efbd84b33ad3d69d14b5b1f19ccc/data.mdb new file mode 100755 index 0000000..5fa213b --- /dev/null +++ b/tests/test_data/lmdb-test-data/7308efbd84b33ad3d69d14b5b1f19ccc/data.mdb Binary files differdiff --git 
a/tests/test_data/lmdb-test-data/7308efbd84b33ad3d69d14b5b1f19ccc/lock.mdb b/tests/test_data/lmdb-test-data/7308efbd84b33ad3d69d14b5b1f19ccc/lock.mdb new file mode 100755 index 0000000..116d824 --- /dev/null +++ b/tests/test_data/lmdb-test-data/7308efbd84b33ad3d69d14b5b1f19ccc/lock.mdb Binary files differdiff --git a/tests/test_data/ttl-files/test-data.ttl b/tests/test_data/ttl-files/test-data.ttl index 3e27652..c570484 100644 --- a/tests/test_data/ttl-files/test-data.ttl +++ b/tests/test_data/ttl-files/test-data.ttl @@ -1054,3 +1054,276 @@ gn:wiki-7273-0 dct:created "2022-08-24 18:34:41"^^xsd:datetime . gn:wiki-7273-0 foaf:mbox <XXX@XXX.com> . gn:wiki-7273-0 dct:identifier "7273"^^xsd:integer . gn:wiki-7273-0 dct:hasVersion "0"^^xsd:integer . + +gnc:NCBIWikiEntry rdfs:subClassOf gnc:GeneWikiEntry . +gnc:NCBIWikiEntry rdfs:comment "Represents GeneRIF Entries obtained from NCBI" . +gn:rif-12709-37156912-2023-05-17T20:43:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Creatine kinase B suppresses ferroptosis by phosphorylating GPX4 through a moonlighting function.'@en ; + gnt:belongsToSpecies gn:Mus_musculus ; + gnt:symbol "Ckb" ; + gnt:hasGeneId generif:12709 ; + skos:notation taxon:10090 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37156912 ; + dct:created "2023-05-17 20:43:00"^^xsd:datetime . +gn:rif-13176-36456775-2023-03-01T20:36:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'DCC/netrin-1 regulates cell death in oligodendrocytes after brain injury.'@en ; + gnt:belongsToSpecies gn:Mus_musculus ; + gnt:symbol "Dcc" ; + gnt:hasGeneId generif:13176 ; + skos:notation taxon:10090 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:36456775 ; + dct:created "2023-03-01 20:36:00"^^xsd:datetime . +gn:rif-13176-37541362-2023-09-21T20:40:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Prefrontal cortex-specific Dcc deletion induces schizophrenia-related behavioral phenotypes and fail to be rescued by olanzapine treatment.'@en ; + gnt:belongsToSpecies gn:Mus_musculus ; + gnt:symbol "Dcc" ; + gnt:hasGeneId generif:13176 ; + skos:notation taxon:10090 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37541362 ; + dct:created "2023-09-21 20:40:00"^^xsd:datetime . +gn:rif-16956-36519761-2023-04-27T20:33:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Parkin regulates neuronal lipid homeostasis through SREBP2-lipoprotein lipase pathway-implications for Parkinson\'s disease.'@en ; + gnt:belongsToSpecies gn:Mus_musculus ; + gnt:symbol "Lpl" ; + gnt:hasGeneId generif:16956 ; + skos:notation taxon:10090 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:36519761 ; + dct:created "2023-04-27 20:33:00"^^xsd:datetime . +gn:rif-20423-36853961-2023-03-08T20:38:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'IHH, SHH, and primary cilia mediate epithelial-stromal cross-talk during decidualization in mice.'@en ; + gnt:belongsToSpecies gn:Mus_musculus ; + gnt:symbol "Shh" ; + gnt:hasGeneId generif:20423 ; + skos:notation taxon:10090 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:36853961 ; + dct:created "2023-03-08 20:38:00"^^xsd:datetime . +gn:rif-20423-37190906-2023-07-12T09:09:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'The SHH-GLI1 pathway is required in skin expansion and angiogenesis.'@en ; + gnt:belongsToSpecies gn:Mus_musculus ; + gnt:symbol "Shh" ; + gnt:hasGeneId generif:20423 ; + skos:notation taxon:10090 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37190906 ; + dct:created "2023-07-12 09:09:00"^^xsd:datetime . 
+gn:rif-20423-37460185-2023-07-20T20:39:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label '[Effect study of Sonic hedgehog overexpressed hair follicle stem cells in hair follicle regeneration].'@en ; + gnt:belongsToSpecies gn:Mus_musculus ; + gnt:symbol "Shh" ; + gnt:hasGeneId generif:20423 ; + skos:notation taxon:10090 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37460185 ; + dct:created "2023-07-20 20:39:00"^^xsd:datetime . +gn:rif-20423-37481204-2023-09-26T20:37:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Sonic Hedgehog and WNT Signaling Regulate a Positive Feedback Loop Between Intestinal Epithelial and Stromal Cells to Promote Epithelial Regeneration.'@en ; + gnt:belongsToSpecies gn:Mus_musculus ; + gnt:symbol "Shh" ; + gnt:hasGeneId generif:20423 ; + skos:notation taxon:10090 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37481204 ; + dct:created "2023-09-26 20:37:00"^^xsd:datetime . +gn:rif-24539-38114521-2023-12-29T20:33:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Developing a model to predict the early risk of hypertriglyceridemia based on inhibiting lipoprotein lipase (LPL): a translational study.'@en ; + gnt:belongsToSpecies gn:Rattus_norvegicus ; + gnt:symbol "Lpl" ; + gnt:hasGeneId generif:24539 ; + skos:notation taxon:10116 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:38114521 ; + dct:created "2023-12-29 20:33:00"^^xsd:datetime . +gn:rif-29499-36906487-2023-06-23T20:38:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Regulation of Shh/Bmp4 Signaling Pathway by DNA Methylation in Rectal Nervous System Development of Fetal Rats with Anorectal Malformation.'@en ; + gnt:belongsToSpecies gn:Rattus_norvegicus ; + gnt:symbol "Shh" ; + gnt:hasGeneId generif:29499 ; + skos:notation taxon:10116 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:36906487 ; + dct:created "2023-06-23 20:38:00"^^xsd:datetime . +gn:rif-29499-37815888-2023-10-24T20:38:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Sonic hedgehog signaling promotes angiogenesis of endothelial progenitor cells to improve pressure ulcers healing by PI3K/AKT/eNOS signaling.'@en ; + gnt:belongsToSpecies gn:Rattus_norvegicus ; + gnt:symbol "Shh" ; + gnt:hasGeneId generif:29499 ; + skos:notation taxon:10116 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37815888 ; + dct:created "2023-10-24 20:38:00"^^xsd:datetime . +gn:rif-1152-37156912-2023-05-17T20:43:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Creatine kinase B suppresses ferroptosis by phosphorylating GPX4 through a moonlighting function.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "CKB" ; + gnt:hasGeneId generif:1152 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37156912 ; + dct:created "2023-05-17 20:43:00"^^xsd:datetime . +gn:rif-1630-36889039-2023-04-04T09:45:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Mirror movements and callosal dysgenesis in a family with a DCC mutation: Neuropsychological and neuroimaging outcomes.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "DCC" ; + gnt:hasGeneId generif:1630 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:36889039 ; + dct:created "2023-04-04 09:45:00"^^xsd:datetime . 
+gn:rif-1630-36852451-2023-07-07T20:37:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'An imbalance of netrin-1 and DCC during nigral degeneration in experimental models and patients with Parkinson\'s disease.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "DCC" ; + gnt:hasGeneId generif:1630 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:36852451 ; + dct:created "2023-07-07 20:37:00"^^xsd:datetime . +gn:rif-4023-36763533-2023-02-23T20:40:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Angiopoietin-like protein 4/8 complex-mediated plasmin generation leads to cleavage of the complex and restoration of LPL activity.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "LPL" ; + gnt:hasGeneId generif:4023 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:36763533 ; + dct:created "2023-02-23 20:40:00"^^xsd:datetime . +gn:rif-4023-36652113-2023-04-07T20:39:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'The breast cancer microenvironment and lipoprotein lipase: Another negative notch for a beneficial enzyme?'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "LPL" ; + gnt:hasGeneId generif:4023 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:36652113 ; + dct:created "2023-04-07 20:39:00"^^xsd:datetime . +gn:rif-4023-36519761-2023-04-27T20:33:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Parkin regulates neuronal lipid homeostasis through SREBP2-lipoprotein lipase pathway-implications for Parkinson\'s disease.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "LPL" ; + gnt:hasGeneId generif:4023 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:36519761 ; + dct:created "2023-04-27 20:33:00"^^xsd:datetime . +gn:rif-4023-36708756-2023-05-22T20:32:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Plasma Lipoprotein Lipase Is Associated with Risk of Future Major Adverse Cardiovascular Events in Patients Following Carotid Endarterectomy.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "LPL" ; + gnt:hasGeneId generif:4023 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:36708756 ; + dct:created "2023-05-22 20:32:00"^^xsd:datetime . +gn:rif-4023-37155355-2023-07-04T21:12:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Inverse association between apolipoprotein C-II and cardiovascular mortality: role of lipoprotein lipase activity modulation.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "LPL" ; + gnt:hasGeneId generif:4023 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37155355 ; + dct:created "2023-07-04 21:12:00"^^xsd:datetime . +gn:rif-4023-37432202-2023-07-13T20:35:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Effect of the Interaction between Seaweed Intake and LPL Polymorphisms on Metabolic Syndrome in Middle-Aged Korean Adults.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "LPL" ; + gnt:hasGeneId generif:4023 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37432202 ; + dct:created "2023-07-13 20:35:00"^^xsd:datetime . 
+gn:rif-4023-37568214-2023-08-14T20:37:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Frameshift coding sequence variants in the LPL gene: identification of two novel events and exploration of the genotype-phenotype relationship for variants reported to date.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "LPL" ; + gnt:hasGeneId generif:4023 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37568214 ; + dct:created "2023-08-14 20:37:00"^^xsd:datetime . +gn:rif-4023-37550668-2023-08-22T20:29:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'The East Asian-specific LPL p.Ala288Thr (c.862G > A) missense variant exerts a mild effect on protein function.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "LPL" ; + gnt:hasGeneId generif:4023 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37550668 ; + dct:created "2023-08-22 20:29:00"^^xsd:datetime . +gn:rif-4023-37128695-2023-09-12T20:35:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Interaction between APOE, APOA1, and LPL Gene Polymorphisms and Variability in Changes in Lipid and Blood Pressure following Orange Juice Intake: A Pilot Study.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "LPL" ; + gnt:hasGeneId generif:4023 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37128695 ; + dct:created "2023-09-12 20:35:00"^^xsd:datetime . +gn:rif-4023-37427758-2023-09-25T09:33:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Variants within the LPL gene confer susceptility to diabetic kidney disease and rapid decline in kidney function in Chinese patients with type 2 diabetes.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "LPL" ; + gnt:hasGeneId generif:4023 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37427758 ; + dct:created "2023-09-25 09:33:00"^^xsd:datetime . +gn:rif-4023-37901192-2023-11-01T08:55:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'The Association of Adipokines and Myokines in the Blood of Obese Children and Adolescents with Lipoprotein Lipase rs328 Gene Variants.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "LPL" ; + gnt:hasGeneId generif:4023 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37901192 ; + dct:created "2023-11-01 08:55:00"^^xsd:datetime . +gn:rif-4023-37871217-2023-11-10T08:44:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'The lipoprotein lipase that is shuttled into capillaries by GPIHBP1 enters the glycocalyx where it mediates lipoprotein processing.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "LPL" ; + gnt:hasGeneId generif:4023 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37871217 ; + dct:created "2023-11-10 08:44:00"^^xsd:datetime . +gn:rif-4023-37858495-2023-12-28T20:33:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Clinical profile, genetic spectrum and therapy evaluation of 19 Chinese pediatric patients with lipoprotein lipase deficiency.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "LPL" ; + gnt:hasGeneId generif:4023 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37858495 ; + dct:created "2023-12-28 20:33:00"^^xsd:datetime . 
+gn:rif-4023-38114521-2023-12-29T20:33:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Developing a model to predict the early risk of hypertriglyceridemia based on inhibiting lipoprotein lipase (LPL): a translational study.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "LPL" ; + gnt:hasGeneId generif:4023 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:38114521 ; + dct:created "2023-12-29 20:33:00"^^xsd:datetime . +gn:rif-6469-37511358-2023-08-19T08:35:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Low Expression of the NRP1 Gene Is Associated with Shorter Overall Survival in Patients with Sonic Hedgehog and Group 3 Medulloblastoma.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "SHH" ; + gnt:hasGeneId generif:6469 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37511358 ; + dct:created "2023-08-19 08:35:00"^^xsd:datetime . +gn:rif-6469-37307020-2023-09-21T20:40:00-5 rdf:type gnc:NCBIWikiEntry ; + rdfs:label 'Activation of Sonic Hedgehog Signaling Pathway Regulates Human Trabecular Meshwork Cell Function.'@en ; + gnt:belongsToSpecies gn:Homo_sapiens ; + gnt:symbol "SHH" ; + gnt:hasGeneId generif:6469 ; + skos:notation taxon:9606 ; + dct:hasVersion "5"^^xsd:integer ; + dct:references pubmed:37307020 ; + dct:created "2023-09-21 20:40:00"^^xsd:datetime . diff --git a/tests/unit/computations/test_partial_correlations.py b/tests/unit/computations/test_partial_correlations.py index 066c650..6364701 100644 --- a/tests/unit/computations/test_partial_correlations.py +++ b/tests/unit/computations/test_partial_correlations.py @@ -159,8 +159,8 @@ class TestPartialCorrelations(TestCase): "variance": None}}}, dictified_control_samples), (("BXD2",), (7.80944,), - (7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944, 8.39265, - 8.17443, 8.30401, 7.80944, 7.51879, 7.77141, 7.80944), + ((7.51879, 7.77141, 8.39265, 8.17443, 8.30401, 7.80944, 8.39265, + 8.17443, 8.30401, 7.80944, 7.51879, 7.77141, 7.80944),), (None,), (None, None, None, None, None, None, None, None, None, None, None, None, None))) diff --git a/tests/unit/computations/test_wgcna.py b/tests/unit/computations/test_wgcna.py index 55432af..325bd5a 100644 --- a/tests/unit/computations/test_wgcna.py +++ b/tests/unit/computations/test_wgcna.py @@ -85,9 +85,9 @@ class TestWgcna(TestCase): mock_img.return_value = b"AFDSFNBSDGJJHH" results = call_wgcna_script( - "Rscript/GUIX_PATH/scripts/r_file.R", request_data) + "Rscript/GUIX_PATH/scripts/r_file.R", request_data, "/tmp") - mock_dumping_data.assert_called_once_with(request_data) + mock_dumping_data.assert_called_once_with(request_data, "/tmp") mock_compose_wgcna.assert_called_once_with( "Rscript/GUIX_PATH/scripts/r_file.R", @@ -119,7 +119,7 @@ class TestWgcna(TestCase): mock_run_cmd.return_value = expected_error self.assertEqual(call_wgcna_script( - "input_file.R", ""), expected_error) + "input_file.R", "", "/tmp"), expected_error) @pytest.mark.skip( "This test assumes that the system will always be invoked from the root" @@ -134,7 +134,6 @@ class TestWgcna(TestCase): wgcna_cmd, "Rscript scripts/wgcna.r /tmp/wgcna.json") @pytest.mark.unit_test - @mock.patch("gn3.computations.wgcna.TMPDIR", "/tmp") @mock.patch("gn3.computations.wgcna.uuid.uuid4") def test_create_json_file(self, file_name_generator): """test for writing the data to a csv file""" @@ -166,8 +165,7 @@ class TestWgcna(TestCase): file_name_generator.return_value = "facb73ff-7eef-4053-b6ea-e91d3a22a00c" - results = 
dump_wgcna_data( - expected_input) + results = dump_wgcna_data(expected_input, "/tmp") file_handler.assert_called_once_with( "/tmp/facb73ff-7eef-4053-b6ea-e91d3a22a00c.json", 'w', encoding='utf-8') diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 8005c8e..5526d16 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -7,6 +7,7 @@ import pytest from gn3.app import create_app + @pytest.fixture(scope="session") def fxtr_app(): """Fixture: setup the test app""" @@ -15,7 +16,12 @@ def fxtr_app(): testdb = Path(testdir).joinpath( f'testdb_{datetime.now().strftime("%Y%m%dT%H%M%S")}') app = create_app({ - "TESTING": True, "AUTH_DB": testdb, + "TESTING": True, + "LMDB_DATA_PATH": str( + Path(__file__).parent.parent / + Path("test_data/lmdb-test-data") + ), + "AUTH_SERVER_URL": "http://127.0.0.1:8081", "OAUTH2_ACCESS_TOKEN_GENERATOR": "tests.unit.auth.test_token.gen_token" }) app.testing = True @@ -23,13 +29,15 @@ def fxtr_app(): # Clean up after ourselves testdb.unlink(missing_ok=True) + @pytest.fixture(scope="session") -def client(fxtr_app): # pylint: disable=redefined-outer-name +def client(fxtr_app): # pylint: disable=redefined-outer-name """Create a test client fixture for tests""" with fxtr_app.app_context(): yield fxtr_app.test_client() + @pytest.fixture(scope="session") -def fxtr_app_config(client): # pylint: disable=redefined-outer-name +def fxtr_app_config(client): # pylint: disable=redefined-outer-name """Return the test application's configuration object""" return client.application.config diff --git a/tests/unit/db/rdf/data.py b/tests/unit/db/rdf/data.py new file mode 100644 index 0000000..6bc612f --- /dev/null +++ b/tests/unit/db/rdf/data.py @@ -0,0 +1,199 @@ +"""Some test data to be used in RDF data.""" + +LPL_RIF_ENTRIES = { + "@context": { + "dct": "http://purl.org/dc/terms/", + "gnt": "http://genenetwork.org/term/", + "rdfs": "http://www.w3.org/2000/01/rdf-schema#", + "skos": "http://www.w3.org/2004/02/skos/core#", + "symbol": "gnt:symbol", + "species": "gnt:species", + "taxonomic_id": "skos:notation", + "gene_id": "gnt:hasGeneId", + "pubmed_id": "dct:references", + "created": "dct:created", + "comment": "rdfs:comment", + "version": "dct:hasVersion", + }, + "data": [ + { + "gene_id": 4023, + "version": 5, + "species": "Homo sapiens", + "symbol": "LPL", + "created": "2023-02-23 20:40:00", + "pubmed_id": 36763533, + "comment": "Angiopoietin-like protein 4/8 complex-mediated plasmin generation \ +leads to cleavage of the complex and restoration of LPL activity.", + "taxonomic_id": 9606, + }, + { + "gene_id": 4023, + "version": 5, + "species": "Homo sapiens", + "symbol": "LPL", + "created": "2023-04-07 20:39:00", + "pubmed_id": 36652113, + "comment": "The breast cancer microenvironment and lipoprotein lipase: \ +Another negative notch for a beneficial enzyme?", + "taxonomic_id": 9606, + }, + { + "gene_id": 4023, + "version": 5, + "species": "Homo sapiens", + "symbol": "LPL", + "created": "2023-04-27 20:33:00", + "pubmed_id": 36519761, + "comment": "Parkin regulates neuronal lipid homeostasis through \ +SREBP2-lipoprotein lipase pathway-implications for Parkinson's disease.", + "taxonomic_id": 9606, + }, + { + "gene_id": 4023, + "version": 5, + "species": "Homo sapiens", + "symbol": "LPL", + "created": "2023-05-22 20:32:00", + "pubmed_id": 36708756, + "comment": "Plasma Lipoprotein Lipase Is Associated with Risk of \ +Future Major Adverse Cardiovascular Events in Patients Following Carotid Endarterectomy.", + "taxonomic_id": 9606, + }, + { + "gene_id": 
4023, + "version": 5, + "species": "Homo sapiens", + "symbol": "LPL", + "created": "2023-07-04 21:12:00", + "pubmed_id": 37155355, + "comment": "Inverse association between apolipoprotein C-II and \ +cardiovascular mortality: role of lipoprotein lipase activity modulation.", + "taxonomic_id": 9606, + }, + { + "gene_id": 4023, + "version": 5, + "species": "Homo sapiens", + "symbol": "LPL", + "created": "2023-07-13 20:35:00", + "pubmed_id": 37432202, + "comment": "Effect of the Interaction between Seaweed Intake and LPL \ +Polymorphisms on Metabolic Syndrome in Middle-Aged Korean Adults.", + "taxonomic_id": 9606, + }, + { + "gene_id": 4023, + "version": 5, + "species": "Homo sapiens", + "symbol": "LPL", + "created": "2023-08-14 20:37:00", + "pubmed_id": 37568214, + "comment": "Frameshift coding sequence variants in the LPL gene: identification \ +of two novel events and exploration of the genotype-phenotype relationship for \ +variants reported to date.", + "taxonomic_id": 9606, + }, + { + "gene_id": 4023, + "version": 5, + "species": "Homo sapiens", + "symbol": "LPL", + "created": "2023-08-22 20:29:00", + "pubmed_id": 37550668, + "comment": "The East Asian-specific LPL p.Ala288Thr (c.862G > A) missense \ +variant exerts a mild effect on protein function.", + "taxonomic_id": 9606, + }, + { + "gene_id": 4023, + "version": 5, + "species": "Homo sapiens", + "symbol": "LPL", + "created": "2023-09-12 20:35:00", + "pubmed_id": 37128695, + "comment": "Interaction between APOE, APOA1, and LPL Gene Polymorphisms \ +and Variability in Changes in Lipid and Blood Pressure following Orange Juice Intake: \ +A Pilot Study.", + "taxonomic_id": 9606, + }, + { + "gene_id": 4023, + "version": 5, + "species": "Homo sapiens", + "symbol": "LPL", + "created": "2023-09-25 09:33:00", + "pubmed_id": 37427758, + "comment": "Variants within the LPL gene confer susceptility to \ +diabetic kidney disease and rapid decline in kidney function in Chinese patients \ +with type 2 diabetes.", + "taxonomic_id": 9606, + }, + { + "gene_id": 4023, + "version": 5, + "species": "Homo sapiens", + "symbol": "LPL", + "created": "2023-11-01 08:55:00", + "pubmed_id": 37901192, + "comment": "The Association of Adipokines and Myokines in the \ +Blood of Obese Children and Adolescents with Lipoprotein Lipase rs328 Gene Variants.", + "taxonomic_id": 9606, + }, + { + "gene_id": 4023, + "version": 5, + "species": "Homo sapiens", + "symbol": "LPL", + "created": "2023-11-10 08:44:00", + "pubmed_id": 37871217, + "comment": "The lipoprotein lipase that is shuttled into \ +capillaries by GPIHBP1 enters the glycocalyx where it mediates lipoprotein processing.", + "taxonomic_id": 9606, + }, + { + "gene_id": 4023, + "version": 5, + "species": "Homo sapiens", + "symbol": "LPL", + "created": "2023-12-28 20:33:00", + "pubmed_id": 37858495, + "comment": "Clinical profile, genetic spectrum and therapy \ +evaluation of 19 Chinese pediatric patients with lipoprotein lipase deficiency.", + "taxonomic_id": 9606, + }, + { + "gene_id": 4023, + "version": 5, + "species": "Homo sapiens", + "symbol": "LPL", + "created": "2023-12-29 20:33:00", + "pubmed_id": 38114521, + "comment": "Developing a model to predict the early risk of \ +hypertriglyceridemia based on inhibiting lipoprotein lipase (LPL): a translational study.", + "taxonomic_id": 9606, + }, + { + "gene_id": 16956, + "version": 5, + "species": "Mus musculus", + "symbol": "Lpl", + "created": "2023-04-27 20:33:00", + "pubmed_id": 36519761, + "comment": "Parkin regulates neuronal lipid homeostasis through \ 
+SREBP2-lipoprotein lipase pathway-implications for Parkinson's disease.", + "taxonomic_id": 10090, + }, + { + "gene_id": 24539, + "version": 5, + "species": "Rattus norvegicus", + "symbol": "Lpl", + "created": "2023-12-29 20:33:00", + "pubmed_id": 38114521, + "comment": "Developing a model to predict the early risk of \ +hypertriglyceridemia based on inhibiting lipoprotein lipase (LPL): a translational study.", + "taxonomic_id": 10116, + }, + ], +} diff --git a/tests/unit/db/rdf/test_wiki.py b/tests/unit/db/rdf/test_wiki.py index 3abf3ad..bab37ce 100644 --- a/tests/unit/db/rdf/test_wiki.py +++ b/tests/unit/db/rdf/test_wiki.py @@ -22,11 +22,15 @@ from tests.fixtures.rdf import ( SPARQL_CONF, ) +from tests.unit.db.rdf.data import LPL_RIF_ENTRIES + from gn3.db.rdf.wiki import ( __sanitize_result, get_wiki_entries_by_symbol, get_comment_history, update_wiki_comment, + get_rif_entries_by_symbol, + delete_wiki_entries_by_id, ) GRAPH = "<http://cd-test.genenetwork.org>" @@ -396,3 +400,49 @@ def test_update_wiki_comment(rdf_setup): # pylint: disable=W0613,W0621 "version": 3, "web_url": "http://some-website.com", }) + + +@pytest.mark.rdf +def test_get_rif_entries_by_symbol(rdf_setup): # pylint: disable=W0613,W0621 + """Test fetching NCBI Rif Metadata from RDF""" + sparql_conf = SPARQL_CONF + entries = get_rif_entries_by_symbol( + symbol="Lpl", + sparql_uri=sparql_conf["sparql_endpoint"], + graph=GRAPH, + ) + assert len(LPL_RIF_ENTRIES["data"]) == len(entries["data"]) + for result, expected in zip(LPL_RIF_ENTRIES["data"], entries["data"]): + TestCase().assertDictEqual(result, expected) + + +@pytest.mark.rdf +def test_delete_wiki_entries_by_id(rdf_setup): # pylint: disable=W0613,W0621 + """Test deleting a given RIF Wiki entry""" + sparql_conf = SPARQL_CONF + delete_wiki_entries_by_id( + 230, + sparql_user=sparql_conf["sparql_user"], + sparql_password=sparql_conf["sparql_password"], + sparql_auth_uri=sparql_conf["sparql_auth_uri"], + graph=GRAPH) + entries = get_comment_history( + comment_id=230, + sparql_uri=sparql_conf["sparql_endpoint"], + graph=GRAPH, + ) + assert len(entries["data"]) == 0 + + # Deleting a non-existent entry has no effect + delete_wiki_entries_by_id( + 199999, + sparql_user=sparql_conf["sparql_user"], + sparql_password=sparql_conf["sparql_password"], + sparql_auth_uri=sparql_conf["sparql_auth_uri"], + graph=GRAPH) + entries = get_comment_history( + comment_id=230, + sparql_uri=sparql_conf["sparql_endpoint"], + graph=GRAPH, + ) + assert len(entries["data"]) == 0 diff --git a/tests/unit/db/test_case_attributes.py b/tests/unit/db/test_case_attributes.py index 97a0703..998b58d 100644 --- a/tests/unit/db/test_case_attributes.py +++ b/tests/unit/db/test_case_attributes.py @@ -1,205 +1,326 @@ """Test cases for gn3.db.case_attributes.py""" +import pickle +import tempfile +import os +import json +from pathlib import Path import pytest from pytest_mock import MockFixture -from gn3.db.case_attributes import get_unreviewed_diffs -from gn3.db.case_attributes import get_case_attributes -from gn3.db.case_attributes import insert_case_attribute_audit -from gn3.db.case_attributes import approve_case_attribute -from gn3.db.case_attributes import reject_case_attribute +from gn3.db.case_attributes import queue_edit +from gn3.db.case_attributes import ( + CaseAttributeEdit, + EditStatus, + apply_change, + get_changes, + view_change +) @pytest.mark.unit_test -def test_get_case_attributes(mocker: MockFixture) -> None: - """Test that all the case attributes are fetched correctly""" +def 
test_queue_edit(mocker: MockFixture) -> None: + """Test queueing an edit.""" mock_conn = mocker.MagicMock() with mock_conn.cursor() as cursor: - cursor.fetchall.return_value = ( - (1, "Condition", None), - (2, "Tissue", None), - (3, "Age", "Cum sociis natoque penatibus et magnis dis"), - (4, "Condition", "Description A"), - (5, "Condition", "Description B"), - ) - results = get_case_attributes(mock_conn) + type(cursor).lastrowid = 28 + tmpdir = Path(os.environ.get("TMPDIR", tempfile.gettempdir())) + caseattr_id = queue_edit( + cursor, + directory=tmpdir, + edit=CaseAttributeEdit( + inbredset_id=1, status=EditStatus.review, + user_id="xxxx", changes={"a": 1, "b": 2} + )) cursor.execute.assert_called_once_with( - "SELECT Id, Name, Description FROM CaseAttribute" - ) - assert results == ( - (1, "Condition", None), - (2, "Tissue", None), - (3, "Age", "Cum sociis natoque penatibus et magnis dis"), - (4, "Condition", "Description A"), - (5, "Condition", "Description B"), - ) + "INSERT INTO " + "caseattributes_audit(status, editor, json_diff_data) " + "VALUES (%s, %s, %s) " + "ON DUPLICATE KEY UPDATE status=%s", + ('review', 'xxxx', '{"a": 1, "b": 2}', 'review')) + assert 28 == caseattr_id @pytest.mark.unit_test -def test_get_unreviewed_diffs(mocker: MockFixture) -> None: - """Test that the correct query is called when fetching unreviewed - case-attributes diff""" - mock_conn = mocker.MagicMock() - with mock_conn.cursor() as cursor: - _ = get_unreviewed_diffs(mock_conn) - cursor.fetchall.return_value = ((1, "editor", "diff_data_1"),) - cursor.execute.assert_called_once_with( - "SELECT id, editor, json_diff_data FROM " - "caseattributes_audit WHERE status = 'review'" - ) +def test_view_change(mocker: MockFixture) -> None: + """Test view_change function.""" + sample_json_diff = { + "inbredset_id": 1, + "Modifications": { + "Original": { + "B6D2F1": {"Epoch": "10au"}, + "BXD100": {"Epoch": "3b"}, + "BXD101": {"SeqCvge": "29"}, + "BXD102": {"Epoch": "3b"}, + "BXD108": {"SeqCvge": ""} + }, + "Current": { + "B6D2F1": {"Epoch": "10"}, + "BXD100": {"Epoch": "3"}, + "BXD101": {"SeqCvge": "2"}, + "BXD102": {"Epoch": "3"}, + "BXD108": {"SeqCvge": "oo"} + } + } + } + change_id = 28 + mock_cursor, mock_conn = mocker.MagicMock(), mocker.MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = (json.dumps(sample_json_diff), None) + assert view_change(mock_cursor, change_id) == sample_json_diff + mock_cursor.execute.assert_called_once_with( + "SELECT json_diff_data FROM caseattributes_audit WHERE id = %s", + (change_id,)) + mock_cursor.fetchone.assert_called_once() @pytest.mark.unit_test -def test_insert_case_attribute_audit(mocker: MockFixture) -> None: - """Test that the updating case attributes uses the correct query""" - mock_conn = mocker.MagicMock() - with mock_conn.cursor() as cursor: - _ = insert_case_attribute_audit( - mock_conn, status="review", author="Author", data="diff_data" - ) - cursor.execute.assert_called_once_with( - "INSERT INTO caseattributes_audit " - "(status, editor, json_diff_data) " - "VALUES (%s, %s, %s)", - ("review", "Author", "diff_data"), - ) +def test_view_change_invalid_json(mocker: MockFixture) -> None: + """Test invalid json when view_change is called""" + change_id = 28 + mock_cursor, mock_conn = mocker.MagicMock(), mocker.MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = ("invalid_json_string", None) + with pytest.raises(json.JSONDecodeError): + view_change(mock_cursor, change_id) + 
mock_cursor.execute.assert_called_once_with( + "SELECT json_diff_data FROM caseattributes_audit WHERE id = %s", + (change_id,)) @pytest.mark.unit_test -def test_reject_case_attribute(mocker: MockFixture) -> None: - """Test rejecting a case-attribute""" - mock_conn = mocker.MagicMock() - with mock_conn.cursor() as cursor: - _ = reject_case_attribute( - mock_conn, - case_attr_audit_id=1, - ) - cursor.execute.assert_called_once_with( - "UPDATE caseattributes_audit SET " - "status = 'rejected' WHERE id = %s", - (1,), - ) +def test_view_change_no_data(mocker: MockFixture) -> None: + "Test no result when view_change is called" + change_id = 28 + mock_cursor, mock_conn = mocker.MagicMock(), mocker.MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = (None, None) + assert view_change(mock_cursor, change_id) == {} + mock_cursor.execute.assert_called_once_with( + "SELECT json_diff_data FROM caseattributes_audit WHERE id = %s", + (change_id,)) @pytest.mark.unit_test -def test_approve_inserting_case_attribute(mocker: MockFixture) -> None: - """Test approving inserting a case-attribute""" - mock_conn = mocker.MagicMock() - with mock_conn.cursor() as cursor: - type(cursor).rowcount = 1 - cursor.fetchone.return_value = ( - """ - {"Insert": {"name": "test", "description": "Random Description"}} - """, - ) - _ = approve_case_attribute( - mock_conn, - case_attr_audit_id=3, - ) - calls = [ - mocker.call( - "SELECT json_diff_data FROM caseattributes_audit " - "WHERE id = %s", - (3,), - ), - mocker.call( - "INSERT INTO CaseAttribute " - "(Name, Description) VALUES " - "(%s, %s)", - ( - "test", - "Random Description", - ), - ), - mocker.call( - "UPDATE caseattributes_audit SET " - "status = 'approved' WHERE id = %s", - (3,), - ), +def test_apply_change_approved(mocker: MockFixture) -> None: + """Test approving a change""" + mock_cursor, mock_conn = mocker.MagicMock(), mocker.MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_lmdb = mocker.patch("gn3.db.case_attributes.lmdb") + mock_env, mock_txn = mocker.MagicMock(), mocker.MagicMock() + mock_lmdb.open.return_value = mock_env + mock_env.begin.return_value.__enter__.return_value = mock_txn + change_id, review_ids = 1, {1, 2, 3} + mock_txn.get.side_effect = ( + pickle.dumps(review_ids), # b"review" key + None, # b"approved" key + ) + tmpdir = Path(os.environ.get("TMPDIR", tempfile.gettempdir())) + mock_cursor.fetchone.return_value = (json.dumps({ + "inbredset_id": 1, + "Modifications": { + "Current": { + "B6D2F1": {"Epoch": "10"}, + "BXD100": {"Epoch": "3"}, + "BXD101": {"SeqCvge": "2"}, + "BXD102": {"Epoch": "3"}, + "BXD108": {"SeqCvge": "oo"} + } + } + }), None) + mock_cursor.fetchall.side_effect = [ + [ # Strain query + ("B6D2F1", 1), ("BXD100", 2), + ("BXD101", 3), ("BXD102", 4), + ("BXD108", 5)], + [ # CaseAttribute query + ("Epoch", 101), ("SeqCvge", 102)] + ] + assert apply_change(mock_cursor, EditStatus.approved, + change_id, tmpdir) is True + assert mock_cursor.execute.call_count == 4 + mock_cursor.execute.assert_has_calls([ + mocker.call( + "SELECT json_diff_data FROM caseattributes_audit WHERE id = %s", + (change_id,)), + mocker.call( + "SELECT Name, Id FROM Strain WHERE Name IN (%s, %s, %s, %s, %s)", + ("B6D2F1", "BXD100", "BXD101", "BXD102", "BXD108")), + mocker.call( + "SELECT Name, CaseAttributeId FROM CaseAttribute " + "WHERE InbredSetId = %s AND Name IN (%s, %s)", + (1, "Epoch", "SeqCvge")), + mocker.call( + "UPDATE caseattributes_audit SET status = %s WHERE id = %s", + ("approved", 
change_id)) + ]) + mock_cursor.executemany.assert_called_once_with( + "INSERT INTO CaseAttributeXRefNew (InbredSetId, StrainId, CaseAttributeId, Value) " + "VALUES (%(inbredset_id)s, %(strain_id)s, %(caseattr_id)s, %(value)s) " + "ON DUPLICATE KEY UPDATE Value = VALUES(Value)", + [ + {"inbredset_id": 1, "strain_id": 1, "caseattr_id": 101, "value": "10"}, + {"inbredset_id": 1, "strain_id": 2, "caseattr_id": 101, "value": "3"}, + {"inbredset_id": 1, "strain_id": 3, "caseattr_id": 102, "value": "2"}, + {"inbredset_id": 1, "strain_id": 4, "caseattr_id": 101, "value": "3"}, + {"inbredset_id": 1, "strain_id": 5, "caseattr_id": 102, "value": "oo"} ] - cursor.execute.assert_has_calls(calls, any_order=False) + ) @pytest.mark.unit_test -def test_approve_deleting_case_attribute(mocker: MockFixture) -> None: - """Test deleting a case-attribute""" - mock_conn = mocker.MagicMock() - with mock_conn.cursor() as cursor: - type(cursor).rowcount = 1 - cursor.fetchone.return_value = ( - """ - {"Deletion": {"id": "12", "name": "test", "description": ""}} - """, - ) - _ = approve_case_attribute( - mock_conn, - case_attr_audit_id=3, - ) - calls = [ - mocker.call( - "SELECT json_diff_data FROM caseattributes_audit " - "WHERE id = %s", - (3,), - ), - mocker.call("DELETE FROM CaseAttribute WHERE Id = %s", ("12",)), - mocker.call( - "UPDATE caseattributes_audit SET " - "status = 'approved' WHERE id = %s", - (3,), - ), - ] - cursor.execute.assert_has_calls(calls, any_order=False) +def test_apply_change_rejected(mocker: MockFixture) -> None: + """Test rejecting a change""" + mock_cursor, mock_conn = mocker.MagicMock(), mocker.MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_lmdb = mocker.patch("gn3.db.case_attributes.lmdb") + mock_env, mock_txn = mocker.MagicMock(), mocker.MagicMock() + mock_lmdb.open.return_value = mock_env + mock_env.begin.return_value.__enter__.return_value = mock_txn + tmpdir = Path(os.environ.get("TMPDIR", tempfile.gettempdir())) + change_id, review_ids = 3, {1, 2, 3} + mock_txn.get.side_effect = [ + pickle.dumps(review_ids), # review_ids + None # rejected_ids (initially empty) + ] + + assert apply_change(mock_cursor, EditStatus.rejected, + change_id, tmpdir) is True + + # Verify SQL query call sequence + mock_cursor.execute.assert_called_once_with( + "UPDATE caseattributes_audit SET status = %s WHERE id = %s", + (str(EditStatus.rejected), change_id)) + mock_cursor.executemany.assert_not_called() + + # Verify LMDB operations + mock_env.begin.assert_called_once_with(write=True) + expected_txn_calls = [ + mocker.call(b"review", pickle.dumps({1, 2})), + mocker.call(b"rejected", pickle.dumps({3})) + ] + mock_txn.put.assert_has_calls(expected_txn_calls, any_order=False) @pytest.mark.unit_test -def test_approve_modifying_case_attribute(mocker: MockFixture) -> None: - """Test modifying a case-attribute""" - mock_conn = mocker.MagicMock() - with mock_conn.cursor() as cursor: - type(cursor).rowcount = 1 - cursor.fetchone.return_value = ( - """ -{ - "id": "12", - "Modification": { - "description": { - "Current": "Test", - "Original": "A" - }, - "name": { - "Current": "Height (A)", - "Original": "Height" +def test_apply_change_non_existent_change_id(mocker: MockFixture) -> None: + """Test that there's a missing change_id from the returned LMDB rejected set.""" + mock_env, mock_txn = mocker.MagicMock(), mocker.MagicMock() + mock_cursor, mock_conn = mocker.MagicMock(), mocker.MagicMock() + mock_lmdb = mocker.patch("gn3.db.case_attributes.lmdb") + mock_lmdb.open.return_value = mock_env + 
+    mock_conn.cursor.return_value = mock_cursor
+    mock_env.begin.return_value.__enter__.return_value = mock_txn
+    change_id, review_ids = 28, {1, 2, 3}
+    mock_txn.get.side_effect = [
+        pickle.dumps(review_ids),  # b"review" key
+        None,  # b"approved" key
+    ]
+    tmpdir = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
+    assert apply_change(mock_cursor, EditStatus.approved,
+                        change_id, tmpdir) is False
+
+
+@pytest.mark.unit_test
+def test_get_changes(mocker: MockFixture) -> None:
+    """Test that reviews are correctly fetched"""
+    mock_fetch_case_attrs_changes = mocker.patch(
+        "gn3.db.case_attributes.__fetch_case_attrs_changes__"
+    )
+    mock_fetch_case_attrs_changes.return_value = [
+        {
+            "editor": "user1",
+            "json_diff_data": {
+                "inbredset_id": 1,
+                "Modifications": {
+                    "Original": {
+                        "B6D2F1": {"Epoch": "10au"},
+                        "BXD100": {"Epoch": "3b"},
+                        "BXD101": {"SeqCvge": "29"},
+                        "BXD102": {"Epoch": "3b"},
+                        "BXD108": {"SeqCvge": ""}
+                    },
+                    "Current": {
+                        "B6D2F1": {"Epoch": "10"},
+                        "BXD100": {"Epoch": "3"},
+                        "BXD101": {"SeqCvge": "2"},
+                        "BXD102": {"Epoch": "3"},
+                        "BXD108": {"SeqCvge": "oo"}
+                    }
+                }
+            },
+            "time_stamp": "2025-07-01 12:00:00"
+        },
+        {
+            "editor": "user2",
+            "json_diff_data": {
+                "inbredset_id": 1,
+                "Modifications": {
+                    "Original": {"BXD200": {"Epoch": "5a"}},
+                    "Current": {"BXD200": {"Epoch": "5"}}
+                }
+            },
+            "time_stamp": "2025-07-01 12:01:00"
+        }
+    ]
+    mock_lmdb = mocker.patch("gn3.db.case_attributes.lmdb")
+    mock_env, mock_txn = mocker.MagicMock(), mocker.MagicMock()
+    mock_lmdb.open.return_value = mock_env
+    mock_env.begin.return_value.__enter__.return_value = mock_txn
+    review_ids, approved_ids, rejected_ids = {1, 4}, {2, 3}, {5, 6, 7, 10}
+    mock_txn.get.side_effect = (
+        pickle.dumps(review_ids),  # b"review" key
+        pickle.dumps(approved_ids),  # b"approved" key
+        pickle.dumps(rejected_ids)  # b"rejected" key
+    )
+    result = get_changes(cursor=mocker.MagicMock(),
+                         change_type=EditStatus.review,
+                         directory=Path("/tmp"))
+    expected = {
+        "change-type": "review",
+        "count": {
+            "reviews": 2,
+            "approvals": 2,
+            "rejections": 4
+        },
+        "data": {
+            1: {
+                "editor": "user1",
+                "json_diff_data": {
+                    "inbredset_id": 1,
+                    "Modifications": {
+                        "Original": {
+                            "B6D2F1": {"Epoch": "10au"},
+                            "BXD100": {"Epoch": "3b"},
+                            "BXD101": {"SeqCvge": "29"},
+                            "BXD102": {"Epoch": "3b"},
+                            "BXD108": {"SeqCvge": ""}
+                        },
+                        "Current": {
+                            "B6D2F1": {"Epoch": "10"},
+                            "BXD100": {"Epoch": "3"},
+                            "BXD101": {"SeqCvge": "2"},
+                            "BXD102": {"Epoch": "3"},
+                            "BXD108": {"SeqCvge": "oo"}
+                        }
+                    }
+                },
+                "time_stamp": "2025-07-01 12:00:00"
+            },
+            4: {
+                'editor': 'user2',
+                'json_diff_data': {
+                    'inbredset_id': 1,
+                    'Modifications': {
+                        'Original': {
+                            'BXD200': {'Epoch': '5a'}
+                        },
+                        'Current': {
+                            'BXD200': {'Epoch': '5'}
+                        }
+                    }
+                },
+                "time_stamp": "2025-07-01 12:01:00"
+            }
+        }
     }
-        }
-}""",
-        )
-        _ = approve_case_attribute(
-            mock_conn,
-            case_attr_audit_id=3,
-        )
-        calls = [
-            mocker.call(
-                "SELECT json_diff_data FROM caseattributes_audit "
-                "WHERE id = %s",
-                (3,),
-            ),
-            mocker.call(
-                "UPDATE CaseAttribute SET Description = %s WHERE Id = %s",
-                (
-                    "Test",
-                    "12",
-                ),
-            ),
-            mocker.call(
-                "UPDATE CaseAttribute SET Name = %s WHERE Id = %s",
-                (
-                    "Height (A)",
-                    "12",
-                ),
-            ),
-            mocker.call(
-                "UPDATE caseattributes_audit SET "
-                "status = 'approved' WHERE id = %s",
-                (3,),
-            ),
-        ]
-        cursor.execute.assert_has_calls(calls, any_order=False)
+    assert result == expected
diff --git a/tests/unit/db/test_gen_menu.py b/tests/unit/db/test_gen_menu.py
index e6b5711..f64b4d3 100644
--- a/tests/unit/db/test_gen_menu.py
+++ b/tests/unit/db/test_gen_menu.py
@@ -120,7 +120,7 @@ class TestGenMenu(unittest.TestCase):
         with db_mock.cursor() as conn:
             with conn.cursor() as cursor:
                 for item in ["x", ("result"), ["result"], [1]]:
-                    cursor.fetchone.return_value = (item)
+                    cursor.fetchone.return_value = item
                     self.assertTrue(phenotypes_exist(db_mock, "test"))

     @pytest.mark.unit_test
@@ -140,7 +140,7 @@ class TestGenMenu(unittest.TestCase):
         db_mock = mock.MagicMock()
         with db_mock.cursor() as cursor:
             for item in ["x", ("result"), ["result"], [1]]:
-                cursor.fetchone.return_value = (item)
+                cursor.fetchone.return_value = item
                 self.assertTrue(phenotypes_exist(db_mock, "test"))

     @pytest.mark.unit_test
diff --git a/tests/unit/test_db_utils.py b/tests/unit/test_db_utils.py
index beb7169..51f4296 100644
--- a/tests/unit/test_db_utils.py
+++ b/tests/unit/test_db_utils.py
@@ -1,25 +1,61 @@
 """module contains test for db_utils"""
-from unittest import mock
-
 import pytest
-from gn3.db_utils import parse_db_url, database_connection
+from gn3.db_utils import parse_db_url
+

 @pytest.mark.unit_test
-@mock.patch("gn3.db_utils.mdb")
-@mock.patch("gn3.db_utils.parse_db_url")
-def test_database_connection(mock_db_parser, mock_sql):
-    """test for creating database connection"""
-    mock_db_parser.return_value = ("localhost", "guest", "4321", "users", None)
+@pytest.mark.parametrize(
+    "sql_uri,expected",
+    (("mysql://theuser:passwd@thehost:3306/thedb",
+      {
+          "host": "thehost",
+          "port": 3306,
+          "user": "theuser",
+          "password": "passwd",
+          "database": "thedb"
+      }),
+     (("mysql://auser:passwd@somehost:3307/thedb?"
+       "unix_socket=/run/mysqld/mysqld.sock&connect_timeout=30"),
+      {
+          "host": "somehost",
+          "port": 3307,
+          "user": "auser",
+          "password": "passwd",
+          "database": "thedb",
+          "unix_socket": "/run/mysqld/mysqld.sock",
+          "connect_timeout": 30
+      }),
+     ("mysql://guest:4321@localhost/users",
+      {
+          "host": "localhost",
+          "port": 3306,
+          "user": "guest",
+          "password": "4321",
+          "database": "users"
+      }),
+     ("mysql://localhost/users",
+      {
+          "host": "localhost",
+          "port": 3306,
+          "user": None,
+          "password": None,
+          "database": "users"
+      })))
+def test_parse_db_url(sql_uri, expected):
+    """Test that valid URIs are parsed into valid connection dicts"""
+    assert parse_db_url(sql_uri) == expected

-    with database_connection("mysql://guest:4321@localhost/users") as _conn:
-        mock_sql.connect.assert_called_with(
-            db="users", user="guest", passwd="4321", host="localhost",
-            port=3306)

 @pytest.mark.unit_test
-def test_parse_db_url():
-    """test for parsing db_uri env variable"""
-    results = parse_db_url("mysql://username:4321@localhost/test")
-    expected_results = ("localhost", "username", "4321", "test", None)
-    assert results == expected_results
+@pytest.mark.parametrize(
+    "sql_uri,invalidopt",
+    (("mysql://localhost/users?socket=/run/mysqld/mysqld.sock", "socket"),
+     ("mysql://localhost/users?connect_timeout=30&notavalidoption=value",
+      "notavalidoption")))
+def test_parse_db_url_with_invalid_options(sql_uri, invalidopt):
+    """Test that invalid options cause the function to raise an exception."""
+    with pytest.raises(AssertionError) as exc_info:
+        parse_db_url(sql_uri)
+
+    assert exc_info.value.args[0] == f"Invalid database connection option ({invalidopt}) provided."
diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py
index 8fbaba6..3a79486 100644
--- a/tests/unit/test_llm.py
+++ b/tests/unit/test_llm.py
@@ -1,11 +1,22 @@
 """Test cases for procedures defined in llms """
 # pylint: disable=C0301
+# pylint: disable=W0613
+from datetime import datetime, timedelta
+from unittest.mock import patch
+from unittest.mock import MagicMock
+
 import pytest
 from gn3.llms.process import fetch_pubmed
 from gn3.llms.process import parse_context
 from gn3.llms.process import format_bibliography_info
+from gn3.llms.errors import LLMError
+from gn3.api.llm import clean_query
+from gn3.api.llm import is_verified_anonymous_user
+from gn3.api.llm import is_valid_address
+from gn3.api.llm import check_rate_limiter
+FAKE_NOW = datetime(2025, 1, 1, 12, 0, 0)


 @pytest.mark.unit_test
 def test_parse_context():
     """test for parsing doc id context"""
@@ -104,3 +115,130 @@ def test_fetching_pubmed_info(monkeypatch):

     assert (fetch_pubmed(data, "/pubmed.json", "data/")
             == expected_results)
+
+
+@pytest.mark.unit_test
+def test_clean_query():
+    """Test function for cleaning up query"""
+    assert clean_query("!what is genetics.") == "what is genetics"
+    assert clean_query("hello test?") == "hello test"
+    assert clean_query(" hello test with space?") == "hello test with space"
+
+
+@pytest.mark.unit_test
+def test_is_verified_anonymous_user():
+    """Test function for verifying anonymous user metadata"""
+    assert is_verified_anonymous_user({}) is False
+    assert is_verified_anonymous_user({"Anonymous-Id": "qws2121dwsdwdwe",
+                                       "Anonymous-Status": "verified"}) is True
+
+
+@pytest.mark.unit_test
+def test_is_valid_address():
+    """Test that an IP address is correctly validated"""
+    assert is_valid_address("invalid_ip") is False
+    assert is_valid_address("127.0.0.1") is True
+
+
+@patch("gn3.api.llm.datetime")
+@patch("gn3.api.llm.db.connection")
+@patch("gn3.api.llm.is_valid_address", return_value=True)
+@pytest.mark.unit_test
+def test_first_time_visitor(mock_is_valid, mock_db_conn, mock_datetime):
+    """Test rate limiting for first-time visitor"""
+    mock_datetime.utcnow.return_value = FAKE_NOW
+    mock_datetime.strptime = datetime.strptime  # keep real one
+    mock_datetime.strftime = datetime.strftime  # keep real one
+
+    # Set up DB mock
+    mock_conn = MagicMock()
+    mock_cursor = MagicMock()
+    mock_conn.__enter__.return_value = mock_conn
+    mock_conn.cursor.return_value = mock_cursor
+    mock_cursor.fetchone.return_value = None
+    mock_db_conn.return_value = mock_conn
+
+    result = check_rate_limiter("127.0.0.1", "test/llm.db", "Chromosome x")
+    assert result is True
+    mock_cursor.execute.assert_any_call("""
+        INSERT INTO Limiter(identifier, tokens, expiry_time)
+        VALUES (?, ?, ?)
+    """, ("127.0.0.1", 4, "2025-01-01 12:24:00"))
+
+
+@patch("gn3.api.llm.datetime")
+@patch("gn3.api.llm.db.connection")
+@patch("gn3.api.llm.is_valid_address", return_value=True)
+@pytest.mark.unit_test
+def test_visitor_at_limit(mock_is_valid, mock_db_conn, mock_datetime):
+    """Test rate limiting for Visitor at limit"""
+    mock_datetime.utcnow.return_value = FAKE_NOW
+    mock_datetime.strptime = datetime.strptime  # keep real one
+    mock_datetime.strftime = datetime.strftime
+
+    mock_conn = MagicMock()
+    mock_cursor = MagicMock()
+    mock_conn.__enter__.return_value = mock_conn
+    mock_conn.cursor.return_value = mock_cursor
+    fake_expiry = (FAKE_NOW + timedelta(minutes=10)).strftime("%Y-%m-%d %H:%M:%S")
+    mock_cursor.fetchone.return_value = (0, fake_expiry)  # no tokens remaining
+    mock_db_conn.return_value = mock_conn
+    with pytest.raises(LLMError) as exc_info:
+        check_rate_limiter("127.0.0.1", "test/llm.db", "Chromosome x")
+    # assert llm error with correct message is raised
+    assert exc_info.value.args == ('Rate limit exceeded. Please try again later.', 'Chromosome x')
+
+
+@patch("gn3.api.llm.datetime")
+@patch("gn3.api.llm.db.connection")
+@patch("gn3.api.llm.is_valid_address", return_value=True)
+@pytest.mark.unit_test
+def test_visitor_with_tokens(mock_is_valid, mock_db_conn, mock_datetime):
+    """Test rate limiting for user with valid tokens"""
+
+    mock_datetime.utcnow.return_value = FAKE_NOW
+    mock_datetime.strptime = datetime.strptime  # Use real versions
+    mock_datetime.strftime = datetime.strftime
+
+    mock_conn = MagicMock()
+    mock_cursor = MagicMock()
+    mock_conn.__enter__.return_value = mock_conn
+    mock_conn.cursor.return_value = mock_cursor
+
+    fake_expiry = (FAKE_NOW + timedelta(minutes=10)).strftime("%Y-%m-%d %H:%M:%S")
+    mock_cursor.fetchone.return_value = (3, fake_expiry)  # Simulate 3 tokens
+
+    mock_db_conn.return_value = mock_conn
+
+    results = check_rate_limiter("127.0.0.1", "test/llm.db", "Chromosome x")
+    assert results is True
+    mock_cursor.execute.assert_any_call("""
+        UPDATE Limiter
+        SET tokens = tokens - 1
+        WHERE identifier = ? AND tokens > 0
+    """, ("127.0.0.1",))
+
+
+@patch("gn3.api.llm.datetime")
+@patch("gn3.api.llm.db.connection")
+@patch("gn3.api.llm.is_valid_address", return_value=True)
+@pytest.mark.unit_test
+def test_visitor_token_expired(mock_is_valid, mock_db_conn, mock_datetime):
+    """Test rate limiting for expired tokens"""
+
+    mock_datetime.utcnow.return_value = FAKE_NOW
+    mock_datetime.strptime = datetime.strptime
+    mock_datetime.strftime = datetime.strftime
+    mock_conn = MagicMock()
+    mock_cursor = MagicMock()
+    mock_conn.__enter__.return_value = mock_conn
+    mock_conn.cursor.return_value = mock_cursor
+    fake_expiry = (FAKE_NOW - timedelta(minutes=10)).strftime("%Y-%m-%d %H:%M:%S")
+    mock_cursor.fetchone.return_value = (3, fake_expiry)  # Simulate 3 tokens
+    mock_db_conn.return_value = mock_conn
+
+    result = check_rate_limiter("127.0.0.1", "test/llm.db", "Chromosome x")
+    assert result is True
+    mock_cursor.execute.assert_any_call("""
+        UPDATE Limiter
+        SET tokens = ?, expiry_time = ?
+        WHERE identifier = ?
+    """, (4, "2025-01-01 12:24:00", "127.0.0.1"))
diff --git a/tests/unit/test_rqtl2.py b/tests/unit/test_rqtl2.py
new file mode 100644
index 0000000..ddce91b
--- /dev/null
+++ b/tests/unit/test_rqtl2.py
@@ -0,0 +1,123 @@
+"""Module contains the unittest for rqtl2 functions """
+# pylint: disable=C0301
+from unittest import mock
+import pytest
+from gn3.computations.rqtl2 import compose_rqtl2_cmd
+from gn3.computations.rqtl2 import generate_rqtl2_files
+from gn3.computations.rqtl2 import prepare_files
+from gn3.computations.rqtl2 import validate_required_keys
+
+
+@pytest.mark.unit_test
+@mock.patch("gn3.computations.rqtl2.write_to_csv")
+def test_generate_rqtl2_files(mock_write_to_csv):
+    """Test for generating rqtl2 files from set of inputs"""
+
+    mock_write_to_csv.side_effect = (
+        "/tmp/workspace/geno_file.csv",
+        "/tmp/workspace/pheno_file.csv"
+    )
+    data = {"crosstype": "riself",
+            "geno_data": [{"NAME": "Ge_code_1"}],
+            "pheno_data": [{"NAME": "14343_at"}],
+            "alleles": ["L", "C"],
+            "geno_codes": {
+                "L": 1,
+                "C": 2
+            },
+            "na.strings": ["-", "NA"]
+            }
+
+    test_results = generate_rqtl2_files(data, "/tmp/workspace")
+    expected_results = {"geno_file": "/tmp/workspace/geno_file.csv",
+                        "pheno_file": "/tmp/workspace/pheno_file.csv",
+                        **data
+                        }
+    assert test_results == expected_results
+
+    # assert data is written to the csv
+    expected_calls = [mock.call(
+        "/tmp/workspace",
+        "geno_file.csv",
+        [{"NAME": "Ge_code_1"}]
+    ),
+        mock.call(
+            "/tmp/workspace",
+            "pheno_file.csv",
+            [{"NAME": "14343_at"}]
+    )]
+    mock_write_to_csv.assert_has_calls(expected_calls)
+
+
+@pytest.mark.unit_test
+def test_validate_required_keys():
+    """Test to validate required keys are in a dataset"""
+    required_keys = ["geno_data", "pheno_data", "geno_codes"]
+    assert ((False,
+             "Required key(s) missing: geno_data, pheno_data, geno_codes")
+            == validate_required_keys(required_keys, {})
+            )
+    assert ((True,
+             "")
+            == validate_required_keys(required_keys, {
+                "geno_data": [],
+                "pheno_data": [],
+                "geno_codes": {}
+            })
+            )
+
+
+@pytest.mark.unit_test
+def test_compose_rqtl2_cmd():
+    """Test for composing rqtl2 command"""
+    input_file = "/tmp/575732e-691e-49e5-8d82-30c564927c95/input_file.json"
+    output_file = "/tmp/575732e-691e-49e5-8d82-30c564927c95/output_file.json"
+    directory = "/tmp/575732e-691e-49e5-8d82-30c564927c95"
+    expected_results = f"Rscript /rqtl2_wrapper.R --input_file {input_file} --directory {directory} --output_file {output_file} --nperm 12 --method LMM --threshold 0.05 --cores 1"
+
+    # test for using default configs
+    assert compose_rqtl2_cmd(rqtl_path="/rqtl2_wrapper.R",
+                             input_file=input_file,
+                             output_file=output_file,
+                             workspace_dir=directory,
+                             data={
+                                 "nperm": 12,
+                                 "threshold": 0.05,
+                                 "method": "LMM"
+                             },
+                             config={}) == expected_results
+
+    # test for default permutation, method and threshold and custom configs
+    expected_results = f"/bin/rscript /rqtl2_wrapper.R --input_file {input_file} --directory {directory} --output_file {output_file} --nperm 0 --method HK --threshold 1 --cores 12"
+    assert (compose_rqtl2_cmd(rqtl_path="/rqtl2_wrapper.R",
+                              input_file=input_file,
+                              output_file=output_file,
+                              workspace_dir=directory,
+                              data={},
+                              config={"MULTIPROCESSOR_PROCS": 12, "RSCRIPT": "/bin/rscript"})
+            == expected_results)
+
+
+@pytest.mark.unit_test
+@mock.patch("gn3.computations.rqtl2.os.makedirs")
+@mock.patch("gn3.computations.rqtl2.create_file")
+@mock.patch("gn3.computations.rqtl2.uuid")
+def test_preparing_rqtl_files(mock_uuid, mock_create_file, mock_mkdir):
+    """test to create required rqtl files"""
+    mock_create_file.return_value = None
+    mock_mkdir.return_value = None
+    mock_uuid.uuid4.return_value = "2fc75611-1524-418e-970f-67f94ea09846"
+    assert (
+        (
+            "/tmp/2fc75611-1524-418e-970f-67f94ea09846",
+            "/tmp/2fc75611-1524-418e-970f-67f94ea09846/rqtl2-input-2fc75611-1524-418e-970f-67f94ea09846.json",
+            "/tmp/2fc75611-1524-418e-970f-67f94ea09846/rqtl2-output-2fc75611-1524-418e-970f-67f94ea09846.json",
+            "/tmp/rqtl2-log-2fc75611-1524-418e-970f-67f94ea09846"
+        ) == prepare_files(tmpdir="/tmp/")
+    )
+    # assert method to create files is called
+    expected_calls = [mock.call("/tmp/2fc75611-1524-418e-970f-67f94ea09846/rqtl2-input-2fc75611-1524-418e-970f-67f94ea09846.json"),
+                      mock.call(
+                          "/tmp/2fc75611-1524-418e-970f-67f94ea09846/rqtl2-output-2fc75611-1524-418e-970f-67f94ea09846.json"),
+                      mock.call("/tmp/rqtl2-log-2fc75611-1524-418e-970f-67f94ea09846")]
+    mock_create_file.assert_has_calls(expected_calls)
