# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import asyncio import time from unittest.mock import MagicMock, patch import pytest from .api_client import ApiClient @patch('genai.api_client.ApiClient._build_request') @patch('genai.api_client.ApiClient._request') def test_request_streamed_non_blocking(mock_request, mock_build_request): api_client = ApiClient(api_key='test_api_key') http_method = 'GET' path = 'test/path' request_dict = {'key': 'value'} mock_http_request = MagicMock() mock_build_request.return_value = mock_http_request def delayed_segments(): chunks = ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}'] for chunk in chunks: time.sleep(0.1) # 100ms delay yield chunk mock_response = MagicMock() mock_response.segments.side_effect = delayed_segments mock_request.return_value = mock_response chunks = [] start_time = time.time() for chunk in api_client.request_streamed(http_method, path, request_dict): chunks.append(chunk) assert len(chunks) <= 3 end_time = time.time() mock_build_request.assert_called_once_with( http_method, path, request_dict, None ) mock_request.assert_called_once_with(mock_http_request, stream=True) assert chunks == ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}'] assert end_time - start_time > 0.3 @patch('genai.api_client.ApiClient._build_request') @patch('genai.api_client.ApiClient._async_request') @pytest.mark.asyncio async def test_async_request(mock_async_request, mock_build_request): api_client = ApiClient(api_key='test_api_key') http_method = 'GET' path = 'test/path' request_dict = {'key': 'value'} mock_http_request = MagicMock() mock_build_request.return_value = mock_http_request class MockResponse: def __init__(self, text): self.text = text async def delayed_response(http_request, stream): await asyncio.sleep(0.1) # 100ms delay return MockResponse('value') mock_async_request.side_effect = delayed_response async_coroutine1 = api_client.async_request(http_method, path, request_dict) async_coroutine2 = api_client.async_request(http_method, path, request_dict) async_coroutine3 = api_client.async_request(http_method, path, request_dict) start_time = time.time() results = await asyncio.gather( async_coroutine1, async_coroutine2, async_coroutine3 ) end_time = time.time() mock_build_request.assert_called_with(http_method, path, request_dict, None) assert mock_build_request.call_count == 3 mock_async_request.assert_called_with( http_request=mock_http_request, stream=False ) assert mock_async_request.call_count == 3 assert results == ['value', 'value', 'value'] assert 0.1 <= end_time - start_time < 0.15 @patch('genai.api_client.ApiClient._build_request') @patch('genai.api_client.ApiClient._async_request') @pytest.mark.asyncio async def test_async_request_streamed_non_blocking( mock_async_request, mock_build_request ): api_client = ApiClient(api_key='test_api_key') http_method = 'GET' path = 'test/path' request_dict = {'key': 'value'} mock_http_request = MagicMock() mock_build_request.return_value = mock_http_request class MockResponse: def __init__(self, segments): self._segments = segments # should mock async generator here but source code combines sync and async streaming in one segment method. # TODO: fix the above def segments(self): for segment in self._segments: time.sleep(0.1) # 100ms delay yield segment async def delayed_response(http_request, stream): return MockResponse(['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']) mock_async_request.side_effect = delayed_response chunks = [] start_time = time.time() async for chunk in api_client.async_request_streamed( http_method, path, request_dict ): chunks.append(chunk) assert len(chunks) <= 3 end_time = time.time() mock_build_request.assert_called_once_with( http_method, path, request_dict, None ) mock_async_request.assert_called_once_with( http_request=mock_http_request, stream=True ) assert chunks == ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}'] assert end_time - start_time > 0.3