aboutsummaryrefslogtreecommitdiff
path: root/gn3/api/llm.py
blob: 1e84bcfe512f823fbeb674d41011e34c479e8f50 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""Flask API for the GNQA LLM feature: asking questions, rating answers and retrieving query history."""

# pylint: skip-file

from flask import jsonify, request, Blueprint, current_app

from functools import wraps
from gn3.auth.authorisation.oauth2.resource_server import require_oauth

from gn3.llms.process import get_gnqa
from gn3.llms.process import rate_document
from gn3.llms.process import get_user_queries
from gn3.llms.process import fetch_query_results
from gn3.auth.authorisation.oauth2.resource_server import require_oauth
from gn3.auth import db
from gn3.settings import SQLITE_DB_PATH

from redis import Redis
import os
import json
import logging
import sqlite3
from datetime import timedelta

# Blueprint that groups the GNQA (LLM question-answering) routes defined below.
GnQNA = Blueprint("GnQNA", __name__)


def handle_errors(func):
    """Wrap a view so any unhandled exception becomes a JSON 500 response.

    The exception is logged with its traceback through this module's logger
    before ``({"error": str(error)}, 500)`` is returned. Without the log, the
    failure would be invisible on the server side.

    Args:
        func: the view function to wrap.

    Returns:
        The wrapped function. Its metadata is copied from *func* via
        ``functools.wraps``.
    """
    @wraps(func)
    def decorated_function(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as error:  # deliberate top-level boundary for a view
            logging.getLogger(__name__).exception(
                "Unhandled error in %s", getattr(func, "__name__", repr(func)))
            return jsonify({"error": str(error)}), 500
    return decorated_function


@GnQNA.route("/gnqna", methods=["POST"])
def gnqa():
    """Ask the GNQA LLM a question and cache the answer in Redis.

    Expects a JSON body containing a non-empty ``querygnqa`` string. The
    answer is obtained with ``get_gnqa`` and stored for 10 days under the
    Redis key ``LLM:random_user-<query>``.

    Returns:
        JSON with ``task_id``, ``query``, ``answer``, ``references`` and
        ``prev_queries`` on success. HTTP 400 when ``querygnqa`` is missing
        or empty. HTTP 500 with an ``error`` message when the lookup or
        caching fails.
    """
    # TODO: add auth. Answers are currently stored under a shared placeholder
    # user ("random_user") rather than the requesting user.
    # get_json(silent=True) returns None instead of raising for a missing or
    # non-JSON body, so such requests get the 400 below.
    payload = request.get_json(silent=True)
    query = payload.get("querygnqa", "") if isinstance(payload, dict) else ""
    if not query:
        return jsonify({"error": "querygnqa is missing in the request"}), 400

    try:
        auth_token = current_app.config.get("FAHAMU_AUTH_TOKEN")
        task_id, answer, refs = get_gnqa(
            query, auth_token, current_app.config.get("DATA_DIR"))

        response = {
            "task_id": task_id,
            "query": query,
            "answer": answer,
            "references": refs
        }
        with Redis.from_url(current_app.config["REDIS_URI"],
                            decode_responses=True) as redis_conn:
            # The cached entry expires after 10 days.
            redis_conn.setex(f"LLM:random_user-{query}", timedelta(days=10),
                             json.dumps(response))
            # Read the history while the connection is still open: leaving
            # the `with` block closes the Redis client.
            return jsonify({
                **response,
                "prev_queries": get_user_queries("random_user", redis_conn)
            })
    except Exception as error:
        logging.getLogger(__name__).exception(
            "GNQA request failed for query %r", query)
        return jsonify({"query": query, "error": f"Request failed-{str(error)}"}), 500


@GnQNA.route("/rating/<task_id>", methods=["POST"])
@require_oauth("profile")
def rating(task_id):
    """Store the authenticated user's rating of a GNQA answer.

    Expects a JSON body with ``query`` and ``answer`` (both required) and an
    optional ``weight`` (default 0). The rating is written to the ``Rating``
    table of ``llm.db`` under SQLITE_DB_PATH, and the table is created if it
    does not exist. ``task_id`` is UNIQUE, so rating the same task again
    only updates its ``weight``.

    Returns:
        ``{"message": "success", "status": 0}`` with HTTP 200, or a JSON
        error with HTTP 400 when ``query`` or ``answer`` is missing.
        Database errors are not caught here; they propagate to the caller.
    """
    with require_oauth.acquire("profile") as the_token:
        user_id = the_token.user.user_id
        # get_json(silent=True) yields None (not an exception) for a missing
        # or non-JSON body; treat anything that isn't an object as empty.
        results = request.get_json(silent=True)
        if not isinstance(results, dict):
            results = {}
        query, answer, weight = (results.get("query"),
                                 results.get("answer"),
                                 results.get("weight", 0))
        # query/answer are NOT NULL columns. Reject early with a 400 rather
        # than letting the INSERT fail with an IntegrityError (HTTP 500).
        if query is None or answer is None:
            return jsonify({
                "error": "Both 'query' and 'answer' are required"}), 400

        with db.connection(os.path.join(SQLITE_DB_PATH, "llm.db")) as conn:
            cursor = conn.cursor()
            create_table = """CREATE TABLE IF NOT EXISTS Rating(
                  user_id TEXT NOT NULL,
                  query TEXT NOT NULL,
                  answer TEXT NOT NULL,
                  weight INTEGER NOT NULL DEFAULT 0,
                  task_id TEXT NOT NULL UNIQUE
                  )"""
            cursor.execute(create_table)
            # Upsert keyed on task_id: an existing rating only has its
            # weight replaced.
            cursor.execute("""INSERT INTO Rating(user_id,query,answer,weight,task_id)
            VALUES(?,?,?,?,?)
            ON CONFLICT(task_id) DO UPDATE SET
            weight=excluded.weight
            """, (str(user_id), query, answer, weight, task_id))
            return {
                "message": "success",
                "status": 0
            }, 200


@GnQNA.route("/history/<query>", methods=["GET"])
@require_oauth("profile user")
@handle_errors
def fetch_user_hist(query):
    """Return the authenticated user's stored result for *query*, plus history.

    The stored result is looked up with ``fetch_query_results`` for the
    token's user. The ``prev_queries`` list still comes from the shared
    ``"random_user"`` placeholder (see the TODO in ``gnqa``). Any exception
    is turned into a JSON 500 by ``handle_errors``.
    """
    with (require_oauth.acquire("profile user") as the_token,
          Redis.from_url(current_app.config["REDIS_URI"],
                         decode_responses=True) as redis_conn):
        # The token's user exposes ``user_id``, matching the usage in
        # ``rating``; ``user.id`` is not an attribute used elsewhere here.
        return jsonify({
            **fetch_query_results(query, the_token.user.user_id, redis_conn),
            "prev_queries": get_user_queries("random_user", redis_conn)
        })


@GnQNA.route("/historys/<query>", methods=["GET"])
@handle_errors
def fetch_users_hist_records(query):
    """Return the stored result for *query* together with the query history.

    Test-only endpoint: it reads the records kept under the shared
    ``"random_user"`` placeholder instead of an authenticated user. It is
    intended to be replaced by ``fetch_user_hist``. Errors are converted
    to a JSON 500 by ``handle_errors``.
    """
    placeholder_user = "random_user"
    redis_uri = current_app.config["REDIS_URI"]
    with Redis.from_url(redis_uri, decode_responses=True) as conn:
        stored_result = fetch_query_results(query, placeholder_user, conn)
        history = get_user_queries(placeholder_user, conn)
        return jsonify({**stored_result, "prev_queries": history})


@GnQNA.route("/get_hist_names", methods=["GET"])
@handle_errors
def fetch_prev_hist_ids():
    """List the queries previously recorded for the placeholder user.

    Responds with ``{"prev_queries": <result of get_user_queries>}`` for the
    shared ``"random_user"`` key. Errors are converted to a JSON 500 by
    ``handle_errors``.
    """
    placeholder_user = "random_user"
    redis_uri = current_app.config["REDIS_URI"]
    with Redis.from_url(redis_uri, decode_responses=True) as conn:
        previous_queries = get_user_queries(placeholder_user, conn)
        return jsonify({"prev_queries": previous_queries})