aboutsummaryrefslogtreecommitdiff
path: root/tests/unit/test_llm.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unit/test_llm.py')
-rw-r--r--tests/unit/test_llm.py28
1 files changed, 16 insertions, 12 deletions
diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py
index 6f9bdcc..7b8a970 100644
--- a/tests/unit/test_llm.py
+++ b/tests/unit/test_llm.py
@@ -1,28 +1,25 @@
-# pylint: skip-file
+# pylint: disable=unused-argument
"""Test cases for procedures defined in llms module"""
-import pytest
from dataclasses import dataclass
+import pytest
from gn3.llms.process import get_gnqa
from gn3.llms.process import parse_context
-@pytest.fixture
-def context_data():
- return {
- "doc1": [{"text": "Summary 1"}, {"text": "Summary 2"}],
- "doc2": [{"text": "Summary 3"}, {"text": "Summary 4"}],
- }
-
@pytest.mark.unit_test
-def test_parse_context(context_data):
+def test_parse_context():
+ """test for parsing doc id context"""
def mock_get_info(doc_id):
return f"Info for {doc_id}"
def mock_format_bib(doc_info):
return f"Formatted Bibliography: {doc_info}"
- parsed_result = parse_context(context_data, mock_get_info, mock_format_bib)
+ parsed_result = parse_context({
+ "doc1": [{"text": "Summary 1"}, {"text": "Summary 2"}],
+ "doc2": [{"text": "Summary 3"}, {"text": "Summary 4"}],
+ }, mock_get_info, mock_format_bib)
expected_result = [
{
@@ -39,23 +36,29 @@ def test_parse_context(context_data):
assert parsed_result == expected_result
+
@dataclass(frozen=True)
class MockResponse:
+ """mock a response object"""
text: str
def __getattr__(self, name: str):
return self.__dict__[f"_{name}"]
+
class MockGeneNetworkQAClient:
+ """mock the GeneNetworkQAClient class"""
+
def __init__(self, session, api_key):
pass
def ask(self, query, auth_token):
+ """mock method for ask query"""
# Simulate the ask method
return MockResponse("Mock response"), "F400995EAFE104EA72A5927CE10C73B7"
def get_answer(self, task_id):
- # Simulate the get_answer method
+ """mock get_answer method"""
return MockResponse("Mock answer"), 1
@@ -71,6 +74,7 @@ def mock_parse_context(context, get_info_func, format_bib_func):
@pytest.mark.unit_test
def test_get_gnqa(monkeypatch):
+ """test for process.get_gnqa functoin"""
monkeypatch.setattr(
"gn3.llms.process.GeneNetworkQAClient",
MockGeneNetworkQAClient