diff options
| author | Alexander_Kabui | 2025-07-15 22:43:56 +0300 |
|---|---|---|
| committer | BonfaceKilz | 2025-07-16 22:50:45 +0300 |
| commit | 15e338f376e9312b20ef660dc75a218739a95bee (patch) | |
| tree | 1a7d1d9fce9f6db4f3864c627fca47f2242a27e0 | |
| parent | 74bab3b7623ab8130fc0dea7f9aa504e109bee77 (diff) | |
| download | genenetwork3-15e338f376e9312b20ef660dc75a218739a95bee.tar.gz | |
feat: Add unittests for llm rate limiting functionality.
| -rw-r--r-- | tests/unit/test_llm.py | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py index 3b91918..3a79486 100644 --- a/tests/unit/test_llm.py +++ b/tests/unit/test_llm.py @@ -1,12 +1,22 @@ """Test cases for procedures defined in llms """ # pylint: disable=C0301 +# pylint: disable=W0613 +from datetime import datetime, timedelta +from unittest.mock import patch +from unittest.mock import MagicMock + import pytest from gn3.llms.process import fetch_pubmed from gn3.llms.process import parse_context from gn3.llms.process import format_bibliography_info +from gn3.llms.errors import LLMError from gn3.api.llm import clean_query +from gn3.api.llm import is_verified_anonymous_user +from gn3.api.llm import is_valid_address +from gn3.api.llm import check_rate_limiter +FAKE_NOW = datetime(2025, 1, 1, 12, 0, 0) @pytest.mark.unit_test def test_parse_context(): """test for parsing doc id context""" @@ -113,3 +123,122 @@ def test_clean_query(): assert clean_query("!what is genetics.") == "what is genetics" assert clean_query("hello test?") == "hello test" assert clean_query(" hello test with space?") == "hello test with space" + + +@pytest.mark.unit_test +def test_is_verified_anonymous_user(): + """Test function for verifying anonymous user metadata""" + assert is_verified_anonymous_user({}) is False + assert is_verified_anonymous_user({"Anonymous-Id" : "qws2121dwsdwdwe", + "Anonymous-Status" : "verified"}) is True + +@pytest.mark.unit_test +def test_is_valid_address() : + """Test function checks if is a valid ip address is valid""" + assert is_valid_address("invalid_ip") is False + assert is_valid_address("127.0.0.1") is True + + +@patch("gn3.api.llm.datetime") +@patch("gn3.api.llm.db.connection") +@patch("gn3.api.llm.is_valid_address", return_value=True) +@pytest.mark.unit_test +def test_first_time_visitor(mock_is_valid, mock_db_conn, mock_datetime): + """Test rate limiting for first-time visitor""" + mock_datetime.utcnow.return_value = FAKE_NOW + mock_datetime.strptime = datetime.strptime # keep real one + mock_datetime.strftime = datetime.strftime # keep real one + + # Set up DB mock + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.__enter__.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = None + mock_db_conn.return_value = mock_conn + + result = check_rate_limiter("127.0.0.1", "test/llm.db", "Chromosome x") + assert result is True + mock_cursor.execute.assert_any_call(""" + INSERT INTO Limiter(identifier, tokens, expiry_time) + VALUES (?, ?, ?) + """, ("127.0.0.1", 4, "2025-01-01 12:24:00")) + + +@patch("gn3.api.llm.datetime") +@patch("gn3.api.llm.db.connection") +@patch("gn3.api.llm.is_valid_address", return_value=True) +@pytest.mark.unit_test +def test_visitor_at_limit(mock_is_valid, mock_db_conn, mock_datetime): + """Test rate limiting for Visitor at limit""" + mock_datetime.utcnow.return_value = FAKE_NOW + mock_datetime.strptime = datetime.strptime # keep real one + mock_datetime.strftime = datetime.strftime + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.__enter__.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + fake_expiry = (FAKE_NOW + timedelta(minutes=10)).strftime("%Y-%m-%d %H:%M:%S") + mock_cursor.fetchone.return_value = (0, fake_expiry) #token returned are 0 + mock_db_conn.return_value = mock_conn + with pytest.raises(LLMError) as exc_info: + check_rate_limiter("127.0.0.1", "test/llm.db", "Chromosome x") + # assert llm error with correct message is raised + assert exc_info.value.args == ('Rate limit exceeded. Please try again later.', 'Chromosome x') + + +@patch("gn3.api.llm.datetime") +@patch("gn3.api.llm.db.connection") +@patch("gn3.api.llm.is_valid_address", return_value=True) +@pytest.mark.unit_test +def test_visitor_with_tokens(mock_is_valid, mock_db_conn, mock_datetime): + """Test rate limiting for user with valid tokens""" + + mock_datetime.utcnow.return_value = FAKE_NOW + mock_datetime.strptime = datetime.strptime # Use real versions + mock_datetime.strftime = datetime.strftime + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.__enter__.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + + fake_expiry = (FAKE_NOW + timedelta(minutes=10)).strftime("%Y-%m-%d %H:%M:%S") + mock_cursor.fetchone.return_value = (3, fake_expiry) # Simulate 3 tokens + + mock_db_conn.return_value = mock_conn + + results = check_rate_limiter("127.0.0.1", "test/llm.db", "Chromosome x") + assert results is True + mock_cursor.execute.assert_any_call(""" + UPDATE Limiter + SET tokens = tokens - 1 + WHERE identifier = ? AND tokens > 0 + """, ("127.0.0.1",)) + +@patch("gn3.api.llm.datetime") +@patch("gn3.api.llm.db.connection") +@patch("gn3.api.llm.is_valid_address", return_value=True) +@pytest.mark.unit_test +def test_visitor_token_expired(mock_is_valid, mock_db_conn, mock_datetime): + """Test rate limiting for expired tokens""" + + mock_datetime.utcnow.return_value = FAKE_NOW + mock_datetime.strptime = datetime.strptime + mock_datetime.strftime = datetime.strftime + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.__enter__.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + fake_expiry = (FAKE_NOW - timedelta(minutes=10)).strftime("%Y-%m-%d %H:%M:%S") + mock_cursor.fetchone.return_value = (3, fake_expiry) # Simulate 3 tokens + mock_db_conn.return_value = mock_conn + + result = check_rate_limiter("127.0.0.1", "test/llm.db", "Chromosome x") + assert result is True + mock_cursor.execute.assert_any_call(""" + UPDATE Limiter + SET tokens = ?, expiry_time = ? + WHERE identifier = ? + """, (4, "2025-01-01 12:24:00", "127.0.0.1")) |
