import asyncio
import io
import uuid
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, Mock
import pytest
from fastapi import UploadFile
from r2r.base import (
Document,
DocumentInfo,
R2RDocumentProcessingError,
R2RException,
generate_id_from_label,
)
from r2r.main import R2RPipelines, R2RProviders
from r2r.main.services.ingestion_service import IngestionService
@pytest.fixture(scope="session", autouse=True)
def event_loop_policy():
    """Install asyncio's default event loop policy (session-wide, autouse)."""
    default_policy = asyncio.DefaultEventLoopPolicy()
    asyncio.set_event_loop_policy(default_policy)
@pytest.fixture(scope="function")
def event_loop():
    """Yield a brand-new event loop for each test function.

    After the test finishes, the loop is closed and the current event loop
    is reset to ``None``.
    """
    policy = asyncio.get_event_loop_policy()
    fresh_loop = policy.new_event_loop()
    yield fresh_loop
    # Teardown: release the loop, then clear the "current loop" slot.
    fresh_loop.close()
    asyncio.set_event_loop(None)
@pytest.fixture(scope="session", autouse=True)
async def cleanup_tasks():
    """Cancel and await every other asyncio task at session teardown.

    Nothing happens before the ``yield``; all work is done afterwards.
    """
    yield
    for pending in asyncio.all_tasks():
        # Never cancel the task that is running this teardown.
        if pending is asyncio.current_task():
            continue
        pending.cancel()
        try:
            await pending
        except asyncio.CancelledError:
            # Expected result of the cancel() issued just above.
            pass
@pytest.fixture
def mock_vector_db():
    """Return a ``MagicMock`` vector DB with an empty documents overview."""
    vector_db = MagicMock()
    # Default to empty list
    vector_db.get_documents_overview.return_value = []
    return vector_db
@pytest.fixture
def mock_embedding_model():
    """Return a plain ``Mock`` standing in for the embedding model."""
    embedding_model = Mock()
    return embedding_model
@pytest.fixture
def ingestion_service(mock_vector_db, mock_embedding_model):
    """Build an ``IngestionService`` whose collaborators are all mocks.

    The ingestion pipeline's ``run`` returns an empty
    ``embedding_pipeline_output`` by default; individual tests override it.
    """
    # Pipelines: an async pipeline mock that yields no results by default.
    ingestion_pipeline = AsyncMock()
    ingestion_pipeline.run.return_value = {"embedding_pipeline_output": []}
    pipelines = Mock(spec=R2RPipelines)
    pipelines.ingestion_pipeline = ingestion_pipeline

    # Providers: wire in the fixture-supplied vector DB and embedding model.
    providers = Mock(spec=R2RProviders)
    providers.vector_db = mock_vector_db
    providers.embedding_model = mock_embedding_model

    config = MagicMock()
    config.app.get.return_value = 32  # Default max file size

    run_manager = Mock()
    run_manager.run_info = {"mock_run_id": {}}

    logging_connection = AsyncMock()

    return IngestionService(
        config, providers, pipelines, run_manager, logging_connection
    )
@pytest.mark.asyncio
async def test_ingest_single_document(ingestion_service, mock_vector_db):
    """Expect one new document to be reported as processed, nothing else."""
    try:
        doc = Document(
            id=generate_id_from_label("test_id"),
            data="Test content",
            type="txt",
            metadata={},
        )
        pipeline_run = ingestion_service.pipelines.ingestion_pipeline.run
        pipeline_run.return_value = {
            "embedding_pipeline_output": [(doc.id, None)]
        }
        # No existing documents
        mock_vector_db.get_documents_overview.return_value = []

        outcome = await ingestion_service.ingest_documents([doc])

        expected_message = f"Document '{doc.id}' processed successfully."
        assert outcome["processed_documents"] == [expected_message]
        assert not outcome["failed_documents"]
        assert not outcome["skipped_documents"]
    except asyncio.CancelledError:
        # NOTE(review): cancellation is swallowed, so a cancelled run ends
        # without failing this test.
        pass
@pytest.mark.asyncio
async def test_ingest_duplicate_document(ingestion_service, mock_vector_db):
    """Expect ``R2RException`` when the overview already lists the document
    with status ``"success"``."""
    try:
        doc = Document(
            id=generate_id_from_label("test_id"),
            data="Test content",
            type="txt",
            metadata={},
        )
        # Overview entry describing the same document as already ingested.
        existing_record = DocumentInfo(
            document_id=doc.id,
            version="v0",
            size_in_bytes=len(doc.data),
            metadata={},
            title=str(doc.id),
            user_id=None,
            created_at=datetime.now(),
            updated_at=datetime.now(),
            status="success",
        )
        mock_vector_db.get_documents_overview.return_value = [existing_record]

        with pytest.raises(R2RException) as raised:
            await ingestion_service.ingest_documents([doc])

        expected_fragment = (
            f"Document with ID {doc.id} was already successfully processed"
        )
        assert expected_fragment in str(raised.value)
    except asyncio.CancelledError:
        # NOTE(review): cancellation is swallowed, so a cancelled run ends
        # without failing this test.
        pass
@pytest.mark.asyncio
async def test_ingest_file(ingestion_service):
    """Expect a single uploaded ``.txt`` file to be processed successfully."""
    try:
        payload = b"Test content"
        upload = UploadFile(filename="test.txt", file=io.BytesIO(payload))
        upload.file.seek(0)
        upload.size = len(payload)  # Set file size manually

        pipeline_run = ingestion_service.pipelines.ingestion_pipeline.run
        pipeline_run.return_value = {
            "embedding_pipeline_output": [
                (generate_id_from_label("test.txt"), None)
            ]
        }

        outcome = await ingestion_service.ingest_files([upload])

        assert len(outcome["processed_documents"]) == 1
        assert not outcome["failed_documents"]
        assert not outcome["skipped_documents"]
    except asyncio.CancelledError:
        # NOTE(review): cancellation is swallowed, so a cancelled run ends
        # without failing this test.
        pass
@pytest.mark.asyncio
async def test_ingest_mixed_success_and_failure(
    ingestion_service, mock_vector_db
):
    """Expect one processed and one failed document when the pipeline
    reports a success message for the first and an error for the second.

    Also checks that the overview is upserted twice and that the last
    upsert carries both documents with their final statuses.
    """
    try:
        ok_doc = Document(
            id=generate_id_from_label("success_id"),
            data="Success content",
            type="txt",
            metadata={},
        )
        bad_doc = Document(
            id=generate_id_from_label("failure_id"),
            data="Failure content",
            type="txt",
            metadata={},
        )
        documents = [ok_doc, bad_doc]

        success_entry = (
            ok_doc.id,
            f"Processed 1 vectors for document {ok_doc.id}.",
        )
        failure_entry = (
            bad_doc.id,
            R2RDocumentProcessingError(
                error_message="Embedding failed",
                document_id=bad_doc.id,
            ),
        )
        pipeline_run = ingestion_service.pipelines.ingestion_pipeline.run
        pipeline_run.return_value = {
            "embedding_pipeline_output": [success_entry, failure_entry]
        }

        outcome = await ingestion_service.ingest_documents(documents)

        processed = outcome["processed_documents"]
        assert len(processed) == 1
        failed = outcome["failed_documents"]
        assert len(failed) == 1
        assert str(ok_doc.id) in processed[0]
        assert str(bad_doc.id) in failed[0]
        assert "Embedding failed" in failed[0]

        upsert = mock_vector_db.upsert_documents_overview
        assert upsert.call_count == 2
        # First positional argument of the most recent upsert call.
        upserted_docs = upsert.call_args[0][0]
        assert len(upserted_docs) == 2
        expected_records = [(ok_doc.id, "success"), (bad_doc.id, "failure")]
        for record, (doc_id, status) in zip(upserted_docs, expected_records):
            assert record.document_id == doc_id
            assert record.status == status
    except asyncio.CancelledError:
        # NOTE(review): cancellation is swallowed, so a cancelled run ends
        # without failing this test.
        pass
@pytest.mark.asyncio
async def test_ingest_unsupported_file_type(ingestion_service):
    """Expect ``R2RException`` for a file with the ``.unsupported`` extension."""
    try:
        payload = b"Test content"
        upload = UploadFile(
            filename="test.unsupported", file=io.BytesIO(payload)
        )
        upload.file.seek(0)
        upload.size = len(payload)  # Set file size manually (12 bytes)

        with pytest.raises(R2RException) as raised:
            await ingestion_service.ingest_files([upload])

        assert "is not a valid DocumentType" in str(raised.value)
    except asyncio.CancelledError:
        # NOTE(review): cancellation is swallowed, so a cancelled run ends
        # without failing this test.
        pass
@pytest.mark.asyncio
async def test_ingest_large_file(ingestion_service):
    """Expect ``R2RException`` when a file is larger than the configured
    maximum size."""
    try:
        # 13-byte literal repeated 1,000,000 times: 13,000,000 bytes.
        oversized = b"Large content" * 1000000
        upload = UploadFile(
            filename="large_file.txt", file=io.BytesIO(oversized)
        )
        upload.file.seek(0)
        upload.size = len(oversized)  # Set file size manually

        # Set max file size to 10MB
        ingestion_service.config.app.get.return_value = 10

        with pytest.raises(R2RException) as raised:
            await ingestion_service.ingest_files([upload])

        assert "File size exceeds maximum allowed size" in str(raised.value)
    except asyncio.CancelledError:
        # NOTE(review): cancellation is swallowed, so a cancelled run ends
        # without failing this test.
        pass
@pytest.mark.asyncio
async def test_partial_ingestion_success(ingestion_service, mock_vector_db):
    """Expect two processed and one failed document when only the middle
    document of three comes back from the pipeline with an error."""
    try:
        specs = [
            ("success_1", "Success content 1"),
            ("fail", "Fail content"),
            ("success_2", "Success content 2"),
        ]
        documents = [
            Document(
                id=generate_id_from_label(label),
                data=content,
                type="txt",
                metadata={},
            )
            for label, content in specs
        ]
        first_ok, failing, second_ok = documents

        embedding_error = R2RDocumentProcessingError(
            error_message="Embedding failed",
            document_id=failing.id,
        )
        pipeline_run = ingestion_service.pipelines.ingestion_pipeline.run
        pipeline_run.return_value = {
            "embedding_pipeline_output": [
                (first_ok.id, None),
                (failing.id, embedding_error),
                (second_ok.id, None),
            ]
        }

        outcome = await ingestion_service.ingest_documents(documents)

        assert len(outcome["processed_documents"]) == 2
        failed = outcome["failed_documents"]
        assert len(failed) == 1
        assert str(failing.id) in failed[0]
    except asyncio.CancelledError:
        # NOTE(review): cancellation is swallowed, so a cancelled run ends
        # without failing this test.
        pass
@pytest.mark.asyncio
async def test_version_increment(ingestion_service, mock_vector_db):
    """Expect ``update_files`` to upsert version ``"v3"`` for a document
    whose overview entry is currently at ``"v2"``."""
    try:
        doc = Document(
            id=generate_id_from_label("test_id"),
            data="Test content",
            type="txt",
            metadata={},
        )
        stored_record = DocumentInfo(
            document_id=doc.id,
            version="v2",
            status="success",
            size_in_bytes=0,
            metadata={},
        )
        mock_vector_db.get_documents_overview.return_value = [stored_record]

        replacement = UploadFile(
            filename="test.txt", file=io.BytesIO(b"Updated content")
        )
        await ingestion_service.update_files([replacement], [doc.id])

        upsert_calls = mock_vector_db.upsert_documents_overview.call_args_list
        assert len(upsert_calls) == 2
        # First positional argument of the second upsert call.
        second_upsert_docs = upsert_calls[1].args[0]
        assert second_upsert_docs[0].version == "v3"
    except asyncio.CancelledError:
        # NOTE(review): cancellation is swallowed, so a cancelled run ends
        # without failing this test.
        pass
@pytest.mark.asyncio
async def test_process_ingestion_results_error_handling(ingestion_service):
    """Expect ``_process_ingestion_results`` to list a document as failed
    when its pipeline result is an ``R2RDocumentProcessingError``."""
    try:
        info = DocumentInfo(
            document_id=uuid.uuid4(),
            version="v0",
            status="processing",
            size_in_bytes=0,
            metadata={},
        )
        document_infos = [info]
        tracked_id = info.document_id

        processing_error = R2RDocumentProcessingError(
            "Unexpected error",
            document_id=tracked_id,
        )
        ingestion_results = {
            "embedding_pipeline_output": [(tracked_id, processing_error)]
        }

        outcome = await ingestion_service._process_ingestion_results(
            ingestion_results,
            document_infos,
            [],
            {tracked_id: "test"},
        )

        failed = outcome["failed_documents"]
        assert len(failed) == 1
        assert "Unexpected error" in failed[0]
    except asyncio.CancelledError:
        # NOTE(review): cancellation is swallowed, so a cancelled run ends
        # without failing this test.
        pass
@pytest.mark.asyncio
async def test_file_size_limit_edge_cases(ingestion_service):
    """Check files one byte under, exactly at, and one byte over 1024 * 1024
    bytes against a configured limit of 1.

    The first two uploads are expected to be accepted; the third is expected
    to raise ``R2RException``.
    """
    try:
        ingestion_service.config.app.get.return_value = 1  # 1MB limit
        limit_bytes = 1024 * 1024

        def build_upload(filename, byte_count):
            # The content length and the declared size are kept equal.
            return UploadFile(
                filename=filename,
                file=io.BytesIO(b"x" * byte_count),
                size=byte_count,
            )

        file_under = build_upload("under.txt", limit_bytes - 1)
        file_at = build_upload("at.txt", limit_bytes)
        file_over = build_upload("over.txt", limit_bytes + 1)

        await ingestion_service.ingest_files([file_under])  # Should succeed
        await ingestion_service.ingest_files([file_at])  # Should succeed

        with pytest.raises(
            R2RException, match="File size exceeds maximum allowed size"
        ):
            await ingestion_service.ingest_files([file_over])
    except asyncio.CancelledError:
        # NOTE(review): cancellation is swallowed, so a cancelled run ends
        # without failing this test.
        pass
@pytest.mark.asyncio
async def test_document_status_update_after_ingestion(
    ingestion_service, mock_vector_db
):
    """Expect two overview upserts for one ingested document, the second
    carrying that document with status ``"success"``."""
    try:
        doc = Document(
            id=generate_id_from_label("test_id"),
            data="Test content",
            type="txt",
            metadata={},
        )
        pipeline_run = ingestion_service.pipelines.ingestion_pipeline.run
        pipeline_run.return_value = {
            "embedding_pipeline_output": [(doc.id, None)]
        }
        # No existing documents
        mock_vector_db.get_documents_overview.return_value = []

        await ingestion_service.ingest_documents([doc])

        upsert = mock_vector_db.upsert_documents_overview
        # Check that upsert_documents_overview was called twice
        assert upsert.call_count == 2

        # Check the second call to upsert_documents_overview (status update)
        status_update_docs = upsert.call_args_list[1].args[0]
        assert len(status_update_docs) == 1
        updated = status_update_docs[0]
        assert updated.document_id == doc.id
        assert updated.status == "success"
    except asyncio.CancelledError:
        # NOTE(review): cancellation is swallowed, so a cancelled run ends
        # without failing this test.
        pass