aboutsummaryrefslogtreecommitdiff
path: root/gnqa/paper1_eval/src/apis/process.py
blob: 37f2d73ceea7d02d53cd99dbf99b2eb81672d453 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""This module contains code for processing responses returned by the Fahamu GNQA API client."""
import os
import string
import json

from urllib.parse import urljoin
from urllib.parse import quote
import logging
import requests

from apis.gnqaclient import GeneNetworkQAClient
from apis.resp import DocIDs


# Root endpoint of the Fahamu GNQA tasks API.
BASE_URL = "https://genenetwork.fahamuai.com/api/tasks"


# pylint: disable=C0301


def format_bibliography_info(bib_info):
    """Render a document's bibliography information as a display string.

    * ``dict`` input is rendered as ``"<author>.<title>.<year>.<doi> "``
      (note the trailing space). The ``author``, ``title``, ``year`` and
      ``doi`` keys are all required; a missing one raises ``KeyError``.
    * ``str`` input has a trailing ``.txt`` suffix removed, if present.
    * Any other value is returned unchanged.
    """
    if isinstance(bib_info, dict):
        author, title, year, doi = (
            bib_info[key] for key in ("author", "title", "year", "doi"))
        return f"{author}.{title}.{year}.{doi} "
    if isinstance(bib_info, str):
        return bib_info.removesuffix('.txt')
    return bib_info


def filter_response_text(val):
    """Strip characters outside ``string.printable`` from *val* and parse it as JSON.

    ``string.printable`` is ASCII-only. Control characters are therefore
    removed, and so are all non-ASCII characters (e.g. accented letters).

    Returns:
        The object decoded by ``json.loads`` from the filtered text.
    """
    printable_chars = filter(string.printable.__contains__, val)
    return json.loads("".join(map(str, printable_chars)))


def parse_context(context, get_info_func, format_bib_func):
    """Build reference entries from the ``context`` part of an answer payload.

    Args:
        context: mapping of document id -> iterable of entries. Each entry
            must provide a ``"text"`` key.
        get_info_func: callable returning bibliography info for a document id.
            When it returns a value equal to the id itself, the id is used
            directly as the bibliography string.
        format_bib_func: callable turning that bibliography info into a
            display string.

    Returns:
        One ``{"doc_id", "bibInfo", "comboTxt"}`` dict per document, in the
        iteration order of ``context``. ``comboTxt`` concatenates every
        entry's text, each prefixed with a tab character.
    """
    results = []
    for doc_id, summary in context.items():
        # str.join is linear; repeated ``+=`` on a str can degrade to quadratic.
        combo_txt = "".join("\t" + entry["text"] for entry in summary)
        doc_info = get_info_func(doc_id)
        bib_info = doc_id if doc_id == doc_info else format_bib_func(doc_info)
        results.append(
            {"doc_id": doc_id, "bibInfo": bib_info, "comboTxt": combo_txt})
    return results


def rate_document(task_id, doc_id, rating, auth_token):
    """Submit a rating (feedback) for one document of a Fahamu task.

    Args:
        task_id: id of the task the document belongs to.
        doc_id: id of the document being rated.
        rating: feedback value, sent as the ``feedback`` query parameter.
        auth_token: bearer token sent in the ``Authorization`` header.

    Returns:
        ``{"status": "success"}`` merged with the JSON object returned by the API.

    Raises:
        RuntimeError: on an HTTP error status (the message includes the
            response body and status code), or on any other failure such as
            a network error or a non-JSON body. The original exception is
            chained as the cause.
    """
    # TODO: move this to the API client.
    # NOTE(review): the leading "/" makes urljoin replace BASE_URL's
    # "/api/tasks" path, so this posts to "<host>/feedback" -- confirm that
    # is the intended endpoint.
    url = urljoin(BASE_URL, "/feedback")
    # Let requests percent-encode the values; interpolating them unescaped
    # into the URL breaks on ids containing "&", "#", spaces, etc.
    params = {"task_id": task_id, "document_id": doc_id, "feedback": rating}
    headers = {"Authorization": f"Bearer {auth_token}"}
    try:
        resp = requests.post(url, params=params, headers=headers)
        resp.raise_for_status()
        return {"status": "success", **resp.json()}
    except requests.exceptions.HTTPError as http_error:
        # Implicit concatenation: a backslash continuation inside the f-string
        # would embed the next line's indentation in the message.
        raise RuntimeError(
            f"HTTP Error Occurred: {http_error.response.text} "
            f"-with status code- {http_error.response.status_code}"
        ) from http_error
    except Exception as error:
        raise RuntimeError(f"An error occurred: {error}") from error


def load_file(filename, dir_path):
    """Parse and return the JSON document stored at ``dir_path/filename``.

    Raises:
        FileNotFoundError: if the path does not exist or is not a regular
            file (e.g. it is a directory).
    """
    target = os.path.join(dir_path, f"{filename}")
    if os.path.isfile(target):
        with open(target, "rb") as handle:
            return json.load(handle)
    raise FileNotFoundError(f"{filename} was not found or is a directory")


def fetch_pubmed(references, file_name, data_dir=""):
    """Attach PubMed metadata to ``references`` in place and return the list.

    Loads ``<data_dir>/gn-meta/lit/<file_name>`` as JSON and looks each
    reference's ``doc_id`` up in it. When the lookup yields a truthy value,
    that value is stored under the reference's ``"pubmed"`` key. If the file
    is missing, the failure is logged and ``references`` is returned unchanged.
    """
    lit_dir = os.path.join(data_dir, "gn-meta/lit")
    try:
        pubmed = load_file(file_name, lit_dir)
    except FileNotFoundError:
        logging.error("failed to find pubmed_path for %s/%s",
                      data_dir, file_name)
        return references
    for reference in references:
        pubmed_entry = pubmed.get(reference["doc_id"])
        if pubmed_entry:
            reference["pubmed"] = pubmed_entry
    return references


_REPHRASE_MESSAGE = "Please try to rephrase your question to receive feedback"


def _parse_answer_response(task_id, res, success):
    """Convert a client answer response into ``(task_id, answer, references)``.

    Returns the "please rephrase" message with no references when the
    request did not succeed (``success != 1``) or the decoded payload has no
    ``data`` entry.
    """
    if success != 1:
        return task_id, _REPHRASE_MESSAGE, []
    resp_text = filter_response_text(res.text)
    if resp_text.get("data") is None:
        return task_id, _REPHRASE_MESSAGE, []
    answer = resp_text['data']['answer']
    context = resp_text['data']['context']
    references = parse_context(
        context, DocIDs().getInfo, format_bibliography_info)
    return task_id, answer, references


def get_gnqa(query, auth_token, tmp_dir=""):
    """Ask the Fahamu API *query* and return ``(task_id, answer, references)``.

    Entry function for the gn3 API endpoint.

    Args:
        query: question text; it is URL-quoted before being sent.
        auth_token: API key used by the client.
        tmp_dir: currently unused (kept for the disabled PubMed enrichment).

    Raises:
        RuntimeError: if the client reports a task id of 0 for the ask request.
    """
    api_client = GeneNetworkQAClient(requests.Session(), api_key=auth_token)
    res, task_id = api_client.ask('?ask=' + quote(query), auth_token)
    if task_id == 0:
        raise RuntimeError(f"Error connecting to Fahamu Api: {str(res)}")
    res, success = api_client.get_answer(task_id)
    # NOTE: PubMed enrichment, i.e. fetch_pubmed(references, "pubmed.json",
    # tmp_dir), is currently disabled, so tmp_dir is not used.
    return _parse_answer_response(task_id, res, success)


def get_response_from_taskid(auth_token, task_id):
    """Fetch the answer for an existing task: ``(task_id, answer, references)``.

    Falls back to the "please rephrase" message with no references when the
    answer request fails or the payload has no ``data``.
    """
    api_client = GeneNetworkQAClient(requests.Session(), api_key=auth_token)
    # NOTE(review): this calls ``answer`` while get_gnqa calls ``get_answer``;
    # confirm GeneNetworkQAClient provides both.
    res, success = api_client.answer(task_id)
    return _parse_answer_response(task_id, res, success)


def fetch_query_results(query, user_id, redis_conn):
    """Return the cached result of a user's previous query.

    Reads the key ``LLM:<user_id>-<query>`` from ``redis_conn``. A non-empty
    value is decoded from JSON. Otherwise a placeholder record with no
    references and a ``task_id`` of ``None`` is returned.
    """
    cached = redis_conn.get(f"LLM:{user_id}-{query}")
    if not cached:
        return {
            "query": query,
            "answer": "Sorry No answer for you",
            "references": [],
            "task_id": None
        }
    return json.loads(cached)


def get_user_queries(user_id, redis_conn):
    """Return every query string a user has cached in Redis.

    Keys have the form ``LLM:<user_id>-<query>``, the same format
    ``fetch_query_results`` reads. Only keys carrying exactly that user's
    prefix are considered, and the query is everything after the prefix.
    Empty queries are skipped. Keys returned as ``bytes`` are decoded as
    UTF-8.
    """
    prefix = f"LLM:{user_id}-"
    # Anchor on the full prefix including the "-" separator; "LLM:<id>*"
    # would also match other users whose id merely starts with <id>.
    # Escape glob metacharacters so the user id is matched literally.
    escaped_prefix = "".join(
        "\\" + char if char in "*?[]\\" else char for char in prefix)
    queries = []
    for key in redis_conn.keys(escaped_prefix + "*"):
        if isinstance(key, bytes):
            key = key.decode("utf-8")
        if not key.startswith(prefix):
            continue
        # Strip the exact prefix instead of splitting on the first "-",
        # which breaks for user ids that contain dashes (e.g. UUIDs).
        query = key[len(prefix):]
        if query:
            queries.append(query)
    return queries