about summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/unit/test_llm.py129
1 files changed, 129 insertions, 0 deletions
diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py
index 3b91918..3a79486 100644
--- a/tests/unit/test_llm.py
+++ b/tests/unit/test_llm.py
@@ -1,12 +1,22 @@
 """Test cases for procedures defined in llms """
 # pylint: disable=C0301
+# pylint: disable=W0613
+from datetime import datetime, timedelta
+from unittest.mock import patch
+from unittest.mock import MagicMock
+
 import pytest
 from gn3.llms.process import fetch_pubmed
 from gn3.llms.process import parse_context
 from gn3.llms.process import format_bibliography_info
+from gn3.llms.errors import LLMError
 from gn3.api.llm  import clean_query
+from gn3.api.llm  import is_verified_anonymous_user
+from gn3.api.llm  import is_valid_address
+from gn3.api.llm  import check_rate_limiter
 
 
+FAKE_NOW = datetime(2025, 1, 1, 12, 0, 0)
 @pytest.mark.unit_test
 def test_parse_context():
     """test for parsing doc id context"""
@@ -113,3 +123,122 @@ def test_clean_query():
     assert clean_query("!what is genetics.") == "what is genetics"
     assert clean_query("hello test?") == "hello test"
     assert clean_query("  hello test with space?") == "hello test with space"
+
+
+@pytest.mark.unit_test
+def test_is_verified_anonymous_user():
+    """Test function for verifying anonymous user metadata"""
+    assert is_verified_anonymous_user({}) is False
+    assert is_verified_anonymous_user({"Anonymous-Id" : "qws2121dwsdwdwe",
+                                        "Anonymous-Status" : "verified"}) is True
+
+@pytest.mark.unit_test
+def test_is_valid_address() :
+    """Test function checks if is a valid ip address is valid"""
+    assert  is_valid_address("invalid_ip") is False
+    assert is_valid_address("127.0.0.1") is True
+
+
+@patch("gn3.api.llm.datetime")
+@patch("gn3.api.llm.db.connection")
+@patch("gn3.api.llm.is_valid_address", return_value=True)
+@pytest.mark.unit_test
+def test_first_time_visitor(mock_is_valid, mock_db_conn, mock_datetime):
+    """Test rate limiting for first-time visitor"""
+    mock_datetime.utcnow.return_value = FAKE_NOW
+    mock_datetime.strptime = datetime.strptime  # keep real one
+    mock_datetime.strftime = datetime.strftime  # keep real one
+
+    # Set up DB mock
+    mock_conn = MagicMock()
+    mock_cursor = MagicMock()
+    mock_conn.__enter__.return_value = mock_conn
+    mock_conn.cursor.return_value = mock_cursor
+    mock_cursor.fetchone.return_value = None
+    mock_db_conn.return_value = mock_conn
+
+    result = check_rate_limiter("127.0.0.1", "test/llm.db", "Chromosome x")
+    assert result is True
+    mock_cursor.execute.assert_any_call("""
+                INSERT INTO Limiter(identifier, tokens, expiry_time)
+                VALUES (?, ?, ?)
+            """, ("127.0.0.1", 4, "2025-01-01 12:24:00"))
+
+
+@patch("gn3.api.llm.datetime")
+@patch("gn3.api.llm.db.connection")
+@patch("gn3.api.llm.is_valid_address", return_value=True)
+@pytest.mark.unit_test
+def test_visitor_at_limit(mock_is_valid, mock_db_conn, mock_datetime):
+    """Test rate limiting for Visitor at limit"""
+    mock_datetime.utcnow.return_value = FAKE_NOW
+    mock_datetime.strptime = datetime.strptime  # keep real one
+    mock_datetime.strftime = datetime.strftime
+
+    mock_conn = MagicMock()
+    mock_cursor = MagicMock()
+    mock_conn.__enter__.return_value = mock_conn
+    mock_conn.cursor.return_value = mock_cursor
+    fake_expiry = (FAKE_NOW + timedelta(minutes=10)).strftime("%Y-%m-%d %H:%M:%S")
+    mock_cursor.fetchone.return_value = (0, fake_expiry) #token returned are 0
+    mock_db_conn.return_value = mock_conn
+    with pytest.raises(LLMError) as exc_info:
+        check_rate_limiter("127.0.0.1", "test/llm.db", "Chromosome x")
+    # assert llm error with correct message is raised
+    assert exc_info.value.args == ('Rate limit exceeded. Please try again later.', 'Chromosome x')
+
+
+@patch("gn3.api.llm.datetime")
+@patch("gn3.api.llm.db.connection")
+@patch("gn3.api.llm.is_valid_address", return_value=True)
+@pytest.mark.unit_test
+def test_visitor_with_tokens(mock_is_valid, mock_db_conn, mock_datetime):
+    """Test rate limiting for user with valid tokens"""
+
+    mock_datetime.utcnow.return_value = FAKE_NOW
+    mock_datetime.strptime = datetime.strptime  # Use real versions
+    mock_datetime.strftime = datetime.strftime
+
+    mock_conn = MagicMock()
+    mock_cursor = MagicMock()
+    mock_conn.__enter__.return_value = mock_conn
+    mock_conn.cursor.return_value = mock_cursor
+
+    fake_expiry = (FAKE_NOW + timedelta(minutes=10)).strftime("%Y-%m-%d %H:%M:%S")
+    mock_cursor.fetchone.return_value = (3, fake_expiry)  # Simulate 3 tokens
+
+    mock_db_conn.return_value = mock_conn
+
+    results = check_rate_limiter("127.0.0.1", "test/llm.db", "Chromosome x")
+    assert results is True
+    mock_cursor.execute.assert_any_call("""
+                        UPDATE Limiter
+                        SET tokens = tokens - 1
+                        WHERE identifier = ? AND tokens > 0
+                    """, ("127.0.0.1",))
+
+@patch("gn3.api.llm.datetime")
+@patch("gn3.api.llm.db.connection")
+@patch("gn3.api.llm.is_valid_address", return_value=True)
+@pytest.mark.unit_test
+def test_visitor_token_expired(mock_is_valid, mock_db_conn, mock_datetime):
+    """Test rate limiting for expired tokens"""
+
+    mock_datetime.utcnow.return_value = FAKE_NOW
+    mock_datetime.strptime = datetime.strptime
+    mock_datetime.strftime = datetime.strftime
+    mock_conn = MagicMock()
+    mock_cursor = MagicMock()
+    mock_conn.__enter__.return_value = mock_conn
+    mock_conn.cursor.return_value = mock_cursor
+    fake_expiry = (FAKE_NOW - timedelta(minutes=10)).strftime("%Y-%m-%d %H:%M:%S")
+    mock_cursor.fetchone.return_value = (3, fake_expiry)  # Simulate 3 tokens
+    mock_db_conn.return_value = mock_conn
+
+    result = check_rate_limiter("127.0.0.1", "test/llm.db", "Chromosome x")
+    assert result is True
+    mock_cursor.execute.assert_any_call("""
+                    UPDATE Limiter
+                    SET tokens = ?, expiry_time = ?
+                    WHERE identifier = ?
+                """, (4, "2025-01-01 12:24:00", "127.0.0.1"))