about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/llms/process.py158
-rw-r--r--tests/unit/test_llm.py97
2 files changed, 153 insertions, 102 deletions
diff --git a/gn3/llms/process.py b/gn3/llms/process.py
index f4a55cf..bd10c19 100644
--- a/gn3/llms/process.py
+++ b/gn3/llms/process.py
@@ -1,132 +1,86 @@
 
-# pylint: skip-file
+"""this module contains code for processing response from fahamu client.py"""
 
-import requests
-import sys
-import time
 import string
 import json
-import os
-from urllib.request import urlretrieve
-from urllib.parse import quote
 
 from urllib.parse import urljoin
+from urllib.parse import quote
+import requests
+
 
 from gn3.llms.client import GeneNetworkQAClient
 from gn3.llms.response import DocIDs
 
-
-baseUrl = 'https://genenetwork.fahamuai.com/api/tasks'
-answerUrl = baseUrl + '/answers'
-basedir = os.path.abspath(os.path.dirname(__file__))
-
-
-def formatBibliographyInfo(bibInfo):
-    if isinstance(bibInfo, str):
-        # remove '.txt'
-        bibInfo = bibInfo.removesuffix('.txt')
-    elif isinstance(bibInfo, dict):
-        # format string bibliography information
-        bibInfo = "{0}.{1}.{2}.{3} ".format(
-            bibInfo['author'], bibInfo['title'], bibInfo['year'], bibInfo['doi'])
-    return bibInfo
+BASE_URL = 'https://genenetwork.fahamuai.com/api/tasks'
 
 
-def askTheDocuments(extendUrl, my_auth):
-    try:
-        res = requests.post(baseUrl+extendUrl,
-                            data={},
-                            headers=my_auth)
-        res.raise_for_status()
-    except:
-        raise  # what
-    if (res.status_code != 200):
-        return negativeStatusMsg(res), 0
-    task_id = getTaskIDFromResult(res)
-    res = getAnswerUsingTaskID(task_id, my_auth)
-    if (res.status_code != 200):
-        return negativeStatusMsg(res), 0
-    return res, 1
-
-
-def getAnswerUsingTaskID(extendUrl, my_auth):
-    try:
-        res = requests.get(answerUrl+extendUrl, data={}, headers=my_auth)
-        res.raise_for_status()
-    except:
-        raise
-    return res
-
-
-def openAPIConfig():
-    f = open(os.path.join(basedir, "api.config.json"), "rb")
-    result = json.load(f)
-    f.close()
-    return result
+# pylint: disable=C0301
 
 
-def getTaskIDFromResult(res):
-    task_id = json.loads(res.text)
-    result = '?task_id=' + str(task_id['task_id'])
-    return result
+def format_bibliography_info(bib_info):
+    """Function for formatting bibliography info"""
+    if isinstance(bib_info, str):
+        return bib_info.removesuffix('.txt')
+    elif isinstance(bib_info, dict):
+        return f"{bib_info['author']}.{bib_info['title']}.{bib_info['year']}.{bib_info['doi']} "
+    return bib_info
 
 
-def negativeStatusMsg(res):
-    # mypy: ignore
-    return 'Problems\n\tStatus code => {0}\n\tReason=> {1}'.format(res.status_code, res.reason)
+def filter_response_text(val):
+    """helper function for filtering non-printable chars"""
+    return json.loads(''.join([str(char)
+                               for char in val if char in string.printable]))
 
 
-def filterResponseText(val):
-    return json.loads(''.join([str(char) for char in val if char in string.printable]))
-
-
-def getGNQA(query, auth_token):
-    apiClient = GeneNetworkQAClient(requests.Session(), api_key=auth_token)
-    res, task_id = apiClient.ask('?ask=' + quote(query), auth_token)
-    res, success = apiClient.get_answer(task_id)
-
-    if (success == 1):
-        respText = filterResponseText(res.text)
-        if respText.get("data") is None:
-            return "Unfortunately I have nothing on the query", []
-        answer = respText['data']['answer']
-        context = respText['data']['context']
-        references = parse_context(context)
-        return task_id, answer, references
-    else:
-        return task_id, res, "Unfortunately I have nothing."
-
-
-def parse_context(context):
-    """parse content map id to reference"""
-    result = []
+def parse_context(context, get_info_func, format_bib_func):
+    """function to parse doc_ids content"""
+    results = []
     for doc_ids, summary in context.items():
-        comboTxt = ""
+        combo_txt = ""
         for entry in summary:
-            comboTxt += '\t' + entry['text']
-
-        docInfo = DocIDs().getInfo(doc_ids)
-        if doc_ids != docInfo:
-            bibInfo = formatBibliographyInfo(docInfo)
-
-        else:
-            bibInfo = doc_ids
-        result.append(
-            {"doc_id": doc_ids, "bibInfo": bibInfo, "comboTxt": comboTxt})
-    return result
+            combo_txt += "\t" + entry["text"]
+        doc_info = get_info_func(doc_ids)
+        bib_info = doc_ids if doc_ids == doc_info else format_bib_func(
+            doc_info)
+        results.append(
+            {"doc_id": doc_ids, "bibInfo": bib_info, "comboTxt": combo_txt})
+    return results
 
 
 def rate_document(task_id, doc_id, rating, auth_token):
     """This method is used to provide feedback for a document by making a rating."""
+    # todo move this to clients
     try:
-        resp = requests.post(
-            urljoin(baseUrl, f"/feedback?task_id={task_id}&document_id={doc_id}&feedback={rating}"),
-            headers={"Authorization": f"Bearer {auth_token}"}
-        )
+        url = urljoin(BASE_URL,
+                      f"""/feedback?task_id={task_id}&document_id={doc_id}&feedback={rating}""")
+        headers = {"Authorization": f"Bearer {auth_token}"}
+
+        resp = requests.post(url, headers=headers)
         resp.raise_for_status()
+
         return {"status": "success", **resp.json()}
     except requests.exceptions.HTTPError as http_error:
-        raise RuntimeError(f"HTTP Error Occurred: {http_error.response.text} -with status code- {http_error.response.status_code}")
+        raise RuntimeError(f"HTTP Error Occurred:\
+            {http_error.response.text} -with status code- {http_error.response.status_code}") from http_error
     except Exception as error:
+        raise RuntimeError(f"An error occurred: {str(error)}") from error
+
 
-        raise RuntimeError(f"An error occurred: {str(error)}")
+def get_gnqa(query, auth_token):
+    """entry function for the gn3 api endpoint()"""
+
+    api_client = GeneNetworkQAClient(requests.Session(), api_key=auth_token)
+    res, task_id = api_client.ask('?ask=' + quote(query), auth_token)
+    res, success = api_client.get_answer(task_id)
+    if success == 1:
+        resp_text = filter_response_text(res.text)
+        if resp_text.get("data") is None:
+            return "Unfortunately I have nothing on the query", []
+        answer = resp_text['data']['answer']
+        context = resp_text['data']['context']
+        references = parse_context(
+            context, DocIDs().getInfo, format_bibliography_info)
+        return task_id, answer, references
+    else:
+        return task_id, res, "Unfortunately I have nothing."
diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py
new file mode 100644
index 0000000..a29190a
--- /dev/null
+++ b/tests/unit/test_llm.py
@@ -0,0 +1,97 @@
+"""Test cases for procedures defined in llms module"""
+import pytest
+from dataclasses import dataclass
+from gn3.llms.process import get_gnqa
+from gn3.llms.process import parse_context
+
+
+@pytest.fixture
+def context_data():
+    return {
+        "doc1": [{"text": "Summary 1"}, {"text": "Summary 2"}],
+        "doc2": [{"text": "Summary 3"}, {"text": "Summary 4"}],
+    }
+
+
+@pytest.mark.unit_test
+def test_parse_context(context_data):
+    def mock_get_info(doc_id):
+        return f"Info for {doc_id}"
+
+    def mock_format_bib(doc_info):
+        return f"Formatted Bibliography: {doc_info}"
+
+    parsed_result = parse_context(context_data, mock_get_info, mock_format_bib)
+
+    expected_result = [
+        {
+            "doc_id": "doc1",
+            "bibInfo": "Formatted Bibliography: Info for doc1",
+            "comboTxt": "\tSummary 1\tSummary 2",
+        },
+        {
+            "doc_id": "doc2",
+            "bibInfo": "Formatted Bibliography: Info for doc2",
+            "comboTxt": "\tSummary 3\tSummary 4",
+        },
+    ]
+
+    assert parsed_result == expected_result
+
+@dataclass(frozen=True)
+class MockResponse:
+    text: str
+
+    def __getattr__(self, name: str):
+        return self.__dict__[f"_{name}"]
+
+class MockGeneNetworkQAClient:
+    def __init__(self, session, api_key):
+        pass
+
+    def ask(self, query, auth_token):
+        # Simulate the ask method
+        return MockResponse("Mock response"), "F400995EAFE104EA72A5927CE10C73B7"
+
+    def get_answer(self, task_id):
+        # Simulate the get_answer method
+        return MockResponse("Mock answer"), 1
+
+
+def mock_filter_response_text(text):
+    """ method to simulate the filterResponseText method"""
+    return {"data": {"answer": "Mock answer for what is a gene", "context": {}}}
+
+
+def mock_parse_context(context, get_info_func, format_bib_func):
+    """method to simulate the  parse context method"""
+    return []
+
+
+@pytest.mark.unit_test
+def test_get_gnqa(monkeypatch):
+    monkeypatch.setattr(
+        "gn3.llms.process.GeneNetworkQAClient",
+        MockGeneNetworkQAClient
+    )
+
+    monkeypatch.setattr(
+        'gn3.llms.process.filter_response_text',
+        mock_filter_response_text
+    )
+    monkeypatch.setattr(
+        'gn3.llms.process.parse_context',
+        mock_parse_context
+    )
+
+    query = "What is a gene"
+    auth_token = "test_token"
+    result = get_gnqa(query, auth_token)
+
+    expected_result = (
+        "F400995EAFE104EA72A5927CE10C73B7",
+        'Mock answer for what is a gene',
+        []
+    )
+
+    assert result == expected_result