1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
|
# pylint: disable=unused-argument
"""Test cases for procedures defined in llms module"""
from dataclasses import dataclass
import pytest
from gn3.llms.process import get_gnqa
from gn3.llms.process import parse_context
@pytest.mark.unit_test
def test_parse_context():
    """Check that parse_context builds one record per document id.

    Each record carries the formatted bibliography produced by the two
    injected callables and the document's summary texts, each prefixed
    with a tab.
    """
    def fake_info_lookup(doc_id):
        return f"Info for {doc_id}"

    def fake_bib_formatter(doc_info):
        return f"Formatted Bibliography: {doc_info}"

    context = {
        "doc1": [{"text": "Summary 1"}, {"text": "Summary 2"}],
        "doc2": [{"text": "Summary 3"}, {"text": "Summary 4"}],
    }
    actual = parse_context(context, fake_info_lookup, fake_bib_formatter)

    expected = [
        {
            "doc_id": "doc1",
            "bibInfo": "Formatted Bibliography: Info for doc1",
            "comboTxt": "\tSummary 1\tSummary 2",
        },
        {
            "doc_id": "doc2",
            "bibInfo": "Formatted Bibliography: Info for doc2",
            "comboTxt": "\tSummary 3\tSummary 4",
        },
    ]
    assert actual == expected
@dataclass(frozen=True)
class MockResponse:
    """Mock of a response object exposing a ``text`` field.

    Any attribute other than the declared ``text`` field is looked up as a
    private ``_<name>`` entry in the instance ``__dict__``.
    """
    text: str

    def __getattr__(self, name: str):
        """Return the ``_<name>`` entry from the instance ``__dict__``.

        Only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: if no ``_<name>`` entry exists. ``__getattr__``
                must raise AttributeError rather than KeyError so that
                ``hasattr``, ``getattr(obj, name, default)`` and the
                copy/pickle protocol probes keep working.
        """
        try:
            return self.__dict__[f"_{name}"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
class MockGeneNetworkQAClient:
    """Test double for ``GeneNetworkQAClient`` that returns canned results."""

    def __init__(self, session, api_key):
        """Accept the real client's constructor arguments and ignore them."""

    def ask(self, query, auth_token):
        """Pretend to submit ``query``.

        Returns:
            A ``(MockResponse("Mock response"), task_id)`` pair with a
            fixed task id string.
        """
        canned_response = MockResponse("Mock response")
        task_id = "F400995EAFE104EA72A5927CE10C73B7"
        return canned_response, task_id

    def get_answer(self, task_id):
        """Pretend to fetch the answer for ``task_id``.

        Returns:
            A ``(MockResponse("Mock answer"), 1)`` pair.
        """
        canned_answer = MockResponse("Mock answer")
        return canned_answer, 1
def mock_filter_response_text(text):
    """Stand-in for ``filter_response_text``.

    Ignores ``text`` and returns a new payload containing a fixed answer
    and an empty context.
    """
    payload = {"answer": "Mock answer for what is a gene", "context": {}}
    return {"data": payload}
def mock_parse_context(context, get_info_func, format_bib_func):
    """Stand-in for ``parse_context`` that ignores its arguments and
    returns a new empty list."""
    references = []
    return references
@pytest.mark.unit_test
def test_get_gnqa(monkeypatch):
    """Test process.get_gnqa with its collaborators replaced by mocks.

    The QA client, the response filter and the context parser in
    ``gn3.llms.process`` are swapped for the local test doubles. The
    result must be the mock task id, the mock answer and an empty
    reference list.
    """
    replacements = {
        "gn3.llms.process.GeneNetworkQAClient": MockGeneNetworkQAClient,
        "gn3.llms.process.filter_response_text": mock_filter_response_text,
        "gn3.llms.process.parse_context": mock_parse_context,
    }
    for target, stand_in in replacements.items():
        monkeypatch.setattr(target, stand_in)

    outcome = get_gnqa("What is a gene", "test_token")

    assert outcome == (
        "F400995EAFE104EA72A5927CE10C73B7",
        "Mock answer for what is a gene",
        [],
    )
|