aboutsummaryrefslogtreecommitdiff
path: root/gn3/llms/process.py
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/llms/process.py')
-rw-r--r--gn3/llms/process.py96
1 files changed, 59 insertions, 37 deletions
diff --git a/gn3/llms/process.py b/gn3/llms/process.py
index e38b73e..4725bcb 100644
--- a/gn3/llms/process.py
+++ b/gn3/llms/process.py
@@ -1,21 +1,53 @@
"""this module contains code for processing response from fahamu client.py"""
+# pylint: disable=C0301
import os
import string
import json
-
-from urllib.parse import urljoin
-from urllib.parse import quote
import logging
-import requests
+from urllib.parse import quote
from gn3.llms.client import GeneNetworkQAClient
-from gn3.llms.response import DocIDs
BASE_URL = 'https://genenetwork.fahamuai.com/api/tasks'
-
-
-# pylint: disable=C0301
+BASEDIR = os.path.abspath(os.path.dirname(__file__))
+
+
+class DocIDs():
+ """ Class Method to Parse document id and names from files"""
+ def __init__(self):
+ """
+ init method for Docids
+ * doc_ids.json: opens doc)ids for gn references
+ * sugar_doc_ids: open doci_ids for diabetes references
+ """
+ self.doc_ids = self.load_file("doc_ids.json")
+ self.sugar_doc_ids = self.load_file("all_files.json")
+ self.format_doc_ids(self.sugar_doc_ids)
+
+ def load_file(self, file_name):
+ """Method to load and read doc_id files"""
+ file_path = os.path.join(BASEDIR, file_name)
+ if os.path.isfile(file_path):
+ with open(file_path, "rb") as file_handler:
+ return json.load(file_handler)
+ else:
+ raise FileNotFoundError(f"{file_path}-- FIle does not exist\n")
+
+ def format_doc_ids(self, docs):
+ """method to format doc_ids for list items"""
+ for _key, val in docs.items():
+ if isinstance(val, list):
+ for doc_obj in val:
+ doc_name = doc_obj["filename"].removesuffix(".pdf").removesuffix(".txt").replace("_", "")
+ self.doc_ids.update({doc_obj["id"]: doc_name})
+
+ def get_info(self, doc_id):
+ """ interface to make read from doc_ids"""
+ if doc_id in self.doc_ids.keys():
+ return self.doc_ids[doc_id]
+ else:
+ return doc_id
def format_bibliography_info(bib_info):
@@ -48,25 +80,6 @@ def parse_context(context, get_info_func, format_bib_func):
return results
-def rate_document(task_id, doc_id, rating, auth_token):
- """This method is used to provide feedback for a document by making a rating."""
- # todo move this to clients
- try:
- url = urljoin(BASE_URL,
- f"""/feedback?task_id={task_id}&document_id={doc_id}&feedback={rating}""")
- headers = {"Authorization": f"Bearer {auth_token}"}
-
- resp = requests.post(url, headers=headers)
- resp.raise_for_status()
-
- return {"status": "success", **resp.json()}
- except requests.exceptions.HTTPError as http_error:
- raise RuntimeError(f"HTTP Error Occurred:\
- {http_error.response.text} -with status code- {http_error.response.status_code}") from http_error
- except Exception as error:
- raise RuntimeError(f"An error occurred: {str(error)}") from error
-
-
def load_file(filename, dir_path):
"""function to open and load json file"""
file_path = os.path.join(dir_path, f"{filename}")
@@ -92,27 +105,36 @@ def fetch_pubmed(references, file_name, data_dir=""):
return references
-def get_gnqa(query, auth_token, tmp_dir=""):
- """entry function for the gn3 api endpoint()"""
+def get_gnqa(query, auth_token, data_dir=""):
+ """entry function for the gn3 api endpoint()
+ ARGS:
+ query: what is a gene
+ auth_token: token to connect to api_client
+ data_dir: base datirectory for gn3 data
+ Returns:
+ task_id: fahamu unique identifier for task
+ answer
+ references: contains doc_name,reference,pub_med_info
+ """
- api_client = GeneNetworkQAClient(requests.Session(), api_key=auth_token)
+ api_client = GeneNetworkQAClient(api_key=auth_token)
res, task_id = api_client.ask('?ask=' + quote(query), auth_token)
if task_id == 0:
raise RuntimeError(f"Error connecting to Fahamu Api: {str(res)}")
- res, success = api_client.get_answer(task_id)
- if success == 1:
+ res, status = api_client.get_answer(task_id)
+ if status == 1:
resp_text = filter_response_text(res.text)
if resp_text.get("data") is None:
return task_id, "Please try to rephrase your question to receive feedback", []
answer = resp_text['data']['answer']
context = resp_text['data']['context']
references = parse_context(
- context, DocIDs().getInfo, format_bibliography_info)
- references = fetch_pubmed(references, "pubmed.json", tmp_dir)
+ context, DocIDs().get_info, format_bibliography_info)
+ references = fetch_pubmed(references, "pubmed.json", data_dir)
return task_id, answer, references
else:
- return task_id, "Please try to rephrase your question to receive feedback", []
+ return task_id, "We couldn't provide a response,Please try to rephrase your question to receive feedback", []
def fetch_query_results(query, user_id, redis_conn):
@@ -130,6 +152,6 @@ def fetch_query_results(query, user_id, redis_conn):
def get_user_queries(user_id, redis_conn):
"""methos to fetch all queries for a specific user"""
-
results = redis_conn.keys(f"LLM:{user_id}*")
- return [query for query in [result.partition("-")[2] for result in results] if query != ""]
+ return [query for query in
+ [result.partition("-")[2] for result in results] if query != ""]