import asyncio
import io
import uuid
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, Mock

import pytest
from fastapi import UploadFile

from r2r.base import (
    Document,
    DocumentInfo,
    R2RDocumentProcessingError,
    R2RException,
    generate_id_from_label,
)
from r2r.main import R2RPipelines, R2RProviders
from r2r.main.services.ingestion_service import IngestionService


@pytest.fixture(scope="session", autouse=True)
def event_loop_policy():
    asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())


@pytest.fixture(scope="function")
def event_loop():
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()
    asyncio.set_event_loop(None)


@pytest.fixture(scope="session", autouse=True)
async def cleanup_tasks():
    # Cancel any tasks still pending at the end of the session so the
    # event loop can shut down cleanly.
    yield
    for task in asyncio.all_tasks():
        if task is not asyncio.current_task():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass


@pytest.fixture
def mock_vector_db():
    mock_db = MagicMock()
    mock_db.get_documents_overview.return_value = []  # Default to empty list
    return mock_db


@pytest.fixture
def mock_embedding_model():
    return Mock()


@pytest.fixture
def ingestion_service(mock_vector_db, mock_embedding_model):
    config = MagicMock()
    config.app.get.return_value = 32  # Default max file size, in MB

    providers = Mock(spec=R2RProviders)
    providers.vector_db = mock_vector_db
    providers.embedding_model = mock_embedding_model

    pipelines = Mock(spec=R2RPipelines)
    pipelines.ingestion_pipeline = AsyncMock()
    pipelines.ingestion_pipeline.run.return_value = {
        "embedding_pipeline_output": []
    }

    run_manager = Mock()
    run_manager.run_info = {"mock_run_id": {}}
    logging_connection = AsyncMock()

    return IngestionService(
        config, providers, pipelines, run_manager, logging_connection
    )


@pytest.mark.asyncio
async def test_ingest_single_document(ingestion_service, mock_vector_db):
    # Each test tolerates CancelledError because the session-level
    # cleanup_tasks fixture cancels leftover tasks on teardown.
    try:
        document = Document(
            id=generate_id_from_label("test_id"),
            data="Test content",
            type="txt",
            metadata={},
        )
        ingestion_service.pipelines.ingestion_pipeline.run.return_value = {
            "embedding_pipeline_output": [(document.id, None)]
        }
        mock_vector_db.get_documents_overview.return_value = []  # No existing documents

        result = await ingestion_service.ingest_documents([document])

        assert result["processed_documents"] == [
            f"Document '{document.id}' processed successfully."
        ]
        assert not result["failed_documents"]
        assert not result["skipped_documents"]
    except asyncio.CancelledError:
        pass
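
# Sketch of a possible factory helper, not part of the original suite: it
# assumes only the Document constructor signature already used throughout
# this file, and would trim the repeated four-field construction below.
def make_document(label: str, data: str) -> Document:
    """Build a minimal txt Document keyed by a deterministic label."""
    return Document(
        id=generate_id_from_label(label),
        data=data,
        type="txt",
        metadata={},
    )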


@pytest.mark.asyncio
async def test_ingest_duplicate_document(ingestion_service, mock_vector_db):
    try:
        document = Document(
            id=generate_id_from_label("test_id"),
            data="Test content",
            type="txt",
            metadata={},
        )
        mock_vector_db.get_documents_overview.return_value = [
            DocumentInfo(
                document_id=document.id,
                version="v0",
                size_in_bytes=len(document.data),
                metadata={},
                title=str(document.id),
                user_id=None,
                created_at=datetime.now(),
                updated_at=datetime.now(),
                status="success",
            )
        ]

        with pytest.raises(R2RException) as exc_info:
            await ingestion_service.ingest_documents([document])
        assert (
            f"Document with ID {document.id} was already successfully processed"
            in str(exc_info.value)
        )
    except asyncio.CancelledError:
        pass


@pytest.mark.asyncio
async def test_ingest_file(ingestion_service):
    try:
        file_content = b"Test content"
        file_mock = UploadFile(
            filename="test.txt", file=io.BytesIO(file_content)
        )
        file_mock.file.seek(0)
        file_mock.size = len(file_content)  # Set file size manually

        ingestion_service.pipelines.ingestion_pipeline.run.return_value = {
            "embedding_pipeline_output": [
                (generate_id_from_label("test.txt"), None)
            ]
        }

        result = await ingestion_service.ingest_files([file_mock])

        assert len(result["processed_documents"]) == 1
        assert not result["failed_documents"]
        assert not result["skipped_documents"]
    except asyncio.CancelledError:
        pass


@pytest.mark.asyncio
async def test_ingest_mixed_success_and_failure(
    ingestion_service, mock_vector_db
):
    try:
        documents = [
            Document(
                id=generate_id_from_label("success_id"),
                data="Success content",
                type="txt",
                metadata={},
            ),
            Document(
                id=generate_id_from_label("failure_id"),
                data="Failure content",
                type="txt",
                metadata={},
            ),
        ]
        ingestion_service.pipelines.ingestion_pipeline.run.return_value = {
            "embedding_pipeline_output": [
                (
                    documents[0].id,
                    f"Processed 1 vectors for document {documents[0].id}.",
                ),
                (
                    documents[1].id,
                    R2RDocumentProcessingError(
                        error_message="Embedding failed",
                        document_id=documents[1].id,
                    ),
                ),
            ]
        }

        result = await ingestion_service.ingest_documents(documents)

        assert len(result["processed_documents"]) == 1
        assert len(result["failed_documents"]) == 1
        assert str(documents[0].id) in result["processed_documents"][0]
        assert str(documents[1].id) in result["failed_documents"][0]
        assert "Embedding failed" in result["failed_documents"][0]

        assert mock_vector_db.upsert_documents_overview.call_count == 2
        upserted_docs = mock_vector_db.upsert_documents_overview.call_args[0][0]
        assert len(upserted_docs) == 2
        assert upserted_docs[0].document_id == documents[0].id
        assert upserted_docs[0].status == "success"
        assert upserted_docs[1].document_id == documents[1].id
        assert upserted_docs[1].status == "failure"
    except asyncio.CancelledError:
        pass


@pytest.mark.asyncio
async def test_ingest_unsupported_file_type(ingestion_service):
    try:
        file_mock = UploadFile(
            filename="test.unsupported", file=io.BytesIO(b"Test content")
        )
        file_mock.file.seek(0)
        file_mock.size = 12  # Set file size manually

        with pytest.raises(R2RException) as exc_info:
            await ingestion_service.ingest_files([file_mock])
        assert "is not a valid DocumentType" in str(exc_info.value)
    except asyncio.CancelledError:
        pass
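
# Sketch of a matching helper for the UploadFile setup repeated above
# (hypothetical, not part of the original suite); it assumes only the
# UploadFile keyword arguments these tests already pass.
def make_upload_file(filename: str, content: bytes) -> UploadFile:
    """Wrap raw bytes in an UploadFile with its size set explicitly."""
    upload = UploadFile(filename=filename, file=io.BytesIO(content))
    upload.file.seek(0)
    upload.size = len(content)
    return upload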


@pytest.mark.asyncio
async def test_ingest_large_file(ingestion_service):
    try:
        large_content = b"Large content" * 1000000  # ~12.4 MB of content
        file_mock = UploadFile(
            filename="large_file.txt", file=io.BytesIO(large_content)
        )
        file_mock.file.seek(0)
        file_mock.size = len(large_content)  # Set file size manually

        ingestion_service.config.app.get.return_value = 10  # Max file size: 10 MB

        with pytest.raises(R2RException) as exc_info:
            await ingestion_service.ingest_files([file_mock])
        assert "File size exceeds maximum allowed size" in str(exc_info.value)
    except asyncio.CancelledError:
        pass


@pytest.mark.asyncio
async def test_partial_ingestion_success(ingestion_service, mock_vector_db):
    try:
        documents = [
            Document(
                id=generate_id_from_label("success_1"),
                data="Success content 1",
                type="txt",
                metadata={},
            ),
            Document(
                id=generate_id_from_label("fail"),
                data="Fail content",
                type="txt",
                metadata={},
            ),
            Document(
                id=generate_id_from_label("success_2"),
                data="Success content 2",
                type="txt",
                metadata={},
            ),
        ]
        ingestion_service.pipelines.ingestion_pipeline.run.return_value = {
            "embedding_pipeline_output": [
                (documents[0].id, None),
                (
                    documents[1].id,
                    R2RDocumentProcessingError(
                        error_message="Embedding failed",
                        document_id=documents[1].id,
                    ),
                ),
                (documents[2].id, None),
            ]
        }

        result = await ingestion_service.ingest_documents(documents)

        assert len(result["processed_documents"]) == 2
        assert len(result["failed_documents"]) == 1
        assert str(documents[1].id) in result["failed_documents"][0]
    except asyncio.CancelledError:
        pass


@pytest.mark.asyncio
async def test_version_increment(ingestion_service, mock_vector_db):
    try:
        document = Document(
            id=generate_id_from_label("test_id"),
            data="Test content",
            type="txt",
            metadata={},
        )
        mock_vector_db.get_documents_overview.return_value = [
            DocumentInfo(
                document_id=document.id,
                version="v2",
                status="success",
                size_in_bytes=0,
                metadata={},
            )
        ]

        file_mock = UploadFile(
            filename="test.txt", file=io.BytesIO(b"Updated content")
        )
        await ingestion_service.update_files([file_mock], [document.id])

        calls = mock_vector_db.upsert_documents_overview.call_args_list
        assert len(calls) == 2
        # The second upsert carries the incremented version.
        assert calls[1][0][0][0].version == "v3"
    except asyncio.CancelledError:
        pass


@pytest.mark.asyncio
async def test_process_ingestion_results_error_handling(ingestion_service):
    try:
        document_infos = [
            DocumentInfo(
                document_id=uuid.uuid4(),
                version="v0",
                status="processing",
                size_in_bytes=0,
                metadata={},
            )
        ]
        ingestion_results = {
            "embedding_pipeline_output": [
                (
                    document_infos[0].document_id,
                    R2RDocumentProcessingError(
                        "Unexpected error",
                        document_id=document_infos[0].document_id,
                    ),
                )
            ]
        }

        result = await ingestion_service._process_ingestion_results(
            ingestion_results,
            document_infos,
            [],
            {document_infos[0].document_id: "test"},
        )

        assert len(result["failed_documents"]) == 1
        assert "Unexpected error" in result["failed_documents"][0]
    except asyncio.CancelledError:
        pass


@pytest.mark.asyncio
async def test_file_size_limit_edge_cases(ingestion_service):
    try:
        ingestion_service.config.app.get.return_value = 1  # 1 MB limit

        just_under_limit = b"x" * (1024 * 1024 - 1)
        at_limit = b"x" * (1024 * 1024)
        over_limit = b"x" * (1024 * 1024 + 1)

        file_under = UploadFile(
            filename="under.txt",
            file=io.BytesIO(just_under_limit),
            size=1024 * 1024 - 1,
        )
        file_at = UploadFile(
            filename="at.txt", file=io.BytesIO(at_limit), size=1024 * 1024
        )
        file_over = UploadFile(
            filename="over.txt",
            file=io.BytesIO(over_limit),
            size=1024 * 1024 + 1,
        )

        await ingestion_service.ingest_files([file_under])  # Should succeed
        await ingestion_service.ingest_files([file_at])  # Should succeed
        with pytest.raises(
            R2RException, match="File size exceeds maximum allowed size"
        ):
            await ingestion_service.ingest_files([file_over])
    except asyncio.CancelledError:
        pass
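
# The edge cases above imply the size check the service is assumed to
# perform: an upload passes when its byte count is at or under
# limit_mb * 1024 * 1024 (the at-limit upload succeeds). A sketch of that
# arithmetic with a hypothetical helper name:
def is_within_limit(size_in_bytes: int, limit_mb: int) -> bool:
    """True when size_in_bytes fits within a limit_mb megabyte cap."""
    return size_in_bytes <= limit_mb * 1024 * 1024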


@pytest.mark.asyncio
async def test_document_status_update_after_ingestion(
    ingestion_service, mock_vector_db
):
    try:
        document = Document(
            id=generate_id_from_label("test_id"),
            data="Test content",
            type="txt",
            metadata={},
        )
        ingestion_service.pipelines.ingestion_pipeline.run.return_value = {
            "embedding_pipeline_output": [(document.id, None)]
        }
        mock_vector_db.get_documents_overview.return_value = []  # No existing documents

        await ingestion_service.ingest_documents([document])

        # Check that upsert_documents_overview was called twice
        assert mock_vector_db.upsert_documents_overview.call_count == 2

        # Check the second call to upsert_documents_overview (status update)
        second_call_args = (
            mock_vector_db.upsert_documents_overview.call_args_list[1][0][0]
        )
        assert len(second_call_args) == 1
        assert second_call_args[0].document_id == document.id
        assert second_call_args[0].status == "success"
    except asyncio.CancelledError:
        pass
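

# Sketch showing the hypothetical make_document helper above in use. It
# re-exercises the single-document happy path already covered by
# test_ingest_single_document, so it asserts only behavior this suite
# already checks.
@pytest.mark.asyncio
async def test_ingest_single_document_via_helpers(
    ingestion_service, mock_vector_db
):
    try:
        document = make_document("helper_id", "Helper content")
        ingestion_service.pipelines.ingestion_pipeline.run.return_value = {
            "embedding_pipeline_output": [(document.id, None)]
        }
        mock_vector_db.get_documents_overview.return_value = []

        result = await ingestion_service.ingest_documents([document])

        assert len(result["processed_documents"]) == 1
        assert not result["failed_documents"]
    except asyncio.CancelledError:
        pass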