diff options
Diffstat (limited to 'tests/unit/test_llm.py')
-rw-r--r-- | tests/unit/test_llm.py | 28 |
1 files changed, 16 insertions, 12 deletions
diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py index 6f9bdcc..7b8a970 100644 --- a/tests/unit/test_llm.py +++ b/tests/unit/test_llm.py @@ -1,28 +1,25 @@ -# pylint: skip-file +# pylint: disable=unused-argument """Test cases for procedures defined in llms module""" -import pytest from dataclasses import dataclass +import pytest from gn3.llms.process import get_gnqa from gn3.llms.process import parse_context -@pytest.fixture -def context_data(): - return { - "doc1": [{"text": "Summary 1"}, {"text": "Summary 2"}], - "doc2": [{"text": "Summary 3"}, {"text": "Summary 4"}], - } - @pytest.mark.unit_test -def test_parse_context(context_data): +def test_parse_context(): + """test for parsing doc id context""" def mock_get_info(doc_id): return f"Info for {doc_id}" def mock_format_bib(doc_info): return f"Formatted Bibliography: {doc_info}" - parsed_result = parse_context(context_data, mock_get_info, mock_format_bib) + parsed_result = parse_context({ + "doc1": [{"text": "Summary 1"}, {"text": "Summary 2"}], + "doc2": [{"text": "Summary 3"}, {"text": "Summary 4"}], + }, mock_get_info, mock_format_bib) expected_result = [ { @@ -39,23 +36,29 @@ def test_parse_context(context_data): assert parsed_result == expected_result + @dataclass(frozen=True) class MockResponse: + """mock a response object""" text: str def __getattr__(self, name: str): return self.__dict__[f"_{name}"] + class MockGeneNetworkQAClient: + """mock the GeneNetworkQAClient class""" + def __init__(self, session, api_key): pass def ask(self, query, auth_token): + """mock method for ask query""" # Simulate the ask method return MockResponse("Mock response"), "F400995EAFE104EA72A5927CE10C73B7" def get_answer(self, task_id): - # Simulate the get_answer method + """mock get_answer method""" return MockResponse("Mock answer"), 1 @@ -71,6 +74,7 @@ def mock_parse_context(context, get_info_func, format_bib_func): @pytest.mark.unit_test def test_get_gnqa(monkeypatch): + """test for process.get_gnqa functoin""" monkeypatch.setattr( "gn3.llms.process.GeneNetworkQAClient", MockGeneNetworkQAClient |