diff options
-rw-r--r-- | gn3/api/llm.py | 47 |
1 files changed, 35 insertions, 12 deletions
diff --git a/gn3/api/llm.py b/gn3/api/llm.py index 98ff4e2..aff01cb 100644 --- a/gn3/api/llm.py +++ b/gn3/api/llm.py @@ -11,9 +11,13 @@ from gn3.llms.process import get_gnqa from gn3.llms.process import rate_document from gn3.llms.process import get_user_queries from gn3.llms.process import fetch_query_results +from gn3.auth.authorisation.oauth2.resource_server import require_oauth +from gn3.auth import db + from redis import Redis import json +import sqlite3 from datetime import timedelta GnQNA = Blueprint("GnQNA", __name__) @@ -58,19 +62,38 @@ def gnqa(): return jsonify({"query": query, "error": f"Request failed-{str(error)}"}), 500 -@GnQNA.route("/rating/<task_id>/<doc_id>/<int:rating>", methods=["POST"]) -def rating(task_id, doc_id, rating): +@GnQNA.route("/rating/<task_id>", methods=["POST"]) +@require_oauth("profile") +def rating(task_id): try: - results = rate_document(task_id, doc_id, rating, - current_app.config.get("FAHAMU_AUTH_TOKEN")) - - return jsonify({ - **results, - "doc_id": doc_id, - "task_id": task_id, - }), - except Exception as error: - return jsonify({"error": str(error), doc_id: doc_id}), 500 + with require_oauth.acquire("profile") as the_token: + user = the_token.user.user_id + results = request.json + user_id, query, answer, weight = (the_token.user.user_id, + results.get("query"), + results.get("answer"), + results.get("weight", 0)) + with db.connection(current_app.config["GNQA_DB"]) as conn: + cursor = conn.cursor() + create_table = """CREATE TABLE IF NOT EXISTS Rating( + user_id INTEGER NOT NULL, + query TEXT NOT NULL, + answer TEXT NOT NULL, + weight INTEGER NOT NULL DEFAULT 0, + task_id TEXT NOT NULL UNIQUE + )""" + cursor.execute(create_table) + cursor.execute("""INSERT INTO Rating(user_id,query,answer,weight,task_id) + VALUES(?,?,?,?,?) + ON CONFLICT(task_id) DO UPDATE SET + weight=excluded.weight + """, (user_id, query, answer, weight, task_id)) + return { + "message": "success", + "status": 0 + }, 200 + except sqlite3.Error as error: + raise error @GnQNA.route("/history/<query>", methods=["GET"]) |