about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJohn Nduli2024-10-15 19:34:11 +0300
committerFrederick Muriuki Muriithi2024-10-15 13:30:27 -0500
commit9bacc257ba784e06bb5f91e8ab3b64805b552f21 (patch)
tree9c6dca843794e6652ad9b10f43307492a838ea54
parentea88747988ad1cf93455559b4fe8ffe3cd126935 (diff)
downloadgenenetwork3-9bacc257ba784e06bb5f91e8ab3b64805b552f21.tar.gz
fix: use require_token to validate gn3 apis
-rw-r--r--gn3/api/llm.py75
-rw-r--r--gn3/api/metadata_api/wiki.py6
2 files changed, 42 insertions, 39 deletions
diff --git a/gn3/api/llm.py b/gn3/api/llm.py
index b9ffbb2..93ffc78 100644
--- a/gn3/api/llm.py
+++ b/gn3/api/llm.py
@@ -1,6 +1,7 @@
 """Api endpoints for gnqa"""
 import json
 from datetime import datetime
+from typing import Optional
 
 from flask import Blueprint
 from flask import current_app
@@ -9,7 +10,8 @@ from flask import request
 
 from gn3.llms.process import get_gnqa
 from gn3.llms.errors import LLMError
-from gn3.auth.authorisation.oauth2.resource_server import require_oauth
+
+from gn3.oauth2.authorisation import require_token
 from gn3.auth import db
 
 
@@ -48,7 +50,8 @@ def database_setup():
 
 
 @gnqa.route("/search", methods=["GET"])
-def search():
+@require_token
+def search(auth_token=None):
     """Api  endpoint for searching queries in fahamu Api"""
     query = request.args.get("query", "")
     if not query:
@@ -57,9 +60,9 @@ def search():
     if not fahamu_token:
         raise LLMError(
             "Request failed: an LLM authorisation token  is required ", query)
+    user_id = get_user_id(auth_token)
     database_setup()
-    with (db.connection(current_app.config["LLM_DB_PATH"]) as conn,
-          require_oauth.acquire("profile user") as token):
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         previous_answer_query = """
         SELECT user_id, task_id, query, results FROM history
@@ -67,7 +70,7 @@ def search():
                 user_id = ? AND
                 query = ?
             ORDER BY created_at DESC LIMIT 1 """
-        res = cursor.execute(previous_answer_query, (str(token.user.user_id), query))
+        res = cursor.execute(previous_answer_query, (user_id, query))
         previous_result = res.fetchone()
         if previous_result:
             _, _, _, response = previous_result
@@ -84,7 +87,7 @@ def search():
         cursor.execute(
             """INSERT INTO history(user_id, task_id, query, results)
             VALUES(?, ?, ?, ?)
-            """, (str(token.user.user_id), str(task_id["task_id"]),
+            """, (user_id, str(task_id["task_id"]),
                   query,
                   json.dumps(response))
         )
@@ -92,14 +95,14 @@ def search():
 
 
 @gnqa.route("/rating/<task_id>", methods=["POST"])
-@require_oauth("profile")
-def rate_queries(task_id):
+@require_token
+def rate_queries(task_id, auth_token=None):
     """Api endpoint for rating GNQA query and answer"""
     database_setup()
-    with (require_oauth.acquire("profile") as token,
-          db.connection(current_app.config["LLM_DB_PATH"]) as conn):
+    user_id = get_user_id(auth_token)
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         results = request.json
-        user_id, query, answer, weight = (token.user.user_id,
+        user_id, query, answer, weight = (user_id,
                                           results.get("query"),
                                           results.get("answer"),
                                           results.get("weight", 0))
@@ -109,24 +112,23 @@ def rate_queries(task_id):
         VALUES(?, ?, ?, ?, ?)
         ON CONFLICT(task_id) DO UPDATE SET
         weight=excluded.weight
-        """, (str(user_id), query, answer, weight, task_id))
+        """, (user_id, query, answer, weight, task_id))
         return {
             "message": "You have successfully rated this query. Thank you!"
         }, 200
 
 
 @gnqa.route("/search/records", methods=["GET"])
-@require_oauth("profile user")
-def get_user_search_records():
+@require_token
+def get_user_search_records(auth_token=None):
     """get all  history records for a given user using their
     user id
     """
-    with (require_oauth.acquire("profile user") as token,
-          db.connection(current_app.config["LLM_DB_PATH"]) as conn):
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         cursor.execute(
             """SELECT task_id, query, created_at from history WHERE user_id=?""",
-            (str(token.user.user_id),))
+            (get_user_id(auth_token),))
         results = [dict(item) for item in cursor.fetchall()]
         return jsonify(sorted(results, reverse=True,
                               key=lambda x: datetime.strptime(x.get("created_at"),
@@ -134,17 +136,15 @@ def get_user_search_records():
 
 
 @gnqa.route("/search/record/<task_id>", methods=["GET"])
-@require_oauth("profile user")
-def get_user_record_by_task(task_id):
+@require_token
+def get_user_record_by_task(task_id, auth_token = None):
     """Get user previous search record by task id """
-    with (require_oauth.acquire("profile user") as token,
-          db.connection(current_app.config["LLM_DB_PATH"]) as conn):
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         cursor.execute(
             """SELECT results from history
             Where task_id=? and user_id=?""",
-            (task_id,
-             str(token.user.user_id),))
+            (task_id, get_user_id(auth_token),))
         record = cursor.fetchone()
         if record:
             return dict(record).get("results")
@@ -152,28 +152,31 @@ def get_user_record_by_task(task_id):
 
 
 @gnqa.route("/search/record/<task_id>", methods=["DELETE"])
-@require_oauth("profile user")
-def delete_record(task_id):
+@require_token
+def delete_record(task_id, auth_token = None):
     """Delete user previous seach record by task-id"""
-    with (require_oauth.acquire("profile user") as token,
-          db.connection(current_app.config["LLM_DB_PATH"]) as conn):
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         query = """DELETE FROM history
         WHERE task_id=? and user_id=?"""
-        cursor.execute(query, (task_id, token.user.user_id,))
+        cursor.execute(query, (task_id, get_user_id(auth_token),))
         return {"msg": f"Successfully Deleted the task {task_id}"}
 
 
 @gnqa.route("/search/records", methods=["DELETE"])
-@require_oauth("profile user")
-def delete_records():
+@require_token
+def delete_records(auth_token=None):
     """ Delete a users records using for all given task ids"""
-    with (require_oauth.acquire("profile user") as token,
-          db.connection(current_app.config["LLM_DB_PATH"]) as conn):
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         task_ids = list(request.json.values())
         cursor = conn.cursor()
-        query = f"""
-DELETE FROM history WHERE task_id IN ({', '.join('?' * len(task_ids))}) AND user_id=?
-        """
-        cursor.execute(query, (*task_ids, str(token.user.user_id),))
+        query = f""" DELETE FROM history WHERE task_id IN ({', '.join('?' * len(task_ids))}) AND user_id=? """
+        cursor.execute(query, (*task_ids, get_user_id(auth_token),))
         return jsonify({})
+
+
+def get_user_id(auth_token: Optional[dict] = None):
+    if auth_token is None or auth_token.get("jwt", {}).get("sub") is None:
+        raise LLMError("Invalid auth token encountered")
+    user_id = auth_token["jwt"]["sub"]
+    return user_id
diff --git a/gn3/api/metadata_api/wiki.py b/gn3/api/metadata_api/wiki.py
index e8c59b5..7a00786 100644
--- a/gn3/api/metadata_api/wiki.py
+++ b/gn3/api/metadata_api/wiki.py
@@ -6,7 +6,7 @@ from typing import Any, Dict
 from flask import Blueprint, request, jsonify, current_app, make_response
 
 from gn3 import db_utils
-from gn3.auth.authorisation.oauth2.resource_server import require_oauth
+from gn3.oauth2.authorisation import require_token
 from gn3.db import wiki
 from gn3.db.rdf.wiki import (
     get_wiki_entries_by_symbol,
@@ -21,8 +21,8 @@ rif_blueprint = Blueprint("rif", __name__, url_prefix="rif")
 
 
 @wiki_blueprint.route("/<int:comment_id>/edit", methods=["POST"])
-@require_oauth("profile")
-def edit_wiki(comment_id: int):
+@require_token
+def edit_wiki(comment_id: int, **kwargs):
     """Edit wiki comment. This is achieved by adding another entry with a new VersionId"""
     # FIXME: attempt to check and fix for types here with relevant errors
     payload: Dict[str, Any] = request.json  # type: ignore