diff options
author | Alexander Kabui | 2024-01-17 15:41:45 +0300 |
---|---|---|
committer | GitHub | 2024-01-17 15:41:45 +0300 |
commit | b61d969711c2f0d3c7a0e09f965c7312c945b0c1 (patch) | |
tree | 30e31a14f9b3e6aadf2d4b944add4520f4d364a2 /gn3/llms | |
parent | 482a8908dc08d6e5a13e576c4ba4bc3ff934bb8d (diff) | |
download | genenetwork3-b61d969711c2f0d3c7a0e09f965c7312c945b0c1.tar.gz |
Feature/gn llm refactoring (#147)
* refactor code for processing response from fahamu client
* Add tests for gn-llm
Diffstat (limited to 'gn3/llms')
-rw-r--r-- | gn3/llms/process.py | 158 |
1 files changed, 56 insertions, 102 deletions
diff --git a/gn3/llms/process.py b/gn3/llms/process.py index f4a55cf..bd10c19 100644 --- a/gn3/llms/process.py +++ b/gn3/llms/process.py @@ -1,132 +1,86 @@ -# pylint: skip-file +"""this module contains code for processing response from fahamu client.py""" -import requests -import sys -import time import string import json -import os -from urllib.request import urlretrieve -from urllib.parse import quote from urllib.parse import urljoin +from urllib.parse import quote +import requests + from gn3.llms.client import GeneNetworkQAClient from gn3.llms.response import DocIDs - -baseUrl = 'https://genenetwork.fahamuai.com/api/tasks' -answerUrl = baseUrl + '/answers' -basedir = os.path.abspath(os.path.dirname(__file__)) - - -def formatBibliographyInfo(bibInfo): - if isinstance(bibInfo, str): - # remove '.txt' - bibInfo = bibInfo.removesuffix('.txt') - elif isinstance(bibInfo, dict): - # format string bibliography information - bibInfo = "{0}.{1}.{2}.{3} ".format( - bibInfo['author'], bibInfo['title'], bibInfo['year'], bibInfo['doi']) - return bibInfo +BASE_URL = 'https://genenetwork.fahamuai.com/api/tasks' -def askTheDocuments(extendUrl, my_auth): - try: - res = requests.post(baseUrl+extendUrl, - data={}, - headers=my_auth) - res.raise_for_status() - except: - raise # what - if (res.status_code != 200): - return negativeStatusMsg(res), 0 - task_id = getTaskIDFromResult(res) - res = getAnswerUsingTaskID(task_id, my_auth) - if (res.status_code != 200): - return negativeStatusMsg(res), 0 - return res, 1 - - -def getAnswerUsingTaskID(extendUrl, my_auth): - try: - res = requests.get(answerUrl+extendUrl, data={}, headers=my_auth) - res.raise_for_status() - except: - raise - return res - - -def openAPIConfig(): - f = open(os.path.join(basedir, "api.config.json"), "rb") - result = json.load(f) - f.close() - return result +# pylint: disable=C0301 -def getTaskIDFromResult(res): - task_id = json.loads(res.text) - result = '?task_id=' + str(task_id['task_id']) - return result +def format_bibliography_info(bib_info): + """Function for formatting bibliography info""" + if isinstance(bib_info, str): + return bib_info.removesuffix('.txt') + elif isinstance(bib_info, dict): + return f"{bib_info['author']}.{bib_info['title']}.{bib_info['year']}.{bib_info['doi']} " + return bib_info -def negativeStatusMsg(res): - # mypy: ignore - return 'Problems\n\tStatus code => {0}\n\tReason=> {1}'.format(res.status_code, res.reason) +def filter_response_text(val): + """helper function for filtering non-printable chars""" + return json.loads(''.join([str(char) + for char in val if char in string.printable])) -def filterResponseText(val): - return json.loads(''.join([str(char) for char in val if char in string.printable])) - - -def getGNQA(query, auth_token): - apiClient = GeneNetworkQAClient(requests.Session(), api_key=auth_token) - res, task_id = apiClient.ask('?ask=' + quote(query), auth_token) - res, success = apiClient.get_answer(task_id) - - if (success == 1): - respText = filterResponseText(res.text) - if respText.get("data") is None: - return "Unfortunately I have nothing on the query", [] - answer = respText['data']['answer'] - context = respText['data']['context'] - references = parse_context(context) - return task_id, answer, references - else: - return task_id, res, "Unfortunately I have nothing." - - -def parse_context(context): - """parse content map id to reference""" - result = [] +def parse_context(context, get_info_func, format_bib_func): + """function to parse doc_ids content""" + results = [] for doc_ids, summary in context.items(): - comboTxt = "" + combo_txt = "" for entry in summary: - comboTxt += '\t' + entry['text'] - - docInfo = DocIDs().getInfo(doc_ids) - if doc_ids != docInfo: - bibInfo = formatBibliographyInfo(docInfo) - - else: - bibInfo = doc_ids - result.append( - {"doc_id": doc_ids, "bibInfo": bibInfo, "comboTxt": comboTxt}) - return result + combo_txt += "\t" + entry["text"] + doc_info = get_info_func(doc_ids) + bib_info = doc_ids if doc_ids == doc_info else format_bib_func( + doc_info) + results.append( + {"doc_id": doc_ids, "bibInfo": bib_info, "comboTxt": combo_txt}) + return results def rate_document(task_id, doc_id, rating, auth_token): """This method is used to provide feedback for a document by making a rating.""" + # todo move this to clients try: - resp = requests.post( - urljoin(baseUrl, f"/feedback?task_id={task_id}&document_id={doc_id}&feedback={rating}"), - headers={"Authorization": f"Bearer {auth_token}"} - ) + url = urljoin(BASE_URL, + f"""/feedback?task_id={task_id}&document_id={doc_id}&feedback={rating}""") + headers = {"Authorization": f"Bearer {auth_token}"} + + resp = requests.post(url, headers=headers) resp.raise_for_status() + return {"status": "success", **resp.json()} except requests.exceptions.HTTPError as http_error: - raise RuntimeError(f"HTTP Error Occurred: {http_error.response.text} -with status code- {http_error.response.status_code}") + raise RuntimeError(f"HTTP Error Occurred:\ + {http_error.response.text} -with status code- {http_error.response.status_code}") from http_error except Exception as error: + raise RuntimeError(f"An error occurred: {str(error)}") from error + - raise RuntimeError(f"An error occurred: {str(error)}") +def get_gnqa(query, auth_token): + """entry function for the gn3 api endpoint()""" + + api_client = GeneNetworkQAClient(requests.Session(), api_key=auth_token) + res, task_id = api_client.ask('?ask=' + quote(query), auth_token) + res, success = api_client.get_answer(task_id) + if success == 1: + resp_text = filter_response_text(res.text) + if resp_text.get("data") is None: + return "Unfortunately I have nothing on the query", [] + answer = resp_text['data']['answer'] + context = resp_text['data']['context'] + references = parse_context( + context, DocIDs().getInfo, format_bibliography_info) + return task_id, answer, references + else: + return task_id, res, "Unfortunately I have nothing." |