about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJohn Nduli2024-09-12 14:29:45 +0300
committerBonfaceKilz2024-09-12 18:35:40 +0300
commit82e5a814b5e0e855c467aa30d773d0927f0520ec (patch)
tree2cfaaff0d80dc7e62eae1b59e72b4775a6798318
parentc60bc85ed3ac69a14c1746cf11c73a8172da9308 (diff)
downloadgenenetwork3-82e5a814b5e0e855c467aa30d773d0927f0520ec.tar.gz
feat: pick results from sqlite3 if they were stored
-rw-r--r--gn3/api/llm.py81
1 files changed, 55 insertions, 26 deletions
diff --git a/gn3/api/llm.py b/gn3/api/llm.py
index e23429e..b9ffbb2 100644
--- a/gn3/api/llm.py
+++ b/gn3/api/llm.py
@@ -15,35 +15,72 @@ from gn3.auth import db
 
 gnqa = Blueprint("gnqa", __name__)
 
+HISTORY_TABLE_CREATE_QUERY = """
+CREATE TABLE IF NOT EXISTS history(
+    user_id TEXT NOT NULL,
+    task_id TEXT NOT NULL,
+    query TEXT NOT NULL,
+    results JSONB,
+    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+    PRIMARY KEY(task_id)
+    ) WITHOUT ROWID
+"""
+
+RATING_TABLE_CREATE_QUERY = """
+CREATE TABLE IF NOT EXISTS Rating(
+    user_id TEXT NOT NULL,
+    query TEXT NOT NULL,
+    answer TEXT NOT NULL,
+    weight INTEGER NOT NULL DEFAULT 0,
+    task_id TEXT NOT NULL UNIQUE,
+    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+    PRIMARY KEY(task_id)
+    )
+"""
+
+
+def database_setup():
+    """Temporary method to remove the need to have CREATE queries in functions"""
+    with db.connection(current_app.config["LLM_DB_PATH"]) as conn:
+        cursor = conn.cursor()
+        cursor.execute(HISTORY_TABLE_CREATE_QUERY)
+        cursor.execute(RATING_TABLE_CREATE_QUERY)
+
 
 @gnqa.route("/search", methods=["GET"])
 def search():
     """Api  endpoint for searching queries in fahamu Api"""
     query = request.args.get("query", "")
     if not query:
-        return jsonify({"error": "querygnqa is missing in the request"}), 400
+        return jsonify({"error": "query get parameter is missing in the request"}), 400
     fahamu_token = current_app.config.get("FAHAMU_AUTH_TOKEN")
     if not fahamu_token:
         raise LLMError(
             "Request failed: an LLM authorisation token  is required ", query)
-    task_id, answer, refs = get_gnqa(
-        query, fahamu_token, current_app.config.get("DATA_DIR"))
-    response = {
-        "task_id": task_id,
-        "query": query,
-        "answer": answer,
-        "references": refs
-    }
+    database_setup()
     with (db.connection(current_app.config["LLM_DB_PATH"]) as conn,
           require_oauth.acquire("profile user") as token):
         cursor = conn.cursor()
-        cursor.execute("""CREATE TABLE IF NOT EXISTS
-        history(user_id TEXT NOT NULL,
-        task_id TEXT NOT NULL,
-        query  TEXT NOT NULL,
-        results  JSONB,
-        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-        PRIMARY KEY(task_id)) WITHOUT ROWID""")
+        previous_answer_query = """
+        SELECT user_id, task_id, query, results FROM history
+            WHERE created_at > DATE('now', '-1 day') AND
+                user_id = ? AND
+                query = ?
+            ORDER BY created_at DESC LIMIT 1 """
+        res = cursor.execute(previous_answer_query, (str(token.user.user_id), query))
+        previous_result = res.fetchone()
+        if previous_result:
+            _, _, _, response = previous_result
+            return response
+
+        task_id, answer, refs = get_gnqa(
+            query, fahamu_token, current_app.config.get("DATA_DIR"))
+        response = {
+            "task_id": task_id,
+            "query": query,
+            "answer": answer,
+            "references": refs
+        }
         cursor.execute(
             """INSERT INTO history(user_id, task_id, query, results)
             VALUES(?, ?, ?, ?)
@@ -51,13 +88,14 @@ def search():
                   query,
                   json.dumps(response))
         )
-    return response
+        return response
 
 
 @gnqa.route("/rating/<task_id>", methods=["POST"])
 @require_oauth("profile")
 def rate_queries(task_id):
     """Api endpoint for rating GNQA query and answer"""
+    database_setup()
     with (require_oauth.acquire("profile") as token,
           db.connection(current_app.config["LLM_DB_PATH"]) as conn):
         results = request.json
@@ -66,15 +104,6 @@ def rate_queries(task_id):
                                           results.get("answer"),
                                           results.get("weight", 0))
         cursor = conn.cursor()
-        create_table = """CREATE TABLE IF NOT EXISTS Rating(
-              user_id TEXT NOT NULL,
-              query TEXT NOT NULL,
-              answer TEXT NOT NULL,
-              weight INTEGER NOT NULL DEFAULT 0,
-              task_id TEXT NOT NULL UNIQUE,
-              created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-              PRIMARY KEY(task_id))"""
-        cursor.execute(create_table)
         cursor.execute("""INSERT INTO Rating(user_id, query,
         answer, weight, task_id)
         VALUES(?, ?, ?, ?, ?)