about summary refs log tree commit diff
path: root/gn3/api/llm.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/api/llm.py')
-rw-r--r-- gn3/api/llm.py 93
1 files changed, 92 insertions, 1 deletions
diff --git a/gn3/api/llm.py b/gn3/api/llm.py
index 39f434a..8689290 100644
--- a/gn3/api/llm.py
+++ b/gn3/api/llm.py
@@ -1,8 +1,11 @@
 """Api endpoints for gnqa"""
+import ipaddress
 import json
 import string
 import uuid
+
 from datetime import datetime
+from datetime import timedelta
 from typing import Optional
 from functools import wraps
 
@@ -46,12 +49,24 @@ CREATE TABLE IF NOT EXISTS Rating(
 """
 
 
+RATE_LIMITER_TABLE_CREATE_QUERY = """
+CREATE TABLE IF NOT EXISTS Limiter(
+    identifier TEXT NOT NULL,
+    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+    tokens INTEGER,
+    expiry_time TIMESTAMP,
+    PRIMARY KEY(identifier)
+)
+"""  # One row per identifier; check_rate_limiter reads/updates tokens and expiry_time.
+
+
 def database_setup():
     """Temporary method to remove the need to have CREATE queries in functions"""
     with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         cursor.execute(HISTORY_TABLE_CREATE_QUERY)
         cursor.execute(RATING_TABLE_CREATE_QUERY)
+        cursor.execute(RATE_LIMITER_TABLE_CREATE_QUERY)  # table queried by check_rate_limiter
 
 
 def clean_query(query:str) -> str:
@@ -73,7 +88,6 @@ def is_verified_anonymous_user(request_metadata):
         request_metadata.headers.get("Anony-Metadata", "")) # TODO~ verify this for integrity
     return bool(anony_id) and user_status.lower() == "verified"
 
-
 def with_gnqna_fallback(view_func):
     """Allow fallback to GNQNA user if token auth fails or token is malformed."""
     @wraps(view_func)
@@ -104,6 +118,74 @@ def with_gnqna_fallback(view_func):
     return wrapper
 
 
+def is_valid_address(ip_string) -> bool:
+    """Return True only when `ip_string` is a str holding an IPv4/IPv6 address."""
+    # todo !verify data is sent from gn2
+    try:
+        # Text only: ip_address() on its own also accepts ints and packed bytes.
+        return isinstance(ip_string, str) and ipaddress.ip_address(ip_string) is not None
+    except ValueError:
+        return False
+
+
+def check_rate_limiter(request_metadata, db_path, tokens_lifespan=1440, default_tokens=4):
+    """
+    Spend one rate-limit token for the client IP found in `Anony-Metadata`.
+
+    Every IP gets `default_tokens` requests per window of `tokens_lifespan`
+    seconds (the default 1440s is 24 minutes).  The request that opens or
+    renews a window is charged to it, so a window admits exactly
+    `default_tokens` requests.
+
+    Returns True when the request is allowed.  Raises ValueError when the
+    header is missing/malformed or holds no valid IP, and LLMError when the
+    current window has no tokens left.
+    """
+    # A missing header must parse as an empty object: json.loads rejects a dict.
+    user_metadata = json.loads(request_metadata.headers.get("Anony-Metadata") or "{}")
+    ip_address = user_metadata.get("ip_address") if isinstance(user_metadata, dict) else None
+    if not ip_address or not is_valid_address(ip_address):
+        raise ValueError("Please provide a valid IP address")
+    now = datetime.utcnow()
+    new_expiry = (now + timedelta(seconds=tokens_lifespan)).strftime("%Y-%m-%d %H:%M:%S")
+    # The request that creates or renews a bucket consumes its first token.
+    fresh_tokens = max(default_tokens - 1, 0)
+    with db.connection(db_path) as conn:
+        cursor = conn.cursor()
+        cursor.execute("""
+            SELECT tokens, expiry_time FROM Limiter
+            WHERE identifier = ?
+        """, (ip_address,))
+        row = cursor.fetchone()
+        if not row:
+            # First request from this IP: open its bucket.
+            cursor.execute("""
+                INSERT INTO Limiter(identifier, tokens, expiry_time)
+                VALUES (?, ?, ?)
+            """, (ip_address, fresh_tokens, new_expiry))
+            return True
+
+        tokens, expiry_time_str = row
+        expiry_time = datetime.strptime(expiry_time_str, "%Y-%m-%d %H:%M:%S")
+        if not 0 < (expiry_time - now).total_seconds() <= tokens_lifespan:
+            # Window elapsed (or outlives the current lifespan): start a new one.
+            cursor.execute("""
+                UPDATE Limiter
+                SET tokens = ?, expiry_time = ?
+                WHERE identifier = ?
+            """, (fresh_tokens, new_expiry, ip_address))
+            return True
+        if tokens <= 0:
+            raise LLMError("Rate limit exceeded. Please try again later.",
+                           request_metadata.args.get("query"))
+        cursor.execute("""
+            UPDATE Limiter
+            SET tokens = tokens - 1
+            WHERE identifier = ? AND tokens > 0
+        """, (ip_address,))
+        return True
+
+
 @gnqa.route("/search", methods=["GET"])
 @with_gnqna_fallback
 @require_token
@@ -112,11 +194,16 @@ def search(auth_token=None, valid_anony=False):
     query = request.args.get("query", "")
     if not query:
         return jsonify({"error": "query get parameter is missing in the request"}), 400
+
     fahamu_token = current_app.config.get("FAHAMU_AUTH_TOKEN")
     if not fahamu_token:
         raise LLMError(
             "Request failed: an LLM authorisation token  is required ", query)
     database_setup()
+    # check if is valid anon
+    # if valid_anony:
+    check_rate_limiter(request, current_app.config["LLM_DB_PATH"]) #Will raise error if not
+    # else verified user allowed
     with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         previous_answer_query = """
@@ -132,6 +219,10 @@ def search(auth_token=None, valid_anony=False):
             response["query"] = query
             return response
 
+        if valid_anony:
+            # rate limit anonymous verified users
+            check_rate_limiter(request, current_app.config["LLM_DB_PATH"])
+
         task_id, answer, refs = get_gnqa(
             query, fahamu_token, current_app.config.get("DATA_DIR"))
         response = {