diff options
Diffstat (limited to 'gn3/api')
-rw-r--r-- | gn3/api/llm.py | 244 |
1 files changed, 133 insertions, 111 deletions
diff --git a/gn3/api/llm.py b/gn3/api/llm.py index 7d860d8..7e60271 100644 --- a/gn3/api/llm.py +++ b/gn3/api/llm.py @@ -1,128 +1,150 @@ -"""API for data used to generate menus""" - -# pylint: skip-file +"""Api endpoints for gnqa""" +import json +from datetime import datetime -from flask import jsonify, request, Blueprint, current_app +from flask import Blueprint +from flask import current_app +from flask import jsonify +from flask import request -from functools import wraps from gn3.llms.process import get_gnqa -from gn3.llms.process import get_user_queries -from gn3.llms.process import fetch_query_results +from gn3.llms.errors import LLMError from gn3.auth.authorisation.oauth2.resource_server import require_oauth from gn3.auth import db -from redis import Redis -import json -import sqlite3 -from datetime import timedelta - -GnQNA = Blueprint("GnQNA", __name__) -def handle_errors(func): - @wraps(func) - def decorated_function(*args, **kwargs): - try: - return func(*args, **kwargs) - except Exception as error: - return jsonify({"error": str(error)}), 500 - return decorated_function +gnqa = Blueprint("gnqa", __name__) -@GnQNA.route("/gnqna", methods=["POST"]) -def gnqa(): - # todo add auth +@gnqa.route("/search", methods=["PUT"]) +def search(): + """Api endpoint for searching queries in fahamu Api""" query = request.json.get("querygnqa", "") if not query: return jsonify({"error": "querygnqa is missing in the request"}), 400 - - try: - fahamu_token = current_app.config.get("FAHAMU_AUTH_TOKEN") - if fahamu_token is None: - return jsonify({"query": query, "error": "Use of invalid fahamu auth token"}), 500 - task_id, answer, refs = get_gnqa( - query, fahamu_token, current_app.config.get("DATA_DIR")) - response = { - "task_id": task_id, - "query": query, - "answer": answer, - "references": refs - } - with (Redis.from_url(current_app.config["REDIS_URI"], - decode_responses=True) as redis_conn): - # The key will be deleted after 60 seconds - redis_conn.setex(f"LLM:random_user-{query}", timedelta(days=10), json.dumps(response)) - return jsonify({ - **response, - "prev_queries": get_user_queries("random_user", redis_conn) - }) - except Exception as error: - return jsonify({"query": query, "error": f"Request failed-{str(error)}"}), 500 - - -@GnQNA.route("/rating/<task_id>", methods=["POST"]) + fahamu_token = current_app.config.get("FAHAMU_AUTH_TOKEN") + if not fahamu_token: + raise LLMError( + "Request failed: an LLM authorisation token is required ", query) + task_id, answer, refs = get_gnqa( + query, fahamu_token, current_app.config.get("DATA_DIR")) + response = { + "task_id": task_id, + "query": query, + "answer": answer, + "references": refs + } + with (db.connection(current_app.config["LLM_DB_PATH"]) as conn, + require_oauth.acquire("profile user") as token): + cursor = conn.cursor() + cursor.execute("""CREATE TABLE IF NOT EXISTS + history(user_id TEXT NOT NULL, + task_id TEXT NOT NULL, + query TEXT NOT NULL, + results JSONB, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY(task_id)) WITHOUT ROWID""") + cursor.execute( + """INSERT INTO history(user_id, task_id, query, results) + VALUES(?, ?, ?, ?) + """, (str(token.user.user_id), str(task_id["task_id"]), + query, + json.dumps(response)) + ) + return response + + +@gnqa.route("/rating/<task_id>", methods=["POST"]) @require_oauth("profile") -def rating(task_id): - try: - llm_db_path = current_app.config["LLM_DB_PATH"] - with (require_oauth.acquire("profile") as token, - db.connection(llm_db_path) as conn): - - results = request.json - user_id, query, answer, weight = (token.user.user_id, - results.get("query"), - results.get("answer"), - results.get("weight", 0)) - cursor = conn.cursor() - create_table = """CREATE TABLE IF NOT EXISTS Rating( - user_id TEXT NOT NULL, - query TEXT NOT NULL, - answer TEXT NOT NULL, - weight INTEGER NOT NULL DEFAULT 0, - task_id TEXT NOT NULL UNIQUE - )""" - cursor.execute(create_table) - cursor.execute("""INSERT INTO Rating(user_id,query,answer,weight,task_id) - VALUES(?,?,?,?,?) - ON CONFLICT(task_id) DO UPDATE SET - weight=excluded.weight - """, (str(user_id), query, answer, weight, task_id)) +def rate_queries(task_id): + """Api endpoint for rating GNQA query and answer""" + with (require_oauth.acquire("profile") as token, + db.connection(current_app.config["LLM_DB_PATH"]) as conn): + results = request.json + user_id, query, answer, weight = (token.user.user_id, + results.get("query"), + results.get("answer"), + results.get("weight", 0)) + cursor = conn.cursor() + create_table = """CREATE TABLE IF NOT EXISTS Rating( + user_id TEXT NOT NULL, + query TEXT NOT NULL, + answer TEXT NOT NULL, + weight INTEGER NOT NULL DEFAULT 0, + task_id TEXT NOT NULL UNIQUE, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY(task_id))""" + cursor.execute(create_table) + cursor.execute("""INSERT INTO Rating(user_id, query, + answer, weight, task_id) + VALUES(?, ?, ?, ?, ?) + ON CONFLICT(task_id) DO UPDATE SET + weight=excluded.weight + """, (str(user_id), query, answer, weight, task_id)) return { - "message": "You have successfully rated this query:Thank you!!" - }, 200 - except sqlite3.Error as error: - return jsonify({"error": str(error)}), 500 - except Exception as error: - raise error + "message": "You have successfully rated this query. Thank you!" + }, 200 -@GnQNA.route("/history/<query>", methods=["GET"]) +@gnqa.route("/search/records", methods=["GET"]) @require_oauth("profile user") -@handle_errors -def fetch_user_hist(query): - - with (require_oauth.acquire("profile user") as the_token, Redis.from_url(current_app.config["REDIS_URI"], - decode_responses=True) as redis_conn): - return jsonify({ - **fetch_query_results(query, the_token.user.id, redis_conn), - "prev_queries": get_user_queries("random_user", redis_conn) - }) - - -@GnQNA.route("/historys/<query>", methods=["GET"]) -@handle_errors -def fetch_users_hist_records(query): - """method to fetch all users hist:note this is a test functionality to be replaced by fetch_user_hist""" - - with Redis.from_url(current_app.config["REDIS_URI"], decode_responses=True) as redis_conn: - return jsonify({ - **fetch_query_results(query, "random_user", redis_conn), - "prev_queries": get_user_queries("random_user", redis_conn) - }) - - -@GnQNA.route("/get_hist_names", methods=["GET"]) -@handle_errors -def fetch_prev_hist_ids(): - - with (Redis.from_url(current_app.config["REDIS_URI"], decode_responses=True)) as redis_conn: - return jsonify({"prev_queries": get_user_queries("random_user", redis_conn)}) +def get_user_search_records(): + """get all history records for a given user using their + user id + """ + with (require_oauth.acquire("profile user") as token, + db.connection(current_app.config["LLM_DB_PATH"]) as conn): + cursor = conn.cursor() + cursor.execute( + """SELECT task_id, query, created_at from history WHERE user_id=?""", + (str(token.user.user_id),)) + results = [dict(item) for item in cursor.fetchall()] + return jsonify(sorted(results, reverse=True, + key=lambda x: datetime.strptime(x.get("created_at"), + '%Y-%m-%d %H:%M:%S'))) + + +@gnqa.route("/search/record/<task_id>", methods=["GET"]) +@require_oauth("profile user") +def get_user_record_by_task(task_id): + """Get user previous search record by task id """ + with (require_oauth.acquire("profile user") as token, + db.connection(current_app.config["LLM_DB_PATH"]) as conn): + cursor = conn.cursor() + cursor.execute( + """SELECT results from history + Where task_id=? and user_id=?""", + (task_id, + str(token.user.user_id),)) + record = cursor.fetchone() + if record: + return dict(record).get("results") + return {} + + +@gnqa.route("/search/record/<task_id>", methods=["DELETE"]) +@require_oauth("profile user") +def delete_record(task_id): + """Delete user previous seach record by task-id""" + with (require_oauth.acquire("profile user") as token, + db.connection(current_app.config["LLM_DB_PATH"]) as conn): + cursor = conn.cursor() + query = """DELETE FROM history + WHERE task_id=? and user_id=?""" + cursor.execute(query, (task_id, token.user.user_id,)) + return {"msg": f"Successfully Deleted the task {task_id}"} + + +@gnqa.route("/search/records", methods=["DELETE"]) +@require_oauth("profile user") +def delete_records(): + """ Delete a users records using for all given task ids""" + with (require_oauth.acquire("profile user") as token, + db.connection(current_app.config["LLM_DB_PATH"]) as conn): + task_ids = list(request.json.values()) + cursor = conn.cursor() + query = """DELETE FROM history + WHERE task_id IN ({}) + and user_id=?""".format(",".join("?" * len(task_ids))) + cursor.execute(query, (*task_ids, str(token.user.user_id),)) + return jsonify({}) |