about summary refs log tree commit diff
path: root/gn3/api/llm.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/api/llm.py')
-rw-r--r-- gn3/api/llm.py 236
1 files changed, 193 insertions, 43 deletions
diff --git a/gn3/api/llm.py b/gn3/api/llm.py
index b9ffbb2..dc8412e 100644
--- a/gn3/api/llm.py
+++ b/gn3/api/llm.py
@@ -1,16 +1,25 @@
 """Api endpoints for gnqa"""
+import ipaddress
 import json
+import string
+import uuid
+
 from datetime import datetime
+from datetime import timedelta
+from typing import Optional
+from functools import wraps
 
 from flask import Blueprint
 from flask import current_app
 from flask import jsonify
 from flask import request
 
+from authlib.jose.errors import DecodeError
 from gn3.llms.process import get_gnqa
 from gn3.llms.errors import LLMError
-from gn3.auth.authorisation.oauth2.resource_server import require_oauth
-from gn3.auth import db
+
+from gn3.oauth2.authorisation import require_token
+from gn3 import sqlite_db_utils as db
 
 
 gnqa = Blueprint("gnqa", __name__)
@@ -26,6 +35,7 @@ CREATE TABLE IF NOT EXISTS history(
     ) WITHOUT ROWID
 """
 
+
 RATING_TABLE_CREATE_QUERY = """
 CREATE TABLE IF NOT EXISTS Rating(
     user_id TEXT NOT NULL,
@@ -39,40 +49,177 @@ CREATE TABLE IF NOT EXISTS Rating(
 """
 
 
+RATE_LIMITER_TABLE_CREATE_QUERY = """
+CREATE TABLE IF NOT EXISTS Limiter(
+    identifier TEXT NOT NULL,
+    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+    tokens INTEGER,
+    expiry_time TIMESTAMP,
+    PRIMARY KEY(identifier)
+)
+"""
+
+
 def database_setup():
     """Temporary method to remove the need to have CREATE queries in functions"""
     with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         cursor.execute(HISTORY_TABLE_CREATE_QUERY)
         cursor.execute(RATING_TABLE_CREATE_QUERY)
+        cursor.execute(RATE_LIMITER_TABLE_CREATE_QUERY)
+
+
+def clean_query(query:str) -> str:
+    """Normalise a query: lowercase it and strip leading and
+    trailing punctuation and whitespace (interior characters
+    are left untouched).
+    clean_query("!hello test.") -> "hello test"
+    """
+    strip_chars = string.punctuation + string.whitespace
+    str_query = query.lower().strip(strip_chars)
+    return str_query
+
+
+def is_verified_anonymous_user(header_metadata):
+    """This function should verify the authenticity of metadata from gn2."""
+    anony_id = header_metadata.get("Anonymous-Id") #should verify this + metadata signature
+    user_status = header_metadata.get("Anonymous-Status", "")
+    _user_signed_metadata = (
+        header_metadata.get("Anony-Metadata", "")) # TODO~ verify this for integrity with tokens
+    return bool(anony_id) and user_status.lower() == "verified"
+
+def with_gnqna_fallback(view_func):
+    """Allow fallback to GNQNA user if token auth fails or token is malformed."""
+    @wraps(view_func)
+    def wrapper(*args, **kwargs):
+        def call_with_anonymous_fallback():
+            return view_func.__wrapped__(*args,
+                   **{**kwargs, "auth_token": None, "valid_anony": True})
+
+        try:
+            response = view_func(*args, **kwargs)
+
+            is_invalid_token = (
+                isinstance(response, tuple) and
+                len(response) == 2 and
+                response[1] == 400
+            )
+
+            if is_invalid_token and is_verified_anonymous_user(dict(request.headers)):
+                return call_with_anonymous_fallback()
+
+            return response
+
+        except (DecodeError, ValueError): # occurs when trying to parse the token or auth results
+            if is_verified_anonymous_user(dict(request.headers)):
+                return call_with_anonymous_fallback()
+            return view_func.__wrapped__(*args, **kwargs)
+
+    return wrapper
+
+
+def is_valid_address(ip_string) -> bool :
+    """Check whether ip_string is a valid IP address."""
+    # todo !verify data is sent from gn2
+    try:
+        ipaddress.ip_address(ip_string)
+        return True
+    except ValueError:
+        return False
+
+
+def check_rate_limiter(ip_address, db_path,  query, tokens_lifespan=1440, default_tokens=4):
+    """
+    Checks if an anonymous user has a valid token within the given lifespan.
+    If expired or not found, creates or resets the token bucket.
+    `tokens_lifespan` is in seconds (default 1440, i.e. 24 minutes).
+    `default_tokens` is the token count a bucket is (re)filled with (default 4).
+    """
+    # Validate the IP address / identifier
+    if not ip_address or not is_valid_address(ip_address):
+        raise ValueError("Please provide a valid IP address")
+    now = datetime.utcnow()
+    new_expiry = (now + timedelta(seconds=tokens_lifespan)).strftime("%Y-%m-%d %H:%M:%S")
+
+    with  db.connection(db_path) as conn:
+        cursor = conn.cursor()
+        # Fetch existing limiter record
+        cursor.execute("""
+            SELECT tokens, expiry_time FROM Limiter
+            WHERE identifier = ?
+        """, (ip_address,))
+        row = cursor.fetchone()
+
+        if row:
+            tokens, expiry_time_str = row
+            expiry_time = datetime.strptime(expiry_time_str, "%Y-%m-%d %H:%M:%S")
+            time_diff = (expiry_time - now).total_seconds()
+
+            if 0 < time_diff <= tokens_lifespan:
+                if tokens > 0:
+                    # Consume token
+                    cursor.execute("""
+                        UPDATE Limiter
+                        SET tokens = tokens - 1
+                        WHERE identifier = ? AND tokens > 0
+                    """, (ip_address,))
+                    return True
+                else:
+                    raise LLMError("Rate limit exceeded. Please try again later.",
+                                   query)
+            else:
+                # Token expired — reset ~probably reset this after 200 status
+                cursor.execute("""
+                    UPDATE Limiter
+                    SET tokens = ?, expiry_time = ?
+                    WHERE identifier = ?
+                """, (default_tokens, new_expiry, ip_address))
+                return True
+        else:
+            # New user — insert record  ~probably reset this after 200 status
+            cursor.execute("""
+                INSERT INTO Limiter(identifier, tokens, expiry_time)
+                VALUES (?, ?, ?)
+            """, (ip_address, default_tokens, new_expiry))
+            return True
 
 
 @gnqa.route("/search", methods=["GET"])
-def search():
+@with_gnqna_fallback
+@require_token
+def search(auth_token=None, valid_anony=False):
     """Api  endpoint for searching queries in fahamu Api"""
     query = request.args.get("query", "")
     if not query:
         return jsonify({"error": "query get parameter is missing in the request"}), 400
+
     fahamu_token = current_app.config.get("FAHAMU_AUTH_TOKEN")
     if not fahamu_token:
         raise LLMError(
             "Request failed: an LLM authorisation token  is required ", query)
     database_setup()
-    with (db.connection(current_app.config["LLM_DB_PATH"]) as conn,
-          require_oauth.acquire("profile user") as token):
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         previous_answer_query = """
         SELECT user_id, task_id, query, results FROM history
-            WHERE created_at > DATE('now', '-1 day') AND
-                user_id = ? AND
+            WHERE created_at > DATE('now', '-21 day') AND
                 query = ?
             ORDER BY created_at DESC LIMIT 1 """
-        res = cursor.execute(previous_answer_query, (str(token.user.user_id), query))
+        res = cursor.execute(previous_answer_query, (clean_query(query),))
         previous_result = res.fetchone()
         if previous_result:
             _, _, _, response = previous_result
+            response = json.loads(response)
+            response["query"] = query
             return response
 
+        if valid_anony:
+            # rate limit anonymous verified users
+            user_metadata = json.loads(request.headers.get("Anony-Metadata", {}))
+            check_rate_limiter(user_metadata.get("ip_address", ""),
+                               current_app.config["LLM_DB_PATH"],
+                               request.args.get("query", ""))
+
         task_id, answer, refs = get_gnqa(
             query, fahamu_token, current_app.config.get("DATA_DIR"))
         response = {
@@ -81,52 +228,51 @@ def search():
             "answer": answer,
             "references": refs
         }
+        user_id = str(uuid.uuid4()) if valid_anony else get_user_id(auth_token)
         cursor.execute(
             """INSERT INTO history(user_id, task_id, query, results)
             VALUES(?, ?, ?, ?)
-            """, (str(token.user.user_id), str(task_id["task_id"]),
-                  query,
+            """, (user_id, str(task_id["task_id"]),
+                  clean_query(query),
                   json.dumps(response))
         )
         return response
 
 
 @gnqa.route("/rating/<task_id>", methods=["POST"])
-@require_oauth("profile")
-def rate_queries(task_id):
+@require_token
+def rate_queries(task_id, auth_token=None):
     """Api endpoint for rating GNQA query and answer"""
     database_setup()
-    with (require_oauth.acquire("profile") as token,
-          db.connection(current_app.config["LLM_DB_PATH"]) as conn):
+    user_id = get_user_id(auth_token)
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         results = request.json
-        user_id, query, answer, weight = (token.user.user_id,
-                                          results.get("query"),
-                                          results.get("answer"),
-                                          results.get("weight", 0))
+        query, answer, weight = (results.get("query"),
+                                 results.get("answer"),
+                                 results.get("weight", 0))
         cursor = conn.cursor()
         cursor.execute("""INSERT INTO Rating(user_id, query,
         answer, weight, task_id)
         VALUES(?, ?, ?, ?, ?)
         ON CONFLICT(task_id) DO UPDATE SET
         weight=excluded.weight
-        """, (str(user_id), query, answer, weight, task_id))
+        """, (user_id, query, answer, weight, task_id))
         return {
             "message": "You have successfully rated this query. Thank you!"
         }, 200
 
 
 @gnqa.route("/search/records", methods=["GET"])
-@require_oauth("profile user")
-def get_user_search_records():
+@require_token
+def get_user_search_records(auth_token=None):
     """get all  history records for a given user using their
     user id
     """
-    with (require_oauth.acquire("profile user") as token,
-          db.connection(current_app.config["LLM_DB_PATH"]) as conn):
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         cursor.execute(
             """SELECT task_id, query, created_at from history WHERE user_id=?""",
-            (str(token.user.user_id),))
+            (get_user_id(auth_token),))
         results = [dict(item) for item in cursor.fetchall()]
         return jsonify(sorted(results, reverse=True,
                               key=lambda x: datetime.strptime(x.get("created_at"),
@@ -134,17 +280,15 @@ def get_user_search_records():
 
 
 @gnqa.route("/search/record/<task_id>", methods=["GET"])
-@require_oauth("profile user")
-def get_user_record_by_task(task_id):
+@require_token
+def get_user_record_by_task(task_id, auth_token = None):
     """Get user previous search record by task id """
-    with (require_oauth.acquire("profile user") as token,
-          db.connection(current_app.config["LLM_DB_PATH"]) as conn):
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         cursor.execute(
             """SELECT results from history
             Where task_id=? and user_id=?""",
-            (task_id,
-             str(token.user.user_id),))
+            (task_id, get_user_id(auth_token),))
         record = cursor.fetchone()
         if record:
             return dict(record).get("results")
@@ -152,28 +296,34 @@ def get_user_record_by_task(task_id):
 
 
 @gnqa.route("/search/record/<task_id>", methods=["DELETE"])
-@require_oauth("profile user")
-def delete_record(task_id):
+@require_token
+def delete_record(task_id, auth_token = None):
     """Delete user previous seach record by task-id"""
-    with (require_oauth.acquire("profile user") as token,
-          db.connection(current_app.config["LLM_DB_PATH"]) as conn):
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         query = """DELETE FROM history
         WHERE task_id=? and user_id=?"""
-        cursor.execute(query, (task_id, token.user.user_id,))
+        cursor.execute(query, (task_id, get_user_id(auth_token),))
         return {"msg": f"Successfully Deleted the task {task_id}"}
 
 
 @gnqa.route("/search/records", methods=["DELETE"])
-@require_oauth("profile user")
-def delete_records():
+@require_token
+def delete_records(auth_token=None):
     """ Delete a users records using for all given task ids"""
-    with (require_oauth.acquire("profile user") as token,
-          db.connection(current_app.config["LLM_DB_PATH"]) as conn):
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         task_ids = list(request.json.values())
         cursor = conn.cursor()
-        query = f"""
-DELETE FROM history WHERE task_id IN ({', '.join('?' * len(task_ids))}) AND user_id=?
-        """
-        cursor.execute(query, (*task_ids, str(token.user.user_id),))
+        query = ("DELETE FROM history WHERE task_id IN "
+                 f"({', '.join('?' * len(task_ids))}) "
+                 "AND user_id=?")
+        cursor.execute(query, (*task_ids, get_user_id(auth_token),))
         return jsonify({})
+
+
+def get_user_id(auth_token: Optional[dict] = None):
+    """Retrieve the user ID from the JWT token."""
+    if auth_token is None or auth_token.get("jwt", {}).get("sub") is None:
+        raise LLMError("Invalid auth token encountered")
+    user_id = auth_token["jwt"]["sub"]
+    return user_id