diff options
-rw-r--r-- | gn3/api/llm.py | 81 |
1 files changed, 55 insertions, 26 deletions
diff --git a/gn3/api/llm.py b/gn3/api/llm.py index e23429e..b9ffbb2 100644 --- a/gn3/api/llm.py +++ b/gn3/api/llm.py @@ -15,35 +15,72 @@ from gn3.auth import db gnqa = Blueprint("gnqa", __name__) +HISTORY_TABLE_CREATE_QUERY = """ +CREATE TABLE IF NOT EXISTS history( + user_id TEXT NOT NULL, + task_id TEXT NOT NULL, + query TEXT NOT NULL, + results JSONB, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY(task_id) + ) WITHOUT ROWID +""" + +RATING_TABLE_CREATE_QUERY = """ +CREATE TABLE IF NOT EXISTS Rating( + user_id TEXT NOT NULL, + query TEXT NOT NULL, + answer TEXT NOT NULL, + weight INTEGER NOT NULL DEFAULT 0, + task_id TEXT NOT NULL UNIQUE, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY(task_id) + ) +""" + + +def database_setup(): + """Temporary method to remove the need to have CREATE queries in functions""" + with db.connection(current_app.config["LLM_DB_PATH"]) as conn: + cursor = conn.cursor() + cursor.execute(HISTORY_TABLE_CREATE_QUERY) + cursor.execute(RATING_TABLE_CREATE_QUERY) + @gnqa.route("/search", methods=["GET"]) def search(): """Api endpoint for searching queries in fahamu Api""" query = request.args.get("query", "") if not query: - return jsonify({"error": "querygnqa is missing in the request"}), 400 + return jsonify({"error": "query get parameter is missing in the request"}), 400 fahamu_token = current_app.config.get("FAHAMU_AUTH_TOKEN") if not fahamu_token: raise LLMError( "Request failed: an LLM authorisation token is required ", query) - task_id, answer, refs = get_gnqa( - query, fahamu_token, current_app.config.get("DATA_DIR")) - response = { - "task_id": task_id, - "query": query, - "answer": answer, - "references": refs - } + database_setup() with (db.connection(current_app.config["LLM_DB_PATH"]) as conn, require_oauth.acquire("profile user") as token): cursor = conn.cursor() - cursor.execute("""CREATE TABLE IF NOT EXISTS - history(user_id TEXT NOT NULL, - task_id TEXT NOT NULL, - query TEXT NOT NULL, - results JSONB, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - PRIMARY KEY(task_id)) WITHOUT ROWID""") + previous_answer_query = """ + SELECT user_id, task_id, query, results FROM history + WHERE created_at > DATE('now', '-1 day') AND + user_id = ? AND + query = ? + ORDER BY created_at DESC LIMIT 1 """ + res = cursor.execute(previous_answer_query, (str(token.user.user_id), query)) + previous_result = res.fetchone() + if previous_result: + _, _, _, response = previous_result + return response + + task_id, answer, refs = get_gnqa( + query, fahamu_token, current_app.config.get("DATA_DIR")) + response = { + "task_id": task_id, + "query": query, + "answer": answer, + "references": refs + } cursor.execute( """INSERT INTO history(user_id, task_id, query, results) VALUES(?, ?, ?, ?) @@ -51,13 +88,14 @@ def search(): query, json.dumps(response)) ) - return response + return response @gnqa.route("/rating/<task_id>", methods=["POST"]) @require_oauth("profile") def rate_queries(task_id): """Api endpoint for rating GNQA query and answer""" + database_setup() with (require_oauth.acquire("profile") as token, db.connection(current_app.config["LLM_DB_PATH"]) as conn): results = request.json @@ -66,15 +104,6 @@ def rate_queries(task_id): results.get("answer"), results.get("weight", 0)) cursor = conn.cursor() - create_table = """CREATE TABLE IF NOT EXISTS Rating( - user_id TEXT NOT NULL, - query TEXT NOT NULL, - answer TEXT NOT NULL, - weight INTEGER NOT NULL DEFAULT 0, - task_id TEXT NOT NULL UNIQUE, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - PRIMARY KEY(task_id))""" - cursor.execute(create_table) cursor.execute("""INSERT INTO Rating(user_id, query, answer, weight, task_id) VALUES(?, ?, ?, ?, ?) |