diff options
author | Alexander Kabui | 2024-01-17 15:41:45 +0300 |
---|---|---|
committer | GitHub | 2024-01-17 15:41:45 +0300 |
commit | b61d969711c2f0d3c7a0e09f965c7312c945b0c1 (patch) | |
tree | 30e31a14f9b3e6aadf2d4b944add4520f4d364a2 | |
parent | 482a8908dc08d6e5a13e576c4ba4bc3ff934bb8d (diff) | |
download | genenetwork3-b61d969711c2f0d3c7a0e09f965c7312c945b0c1.tar.gz |
Feature/gn llm refactoring (#147)
* refactor code for processing response from fahamu client
* Add tests for gn-llm
-rw-r--r-- | gn3/llms/process.py | 158 | ||||
-rw-r--r-- | tests/unit/test_llm.py | 97 |
2 files changed, 153 insertions, 102 deletions
diff --git a/gn3/llms/process.py b/gn3/llms/process.py index f4a55cf..bd10c19 100644 --- a/gn3/llms/process.py +++ b/gn3/llms/process.py @@ -1,132 +1,86 @@ -# pylint: skip-file +"""this module contains code for processing response from fahamu client.py""" -import requests -import sys -import time import string import json -import os -from urllib.request import urlretrieve -from urllib.parse import quote from urllib.parse import urljoin +from urllib.parse import quote +import requests + from gn3.llms.client import GeneNetworkQAClient from gn3.llms.response import DocIDs - -baseUrl = 'https://genenetwork.fahamuai.com/api/tasks' -answerUrl = baseUrl + '/answers' -basedir = os.path.abspath(os.path.dirname(__file__)) - - -def formatBibliographyInfo(bibInfo): - if isinstance(bibInfo, str): - # remove '.txt' - bibInfo = bibInfo.removesuffix('.txt') - elif isinstance(bibInfo, dict): - # format string bibliography information - bibInfo = "{0}.{1}.{2}.{3} ".format( - bibInfo['author'], bibInfo['title'], bibInfo['year'], bibInfo['doi']) - return bibInfo +BASE_URL = 'https://genenetwork.fahamuai.com/api/tasks' -def askTheDocuments(extendUrl, my_auth): - try: - res = requests.post(baseUrl+extendUrl, - data={}, - headers=my_auth) - res.raise_for_status() - except: - raise # what - if (res.status_code != 200): - return negativeStatusMsg(res), 0 - task_id = getTaskIDFromResult(res) - res = getAnswerUsingTaskID(task_id, my_auth) - if (res.status_code != 200): - return negativeStatusMsg(res), 0 - return res, 1 - - -def getAnswerUsingTaskID(extendUrl, my_auth): - try: - res = requests.get(answerUrl+extendUrl, data={}, headers=my_auth) - res.raise_for_status() - except: - raise - return res - - -def openAPIConfig(): - f = open(os.path.join(basedir, "api.config.json"), "rb") - result = json.load(f) - f.close() - return result +# pylint: disable=C0301 -def getTaskIDFromResult(res): - task_id = json.loads(res.text) - result = '?task_id=' + str(task_id['task_id']) - return result +def format_bibliography_info(bib_info): + """Function for formatting bibliography info""" + if isinstance(bib_info, str): + return bib_info.removesuffix('.txt') + elif isinstance(bib_info, dict): + return f"{bib_info['author']}.{bib_info['title']}.{bib_info['year']}.{bib_info['doi']} " + return bib_info -def negativeStatusMsg(res): - # mypy: ignore - return 'Problems\n\tStatus code => {0}\n\tReason=> {1}'.format(res.status_code, res.reason) +def filter_response_text(val): + """helper function for filtering non-printable chars""" + return json.loads(''.join([str(char) + for char in val if char in string.printable])) -def filterResponseText(val): - return json.loads(''.join([str(char) for char in val if char in string.printable])) - - -def getGNQA(query, auth_token): - apiClient = GeneNetworkQAClient(requests.Session(), api_key=auth_token) - res, task_id = apiClient.ask('?ask=' + quote(query), auth_token) - res, success = apiClient.get_answer(task_id) - - if (success == 1): - respText = filterResponseText(res.text) - if respText.get("data") is None: - return "Unfortunately I have nothing on the query", [] - answer = respText['data']['answer'] - context = respText['data']['context'] - references = parse_context(context) - return task_id, answer, references - else: - return task_id, res, "Unfortunately I have nothing." - - -def parse_context(context): - """parse content map id to reference""" - result = [] +def parse_context(context, get_info_func, format_bib_func): + """function to parse doc_ids content""" + results = [] for doc_ids, summary in context.items(): - comboTxt = "" + combo_txt = "" for entry in summary: - comboTxt += '\t' + entry['text'] - - docInfo = DocIDs().getInfo(doc_ids) - if doc_ids != docInfo: - bibInfo = formatBibliographyInfo(docInfo) - - else: - bibInfo = doc_ids - result.append( - {"doc_id": doc_ids, "bibInfo": bibInfo, "comboTxt": comboTxt}) - return result + combo_txt += "\t" + entry["text"] + doc_info = get_info_func(doc_ids) + bib_info = doc_ids if doc_ids == doc_info else format_bib_func( + doc_info) + results.append( + {"doc_id": doc_ids, "bibInfo": bib_info, "comboTxt": combo_txt}) + return results def rate_document(task_id, doc_id, rating, auth_token): """This method is used to provide feedback for a document by making a rating.""" + # todo move this to clients try: - resp = requests.post( - urljoin(baseUrl, f"/feedback?task_id={task_id}&document_id={doc_id}&feedback={rating}"), - headers={"Authorization": f"Bearer {auth_token}"} - ) + url = urljoin(BASE_URL, + f"""/feedback?task_id={task_id}&document_id={doc_id}&feedback={rating}""") + headers = {"Authorization": f"Bearer {auth_token}"} + + resp = requests.post(url, headers=headers) resp.raise_for_status() + return {"status": "success", **resp.json()} except requests.exceptions.HTTPError as http_error: - raise RuntimeError(f"HTTP Error Occurred: {http_error.response.text} -with status code- {http_error.response.status_code}") + raise RuntimeError(f"HTTP Error Occurred:\ + {http_error.response.text} -with status code- {http_error.response.status_code}") from http_error except Exception as error: + raise RuntimeError(f"An error occurred: {str(error)}") from error + - raise RuntimeError(f"An error occurred: {str(error)}") +def get_gnqa(query, auth_token): + """entry function for the gn3 api endpoint()""" + + api_client = GeneNetworkQAClient(requests.Session(), api_key=auth_token) + res, task_id = api_client.ask('?ask=' + quote(query), auth_token) + res, success = api_client.get_answer(task_id) + if success == 1: + resp_text = filter_response_text(res.text) + if resp_text.get("data") is None: + return "Unfortunately I have nothing on the query", [] + answer = resp_text['data']['answer'] + context = resp_text['data']['context'] + references = parse_context( + context, DocIDs().getInfo, format_bibliography_info) + return task_id, answer, references + else: + return task_id, res, "Unfortunately I have nothing." diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py new file mode 100644 index 0000000..a29190a --- /dev/null +++ b/tests/unit/test_llm.py @@ -0,0 +1,97 @@ +"""Test cases for procedures defined in llms module""" +import pytest +from dataclasses import dataclass +from gn3.llms.process import get_gnqa +from gn3.llms.process import parse_context + + +@pytest.fixture +def context_data(): + return { + "doc1": [{"text": "Summary 1"}, {"text": "Summary 2"}], + "doc2": [{"text": "Summary 3"}, {"text": "Summary 4"}], + } + + +@pytest.mark.unit_test +def test_parse_context(context_data): + def mock_get_info(doc_id): + return f"Info for {doc_id}" + + def mock_format_bib(doc_info): + return f"Formatted Bibliography: {doc_info}" + + parsed_result = parse_context(context_data, mock_get_info, mock_format_bib) + + expected_result = [ + { + "doc_id": "doc1", + "bibInfo": "Formatted Bibliography: Info for doc1", + "comboTxt": "\tSummary 1\tSummary 2", + }, + { + "doc_id": "doc2", + "bibInfo": "Formatted Bibliography: Info for doc2", + "comboTxt": "\tSummary 3\tSummary 4", + }, + ] + + assert parsed_result == expected_result + +@dataclass(frozen=True) +class MockResponse: + text: str + + def __getattr__(self, name: str): + return self.__dict__[f"_{name}"] + +class MockGeneNetworkQAClient: + def __init__(self, session, api_key): + pass + + def ask(self, query, auth_token): + # Simulate the ask method + return MockResponse("Mock response"), "F400995EAFE104EA72A5927CE10C73B7" + + def get_answer(self, task_id): + # Simulate the get_answer method + return MockResponse("Mock answer"), 1 + + +def mock_filter_response_text(text): + """ method to simulate the filterResponseText method""" + return {"data": {"answer": "Mock answer for what is a gene", "context": {}}} + + +def mock_parse_context(context, get_info_func, format_bib_func): + """method to simulate the parse context method""" + return [] + + +@pytest.mark.unit_test +def test_get_gnqa(monkeypatch): + monkeypatch.setattr( + "gn3.llms.process.GeneNetworkQAClient", + MockGeneNetworkQAClient + ) + + monkeypatch.setattr( + 'gn3.llms.process.filter_response_text', + mock_filter_response_text + ) + monkeypatch.setattr( + 'gn3.llms.process.parse_context', + mock_parse_context + ) + + query = "What is a gene" + auth_token = "test_token" + result = get_gnqa(query, auth_token) + + expected_result = ( + "F400995EAFE104EA72A5927CE10C73B7", + 'Mock answer for what is a gene', + [] + ) + + assert result == expected_result |