Diffstat (limited to 'gn3/api')
| -rw-r--r-- | gn3/api/case_attributes.py   | 296 |
| -rw-r--r-- | gn3/api/correlation.py       |  29 |
| -rw-r--r-- | gn3/api/ctl.py               |   5 |
| -rw-r--r-- | gn3/api/general.py           |  29 |
| -rw-r--r-- | gn3/api/heatmaps.py          |   2 |
| -rw-r--r-- | gn3/api/llm.py               | 236 |
| -rw-r--r-- | gn3/api/lmdb_sample_data.py  |  40 |
| -rw-r--r-- | gn3/api/metadata.py          |   6 |
| -rw-r--r-- | gn3/api/metadata_api/wiki.py |  80 |
| -rw-r--r-- | gn3/api/rqtl.py              |  75 |
| -rw-r--r-- | gn3/api/rqtl2.py             |  55 |
| -rw-r--r-- | gn3/api/search.py            |  15 |
| -rw-r--r-- | gn3/api/streaming.py         |  26 |
| -rw-r--r-- | gn3/api/wgcna.py             |   4 |
14 files changed, 801 insertions, 97 deletions
diff --git a/gn3/api/case_attributes.py b/gn3/api/case_attributes.py
new file mode 100644
index 0000000..28337ea
--- /dev/null
+++ b/gn3/api/case_attributes.py
@@ -0,0 +1,296 @@
+"""Implement case-attribute manipulations."""
+from typing import Union
+from pathlib import Path
+
+from functools import reduce
+from urllib.parse import urljoin
+
+import requests
+from MySQLdb.cursors import DictCursor
+from authlib.integrations.flask_oauth2.errors import _HTTPException
+from flask import (
+    jsonify,
+    request,
+    Response,
+    Blueprint,
+    current_app)
+from gn3.db.case_attributes import (
+    CaseAttributeEdit,
+    EditStatus,
+    queue_edit,
+    apply_change,
+    get_changes)
+
+
+from gn3.db_utils import Connection, database_connection
+
+from gn3.oauth2.authorisation import require_token
+from gn3.oauth2.errors import AuthorisationError
+
+caseattr = Blueprint("case-attribute", __name__)
+
+
+def required_access(
+        token: dict,
+        inbredset_id: int,
+        access_levels: tuple[str, ...]
+) -> Union[bool, tuple[str, ...]]:
+    """Check whether the user has the appropriate access"""
+    def __species_id__(conn):
+        with conn.cursor() as cursor:
+            cursor.execute(
+                "SELECT SpeciesId FROM InbredSet WHERE InbredSetId=%s",
+                (inbredset_id,))
+            return cursor.fetchone()[0]
+    try:
+        with database_connection(current_app.config["SQL_URI"]) as conn:
+            result = requests.get(
+                # this section fetches the resource ID from the auth server
+                urljoin(current_app.config["AUTH_SERVER_URL"],
+                        "auth/resource/populations/resource-id"
+                        f"/{__species_id__(conn)}/{inbredset_id}"),
+                timeout=300)
+            if result.status_code == 200:
+                resource_id = result.json()["resource-id"]
+                auth = requests.post(
+                    # this section fetches the authorisations/privileges that
+                    # the current user has on the resource we got above
+                    urljoin(current_app.config["AUTH_SERVER_URL"],
+                            "auth/resource/authorisation"),
+                    json={"resource-ids": [resource_id]},
+                    headers={
+                        "Authorization": f"Bearer {token['access_token']}"},
+                    timeout=300)
+                if auth.status_code == 200:
+                    privs = tuple(priv["privilege_id"]
+                                  for role in auth.json()[resource_id]["roles"]
+                                  for priv in role["privileges"])
+                    if all(lvl in privs for lvl in access_levels):
+                        return privs
+    except _HTTPException as httpe:
+        raise AuthorisationError("You need to be logged in.") from httpe
+
+    raise AuthorisationError(
+        f"User does not have the privileges {access_levels}")
+
+
+def __inbredset_group__(conn, inbredset_id):
+    """Return InbredSet group's top-level details."""
+    with conn.cursor(cursorclass=DictCursor) as cursor:
+        cursor.execute(
+            "SELECT * FROM InbredSet WHERE InbredSetId=%(inbredset_id)s",
+            {"inbredset_id": inbredset_id})
+        return dict(cursor.fetchone())
+
+
+def __inbredset_strains__(conn, inbredset_id):
+    """Return all samples/strains for given InbredSet group."""
+    with conn.cursor(cursorclass=DictCursor) as cursor:
+        cursor.execute(
+            "SELECT s.* FROM StrainXRef AS sxr INNER JOIN Strain AS s "
+            "ON sxr.StrainId=s.Id WHERE sxr.InbredSetId=%(inbredset_id)s "
+            "ORDER BY s.Name ASC",
+            {"inbredset_id": inbredset_id})
+        return tuple(dict(row) for row in cursor.fetchall())
+
+
+def __case_attribute_labels_by_inbred_set__(conn, inbredset_id):
+    """Return the case-attribute labels/names for the given InbredSet group."""
+    with conn.cursor(cursorclass=DictCursor) as cursor:
+        cursor.execute(
+            "SELECT * FROM CaseAttribute WHERE InbredSetId=%(inbredset_id)s",
+            {"inbredset_id": inbredset_id})
+        return tuple(dict(row) for row in cursor.fetchall())
+
+
+@caseattr.route("/<int:inbredset_id>", methods=["GET"])
+def inbredset_group(inbredset_id: int) -> Response:
+    """Retrieve InbredSet group's details."""
+    with database_connection(current_app.config["SQL_URI"]) as conn:
+        return jsonify(__inbredset_group__(conn, inbredset_id))
+
+
+@caseattr.route("/<int:inbredset_id>/strains", methods=["GET"])
+def inbredset_strains(inbredset_id: int) -> Response:
+    """Retrieve ALL strains/samples relating to a specific InbredSet group."""
+    with database_connection(current_app.config["SQL_URI"]) as conn:
+        return jsonify(__inbredset_strains__(conn, inbredset_id))
+
+
+@caseattr.route("/<int:inbredset_id>/names", methods=["GET"])
+def inbredset_case_attribute_names(inbredset_id: int) -> Response:
+    """Retrieve ALL case-attributes for a specific InbredSet group."""
+    with database_connection(current_app.config["SQL_URI"]) as conn:
+        return jsonify(
+            __case_attribute_labels_by_inbred_set__(conn, inbredset_id))
+
+
+def __by_strain__(accumulator, item):
+    attr = {item["CaseAttributeName"]: item["CaseAttributeValue"]}
+    strain_name = item["StrainName"]
+    if bool(accumulator.get(strain_name)):
+        return {
+            **accumulator,
+            strain_name: {
+                **accumulator[strain_name],
+                "case-attributes": {
+                    **accumulator[strain_name]["case-attributes"],
+                    **attr
+                }
+            }
+        }
+    return {
+        **accumulator,
+        strain_name: {
+            **{
+                key: value for key, value in item.items()
+                if key in ("StrainName", "StrainName2", "Symbol", "Alias")
+            },
+            "case-attributes": attr
+        }
+    }
+
+
+def __case_attribute_values_by_inbred_set__(
+        conn: Connection, inbredset_id: int) -> tuple[dict, ...]:
+    """
+    Retrieve Case-Attributes by their InbredSet ID. Do not call this outside
+    this module.
+    """
+    with conn.cursor(cursorclass=DictCursor) as cursor:
+        cursor.execute(
+            "SELECT ca.Name AS CaseAttributeName, "
+            "caxrn.Value AS CaseAttributeValue, s.Name AS StrainName, "
+            "s.Name2 AS StrainName2, s.Symbol, s.Alias "
+            "FROM CaseAttribute AS ca "
+            "INNER JOIN CaseAttributeXRefNew AS caxrn "
+            "ON ca.CaseAttributeId=caxrn.CaseAttributeId "
+            "INNER JOIN Strain AS s "
+            "ON caxrn.StrainId=s.Id "
+            "WHERE caxrn.InbredSetId=%(inbredset_id)s "
+            "ORDER BY StrainName",
+            {"inbredset_id": inbredset_id})
+        return tuple(
+            reduce(__by_strain__, cursor.fetchall(), {}).values())
+
+
+@caseattr.route("/<int:inbredset_id>/values", methods=["GET"])
+def inbredset_case_attribute_values(inbredset_id: int) -> Response:
+    """Retrieve the group's (InbredSet's) case-attribute values."""
+    with database_connection(current_app.config["SQL_URI"]) as conn:
+        return jsonify(__case_attribute_values_by_inbred_set__(conn, inbredset_id))
+
+
+# pylint: disable=[too-many-locals]
+@caseattr.route("/<int:inbredset_id>/edit", methods=["POST"])
+@require_token
+def edit_case_attributes(inbredset_id: int, auth_token=None) -> tuple[Response, int]:
+    """Edit the case attributes for `InbredSetId` based on data received.
+
+    :inbredset_id: Identifier for the population that the case attribute belongs
+    :auth_token: A validated JWT from the auth server
+    """
+    with database_connection(current_app.config["SQL_URI"]) as conn, conn.cursor() as cursor:
+        data = request.json["edit-data"]  # type: ignore
+        edit = CaseAttributeEdit(
+            inbredset_id=inbredset_id,
+            status=EditStatus.review,
+            user_id=auth_token["jwt"]["sub"],
+            changes=data
+        )
+        directory = (Path(current_app.config["LMDB_DATA_PATH"]) /
+                     "case-attributes" / str(inbredset_id))
+        queue_edit(cursor=cursor,
+                   directory=directory,
+                   edit=edit)
+        return jsonify({
+            "diff-status": "queued",
+            "message": ("The changes to the case-attributes have been "
+                        "queued for approval."),
+        }), 201
+
+
+@caseattr.route("/<int:inbredset_id>/diffs/<string:change_type>/list", methods=["GET"])
+def list_diffs(inbredset_id: int, change_type: str) -> tuple[Response, int]:
+    """List any changes that have been made by change_type."""
+    with (database_connection(current_app.config["SQL_URI"]) as conn,
+          conn.cursor(cursorclass=DictCursor) as cursor):
+        directory = (Path(current_app.config["LMDB_DATA_PATH"]) /
+                     "case-attributes" / str(inbredset_id))
+        return jsonify(
+            get_changes(
+                cursor=cursor,
+                change_type=EditStatus[change_type],
+                directory=directory
+            )
+        ), 200
+
+
+@caseattr.route("/<int:inbredset_id>/approve/<int:change_id>", methods=["POST"])
+@require_token
+def approve_case_attributes_diff(
+        inbredset_id: int,
+        change_id: int, auth_token=None
+) -> tuple[Response, int]:
+    """Approve the changes to the case attributes in the diff."""
+    try:
+        required_access(auth_token,
+                        inbredset_id,
+                        ("system:inbredset:edit-case-attribute",))
+        with (database_connection(current_app.config["SQL_URI"]) as conn,
+              conn.cursor() as cursor):
+            directory = (Path(current_app.config["LMDB_DATA_PATH"]) /
+                         "case-attributes" / str(inbredset_id))
+            match apply_change(cursor, change_type=EditStatus.approved,
+                               change_id=change_id,
+                               directory=directory):
+                case True:
+                    return jsonify({
+                        "diff-status": "approved",
+                        "message": (f"Successfully approved # {change_id}")
+                    }), 201
+                case _:
+                    return jsonify({
+                        "diff-status": "queued",
+                        "message": (f"Was not able to successfully approve # {change_id}")
+                    }), 200
+    except AuthorisationError as __auth_err:
+        return jsonify({
+            "diff-status": "queued",
+            "message": ("You don't have the right privileges to edit this resource.")
+        }), 401
+
+
+@caseattr.route("/<int:inbredset_id>/reject/<int:change_id>", methods=["POST"])
+@require_token
+def reject_case_attributes_diff(
+        inbredset_id: int, change_id: int, auth_token=None
+) -> tuple[Response, int]:
+    """Reject the changes to the case attributes in the diff."""
+    try:
+        required_access(auth_token,
+                        inbredset_id,
+                        ("system:inbredset:edit-case-attribute",
+                         "system:inbredset:apply-case-attribute-edit"))
+        with database_connection(current_app.config["SQL_URI"]) as conn, \
+                conn.cursor() as cursor:
+            directory = (Path(current_app.config["LMDB_DATA_PATH"]) /
+                         "case-attributes" / str(inbredset_id))
+            match apply_change(cursor, change_type=EditStatus.rejected,
+                               change_id=change_id,
+                               directory=directory):
+                case True:
+                    return jsonify({
+                        "diff-status": "rejected",
+                        "message": ("The changes to the case-attributes have been "
+                                    "rejected.")
+                    }), 201
+                case _:
+                    return jsonify({
+                        "diff-status": "queued",
+                        "message": ("Failed to reject changes")
+                    }), 200
+    except AuthorisationError as __auth_err:
+        return jsonify({
+            "message": ("You don't have the right privileges to edit this resource.")
+        }), 401
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py
index c77dd93..3667a24 100644
--- a/gn3/api/correlation.py
+++ b/gn3/api/correlation.py
@@ -1,5 +1,6 @@
 """Endpoints for running correlations"""
 import sys
+import logging
 from functools import reduce
 
 import redis
@@ -8,14 +9,14 @@ from flask import Blueprint
 from flask import request
 from flask import current_app
 
-from gn3.settings import SQL_URI
 from gn3.db_utils import database_connection
 from gn3.commands import run_sample_corr_cmd
 from gn3.responses.pcorrs_responses import build_response
-from gn3.commands import run_async_cmd, compose_pcorrs_command
 from gn3.computations.correlations import map_shared_keys_to_values
 from gn3.computations.correlations import compute_tissue_correlation
 from gn3.computations.correlations import compute_all_lit_correlation
+from gn3.commands import (
+    run_async_cmd, compute_job_queue, compose_pcorrs_command)
 
 correlation = Blueprint("correlation", __name__)
 
@@ -88,6 +89,7 @@ def compute_tissue_corr(corr_method="pearson"):
     return jsonify(results)
 
 
+
 @correlation.route("/partial", methods=["POST"])
 def partial_correlation():
     """API endpoint for partial correlations."""
@@ -111,9 +113,9 @@ def partial_correlation():
     args = request.get_json()
     with_target_db = args.get("with_target_db", True)
     request_errors = __errors__(
-        args, ("primary_trait", "control_traits",
-            ("target_db" if with_target_db else "target_traits"),
-            "method"))
+        args, ("primary_trait", "control_traits",
+               ("target_db" if with_target_db else "target_traits"),
+               "method"))
     if request_errors:
         return build_response({
             "status": "error",
@@ -127,7 +129,7 @@ def partial_correlation():
             tuple(
                 trait_fullname(trait) for trait in args["control_traits"]),
             args["method"], target_database=args["target_db"],
-            criteria = int(args.get("criteria", 500)))
+            criteria=int(args.get("criteria", 500)))
     else:
         command = compose_pcorrs_command(
             trait_fullname(args["primary_trait"]),
@@ -137,10 +139,17 @@ def partial_correlation():
                 trait_fullname(trait) for trait in args["target_traits"]))
 
     queueing_results = run_async_cmd(
-        conn=conn,
-        cmd=command,
-        job_queue=current_app.config.get("REDIS_JOB_QUEUE"),
-        env = {"PYTHONPATH": ":".join(sys.path), "SQL_URI": SQL_URI})
+        conn=conn,
+        cmd=command,
+        job_queue=compute_job_queue(current_app),
+        options={
+            "env": {
+                "PYTHONPATH": ":".join(sys.path),
+                "SQL_URI": current_app.config["SQL_URI"]
+            },
+        },
+        log_level=logging.getLevelName(
+            current_app.logger.getEffectiveLevel()).lower())
     return build_response({
         "status": "success",
         "results": queueing_results,
diff --git a/gn3/api/ctl.py b/gn3/api/ctl.py
index ac33d63..39c286f 100644
--- a/gn3/api/ctl.py
+++ b/gn3/api/ctl.py
@@ -2,7 +2,7 @@
 
 from flask import Blueprint
 from flask import request
-from flask import jsonify
+from flask import jsonify, current_app
 
 from gn3.computations.ctl import call_ctl_script
 
@@ -18,7 +18,8 @@ def run_ctl():
     """
     ctl_data = request.json
 
-    (cmd_results, response) = call_ctl_script(ctl_data)
+    (cmd_results, response) = call_ctl_script(
+        ctl_data, current_app.config["TMPDIR"])
     return (jsonify({
         "results": response
     }), 200) if response is not None else (jsonify({"error": str(cmd_results)}), 401)
diff --git a/gn3/api/general.py b/gn3/api/general.py
index b984361..8b57f23 100644
--- a/gn3/api/general.py
+++ b/gn3/api/general.py
@@ -1,5 +1,6 @@
 """General API endpoints.
 Put endpoints that can't be grouped together nicely here."""
+import os
 from flask import Blueprint
 from flask import current_app
 from flask import jsonify
@@ -68,3 +69,31 @@ def run_r_qtl(geno_filestr, pheno_filestr):
     cmd = (f"Rscript {rqtl_wrapper} "
            f"{geno_filestr} {pheno_filestr}")
     return jsonify(run_cmd(cmd)), 201
+
+
+@general.route("/stream ", methods=["GET"])
+def stream():
+    """
+    This endpoint streams the stdout content from a file.
+    It expects an identifier to be passed as a query parameter.
+    Example: `/stream?id=<identifier>`
+    The `id` will be used to locate the corresponding file.
+    You can also pass an optional `peak` parameter
+    to specify the file position to start reading from.
+    Query Parameters:
+    - `id` (required): The identifier used to locate the file.
+    - `peak` (optional): The position in the file to start reading from.
+    Returns:
+    - dict with data(stdout), run_id unique id for file,
+      pointer last read position for file
+    """
+    run_id = request.args.get("id", "")
+    output_file = os.path.join(current_app.config.get("TMPDIR"),
+                               f"{run_id}.txt")
+    seek_position = int(request.args.get("peak", 0))
+    with open(output_file, encoding="utf-8") as file_handler:
+        # read to the last position default to 0
+        file_handler.seek(seek_position)
+        return jsonify({"data": file_handler.readlines(),
+                        "run_id": run_id,
+                        "pointer": file_handler.tell()})
diff --git a/gn3/api/heatmaps.py b/gn3/api/heatmaps.py
index 172d555..d3f9a45 100644
--- a/gn3/api/heatmaps.py
+++ b/gn3/api/heatmaps.py
@@ -33,7 +33,7 @@ def clustered_heatmaps():
     with io.StringIO() as io_str:
         figure = build_heatmap(conn,
                                traits_fullnames,
-                               current_app.config["GENOTYPE_FILES"],
+                               f'{current_app.config["GENOTYPE_FILES"]}/genotype',
                                vertical=vertical,
                                tmpdir=current_app.config["TMPDIR"])
         figure.write_json(io_str)
diff --git a/gn3/api/llm.py b/gn3/api/llm.py
index b9ffbb2..dc8412e 100644
--- a/gn3/api/llm.py
+++ b/gn3/api/llm.py
@@ -1,16 +1,25 @@
 """Api endpoints for gnqa"""
+import ipaddress
 import json
+import string
+import uuid
+
 from datetime import datetime
+from datetime import timedelta
+from typing import Optional
+from functools import wraps
 
 from flask import Blueprint
 from flask import current_app
 from flask import jsonify
 from flask import request
+from authlib.jose.errors import DecodeError
 
 from gn3.llms.process import get_gnqa
 from gn3.llms.errors import LLMError
-from gn3.auth.authorisation.oauth2.resource_server import require_oauth
-from gn3.auth import db
+
+from gn3.oauth2.authorisation import require_token
+from gn3 import sqlite_db_utils as db
 
 gnqa = Blueprint("gnqa", __name__)
 
@@ -26,6 +35,7 @@ CREATE TABLE IF NOT EXISTS history(
 ) WITHOUT ROWID
 """
 
+
 RATING_TABLE_CREATE_QUERY = """
 CREATE TABLE IF NOT EXISTS Rating(
     user_id TEXT NOT NULL,
@@ -39,40 +49,177 @@ CREATE TABLE IF NOT EXISTS Rating(
 """
 
 
+RATE_LIMITER_TABLE_CREATE_QUERY = """
+CREATE TABLE IF NOT EXISTS Limiter(
+    identifier TEXT NOT NULL,
+    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+    tokens INTEGER,
+    expiry_time TIMESTAMP,
+    PRIMARY KEY(identifier)
+)
+"""
+
+
 def database_setup():
     """Temporary method to remove the need to have CREATE queries in functions"""
     with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         cursor.execute(HISTORY_TABLE_CREATE_QUERY)
         cursor.execute(RATING_TABLE_CREATE_QUERY)
+        cursor.execute(RATE_LIMITER_TABLE_CREATE_QUERY)
+
+
+def clean_query(query:str) -> str:
+    """This function cleans up query removing
+    punctuation and whitepace and transform to
+    lowercase
+
+    clean_query("!hello test.") -> "hello test"
+    """
+    strip_chars = string.punctuation + string.whitespace
+    str_query = query.lower().strip(strip_chars)
+    return str_query
+
+
+def is_verified_anonymous_user(header_metadata):
+    """This function should verify autheniticity of metadate from gn2 """
+    anony_id = header_metadata.get("Anonymous-Id")  # should verify this + metadata signature
+    user_status = header_metadata.get("Anonymous-Status", "")
+    _user_signed_metadata = (
+        header_metadata.get("Anony-Metadata", ""))  # TODO~ verify this for integrity with tokens
+    return bool(anony_id) and user_status.lower() == "verified"
+
+
+def with_gnqna_fallback(view_func):
+    """Allow fallback to GNQNA user if token auth fails or token is malformed."""
+    @wraps(view_func)
+    def wrapper(*args, **kwargs):
+        def call_with_anonymous_fallback():
+            return view_func.__wrapped__(*args,
+                                         **{**kwargs, "auth_token": None, "valid_anony": True})
+
+        try:
+            response = view_func(*args, **kwargs)
+
+            is_invalid_token = (
+                isinstance(response, tuple) and
+                len(response) == 2 and
+                response[1] == 400
+            )
+
+            if is_invalid_token and is_verified_anonymous_user(dict(request.headers)):
+                return call_with_anonymous_fallback()
+
+            return response
+
+        except (DecodeError, ValueError):  # occurs when trying to parse the token or auth results
+            if is_verified_anonymous_user(dict(request.headers)):
+                return call_with_anonymous_fallback()
+            return view_func.__wrapped__(*args, **kwargs)
+
+    return wrapper
+
+
+def is_valid_address(ip_string) -> bool:
+    """Function checks if is a valid ip address is valid"""
+    # todo !verify data is sent from gn2
+    try:
+        ipaddress.ip_address(ip_string)
+        return True
+    except ValueError:
+        return False
+
+
+def check_rate_limiter(ip_address, db_path, query, tokens_lifespan=1440, default_tokens=4):
+    """
+    Checks if an anonymous user has a valid token within the given lifespan.
+    If expired or not found, creates or resets the token bucket.
+    `tokens_lifespan` is in seconds. 1440 seconds.
+    default_token set to 4 requests per hour.
+    """
+    # Extract IP address /identifier
+    if not ip_address or not is_valid_address(ip_address):
+        raise ValueError("Please provide a valid IP address")
+    now = datetime.utcnow()
+    new_expiry = (now + timedelta(seconds=tokens_lifespan)).strftime("%Y-%m-%d %H:%M:%S")
+
+    with db.connection(db_path) as conn:
+        cursor = conn.cursor()
+        # Fetch existing limiter record
+        cursor.execute("""
+            SELECT tokens, expiry_time FROM Limiter
+            WHERE identifier = ?
+        """, (ip_address,))
+        row = cursor.fetchone()
+
+        if row:
+            tokens, expiry_time_str = row
+            expiry_time = datetime.strptime(expiry_time_str, "%Y-%m-%d %H:%M:%S")
+            time_diff = (expiry_time - now).total_seconds()
+
+            if 0 < time_diff <= tokens_lifespan:
+                if tokens > 0:
+                    # Consume token
+                    cursor.execute("""
+                        UPDATE Limiter
+                        SET tokens = tokens - 1
+                        WHERE identifier = ? AND tokens > 0
+                    """, (ip_address,))
+                    return True
+                else:
+                    raise LLMError("Rate limit exceeded. Please try again later.",
+                                   query)
+            else:
+                # Token expired — reset ~probably reset this after 200 status
+                cursor.execute("""
+                    UPDATE Limiter
+                    SET tokens = ?, expiry_time = ?
+                    WHERE identifier = ?
+                """, (default_tokens, new_expiry, ip_address))
+                return True
+        else:
+            # New user — insert record ~probably reset this after 200 status
+            cursor.execute("""
+                INSERT INTO Limiter(identifier, tokens, expiry_time)
+                VALUES (?, ?, ?)
+ """, (ip_address, default_tokens, new_expiry)) + return True @gnqa.route("/search", methods=["GET"]) -def search(): +@with_gnqna_fallback +@require_token +def search(auth_token=None, valid_anony=False): """Api endpoint for searching queries in fahamu Api""" query = request.args.get("query", "") if not query: return jsonify({"error": "query get parameter is missing in the request"}), 400 + fahamu_token = current_app.config.get("FAHAMU_AUTH_TOKEN") if not fahamu_token: raise LLMError( "Request failed: an LLM authorisation token is required ", query) database_setup() - with (db.connection(current_app.config["LLM_DB_PATH"]) as conn, - require_oauth.acquire("profile user") as token): + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: cursor = conn.cursor() previous_answer_query = """ SELECT user_id, task_id, query, results FROM history - WHERE created_at > DATE('now', '-1 day') AND - user_id = ? AND + WHERE created_at > DATE('now', '-21 day') AND query = ? ORDER BY created_at DESC LIMIT 1 """ - res = cursor.execute(previous_answer_query, (str(token.user.user_id), query)) + res = cursor.execute(previous_answer_query, (clean_query(query),)) previous_result = res.fetchone() if previous_result: _, _, _, response = previous_result + response = json.loads(response) + response["query"] = query return response + if valid_anony: + # rate limit anonymous verified users + user_metadata = json.loads(request.headers.get("Anony-Metadata", {})) + check_rate_limiter(user_metadata.get("ip_address", ""), + current_app.config["LLM_DB_PATH"], + request.args.get("query", "")) + task_id, answer, refs = get_gnqa( query, fahamu_token, current_app.config.get("DATA_DIR")) response = { @@ -81,52 +228,51 @@ def search(): "answer": answer, "references": refs } + user_id = str(uuid.uuid4()) if valid_anony else get_user_id(auth_token) cursor.execute( """INSERT INTO history(user_id, task_id, query, results) VALUES(?, ?, ?, ?) - """, (str(token.user.user_id), str(task_id["task_id"]), - query, + """, (user_id, str(task_id["task_id"]), + clean_query(query), json.dumps(response)) ) return response @gnqa.route("/rating/<task_id>", methods=["POST"]) -@require_oauth("profile") -def rate_queries(task_id): +@require_token +def rate_queries(task_id, auth_token=None): """Api endpoint for rating GNQA query and answer""" database_setup() - with (require_oauth.acquire("profile") as token, - db.connection(current_app.config["LLM_DB_PATH"]) as conn): + user_id = get_user_id(auth_token) + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: results = request.json - user_id, query, answer, weight = (token.user.user_id, - results.get("query"), - results.get("answer"), - results.get("weight", 0)) + query, answer, weight = (results.get("query"), + results.get("answer"), + results.get("weight", 0)) cursor = conn.cursor() cursor.execute("""INSERT INTO Rating(user_id, query, answer, weight, task_id) VALUES(?, ?, ?, ?, ?) ON CONFLICT(task_id) DO UPDATE SET weight=excluded.weight - """, (str(user_id), query, answer, weight, task_id)) + """, (user_id, query, answer, weight, task_id)) return { "message": "You have successfully rated this query. Thank you!" 
    }, 200
 
 
 @gnqa.route("/search/records", methods=["GET"])
-@require_oauth("profile user")
-def get_user_search_records():
+@require_token
+def get_user_search_records(auth_token=None):
     """get all history records for a given user using their
     user id
     """
-    with (require_oauth.acquire("profile user") as token,
-          db.connection(current_app.config["LLM_DB_PATH"]) as conn):
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         cursor.execute(
             """SELECT task_id, query, created_at from history WHERE user_id=?""",
-            (str(token.user.user_id),))
+            (get_user_id(auth_token),))
         results = [dict(item) for item in cursor.fetchall()]
         return jsonify(sorted(results, reverse=True,
                               key=lambda x: datetime.strptime(x.get("created_at"),
@@ -134,17 +280,15 @@ def get_user_search_records():
 
 
 @gnqa.route("/search/record/<task_id>", methods=["GET"])
-@require_oauth("profile user")
-def get_user_record_by_task(task_id):
+@require_token
+def get_user_record_by_task(task_id, auth_token = None):
     """Get user previous search record by task id """
-    with (require_oauth.acquire("profile user") as token,
-          db.connection(current_app.config["LLM_DB_PATH"]) as conn):
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         cursor.execute(
             """SELECT results from history Where task_id=? and user_id=?""",
-            (task_id,
-             str(token.user.user_id),))
+            (task_id, get_user_id(auth_token),))
         record = cursor.fetchone()
         if record:
             return dict(record).get("results")
@@ -152,28 +296,34 @@ def get_user_record_by_task(task_id):
 
 
 @gnqa.route("/search/record/<task_id>", methods=["DELETE"])
-@require_oauth("profile user")
-def delete_record(task_id):
+@require_token
+def delete_record(task_id, auth_token = None):
     """Delete user previous seach record by task-id"""
-    with (require_oauth.acquire("profile user") as token,
-          db.connection(current_app.config["LLM_DB_PATH"]) as conn):
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         query = """DELETE FROM history
         WHERE task_id=? and user_id=?"""
-        cursor.execute(query, (task_id, token.user.user_id,))
+        cursor.execute(query, (task_id, get_user_id(auth_token),))
         return {"msg": f"Successfully Deleted the task {task_id}"}
 
 
 @gnqa.route("/search/records", methods=["DELETE"])
-@require_oauth("profile user")
-def delete_records():
+@require_token
+def delete_records(auth_token=None):
     """ Delete a users records using for all given task ids"""
-    with (require_oauth.acquire("profile user") as token,
-          db.connection(current_app.config["LLM_DB_PATH"]) as conn):
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         task_ids = list(request.json.values())
         cursor = conn.cursor()
-        query = f"""
-DELETE FROM history WHERE task_id IN ({', '.join('?' * len(task_ids))}) AND user_id=?
-        """
-        cursor.execute(query, (*task_ids, str(token.user.user_id),))
+        query = ("DELETE FROM history WHERE task_id IN "
+                 f"({', '.join('?' * len(task_ids))}) "
+                 "AND user_id=?")
+        cursor.execute(query, (*task_ids, get_user_id(auth_token),))
     return jsonify({})
+
+
+def get_user_id(auth_token: Optional[dict] = None):
+    """Retrieve the user ID from the JWT token."""
+    if auth_token is None or auth_token.get("jwt", {}).get("sub") is None:
+        raise LLMError("Invalid auth token encountered")
+    user_id = auth_token["jwt"]["sub"]
+    return user_id
diff --git a/gn3/api/lmdb_sample_data.py b/gn3/api/lmdb_sample_data.py
new file mode 100644
index 0000000..eaa71c2
--- /dev/null
+++ b/gn3/api/lmdb_sample_data.py
@@ -0,0 +1,40 @@
+"""API endpoint for retrieving sample data from LMDB storage"""
+import hashlib
+from pathlib import Path
+
+import lmdb
+from flask import Blueprint, current_app, jsonify
+
+lmdb_sample_data = Blueprint("lmdb_sample_data", __name__)
+
+
+@lmdb_sample_data.route("/sample-data/<string:dataset>/<int:trait_id>", methods=["GET"])
+def get_sample_data(dataset: str, trait_id: int):
+    """Retrieve sample data from LMDB for a given dataset and trait.
+
+    Path Parameters:
+        dataset: The name of the dataset
+        trait_id: The ID of the trait
+
+    Returns:
+        JSON object mapping sample IDs to their values
+    """
+    checksum = hashlib.md5(
+        f"{dataset}-{trait_id}".encode()
+    ).hexdigest()
+
+    db_path = Path(current_app.config["LMDB_DATA_PATH"]) / checksum
+    if not db_path.exists():
+        return jsonify(error="No data found for given dataset and trait"), 404
+    try:
+        with lmdb.open(str(db_path), max_dbs=15, readonly=True) as env:
+            data = {}
+            with env.begin(write=False) as txn:
+                cursor = txn.cursor()
+                for key, value in cursor:
+                    data[key.decode()] = float(value.decode())
+
+            return jsonify(data)
+
+    except lmdb.Error as err:
+        return jsonify(error=f"LMDB error: {str(err)}"), 500
diff --git a/gn3/api/metadata.py b/gn3/api/metadata.py
index 6110880..e272c0d 100644
--- a/gn3/api/metadata.py
+++ b/gn3/api/metadata.py
@@ -9,7 +9,8 @@ from flask import Blueprint
 from flask import request
 from flask import current_app
 
-from gn3.auth.authorisation.errors import AuthorisationError
+
+from gn3.oauth2.errors import AuthorisationError
 
 from gn3.db.datasets import (retrieve_metadata,
                              save_metadata,
                              get_history)
@@ -27,6 +28,7 @@ from gn3.api.metadata_api import wiki
 
 metadata = Blueprint("metadata", __name__)
 metadata.register_blueprint(wiki.wiki_blueprint)
+metadata.register_blueprint(wiki.rif_blueprint)
 
 
 @metadata.route("/datasets/<name>", methods=["GET"])
@@ -170,7 +172,7 @@ def view_history(id_):
             "history": history,
         })
     if history.get("error"):
-        raise Exception(history.get("error_description"))
+        raise Exception(history.get("error_description"))  # pylint: disable=[broad-exception-raised]
     return history
diff --git a/gn3/api/metadata_api/wiki.py b/gn3/api/metadata_api/wiki.py
index 8df6cfb..70c5cf4 100644
--- a/gn3/api/metadata_api/wiki.py
+++ b/gn3/api/metadata_api/wiki.py
@@ -1,27 +1,31 @@
-"""API for accessing/editting wiki metadata"""
+"""API for accessing/editing rif/wiki metadata"""
 import datetime
 from typing import Any, Dict
+from typing import Optional
 
 from flask import Blueprint, request, jsonify, current_app, make_response
 
 from gn3 import db_utils
-from gn3.auth.authorisation.oauth2.resource_server import require_oauth
+from gn3.oauth2.authorisation import require_token
 from gn3.db import wiki
 from gn3.db.rdf.wiki import (
     get_wiki_entries_by_symbol,
     get_comment_history,
     update_wiki_comment,
+    get_rif_entries_by_symbol,
+    delete_wiki_entries_by_id,
 )
 
 
 wiki_blueprint = Blueprint("wiki", __name__, url_prefix="wiki")
+rif_blueprint = Blueprint("rif", __name__, url_prefix="rif")
 
 
+@wiki_blueprint.route("/edit", methods=["POST"], defaults={'comment_id': None})
 @wiki_blueprint.route("/<int:comment_id>/edit", methods=["POST"])
-@require_oauth("profile")
-def edit_wiki(comment_id: int):
-    """Edit wiki comment. This is achieved by adding another entry with a new VersionId"""
+def edit_wiki(comment_id: Optional[int]):
+    """Edit/Insert wiki comment. This is achieved by adding another entry with a new VersionId"""
     # FIXME: attempt to check and fix for types here with relevant errors
     payload: Dict[str, Any] = request.json  # type: ignore
     pubmed_ids = [str(x) for x in payload.get("pubmed_ids", [])]
@@ -48,13 +52,17 @@ def edit_wiki(comment_id: int):
         VALUES (%(Id)s, %(versionId)s, %(symbol)s, %(PubMed_ID)s, %(SpeciesID)s, %(comment)s, %(email)s, %(createtime)s, %(user_ip)s, %(weburl)s, %(initial)s, %(reason)s)
         """
     with db_utils.database_connection(current_app.config["SQL_URI"]) as conn:
-        cursor = conn.cursor()
-        next_version = 0
+        cursor, next_version = conn.cursor(), 0
+        if not comment_id:
+            comment_id = wiki.get_next_comment_id(cursor)
+            insert_dict["Id"] = comment_id
+        else:
+            next_version = wiki.get_next_comment_version(cursor, comment_id)
+
         try:
             category_ids = wiki.get_categories_ids(
                 cursor, payload["categories"])
             species_id = wiki.get_species_id(cursor, payload["species"])
-            next_version = wiki.get_next_comment_version(cursor, comment_id)
         except wiki.MissingDBDataException as missing_exc:
             return jsonify(error=f"Error editing wiki entry, {missing_exc}"), 500
         insert_dict["SpeciesID"] = species_id
@@ -83,7 +91,7 @@ def edit_wiki(comment_id: int):
                     sparql_auth_uri=current_app.config["SPARQL_AUTH_URI"]
                 )
             except Exception as exc:
-                conn.rollback() # type: ignore
+                conn.rollback()  # type: ignore
                 raise exc
         return jsonify({"success": "ok"})
     return jsonify(error="Error editing wiki entry, most likely due to DB error!"), 500
@@ -154,3 +162,57 @@ def get_history(comment_id):
             payload.headers["Content-Type"] = "application/ld+json"
             return payload, status_code
     return jsonify(data), status_code
+
+
+@rif_blueprint.route("/<string:symbol>", methods=["GET"])
+def get_ncbi_rif_entries(symbol: str):
+    """Fetch NCBI RIF entries"""
+    status_code = 200
+    response = get_rif_entries_by_symbol(
+        symbol,
+        sparql_uri=current_app.config["SPARQL_ENDPOINT"])
+    data = response.get("data")
+    if not data:
+        data, status_code = {}, 404
+    if request.headers.get("Accept") == "application/ld+json":
+        payload = make_response(response)
+        payload.headers["Content-Type"] = "application/ld+json"
+        return payload, status_code
+    return jsonify(data), status_code
+
+
+@wiki_blueprint.route("/delete", methods=["POST"], defaults={'comment_id': None})
+@wiki_blueprint.route("/<int:comment_id>/delete", methods=["POST"])
+@require_token
+def delete_wiki(comment_id: Optional[int] = None, **kwargs):  # pylint: disable=[unused-argument]
+    """Delete a wiki entry by its comment_id from both SQL and RDF."""
+    if comment_id is None:
+        return jsonify(error="comment_id is required for deletion."), 400
+    with (db_utils.database_connection(current_app.config["SQL_URI"]) as conn,
+          conn.cursor() as cursor):
+        try:
+            # Delete from SQL
+            delete_query = "DELETE FROM GeneRIF WHERE Id = %s"
+            current_app.logger.debug(
+                f"Running query: {delete_query} with Id={comment_id}")
+            cursor.execute(delete_query, (comment_id,))
+            # Delete from RDF
+            try:
+                delete_wiki_entries_by_id(
+                    wiki_id=comment_id,
+                    sparql_user=current_app.config["SPARQL_USER"],
+                    sparql_password=current_app.config["SPARQL_PASSWORD"],
+                    sparql_auth_uri=current_app.config["SPARQL_AUTH_URI"],
+                    graph="<http://genenetwork.org>"
+                )
+            # pylint: disable=W0718
+            except Exception as rdf_exc:
+                current_app.logger.error(f"RDF deletion failed: {rdf_exc}")
+                conn.rollback()
+                return jsonify(error="Failed to delete wiki entry from RDF store."), 500
+            return jsonify({"success": "Wiki entry deleted successfully."}), 200
+        # pylint: disable=W0718
+        except Exception as exc:
+            conn.rollback()
+            current_app.logger.error(f"Error deleting wiki entry: {exc}")
+            return jsonify(error="Error deleting wiki entry, most likely due to DB error!"), 500
diff --git a/gn3/api/rqtl.py b/gn3/api/rqtl.py
index e029d8d..eb49f8b 100644
--- a/gn3/api/rqtl.py
+++ b/gn3/api/rqtl.py
@@ -1,5 +1,7 @@
 """Endpoints for running the rqtl cmd"""
+
 import os
+import uuid
 from pathlib import Path
 
 from flask import Blueprint
@@ -7,30 +9,40 @@ from flask import current_app
 from flask import jsonify
 from flask import request
 
-from gn3.computations.rqtl import generate_rqtl_cmd, process_rqtl_mapping, \
-    process_rqtl_pairscan, process_perm_output
-from gn3.fs_helpers import assert_paths_exist, get_tmpdir
+from gn3.computations.rqtl import (
+    generate_rqtl_cmd,
+    process_rqtl_mapping,
+    process_rqtl_pairscan,
+    process_perm_output,
+)
+from gn3.computations.streaming import run_process, enable_streaming
+from gn3.fs_helpers import assert_path_exists, get_tmpdir
 
 rqtl = Blueprint("rqtl", __name__)
 
+
 @rqtl.route("/compute", methods=["POST"])
-def compute():
+@enable_streaming
+def compute(stream_output_file):
     """Given at least a geno_file and pheno_file, generate and
-run the rqtl_wrapper script and return the results as JSON
+    run the rqtl_wrapper script and return the results as JSON
     """
-    genofile = request.form['geno_file']
-    phenofile = request.form['pheno_file']
-
-    assert_paths_exist([genofile, phenofile])
-
-    # Split kwargs by those with values and boolean ones that just convert to True/False
+    genofile = request.form["geno_file"]
+    phenofile = request.form["pheno_file"]
+    assert_path_exists(genofile)
+    assert_path_exists(phenofile)
     kwargs = ["covarstruct", "model", "method", "nperm", "scale", "control"]
     boolean_kwargs = ["addcovar", "interval", "pstrata", "pairscan"]
     all_kwargs = kwargs + boolean_kwargs
 
-    rqtl_kwargs = {"geno": genofile, "pheno": phenofile, "outdir": current_app.config.get("TMPDIR")}
+    rqtl_kwargs = {
+        "geno": genofile,
+        "pheno": phenofile,
+        "outdir": current_app.config.get("TMPDIR"),
+    }
     rqtl_bool_kwargs = []
+
    for kwarg in all_kwargs:
        if kwarg in request.form:
            if kwarg in kwargs:
@@ -38,30 +50,43 @@ run the rqtl_wrapper script and return the results as JSON
             if kwarg in boolean_kwargs:
                 rqtl_bool_kwargs.append(kwarg)
 
-    outdir = os.path.join(get_tmpdir(),"gn3")
+    outdir = os.path.join(get_tmpdir(), "gn3")
     if not os.path.isdir(outdir):
         os.mkdir(outdir)
 
     rqtl_cmd = generate_rqtl_cmd(
         rqtl_wrapper_cmd=str(
-            Path(__file__).absolute().parent.parent.parent.joinpath(
-                'scripts/rqtl_wrapper.R')),
+            Path(__file__)
+            .absolute()
+            .parent.parent.parent.joinpath("scripts/rqtl_wrapper.R")
+        ),
         rqtl_wrapper_kwargs=rqtl_kwargs,
-        rqtl_wrapper_bool_kwargs=rqtl_bool_kwargs
+        rqtl_wrapper_bool_kwargs=rqtl_bool_kwargs,
     )
 
     rqtl_output = {}
-    if not os.path.isfile(os.path.join(current_app.config.get("TMPDIR"),
-                                       "gn3", rqtl_cmd.get('output_file'))):
-        os.system(rqtl_cmd.get('rqtl_cmd'))
+    run_id = request.args.get("id", str(uuid.uuid4()))
+    if not os.path.isfile(
+        os.path.join(
+            current_app.config.get("TMPDIR"), "gn3", rqtl_cmd.get("output_file")
+        )
+    ):
+        pass
+    run_process(rqtl_cmd.get("rqtl_cmd").split(), stream_output_file, run_id)
 
     if "pairscan" in rqtl_bool_kwargs:
-        rqtl_output['results'] = process_rqtl_pairscan(rqtl_cmd.get('output_file'), genofile)
+        rqtl_output["results"] = process_rqtl_pairscan(
+            rqtl_cmd.get("output_file"), genofile
+        )
     else:
-        rqtl_output['results'] = process_rqtl_mapping(rqtl_cmd.get('output_file'))
-
-    if int(rqtl_kwargs['nperm']) > 0:
-        rqtl_output['perm_results'], rqtl_output['suggestive'], rqtl_output['significant'] = \
-            process_perm_output(rqtl_cmd.get('output_file'))
+        rqtl_output["results"] = process_rqtl_mapping(rqtl_cmd.get("output_file"))
 
+    if int(rqtl_kwargs["nperm"]) > 0:
+        # pylint: disable=C0301
+        perm_output_file = f"PERM_{rqtl_cmd.get('output_file')}"
+        (
+            rqtl_output["perm_results"],
+            rqtl_output["suggestive"],
+            rqtl_output["significant"],
+        ) = process_perm_output(perm_output_file)
     return jsonify(rqtl_output)
diff --git a/gn3/api/rqtl2.py b/gn3/api/rqtl2.py
new file mode 100644
index 0000000..dc06d1d
--- /dev/null
+++ b/gn3/api/rqtl2.py
@@ -0,0 +1,55 @@
+""" File contains endpoints for rqlt2"""
+import shutil
+from pathlib import Path
+from flask import current_app
+from flask import jsonify
+from flask import Blueprint
+from flask import request
+from gn3.computations.rqtl2 import (compose_rqtl2_cmd,
+                                    prepare_files,
+                                    validate_required_keys,
+                                    write_input_file,
+                                    process_qtl2_results
+                                    )
+from gn3.computations.streaming import run_process
+from gn3.computations.streaming import enable_streaming
+
+rqtl2 = Blueprint("rqtl2", __name__)
+
+
+@rqtl2.route("/compute", methods=["POST"])
+@enable_streaming
+def compute(log_file):
+    """Endpoint for computing QTL analysis using R/QTL2"""
+    data = request.json
+    required_keys = ["crosstype", "geno_data","pheno_data", "geno_codes"]
+    valid, error = validate_required_keys(required_keys,data)
+    if not valid:
+        return jsonify(error=error), 400
+    # Provide atleast one of this data entries.
+    if "physical_map_data" not in data and "geno_map_data" not in data:
+        return jsonify(error="You need to Provide\
+        Either the Physical map or Geno Map data of markers"), 400
+    run_id = request.args.get("id", "output")
+    # prepare necessary files and dir for computation
+    (workspace_dir, input_file,
+     output_file, _log2_file) = prepare_files(current_app.config.get("TMPDIR"))
+    # write the input file with data required for creating the cross
+    write_input_file(input_file, workspace_dir, data)
+    # TODO : Implement a better way for fetching the file Path.
+    rqtl_path =Path(__file__).absolute().parent.parent.parent.joinpath("scripts/rqtl2_wrapper.R")
+    if not rqtl_path.is_file():
+        return jsonify({"error" : f"The script {rqtl_path} does not exists"}), 400
+    rqtl2_cmd = compose_rqtl2_cmd(rqtl_path, input_file,
+                                  output_file, workspace_dir,
+                                  data, current_app.config)
+    process_output = run_process(rqtl2_cmd.split(),log_file, run_id)
+    if process_output["code"] != 0:
+        # Err out for any non-zero status code
+        return jsonify(process_output), 400
+    results = process_qtl2_results(output_file)
+    shutil.rmtree(workspace_dir, ignore_errors=True, onerror=None)
+    # append this at end of computation to the log file to mark end of gn3 computation
+    with open(log_file, "ab+") as file_handler:
+        file_handler.write("Done with GN3 Computation".encode("utf-8"))
+    return jsonify(results)
diff --git a/gn3/api/search.py b/gn3/api/search.py
index f696428..e814a00 100644
--- a/gn3/api/search.py
+++ b/gn3/api/search.py
@@ -6,7 +6,7 @@ import gzip
 import json
 from functools import partial, reduce
 from pathlib import Path
-from typing import Callable
+from typing import Union, Callable
 import urllib.parse
 
 from flask import abort, Blueprint, current_app, jsonify, request
@@ -40,6 +40,7 @@ def combine_queries(operator: int, *queries: xapian.Query) -> xapian.Query:
     return reduce(partial(xapian.Query, operator), queries)
 
 
+# pylint: disable=R0903
 class FieldProcessor(xapian.FieldProcessor):
     """
     Field processor for use in a xapian query parser.
@@ -65,7 +66,10 @@ def field_processor_or(*field_processors: FieldProcessorFunction) -> FieldProces
                                     for field_processor in field_processors]))
 
 
-def liftover(chain_file: Path, position: ChromosomalPosition) -> Maybe[ChromosomalPosition]:
+def liftover(
+        chain_file: Union[str, Path],
+        position: ChromosomalPosition
+) -> Maybe[ChromosomalPosition]:
     """Liftover chromosomal position using chain file."""
     # The chain file format is described at
     # https://genome.ucsc.edu/goldenPath/help/chain.html
@@ -91,7 +95,10 @@ def liftover(chain_file: Path, position: ChromosomalPosition) -> Maybe[Chromosom
     return Nothing
 
 
-def liftover_interval(chain_file: str, interval: ChromosomalInterval) -> ChromosomalInterval:
+def liftover_interval(
+        chain_file: Union[str, Path],
+        interval: ChromosomalInterval
+) -> ChromosomalInterval:
     """
     Liftover interval using chain file.
 
@@ -258,7 +265,7 @@ def parse_query(synteny_files_directory: Path, query: str):
                     xapian.Query(species_prefix + lifted_species),
                     chromosome_prefix,
                     range_prefixes.index("position"),
-                    partial(liftover_interval,
+                    partial(liftover_interval,# type: ignore[arg-type]
                             synteny_files_directory / chain_file)))
             queryparser.add_boolean_prefix(
                 shorthand,
diff --git a/gn3/api/streaming.py b/gn3/api/streaming.py
new file mode 100644
index 0000000..2b6b431
--- /dev/null
+++ b/gn3/api/streaming.py
@@ -0,0 +1,26 @@
+""" File contains endpoint for computational streaming"""
+import os
+from flask import current_app
+from flask import jsonify
+from flask import Blueprint
+from flask import request
+
+streaming = Blueprint("stream", __name__)
+
+
+@streaming.route("/<identifier>", methods=["GET"])
+def stream(identifier):
+    """ This endpoint streams stdout from a file.
+    It expects the identifier to be the filename
+    in the TMPDIR created at the main computation
+    endpoint see example api/rqtl."""
+    output_file = os.path.join(current_app.config.get("TMPDIR"),
+                               f"{identifier}.txt")
+    seek_position = int(request.args.get("peak", 0))
+    with open(output_file, encoding="utf-8") as file_handler:
+        # read from the last read position default to 0
+        file_handler.seek(seek_position)
+        results = {"data": file_handler.readlines(),
+                   "run_id": identifier,
+                   "pointer": file_handler.tell()}
+    return jsonify(results)
diff --git a/gn3/api/wgcna.py b/gn3/api/wgcna.py
index fa044cf..5468a2e 100644
--- a/gn3/api/wgcna.py
+++ b/gn3/api/wgcna.py
@@ -17,7 +17,9 @@ def run_wgcna():
 
     wgcna_script = current_app.config["WGCNA_RSCRIPT"]
 
-    results = call_wgcna_script(wgcna_script, wgcna_data)
+    results = call_wgcna_script(wgcna_script,
+                                wgcna_data,
+                                current_app.config["TMPDIR"])
 
     if results.get("data") is None:
         return jsonify(results), 401
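A note on consuming the new streaming flow: the `/compute` endpoints in rqtl.py and rqtl2.py now write process output to a log file keyed by a run id, and streaming.py serves that file for polling. Below is a minimal client sketch; the base URL and the `/stream` mount point are assumptions (blueprint registration is outside this diff), while the `peak` query parameter and the `data`/`pointer` response fields come directly from gn3/api/streaming.py above.

# Hypothetical polling client for the streaming endpoint (sketch only).
# BASE_URL and the "/stream" prefix are assumed, not taken from this diff.
import time
import requests

BASE_URL = "http://localhost:8080"  # assumed GN3 host

def follow_stream(run_id: str, poll_seconds: float = 1.0) -> None:
    """Tail the stdout log for `run_id`, resuming from the last pointer."""
    pointer = 0
    while True:
        reply = requests.get(f"{BASE_URL}/stream/{run_id}",
                             params={"peak": pointer}, timeout=30)
        reply.raise_for_status()
        payload = reply.json()
        for line in payload["data"]:
            print(line, end="")
        pointer = payload["pointer"]  # last read position, per streaming.py
        # rqtl2.py appends this marker to the log when GN3 computation ends.
        if any("Done with GN3 Computation" in line for line in payload["data"]):
            break
        time.sleep(poll_seconds)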

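Similarly, a small sketch of how a client might locate and query the new LMDB sample-data endpoint from lmdb_sample_data.py. The MD5 key derivation mirrors the handler above; the host, dataset name, and trait id are made-up examples.

# Sketch: reproduce the handler's LMDB path derivation and call the endpoint.
import hashlib
from pathlib import Path

import requests

def sample_data_db_path(lmdb_data_path: str, dataset: str, trait_id: int) -> Path:
    """MD5 of '<dataset>-<trait_id>' names the LMDB directory under LMDB_DATA_PATH."""
    checksum = hashlib.md5(f"{dataset}-{trait_id}".encode()).hexdigest()
    return Path(lmdb_data_path) / checksum

# Hypothetical request against a local GN3 instance; returns {sample_id: value}.
reply = requests.get("http://localhost:8080/sample-data/HC_M2_0606_P/12345", timeout=30)
if reply.status_code == 200:
    samples = reply.json()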