diff options
| -rw-r--r-- | gn3/api/llm.py | 93 |
1 files changed, 92 insertions, 1 deletions
diff --git a/gn3/api/llm.py b/gn3/api/llm.py
index 39f434a..8689290 100644
--- a/gn3/api/llm.py
+++ b/gn3/api/llm.py
@@ -1,8 +1,11 @@
 """Api endpoints for gnqa"""
+import ipaddress
 import json
 import string
 import uuid
+
 from datetime import datetime
+from datetime import timedelta
 from typing import Optional
 from functools import wraps

@@ -46,12 +49,24 @@ CREATE TABLE IF NOT EXISTS Rating(
 """


+RATE_LIMITER_TABLE_CREATE_QUERY = """
+CREATE TABLE IF NOT EXISTS Limiter(
+    identifier TEXT NOT NULL,
+    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+    tokens INTEGER,
+    expiry_time TIMESTAMP,
+    PRIMARY KEY(identifier)
+)
+"""
+
+
 def database_setup():
     """Temporary method to remove the need to have CREATE queries in functions"""
     with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         cursor.execute(HISTORY_TABLE_CREATE_QUERY)
         cursor.execute(RATING_TABLE_CREATE_QUERY)
+        cursor.execute(RATE_LIMITER_TABLE_CREATE_QUERY)


 def clean_query(query:str) -> str:
@@ -73,7 +88,6 @@ def is_verified_anonymous_user(request_metadata):
         request_metadata.headers.get("Anony-Metadata", ""))
     # TODO~ verify this for integrity
     return bool(anony_id) and user_status.lower() == "verified"
-
 def with_gnqna_fallback(view_func):
     """Allow fallback to GNQNA user if token auth fails or token is malformed."""
     @wraps(view_func)
@@ -104,6 +118,74 @@ def with_gnqna_fallback(view_func):
     return wrapper


+def is_valid_address(ip_string) -> bool:
+    """Return True if `ip_string` parses as a valid IPv4 or IPv6 address, else False."""
+    # TODO: verify that the data really is sent from gn2
+    try:
+        ipaddress.ip_address(ip_string)
+        return True
+    except ValueError:
+        return False
+
+
+def check_rate_limiter(request_metadata, db_path, tokens_lifespan=1440, default_tokens=4):
+    """
+    Check, and update, the token bucket of the IP address in the "Anony-Metadata" header.
+    Returns True when allowed: a live bucket loses one token, a missing/expired one is (re)created.
+    `tokens_lifespan` is the bucket lifetime in seconds (1440 s = 24 min); `default_tokens` its size.
+    Raises ValueError on a missing/invalid IP address, LLMError when the live bucket is empty.
+    """
+    # Default to an empty JSON object: json.loads() rejects a dict, so a missing header must yield "{}"
+    user_metadata = json.loads(request_metadata.headers.get("Anony-Metadata", "{}"))
+    ip_address = user_metadata.get("ip_address")
+    if not ip_address or not is_valid_address(ip_address):
+        raise ValueError("Please provide a valid IP address")
+    now = datetime.utcnow()
+    new_expiry = (now + timedelta(seconds=tokens_lifespan)).strftime("%Y-%m-%d %H:%M:%S")
+
+    with db.connection(db_path) as conn:
+        cursor = conn.cursor()
+        # Fetch the existing limiter record for this identifier, if any
+        cursor.execute("""
+            SELECT tokens, expiry_time FROM Limiter
+            WHERE identifier = ?
+        """, (ip_address,))
+        row = cursor.fetchone()
+
+        if row:
+            tokens, expiry_time_str = row
+            expiry_time = datetime.strptime(expiry_time_str, "%Y-%m-%d %H:%M:%S")
+            time_diff = (expiry_time - now).total_seconds()
+
+            if 0 < time_diff <= tokens_lifespan:
+                if tokens > 0:
+                    # Bucket still live: consume one token (guarded so it never goes negative)
+                    cursor.execute("""
+                        UPDATE Limiter
+                        SET tokens = tokens - 1
+                        WHERE identifier = ? AND tokens > 0
+                    """, (ip_address,))
+                    return True
+                else:
+                    raise LLMError("Rate limit exceeded. Please try again later.",
+                                   request_metadata.args.get("query"))
+            else:
+                # Bucket expired — refill it (TODO: maybe refill only after a 200 status)
+                cursor.execute("""
+                    UPDATE Limiter
+                    SET tokens = ?, expiry_time = ?
+                    WHERE identifier = ?
+                """, (default_tokens, new_expiry, ip_address))
+                return True
+        else:
+            # New identifier — insert a full bucket (TODO: maybe insert only after a 200 status)
+            cursor.execute("""
+                INSERT INTO Limiter(identifier, tokens, expiry_time)
+                VALUES (?, ?, ?)
+            """, (ip_address, default_tokens, new_expiry))
+            return True
+
+
 @gnqa.route("/search", methods=["GET"])
 @with_gnqna_fallback
 @require_token
@@ -112,11 +194,16 @@ def search(auth_token=None, valid_anony=False):
     query = request.args.get("query", "")
     if not query:
         return jsonify({"error": "query get parameter is missing in the request"}), 400
+
     fahamu_token = current_app.config.get("FAHAMU_AUTH_TOKEN")
     if not fahamu_token:
         raise LLMError(
             "Request failed: an LLM authorisation token is required ", query)
     database_setup()
+    # check if is valid anon
+    # if valid_anony:
+    check_rate_limiter(request, current_app.config["LLM_DB_PATH"]) #Will raise error if not
+    # else verified user allowed
     with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
         cursor = conn.cursor()
         previous_answer_query = """
@@ -132,6 +219,10 @@ def search(auth_token=None, valid_anony=False):
             response["query"] = query
             return response

+        if valid_anony:
+            # rate limit anonymous verified users
+            check_rate_limiter(request, current_app.config["LLM_DB_PATH"])
+
         task_id, answer, refs = get_gnqa(
             query, fahamu_token, current_app.config.get("DATA_DIR"))
         response = { |
