diff options
Diffstat (limited to 'gn3')
| -rw-r--r-- | gn3/api/llm.py | 75 | ||||
| -rw-r--r-- | gn3/api/metadata_api/wiki.py | 6 |
2 files changed, 42 insertions, 39 deletions
diff --git a/gn3/api/llm.py b/gn3/api/llm.py index b9ffbb2..93ffc78 100644 --- a/gn3/api/llm.py +++ b/gn3/api/llm.py @@ -1,6 +1,7 @@ """Api endpoints for gnqa""" import json from datetime import datetime +from typing import Optional from flask import Blueprint from flask import current_app @@ -9,7 +10,8 @@ from flask import request from gn3.llms.process import get_gnqa from gn3.llms.errors import LLMError -from gn3.auth.authorisation.oauth2.resource_server import require_oauth + +from gn3.oauth2.authorisation import require_token from gn3.auth import db @@ -48,7 +50,8 @@ def database_setup(): @gnqa.route("/search", methods=["GET"]) -def search(): +@require_token +def search(auth_token=None): """Api endpoint for searching queries in fahamu Api""" query = request.args.get("query", "") if not query: @@ -57,9 +60,9 @@ def search(): if not fahamu_token: raise LLMError( "Request failed: an LLM authorisation token is required ", query) + user_id = get_user_id(auth_token) database_setup() - with (db.connection(current_app.config["LLM_DB_PATH"]) as conn, - require_oauth.acquire("profile user") as token): + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: cursor = conn.cursor() previous_answer_query = """ SELECT user_id, task_id, query, results FROM history @@ -67,7 +70,7 @@ def search(): user_id = ? AND query = ? ORDER BY created_at DESC LIMIT 1 """ - res = cursor.execute(previous_answer_query, (str(token.user.user_id), query)) + res = cursor.execute(previous_answer_query, (user_id, query)) previous_result = res.fetchone() if previous_result: _, _, _, response = previous_result @@ -84,7 +87,7 @@ def search(): cursor.execute( """INSERT INTO history(user_id, task_id, query, results) VALUES(?, ?, ?, ?) - """, (str(token.user.user_id), str(task_id["task_id"]), + """, (user_id, str(task_id["task_id"]), query, json.dumps(response)) ) @@ -92,14 +95,14 @@ def search(): @gnqa.route("/rating/<task_id>", methods=["POST"]) -@require_oauth("profile") -def rate_queries(task_id): +@require_token +def rate_queries(task_id, auth_token=None): """Api endpoint for rating GNQA query and answer""" database_setup() - with (require_oauth.acquire("profile") as token, - db.connection(current_app.config["LLM_DB_PATH"]) as conn): + user_id = get_user_id(auth_token) + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: results = request.json - user_id, query, answer, weight = (token.user.user_id, + user_id, query, answer, weight = (user_id, results.get("query"), results.get("answer"), results.get("weight", 0)) @@ -109,24 +112,23 @@ def rate_queries(task_id): VALUES(?, ?, ?, ?, ?) ON CONFLICT(task_id) DO UPDATE SET weight=excluded.weight - """, (str(user_id), query, answer, weight, task_id)) + """, (user_id, query, answer, weight, task_id)) return { "message": "You have successfully rated this query. Thank you!" }, 200 @gnqa.route("/search/records", methods=["GET"]) -@require_oauth("profile user") -def get_user_search_records(): +@require_token +def get_user_search_records(auth_token=None): """get all history records for a given user using their user id """ - with (require_oauth.acquire("profile user") as token, - db.connection(current_app.config["LLM_DB_PATH"]) as conn): + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: cursor = conn.cursor() cursor.execute( """SELECT task_id, query, created_at from history WHERE user_id=?""", - (str(token.user.user_id),)) + (get_user_id(auth_token),)) results = [dict(item) for item in cursor.fetchall()] return jsonify(sorted(results, reverse=True, key=lambda x: datetime.strptime(x.get("created_at"), @@ -134,17 +136,15 @@ def get_user_search_records(): @gnqa.route("/search/record/<task_id>", methods=["GET"]) -@require_oauth("profile user") -def get_user_record_by_task(task_id): +@require_token +def get_user_record_by_task(task_id, auth_token = None): """Get user previous search record by task id """ - with (require_oauth.acquire("profile user") as token, - db.connection(current_app.config["LLM_DB_PATH"]) as conn): + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: cursor = conn.cursor() cursor.execute( """SELECT results from history Where task_id=? and user_id=?""", - (task_id, - str(token.user.user_id),)) + (task_id, get_user_id(auth_token),)) record = cursor.fetchone() if record: return dict(record).get("results") @@ -152,28 +152,31 @@ def get_user_record_by_task(task_id): @gnqa.route("/search/record/<task_id>", methods=["DELETE"]) -@require_oauth("profile user") -def delete_record(task_id): +@require_token +def delete_record(task_id, auth_token = None): """Delete user previous seach record by task-id""" - with (require_oauth.acquire("profile user") as token, - db.connection(current_app.config["LLM_DB_PATH"]) as conn): + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: cursor = conn.cursor() query = """DELETE FROM history WHERE task_id=? and user_id=?""" - cursor.execute(query, (task_id, token.user.user_id,)) + cursor.execute(query, (task_id, get_user_id(auth_token),)) return {"msg": f"Successfully Deleted the task {task_id}"} @gnqa.route("/search/records", methods=["DELETE"]) -@require_oauth("profile user") -def delete_records(): +@require_token +def delete_records(auth_token=None): """ Delete a users records using for all given task ids""" - with (require_oauth.acquire("profile user") as token, - db.connection(current_app.config["LLM_DB_PATH"]) as conn): + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: task_ids = list(request.json.values()) cursor = conn.cursor() - query = f""" -DELETE FROM history WHERE task_id IN ({', '.join('?' * len(task_ids))}) AND user_id=? - """ - cursor.execute(query, (*task_ids, str(token.user.user_id),)) + query = f""" DELETE FROM history WHERE task_id IN ({', '.join('?' * len(task_ids))}) AND user_id=? """ + cursor.execute(query, (*task_ids, get_user_id(auth_token),)) return jsonify({}) + + +def get_user_id(auth_token: Optional[dict] = None): + if auth_token is None or auth_token.get("jwt", {}).get("sub") is None: + raise LLMError("Invalid auth token encountered") + user_id = auth_token["jwt"]["sub"] + return user_id diff --git a/gn3/api/metadata_api/wiki.py b/gn3/api/metadata_api/wiki.py index e8c59b5..7a00786 100644 --- a/gn3/api/metadata_api/wiki.py +++ b/gn3/api/metadata_api/wiki.py @@ -6,7 +6,7 @@ from typing import Any, Dict from flask import Blueprint, request, jsonify, current_app, make_response from gn3 import db_utils -from gn3.auth.authorisation.oauth2.resource_server import require_oauth +from gn3.oauth2.authorisation import require_token from gn3.db import wiki from gn3.db.rdf.wiki import ( get_wiki_entries_by_symbol, @@ -21,8 +21,8 @@ rif_blueprint = Blueprint("rif", __name__, url_prefix="rif") @wiki_blueprint.route("/<int:comment_id>/edit", methods=["POST"]) -@require_oauth("profile") -def edit_wiki(comment_id: int): +@require_token +def edit_wiki(comment_id: int, **kwargs): """Edit wiki comment. This is achieved by adding another entry with a new VersionId""" # FIXME: attempt to check and fix for types here with relevant errors payload: Dict[str, Any] = request.json # type: ignore |
