Diffstat (limited to 'R2R/tests')
 -rwxr-xr-x  R2R/tests/test_abstractions.py       162
 -rwxr-xr-x  R2R/tests/test_config.py             187
 -rwxr-xr-x  R2R/tests/test_embedding.py          162
 -rwxr-xr-x  R2R/tests/test_end_to_end.py         375
 -rwxr-xr-x  R2R/tests/test_ingestion_service.py  443
 -rwxr-xr-x  R2R/tests/test_llms.py                59
 -rwxr-xr-x  R2R/tests/test_logging.py            360
 -rwxr-xr-x  R2R/tests/test_parser.py             159
 -rwxr-xr-x  R2R/tests/test_pipeline.py           291
 -rwxr-xr-x  R2R/tests/test_vector_db.py          160
10 files changed, 2358 insertions(+), 0 deletions(-)
diff --git a/R2R/tests/test_abstractions.py b/R2R/tests/test_abstractions.py new file mode 100755 index 00000000..a360e952 --- /dev/null +++ b/R2R/tests/test_abstractions.py @@ -0,0 +1,162 @@ +import asyncio +import uuid + +import pytest + +from r2r import ( + AsyncPipe, + AsyncState, + Prompt, + Vector, + VectorEntry, + VectorSearchRequest, + VectorSearchResult, + VectorType, + generate_id_from_label, +) + + +@pytest.fixture(scope="session", autouse=True) +def event_loop_policy(): + asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) + + +@pytest.fixture(scope="function") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + asyncio.set_event_loop(None) + + +@pytest.fixture(scope="session", autouse=True) +async def cleanup_tasks(): + yield + for task in asyncio.all_tasks(): + if task is not asyncio.current_task(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_async_state_update_and_get(): + state = AsyncState() + outer_key = "test_key" + values = {"inner_key": "value"} + await state.update(outer_key, values) + result = await state.get(outer_key, "inner_key") + assert result == "value" + + +@pytest.mark.asyncio +async def test_async_state_delete(): + state = AsyncState() + outer_key = "test_key" + values = {"inner_key": "value"} + await state.update(outer_key, values) + await state.delete(outer_key, "inner_key") + result = await state.get(outer_key, "inner_key") + assert result == {}, "Expect empty result after deletion" + + +class MockAsyncPipe(AsyncPipe): + async def _run_logic(self, input, state, run_id, *args, **kwargs): + yield "processed" + + +@pytest.mark.asyncio +async def test_async_pipe_run(): + pipe = MockAsyncPipe() + + async def list_to_generator(lst): + for item in lst: + yield item + + input = pipe.Input(message=list_to_generator(["test"])) + state = AsyncState() + try: + async_generator = await pipe.run(input, state) + results = [result async for result in async_generator] + assert results == ["processed"] + except asyncio.CancelledError: + pass # Task cancelled as expected + + +def test_prompt_initialization_and_formatting(): + prompt = Prompt( + name="greet", template="Hello, {name}!", input_types={"name": "str"} + ) + formatted = prompt.format_prompt({"name": "Alice"}) + assert formatted == "Hello, Alice!" 
+ + +def test_prompt_missing_input(): + prompt = Prompt( + name="greet", template="Hello, {name}!", input_types={"name": "str"} + ) + with pytest.raises(ValueError): + prompt.format_prompt({}) + + +def test_prompt_invalid_input_type(): + prompt = Prompt( + name="greet", template="Hello, {name}!", input_types={"name": "int"} + ) + with pytest.raises(TypeError): + prompt.format_prompt({"name": "Alice"}) + + +def test_search_request_with_optional_filters(): + request = VectorSearchRequest( + query="test", limit=10, filters={"category": "books"} + ) + assert request.query == "test" + assert request.limit == 10 + assert request.filters == {"category": "books"} + + +def test_search_result_to_string(): + result = VectorSearchResult( + id=generate_id_from_label("1"), + score=9.5, + metadata={"author": "John Doe"}, + ) + result_str = str(result) + assert ( + result_str + == f"VectorSearchResult(id={str(generate_id_from_label('1'))}, score=9.5, metadata={{'author': 'John Doe'}})" + ) + + +def test_search_result_repr(): + result = VectorSearchResult( + id=generate_id_from_label("1"), + score=9.5, + metadata={"author": "John Doe"}, + ) + assert ( + repr(result) + == f"VectorSearchResult(id={str(generate_id_from_label('1'))}, score=9.5, metadata={{'author': 'John Doe'}})" + ) + + +def test_vector_fixed_length_validation(): + with pytest.raises(ValueError): + Vector(data=[1.0, 2.0], type=VectorType.FIXED, length=3) + + +def test_vector_entry_serialization(): + vector = Vector(data=[1.0, 2.0], type=VectorType.FIXED, length=2) + entry_id = uuid.uuid4() + entry = VectorEntry( + id=entry_id, vector=vector, metadata={"key": uuid.uuid4()} + ) + serializable = entry.to_serializable() + assert serializable["id"] == str(entry_id) + assert serializable["vector"] == [1.0, 2.0] + assert isinstance( + serializable["metadata"]["key"], str + ) # Check UUID conversion to string diff --git a/R2R/tests/test_config.py b/R2R/tests/test_config.py new file mode 100755 index 00000000..5e60833c --- /dev/null +++ b/R2R/tests/test_config.py @@ -0,0 +1,187 @@ +import asyncio +import json +from unittest.mock import Mock, mock_open, patch + +import pytest + +from r2r import DocumentType, R2RConfig + + +@pytest.fixture(scope="session", autouse=True) +def event_loop_policy(): + asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) + + +@pytest.fixture(scope="function") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + asyncio.set_event_loop(None) + + +@pytest.fixture(scope="session", autouse=True) +async def cleanup_tasks(): + yield + for task in asyncio.all_tasks(): + if task is not asyncio.current_task(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.fixture +def mock_bad_file(): + mock_data = json.dumps({}) + with patch("builtins.open", mock_open(read_data=mock_data)) as m: + yield m + + +@pytest.fixture +def mock_file(): + mock_data = json.dumps( + { + "app": {"max_file_size_in_mb": 128}, + "embedding": { + "provider": "example_provider", + "base_model": "model", + "base_dimension": 128, + "batch_size": 16, + "text_splitter": "default", + }, + "kg": { + "provider": "None", + "batch_size": 1, + "text_splitter": { + "type": "recursive_character", + "chunk_size": 2048, + "chunk_overlap": 0, + }, + }, + "eval": {"llm": {"provider": "local"}}, + "ingestion": {"excluded_parsers": {}}, + "completions": {"provider": "lm_provider"}, + "logging": { + "provider": "local", + "log_table": "logs", + "log_info_table": "log_info", + }, + 
"prompt": {"provider": "prompt_provider"}, + "vector_database": {"provider": "vector_db"}, + } + ) + with patch("builtins.open", mock_open(read_data=mock_data)) as m: + yield m + + +@pytest.mark.asyncio +async def test_r2r_config_loading_required_keys(mock_bad_file): + with pytest.raises(KeyError): + R2RConfig.from_json("config.json") + + +@pytest.mark.asyncio +async def test_r2r_config_loading(mock_file): + config = R2RConfig.from_json("config.json") + assert ( + config.embedding.provider == "example_provider" + ), "Provider should match the mock data" + + +@pytest.fixture +def mock_redis_client(): + client = Mock() + return client + + +def test_r2r_config_serialization(mock_file, mock_redis_client): + config = R2RConfig.from_json("config.json") + config.save_to_redis(mock_redis_client, "test_key") + mock_redis_client.set.assert_called_once() + saved_data = json.loads(mock_redis_client.set.call_args[0][1]) + assert saved_data["app"]["max_file_size_in_mb"] == 128 + + +def test_r2r_config_deserialization(mock_file, mock_redis_client): + config_data = { + "app": {"max_file_size_in_mb": 128}, + "embedding": { + "provider": "example_provider", + "base_model": "model", + "base_dimension": 128, + "batch_size": 16, + "text_splitter": "default", + }, + "kg": { + "provider": "None", + "batch_size": 1, + "text_splitter": { + "type": "recursive_character", + "chunk_size": 2048, + "chunk_overlap": 0, + }, + }, + "eval": {"llm": {"provider": "local"}}, + "ingestion": {"excluded_parsers": ["pdf"]}, + "completions": {"provider": "lm_provider"}, + "logging": { + "provider": "local", + "log_table": "logs", + "log_info_table": "log_info", + }, + "prompt": {"provider": "prompt_provider"}, + "vector_database": {"provider": "vector_db"}, + } + mock_redis_client.get.return_value = json.dumps(config_data) + config = R2RConfig.load_from_redis(mock_redis_client, "test_key") + assert config.app["max_file_size_in_mb"] == 128 + assert DocumentType.PDF in config.ingestion["excluded_parsers"] + + +def test_r2r_config_missing_section(): + invalid_data = { + "embedding": { + "provider": "example_provider", + "base_model": "model", + "base_dimension": 128, + "batch_size": 16, + "text_splitter": "default", + } + } + with patch("builtins.open", mock_open(read_data=json.dumps(invalid_data))): + with pytest.raises(KeyError): + R2RConfig.from_json("config.json") + + +def test_r2r_config_missing_required_key(): + invalid_data = { + "app": {"max_file_size_in_mb": 128}, + "embedding": { + "base_model": "model", + "base_dimension": 128, + "batch_size": 16, + "text_splitter": "default", + }, + "kg": { + "provider": "None", + "batch_size": 1, + "text_splitter": { + "type": "recursive_character", + "chunk_size": 2048, + "chunk_overlap": 0, + }, + }, + "completions": {"provider": "lm_provider"}, + "logging": { + "provider": "local", + "log_table": "logs", + "log_info_table": "log_info", + }, + "prompt": {"provider": "prompt_provider"}, + "vector_database": {"provider": "vector_db"}, + } + with patch("builtins.open", mock_open(read_data=json.dumps(invalid_data))): + with pytest.raises(KeyError): + R2RConfig.from_json("config.json") diff --git a/R2R/tests/test_embedding.py b/R2R/tests/test_embedding.py new file mode 100755 index 00000000..7a3e760a --- /dev/null +++ b/R2R/tests/test_embedding.py @@ -0,0 +1,162 @@ +import asyncio + +import pytest + +from r2r import EmbeddingConfig, VectorSearchResult, generate_id_from_label +from r2r.providers.embeddings import ( + OpenAIEmbeddingProvider, + SentenceTransformerEmbeddingProvider, +) + + 
+@pytest.fixture(scope="session", autouse=True) +def event_loop_policy(): + asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) + + +@pytest.fixture(scope="function") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + asyncio.set_event_loop(None) + + +@pytest.fixture(scope="session", autouse=True) +async def cleanup_tasks(): + yield + for task in asyncio.all_tasks(): + if task is not asyncio.current_task(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.fixture +def openai_provider(): + config = EmbeddingConfig( + provider="openai", + base_model="text-embedding-3-small", + base_dimension=1536, + ) + return OpenAIEmbeddingProvider(config) + + +def test_openai_initialization(openai_provider): + assert isinstance(openai_provider, OpenAIEmbeddingProvider) + assert openai_provider.base_model == "text-embedding-3-small" + assert openai_provider.base_dimension == 1536 + + +def test_openai_invalid_provider_initialization(): + config = EmbeddingConfig(provider="invalid_provider") + with pytest.raises(ValueError): + OpenAIEmbeddingProvider(config) + + +def test_openai_get_embedding(openai_provider): + embedding = openai_provider.get_embedding("test text") + assert len(embedding) == 1536 + assert isinstance(embedding, list) + + +@pytest.mark.asyncio +async def test_openai_async_get_embedding(openai_provider): + try: + embedding = await openai_provider.async_get_embedding("test text") + assert len(embedding) == 1536 + assert isinstance(embedding, list) + except asyncio.CancelledError: + pass # Task cancelled as expected + + +def test_openai_get_embeddings(openai_provider): + embeddings = openai_provider.get_embeddings(["text1", "text2"]) + assert len(embeddings) == 2 + assert all(len(emb) == 1536 for emb in embeddings) + + +@pytest.mark.asyncio +async def test_openai_async_get_embeddings(openai_provider): + try: + embeddings = await openai_provider.async_get_embeddings( + ["text1", "text2"] + ) + assert len(embeddings) == 2 + assert all(len(emb) == 1536 for emb in embeddings) + except asyncio.CancelledError: + pass # Task cancelled as expected + + +def test_openai_tokenize_string(openai_provider): + tokens = openai_provider.tokenize_string( + "test text", "text-embedding-3-small" + ) + assert isinstance(tokens, list) + assert all(isinstance(token, int) for token in tokens) + + +@pytest.fixture +def sentence_transformer_provider(): + config = EmbeddingConfig( + provider="sentence-transformers", + base_model="mixedbread-ai/mxbai-embed-large-v1", + base_dimension=512, + rerank_model="jinaai/jina-reranker-v1-turbo-en", + rerank_dimension=384, + ) + return SentenceTransformerEmbeddingProvider(config) + + +def test_sentence_transformer_initialization(sentence_transformer_provider): + assert isinstance( + sentence_transformer_provider, SentenceTransformerEmbeddingProvider + ) + assert sentence_transformer_provider.do_search + # assert sentence_transformer_provider.do_rerank + + +def test_sentence_transformer_invalid_provider_initialization(): + config = EmbeddingConfig(provider="invalid_provider") + with pytest.raises(ValueError): + SentenceTransformerEmbeddingProvider(config) + + +def test_sentence_transformer_get_embedding(sentence_transformer_provider): + embedding = sentence_transformer_provider.get_embedding("test text") + assert len(embedding) == 512 + assert isinstance(embedding, list) + + +def test_sentence_transformer_get_embeddings(sentence_transformer_provider): + embeddings = 
sentence_transformer_provider.get_embeddings( + ["text1", "text2"] + ) + assert len(embeddings) == 2 + assert all(len(emb) == 512 for emb in embeddings) + + +def test_sentence_transformer_rerank(sentence_transformer_provider): + results = [ + VectorSearchResult( + id=generate_id_from_label("x"), + score=0.9, + metadata={"text": "doc1"}, + ), + VectorSearchResult( + id=generate_id_from_label("y"), + score=0.8, + metadata={"text": "doc2"}, + ), + ] + reranked_results = sentence_transformer_provider.rerank("query", results) + assert len(reranked_results) == 2 + assert reranked_results[0].metadata["text"] == "doc1" + assert reranked_results[1].metadata["text"] == "doc2" + + +def test_sentence_transformer_tokenize_string(sentence_transformer_provider): + with pytest.raises(ValueError): + sentence_transformer_provider.tokenize_string("test text") diff --git a/R2R/tests/test_end_to_end.py b/R2R/tests/test_end_to_end.py new file mode 100755 index 00000000..5e13ab5c --- /dev/null +++ b/R2R/tests/test_end_to_end.py @@ -0,0 +1,375 @@ +import asyncio +import os +import uuid + +import pytest +from fastapi.datastructures import UploadFile + +from r2r import ( + Document, + KVLoggingSingleton, + R2RConfig, + R2REngine, + R2RPipeFactory, + R2RPipelineFactory, + R2RProviderFactory, + VectorSearchSettings, + generate_id_from_label, +) +from r2r.base.abstractions.llm import GenerationConfig + + +@pytest.fixture(scope="session", autouse=True) +def event_loop_policy(): + asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) + + +@pytest.fixture(scope="function") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + asyncio.set_event_loop(None) + + +@pytest.fixture(scope="session", autouse=True) +async def cleanup_tasks(): + yield + for task in asyncio.all_tasks(): + if task is not asyncio.current_task(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.fixture(scope="function") +def app(request): + config = R2RConfig.from_json() + config.logging.provider = "local" + config.logging.logging_path = uuid.uuid4().hex + + vector_db_provider = request.param + if vector_db_provider == "pgvector": + config.vector_database.provider = "pgvector" + config.vector_database.extra_fields["vecs_collection"] = ( + config.logging.logging_path + ) + try: + providers = R2RProviderFactory(config).create_providers() + pipes = R2RPipeFactory(config, providers).create_pipes() + pipelines = R2RPipelineFactory(config, pipes).create_pipelines() + + r2r = R2REngine( + config=config, + providers=providers, + pipelines=pipelines, + ) + + try: + KVLoggingSingleton.configure(config.logging) + except: + KVLoggingSingleton._config.logging_path = ( + config.logging.logging_path + ) + + yield r2r + finally: + if os.path.exists(config.logging.logging_path): + os.remove(config.logging.logging_path) + + +@pytest.fixture +def logging_connection(): + return KVLoggingSingleton() + + +@pytest.mark.parametrize("app", ["pgvector"], indirect=True) +@pytest.mark.asyncio +async def test_ingest_txt_document(app, logging_connection): + try: + await app.aingest_documents( + [ + Document( + id=generate_id_from_label("doc_1"), + data="The quick brown fox jumps over the lazy dog.", + type="txt", + metadata={"author": "John Doe"}, + ), + ] + ) + except asyncio.CancelledError: + pass + + +@pytest.mark.parametrize("app", ["pgvector"], indirect=True) +@pytest.mark.asyncio +async def test_ingest_txt_file(app, logging_connection): + try: + # Prepare the test data + 
metadata = {"author": "John Doe"} + files = [ + UploadFile( + filename="test.txt", + file=open( + os.path.join( + os.path.dirname(__file__), + "..", + "r2r", + "examples", + "data", + "test.txt", + ), + "rb", + ), + ) + ] + # Set file size manually + for file in files: + file.file.seek(0, 2) # Move to the end of the file + file.size = file.file.tell() # Get the file size + file.file.seek(0) # Move back to the start of the file + + await app.aingest_files(metadatas=[metadata], files=files) + except asyncio.CancelledError: + pass + + +@pytest.mark.parametrize("app", ["pgvector"], indirect=True) +@pytest.mark.asyncio +async def test_ingest_search_txt_file(app, logging_connection): + try: + # Prepare the test data + metadata = {} + files = [ + UploadFile( + filename="aristotle.txt", + file=open( + os.path.join( + os.path.dirname(__file__), + "..", + "r2r", + "examples", + "data", + "aristotle.txt", + ), + "rb", + ), + ), + ] + + # Set file size manually + for file in files: + file.file.seek(0, 2) # Move to the end of the file + file.size = file.file.tell() # Get the file size + file.file.seek(0) # Move back to the start of the file + + await app.aingest_files(metadatas=[metadata], files=files) + + search_results = await app.asearch("who was aristotle?") + assert len(search_results["vector_search_results"]) == 10 + assert ( + "was an Ancient Greek philosopher and polymath" + in search_results["vector_search_results"][0]["metadata"]["text"] + ) + + search_results = await app.asearch( + "who was aristotle?", + vector_search_settings=VectorSearchSettings(search_limit=20), + ) + assert len(search_results["vector_search_results"]) == 20 + assert ( + "was an Ancient Greek philosopher and polymath" + in search_results["vector_search_results"][0]["metadata"]["text"] + ) + run_info = await logging_connection.get_run_info( + log_type_filter="search" + ) + + assert len(run_info) == 2, f"Expected 2 runs, but got {len(run_info)}" + + logs = await logging_connection.get_logs( + [run.run_id for run in run_info], 100 + ) + assert len(logs) == 6, f"Expected 6 logs, but got {len(logs)}" + + ## test stream + response = await app.arag( + query="Who was aristotle?", + rag_generation_config=GenerationConfig( + **{"model": "gpt-3.5-turbo", "stream": True} + ), + ) + collector = "" + async for chunk in response: + collector += chunk + assert "Aristotle" in collector + assert "Greek" in collector + assert "philosopher" in collector + assert "polymath" in collector + assert "Ancient" in collector + except asyncio.CancelledError: + pass + + +@pytest.mark.parametrize("app", ["pgvector"], indirect=True) +@pytest.mark.asyncio +async def test_ingest_search_then_delete(app, logging_connection): + try: + # Ingest a document + await app.aingest_documents( + [ + Document( + id=generate_id_from_label("doc_1"), + data="The quick brown fox jumps over the lazy dog.", + type="txt", + metadata={"author": "John Doe"}, + ), + ] + ) + + # Search for the document + search_results = await app.asearch("who was aristotle?") + + # Verify that the search results are not empty + assert ( + len(search_results["vector_search_results"]) > 0 + ), "Expected search results, but got none" + assert ( + search_results["vector_search_results"][0]["metadata"]["text"] + == "The quick brown fox jumps over the lazy dog." 
+ ) + + # Delete the document + delete_result = await app.adelete(["author"], ["John Doe"]) + + # Verify the deletion was successful + expected_deletion_message = "deleted successfully" + assert ( + expected_deletion_message in delete_result + ), f"Expected successful deletion message, but got {delete_result}" + + # Search for the document again + search_results_2 = await app.asearch("who was aristotle?") + + # Verify that the search results are empty + assert ( + len(search_results_2["vector_search_results"]) == 0 + ), f"Expected no search results, but got {search_results_2['results']}" + except asyncio.CancelledError: + pass + + +@pytest.mark.parametrize("app", ["local", "pgvector"], indirect=True) +@pytest.mark.asyncio +async def test_ingest_user_documents(app, logging_connection): + try: + user_id_0 = generate_id_from_label("user_0") + user_id_1 = generate_id_from_label("user_1") + + try: + await app.aingest_documents( + [ + Document( + id=generate_id_from_label("doc_01"), + data="The quick brown fox jumps over the lazy dog.", + type="txt", + metadata={"author": "John Doe", "user_id": user_id_0}, + ), + Document( + id=generate_id_from_label("doc_11"), + data="The lazy dog jumps over the quick brown fox.", + type="txt", + metadata={"author": "John Doe", "user_id": user_id_1}, + ), + ] + ) + user_id_results = await app.ausers_overview([user_id_0, user_id_1]) + assert set([stats.user_id for stats in user_id_results]) == set( + [user_id_0, user_id_1] + ), f"Expected user ids {user_id_0} and {user_id_1}, but got {user_id_results}" + + user_0_docs = await app.adocuments_overview(user_ids=[user_id_0]) + user_1_docs = await app.adocuments_overview(user_ids=[user_id_1]) + + assert ( + len(user_0_docs) == 1 + ), f"Expected 1 document for user {user_id_0}, but got {len(user_0_docs)}" + assert ( + len(user_1_docs) == 1 + ), f"Expected 1 document for user {user_id_1}, but got {len(user_1_docs)}" + assert user_0_docs[0].document_id == generate_id_from_label( + "doc_01" + ), f"Expected document id {str(generate_id_from_label('doc_0'))} for user {user_id_0}, but got {user_0_docs[0].document_id}" + assert user_1_docs[0].document_id == generate_id_from_label( + "doc_11" + ), f"Expected document id {str(generate_id_from_label('doc_1'))} for user {user_id_1}, but got {user_1_docs[0].document_id}" + finally: + await app.adelete( + ["document_id", "document_id"], + [ + str(generate_id_from_label("doc_01")), + str(generate_id_from_label("doc_11")), + ], + ) + except asyncio.CancelledError: + pass + + +@pytest.mark.parametrize("app", ["pgvector"], indirect=True) +@pytest.mark.asyncio +async def test_delete_by_id(app, logging_connection): + try: + await app.aingest_documents( + [ + Document( + id=generate_id_from_label("doc_1"), + data="The quick brown fox jumps over the lazy dog.", + type="txt", + metadata={"author": "John Doe"}, + ), + ] + ) + search_results = await app.asearch("who was aristotle?") + + assert len(search_results["vector_search_results"]) > 0 + await app.adelete( + ["document_id"], [str(generate_id_from_label("doc_1"))] + ) + search_results = await app.asearch("who was aristotle?") + assert len(search_results["vector_search_results"]) == 0 + except asyncio.CancelledError: + pass + + +@pytest.mark.parametrize("app", ["pgvector"], indirect=True) +@pytest.mark.asyncio +async def test_double_ingest(app, logging_connection): + try: + await app.aingest_documents( + [ + Document( + id=generate_id_from_label("doc_1"), + data="The quick brown fox jumps over the lazy dog.", + type="txt", + 
metadata={"author": "John Doe"}, + ), + ] + ) + search_results = await app.asearch("who was aristotle?") + + assert len(search_results["vector_search_results"]) == 1 + with pytest.raises(Exception): + await app.aingest_documents( + [ + Document( + id=generate_id_from_label("doc_1"), + data="The quick brown fox jumps over the lazy dog.", + type="txt", + metadata={"author": "John Doe"}, + ), + ] + ) + except asyncio.CancelledError: + pass diff --git a/R2R/tests/test_ingestion_service.py b/R2R/tests/test_ingestion_service.py new file mode 100755 index 00000000..375e51f9 --- /dev/null +++ b/R2R/tests/test_ingestion_service.py @@ -0,0 +1,443 @@ +import asyncio +import io +import uuid +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock, Mock + +import pytest +from fastapi import UploadFile + +from r2r.base import ( + Document, + DocumentInfo, + R2RDocumentProcessingError, + R2RException, + generate_id_from_label, +) +from r2r.main import R2RPipelines, R2RProviders +from r2r.main.services.ingestion_service import IngestionService + + +@pytest.fixture(scope="session", autouse=True) +def event_loop_policy(): + asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) + + +@pytest.fixture(scope="function") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + asyncio.set_event_loop(None) + + +@pytest.fixture(scope="session", autouse=True) +async def cleanup_tasks(): + yield + for task in asyncio.all_tasks(): + if task is not asyncio.current_task(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.fixture +def mock_vector_db(): + mock_db = MagicMock() + mock_db.get_documents_overview.return_value = [] # Default to empty list + return mock_db + + +@pytest.fixture +def mock_embedding_model(): + return Mock() + + +@pytest.fixture +def ingestion_service(mock_vector_db, mock_embedding_model): + config = MagicMock() + config.app.get.return_value = 32 # Default max file size + providers = Mock(spec=R2RProviders) + providers.vector_db = mock_vector_db + providers.embedding_model = mock_embedding_model + pipelines = Mock(spec=R2RPipelines) + pipelines.ingestion_pipeline = AsyncMock() + pipelines.ingestion_pipeline.run.return_value = { + "embedding_pipeline_output": [] + } + run_manager = Mock() + run_manager.run_info = {"mock_run_id": {}} + logging_connection = AsyncMock() + + return IngestionService( + config, providers, pipelines, run_manager, logging_connection + ) + + +@pytest.mark.asyncio +async def test_ingest_single_document(ingestion_service, mock_vector_db): + try: + document = Document( + id=generate_id_from_label("test_id"), + data="Test content", + type="txt", + metadata={}, + ) + + ingestion_service.pipelines.ingestion_pipeline.run.return_value = { + "embedding_pipeline_output": [(document.id, None)] + } + mock_vector_db.get_documents_overview.return_value = ( + [] + ) # No existing documents + + result = await ingestion_service.ingest_documents([document]) + + assert result["processed_documents"] == [ + f"Document '{document.id}' processed successfully." 
+ ] + assert not result["failed_documents"] + assert not result["skipped_documents"] + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_ingest_duplicate_document(ingestion_service, mock_vector_db): + try: + document = Document( + id=generate_id_from_label("test_id"), + data="Test content", + type="txt", + metadata={}, + ) + mock_vector_db.get_documents_overview.return_value = [ + DocumentInfo( + document_id=document.id, + version="v0", + size_in_bytes=len(document.data), + metadata={}, + title=str(document.id), + user_id=None, + created_at=datetime.now(), + updated_at=datetime.now(), + status="success", + ) + ] + + with pytest.raises(R2RException) as exc_info: + await ingestion_service.ingest_documents([document]) + + assert ( + f"Document with ID {document.id} was already successfully processed" + in str(exc_info.value) + ) + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_ingest_file(ingestion_service): + try: + file_content = b"Test content" + file_mock = UploadFile( + filename="test.txt", file=io.BytesIO(file_content) + ) + file_mock.file.seek(0) + file_mock.size = len(file_content) # Set file size manually + + ingestion_service.pipelines.ingestion_pipeline.run.return_value = { + "embedding_pipeline_output": [ + (generate_id_from_label("test.txt"), None) + ] + } + + result = await ingestion_service.ingest_files([file_mock]) + + assert len(result["processed_documents"]) == 1 + assert not result["failed_documents"] + assert not result["skipped_documents"] + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_ingest_mixed_success_and_failure( + ingestion_service, mock_vector_db +): + try: + documents = [ + Document( + id=generate_id_from_label("success_id"), + data="Success content", + type="txt", + metadata={}, + ), + Document( + id=generate_id_from_label("failure_id"), + data="Failure content", + type="txt", + metadata={}, + ), + ] + + ingestion_service.pipelines.ingestion_pipeline.run.return_value = { + "embedding_pipeline_output": [ + ( + documents[0].id, + f"Processed 1 vectors for document {documents[0].id}.", + ), + ( + documents[1].id, + R2RDocumentProcessingError( + error_message="Embedding failed", + document_id=documents[1].id, + ), + ), + ] + } + + result = await ingestion_service.ingest_documents(documents) + + assert len(result["processed_documents"]) == 1 + assert len(result["failed_documents"]) == 1 + assert str(documents[0].id) in result["processed_documents"][0] + assert str(documents[1].id) in result["failed_documents"][0] + assert "Embedding failed" in result["failed_documents"][0] + + assert mock_vector_db.upsert_documents_overview.call_count == 2 + upserted_docs = mock_vector_db.upsert_documents_overview.call_args[0][ + 0 + ] + assert len(upserted_docs) == 2 + assert upserted_docs[0].document_id == documents[0].id + assert upserted_docs[0].status == "success" + assert upserted_docs[1].document_id == documents[1].id + assert upserted_docs[1].status == "failure" + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_ingest_unsupported_file_type(ingestion_service): + try: + file_mock = UploadFile( + filename="test.unsupported", file=io.BytesIO(b"Test content") + ) + file_mock.file.seek(0) + file_mock.size = 12 # Set file size manually + + with pytest.raises(R2RException) as exc_info: + await ingestion_service.ingest_files([file_mock]) + + assert "is not a valid DocumentType" in str(exc_info.value) + except asyncio.CancelledError: + pass + + 
+@pytest.mark.asyncio +async def test_ingest_large_file(ingestion_service): + try: + large_content = b"Large content" * 1000000 # 12MB content + file_mock = UploadFile( + filename="large_file.txt", file=io.BytesIO(large_content) + ) + file_mock.file.seek(0) + file_mock.size = len(large_content) # Set file size manually + + ingestion_service.config.app.get.return_value = ( + 10 # Set max file size to 10MB + ) + + with pytest.raises(R2RException) as exc_info: + await ingestion_service.ingest_files([file_mock]) + + assert "File size exceeds maximum allowed size" in str(exc_info.value) + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_partial_ingestion_success(ingestion_service, mock_vector_db): + try: + documents = [ + Document( + id=generate_id_from_label("success_1"), + data="Success content 1", + type="txt", + metadata={}, + ), + Document( + id=generate_id_from_label("fail"), + data="Fail content", + type="txt", + metadata={}, + ), + Document( + id=generate_id_from_label("success_2"), + data="Success content 2", + type="txt", + metadata={}, + ), + ] + + ingestion_service.pipelines.ingestion_pipeline.run.return_value = { + "embedding_pipeline_output": [ + (documents[0].id, None), + ( + documents[1].id, + R2RDocumentProcessingError( + error_message="Embedding failed", + document_id=documents[1].id, + ), + ), + (documents[2].id, None), + ] + } + + result = await ingestion_service.ingest_documents(documents) + + assert len(result["processed_documents"]) == 2 + assert len(result["failed_documents"]) == 1 + assert str(documents[1].id) in result["failed_documents"][0] + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_version_increment(ingestion_service, mock_vector_db): + try: + document = Document( + id=generate_id_from_label("test_id"), + data="Test content", + type="txt", + metadata={}, + ) + mock_vector_db.get_documents_overview.return_value = [ + DocumentInfo( + document_id=document.id, + version="v2", + status="success", + size_in_bytes=0, + metadata={}, + ) + ] + + file_mock = UploadFile( + filename="test.txt", file=io.BytesIO(b"Updated content") + ) + await ingestion_service.update_files([file_mock], [document.id]) + + calls = mock_vector_db.upsert_documents_overview.call_args_list + assert len(calls) == 2 + assert calls[1][0][0][0].version == "v3" + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_process_ingestion_results_error_handling(ingestion_service): + try: + document_infos = [ + DocumentInfo( + document_id=uuid.uuid4(), + version="v0", + status="processing", + size_in_bytes=0, + metadata={}, + ) + ] + ingestion_results = { + "embedding_pipeline_output": [ + ( + document_infos[0].document_id, + R2RDocumentProcessingError( + "Unexpected error", + document_id=document_infos[0].document_id, + ), + ) + ] + } + + result = await ingestion_service._process_ingestion_results( + ingestion_results, + document_infos, + [], + {document_infos[0].document_id: "test"}, + ) + + assert len(result["failed_documents"]) == 1 + assert "Unexpected error" in result["failed_documents"][0] + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_file_size_limit_edge_cases(ingestion_service): + try: + ingestion_service.config.app.get.return_value = 1 # 1MB limit + + just_under_limit = b"x" * (1024 * 1024 - 1) + at_limit = b"x" * (1024 * 1024) + over_limit = b"x" * (1024 * 1024 + 1) + + file_under = UploadFile( + filename="under.txt", + file=io.BytesIO(just_under_limit), + size=1024 * 1024 
- 1, + ) + file_at = UploadFile( + filename="at.txt", file=io.BytesIO(at_limit), size=1024 * 1024 + ) + file_over = UploadFile( + filename="over.txt", + file=io.BytesIO(over_limit), + size=1024 * 1024 + 1, + ) + + await ingestion_service.ingest_files([file_under]) # Should succeed + await ingestion_service.ingest_files([file_at]) # Should succeed + + with pytest.raises( + R2RException, match="File size exceeds maximum allowed size" + ): + await ingestion_service.ingest_files([file_over]) + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_document_status_update_after_ingestion( + ingestion_service, mock_vector_db +): + try: + document = Document( + id=generate_id_from_label("test_id"), + data="Test content", + type="txt", + metadata={}, + ) + + ingestion_service.pipelines.ingestion_pipeline.run.return_value = { + "embedding_pipeline_output": [(document.id, None)] + } + mock_vector_db.get_documents_overview.return_value = ( + [] + ) # No existing documents + + await ingestion_service.ingest_documents([document]) + + # Check that upsert_documents_overview was called twice + assert mock_vector_db.upsert_documents_overview.call_count == 2 + + # Check the second call to upsert_documents_overview (status update) + second_call_args = ( + mock_vector_db.upsert_documents_overview.call_args_list[1][0][0] + ) + assert len(second_call_args) == 1 + assert second_call_args[0].document_id == document.id + assert second_call_args[0].status == "success" + except asyncio.CancelledError: + pass diff --git a/R2R/tests/test_llms.py b/R2R/tests/test_llms.py new file mode 100755 index 00000000..666bbff8 --- /dev/null +++ b/R2R/tests/test_llms.py @@ -0,0 +1,59 @@ +import pytest + +from r2r import LLMConfig +from r2r.base.abstractions.llm import GenerationConfig +from r2r.providers.llms import LiteLLM + + +@pytest.fixture +def lite_llm(): + config = LLMConfig(provider="litellm") + return LiteLLM(config) + + +@pytest.mark.parametrize("llm_fixture", ["lite_llm"]) +def test_get_completion_ollama(request, llm_fixture): + llm = request.getfixturevalue(llm_fixture) + + messages = [ + { + "role": "user", + "content": "This is a test, return only the word `True`", + } + ] + generation_config = GenerationConfig( + model="ollama/llama2", + temperature=0.0, + top_p=0.9, + max_tokens_to_sample=50, + stream=False, + ) + + completion = llm.get_completion(messages, generation_config) + # assert isinstance(completion, LLMChatCompletion) + assert completion.choices[0].message.role == "assistant" + assert completion.choices[0].message.content.strip() == "True" + + +@pytest.mark.parametrize("llm_fixture", ["lite_llm"]) +def test_get_completion_openai(request, llm_fixture): + llm = request.getfixturevalue(llm_fixture) + + messages = [ + { + "role": "user", + "content": "This is a test, return only the word `True`", + } + ] + generation_config = GenerationConfig( + model="gpt-3.5-turbo", + temperature=0.0, + top_p=0.9, + max_tokens_to_sample=50, + stream=False, + ) + + completion = llm.get_completion(messages, generation_config) + # assert isinstance(completion, LLMChatCompletion) + assert completion.choices[0].message.role == "assistant" + assert completion.choices[0].message.content.strip() == "True" diff --git a/R2R/tests/test_logging.py b/R2R/tests/test_logging.py new file mode 100755 index 00000000..cab5051d --- /dev/null +++ b/R2R/tests/test_logging.py @@ -0,0 +1,360 @@ +import asyncio +import logging +import os +import uuid + +import pytest + +from r2r import ( + LocalKVLoggingProvider, + 
LoggingConfig, + PostgresKVLoggingProvider, + PostgresLoggingConfig, + RedisKVLoggingProvider, + RedisLoggingConfig, + generate_run_id, +) + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="session", autouse=True) +def event_loop_policy(): + asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) + + +@pytest.fixture(scope="function") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + asyncio.set_event_loop(None) + + +@pytest.fixture(scope="session", autouse=True) +async def cleanup_tasks(): + yield + for task in asyncio.all_tasks(): + if task is not asyncio.current_task(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.fixture(scope="function") +def local_provider(): + """Fixture to create and tear down the LocalKVLoggingProvider with a unique database file.""" + # Generate a unique file name for the SQLite database + unique_id = str(uuid.uuid4()) + logging_path = f"test_{unique_id}.sqlite" + + # Setup the LocalKVLoggingProvider with the unique file + provider = LocalKVLoggingProvider(LoggingConfig(logging_path=logging_path)) + + # Provide the setup provider to the test + yield provider + + # Cleanup: Remove the SQLite file after test completes + provider.close() + if os.path.exists(logging_path): + os.remove(logging_path) + + +@pytest.mark.asyncio +async def test_local_logging(local_provider): + """Test logging and retrieving from the local logging provider.""" + try: + run_id = generate_run_id() + await local_provider.init() + await local_provider.log(run_id, "key", "value") + logs = await local_provider.get_logs([run_id]) + assert len(logs) == 1 + assert logs[0]["key"] == "key" + assert logs[0]["value"] == "value" + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_multiple_log_entries(local_provider): + """Test logging multiple entries and retrieving them.""" + try: + run_id_0 = generate_run_id() + run_id_1 = generate_run_id() + run_id_2 = generate_run_id() + await local_provider.init() + + entries = [ + (run_id_0, "key_0", "value_0"), + (run_id_1, "key_1", "value_1"), + (run_id_2, "key_2", "value_2"), + ] + for run_id, key, value in entries: + await local_provider.log(run_id, key, value) + + logs = await local_provider.get_logs([run_id_0, run_id_1, run_id_2]) + assert len(logs) == 3 + + # Check that logs are returned in the correct order (most recent first if applicable) + for log in logs: + selected_entry = [ + entry for entry in entries if entry[0] == log["log_id"] + ][0] + assert log["log_id"] == selected_entry[0] + assert log["key"] == selected_entry[1] + assert log["value"] == selected_entry[2] + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_log_retrieval_limit(local_provider): + """Test the max_logs limit parameter works correctly.""" + try: + await local_provider.init() + + run_ids = [] + for i in range(10): # Add 10 entries + run_ids.append(generate_run_id()) + await local_provider.log(run_ids[-1], f"key_{i}", f"value_{i}") + + logs = await local_provider.get_logs(run_ids[0:5]) + assert len(logs) == 5 # Ensure only 5 logs are returned + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_specific_run_type_retrieval(local_provider): + """Test retrieving logs for a specific run type works correctly.""" + try: + await local_provider.init() + run_id_0 = generate_run_id() + run_id_1 = generate_run_id() + + await local_provider.log( + run_id_0, "pipeline_type", "search", 
is_info_log=True + ) + await local_provider.log(run_id_0, "key_0", "value_0") + await local_provider.log( + run_id_1, "pipeline_type", "rag", is_info_log=True + ) + await local_provider.log(run_id_1, "key_1", "value_1") + + run_info = await local_provider.get_run_info(log_type_filter="search") + logs = await local_provider.get_logs([run.run_id for run in run_info]) + assert len(logs) == 1 + assert logs[0]["log_id"] == run_id_0 + assert logs[0]["key"] == "key_0" + assert logs[0]["value"] == "value_0" + except asyncio.CancelledError: + pass + + +@pytest.fixture(scope="function") +def postgres_provider(): + """Fixture to create and tear down the PostgresKVLoggingProvider.""" + log_table = f"logs_{str(uuid.uuid4()).replace('-', '_')}" + log_info_table = f"log_info_{str(uuid.uuid4()).replace('-', '_')}" + + provider = PostgresKVLoggingProvider( + PostgresLoggingConfig( + log_table=log_table, log_info_table=log_info_table + ) + ) + yield provider + + +@pytest.mark.asyncio +async def test_postgres_logging(postgres_provider): + """Test logging and retrieving from the postgres logging provider.""" + try: + await postgres_provider.init() + run_id = generate_run_id() + await postgres_provider.log(run_id, "key", "value") + logs = await postgres_provider.get_logs([run_id]) + assert len(logs) == 1 + assert logs[0]["key"] == "key" + assert logs[0]["value"] == "value" + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_postgres_multiple_log_entries(postgres_provider): + """Test logging multiple entries and retrieving them.""" + try: + await postgres_provider.init() + run_id_0 = generate_run_id() + run_id_1 = generate_run_id() + run_id_2 = generate_run_id() + + entries = [ + (run_id_0, "key_0", "value_0"), + (run_id_1, "key_1", "value_1"), + (run_id_2, "key_2", "value_2"), + ] + for run_id, key, value in entries: + await postgres_provider.log(run_id, key, value) + + logs = await postgres_provider.get_logs([run_id_0, run_id_1, run_id_2]) + assert len(logs) == 3 + + # Check that logs are returned in the correct order (most recent first if applicable) + for log in logs: + selected_entry = [ + entry for entry in entries if entry[0] == log["log_id"] + ][0] + assert log["log_id"] == selected_entry[0] + assert log["key"] == selected_entry[1] + assert log["value"] == selected_entry[2] + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_postgres_log_retrieval_limit(postgres_provider): + """Test the max_logs limit parameter works correctly.""" + try: + await postgres_provider.init() + run_ids = [] + for i in range(10): # Add 10 entries + run_ids.append(generate_run_id()) + await postgres_provider.log(run_ids[-1], f"key_{i}", f"value_{i}") + + logs = await postgres_provider.get_logs(run_ids[:5]) + assert len(logs) == 5 # Ensure only 5 logs are returned + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_postgres_specific_run_type_retrieval(postgres_provider): + """Test retrieving logs for a specific run type works correctly.""" + try: + await postgres_provider.init() + run_id_0 = generate_run_id() + run_id_1 = generate_run_id() + + await postgres_provider.log( + run_id_0, "pipeline_type", "search", is_info_log=True + ) + await postgres_provider.log(run_id_0, "key_0", "value_0") + await postgres_provider.log( + run_id_1, "pipeline_type", "rag", is_info_log=True + ) + await postgres_provider.log(run_id_1, "key_1", "value_1") + + run_info = await postgres_provider.get_run_info( + log_type_filter="search" + ) + logs = await 
postgres_provider.get_logs( + [run.run_id for run in run_info] + ) + assert len(logs) == 1 + assert logs[0]["log_id"] == run_id_0 + assert logs[0]["key"] == "key_0" + assert logs[0]["value"] == "value_0" + except asyncio.CancelledError: + pass + + +@pytest.fixture(scope="function") +def redis_provider(): + """Fixture to create and tear down the RedisKVLoggingProvider.""" + log_table = f"logs_{str(uuid.uuid4()).replace('-', '_')}" + log_info_table = f"log_info_{str(uuid.uuid4()).replace('-', '_')}" + + provider = RedisKVLoggingProvider( + RedisLoggingConfig(log_table=log_table, log_info_table=log_info_table) + ) + yield provider + provider.close() + + +@pytest.mark.asyncio +async def test_redis_logging(redis_provider): + """Test logging and retrieving from the Redis logging provider.""" + try: + run_id = generate_run_id() + await redis_provider.log(run_id, "key", "value") + logs = await redis_provider.get_logs([run_id]) + assert len(logs) == 1 + assert logs[0]["key"] == "key" + assert logs[0]["value"] == "value" + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_redis_multiple_log_entries(redis_provider): + """Test logging multiple entries and retrieving them.""" + try: + run_id_0 = generate_run_id() + run_id_1 = generate_run_id() + run_id_2 = generate_run_id() + + entries = [ + (run_id_0, "key_0", "value_0"), + (run_id_1, "key_1", "value_1"), + (run_id_2, "key_2", "value_2"), + ] + for run_id, key, value in entries: + await redis_provider.log(run_id, key, value) + + logs = await redis_provider.get_logs([run_id_0, run_id_1, run_id_2]) + assert len(logs) == 3 + + # Check that logs are returned in the correct order (most recent first if applicable) + for log in logs: + selected_entry = [ + entry for entry in entries if entry[0] == log["log_id"] + ][0] + assert log["log_id"] == selected_entry[0] + assert log["key"] == selected_entry[1] + assert log["value"] == selected_entry[2] + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_redis_log_retrieval_limit(redis_provider): + """Test the max_logs limit parameter works correctly.""" + try: + run_ids = [] + for i in range(10): # Add 10 entries + run_ids.append(generate_run_id()) + await redis_provider.log(run_ids[-1], f"key_{i}", f"value_{i}") + + logs = await redis_provider.get_logs(run_ids[0:5]) + assert len(logs) == 5 # Ensure only 5 logs are returned + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_redis_specific_run_type_retrieval(redis_provider): + """Test retrieving logs for a specific run type works correctly.""" + try: + run_id_0 = generate_run_id() + run_id_1 = generate_run_id() + + await redis_provider.log( + run_id_0, "pipeline_type", "search", is_info_log=True + ) + await redis_provider.log(run_id_0, "key_0", "value_0") + await redis_provider.log( + run_id_1, "pipeline_type", "rag", is_info_log=True + ) + await redis_provider.log(run_id_1, "key_1", "value_1") + + run_info = await redis_provider.get_run_info(log_type_filter="search") + logs = await redis_provider.get_logs([run.run_id for run in run_info]) + assert len(logs) == 1 + assert logs[0]["log_id"] == run_id_0 + assert logs[0]["key"] == "key_0" + assert logs[0]["value"] == "value_0" + except asyncio.CancelledError: + pass diff --git a/R2R/tests/test_parser.py b/R2R/tests/test_parser.py new file mode 100755 index 00000000..6965c5a9 --- /dev/null +++ b/R2R/tests/test_parser.py @@ -0,0 +1,159 @@ +import asyncio +import json +from unittest.mock import MagicMock, patch + +import pytest + 
+from r2r.parsers.media.docx_parser import DOCXParser +from r2r.parsers.media.pdf_parser import PDFParser +from r2r.parsers.media.ppt_parser import PPTParser +from r2r.parsers.structured.csv_parser import CSVParser +from r2r.parsers.structured.json_parser import JSONParser +from r2r.parsers.structured.xlsx_parser import XLSXParser +from r2r.parsers.text.html_parser import HTMLParser +from r2r.parsers.text.md_parser import MDParser +from r2r.parsers.text.text_parser import TextParser + + +@pytest.fixture(scope="session", autouse=True) +def event_loop_policy(): + asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) + + +@pytest.fixture(scope="function") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + asyncio.set_event_loop(None) + + +@pytest.fixture(scope="session", autouse=True) +async def cleanup_tasks(): + yield + for task in asyncio.all_tasks(): + if task is not asyncio.current_task(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_text_parser(): + try: + parser = TextParser() + data = "Simple text" + async for result in parser.ingest(data): + assert result == "Simple text" + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_json_parser(): + try: + parser = JSONParser() + data = json.dumps({"key": "value", "null_key": None}) + async for result in parser.ingest(data): + assert "key: value" in result + assert "null_key" not in result + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_html_parser(): + try: + parser = HTMLParser() + data = "<html><body><p>Hello World</p></body></html>" + async for result in parser.ingest(data): + assert result.strip() == "Hello World" + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +@patch("pypdf.PdfReader") +async def test_pdf_parser(mock_pdf_reader): + try: + parser = PDFParser() + mock_pdf_reader.return_value.pages = [ + MagicMock(extract_text=lambda: "Page text") + ] + data = b"fake PDF data" + async for result in parser.ingest(data): + assert result == "Page text" + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +@patch("pptx.Presentation") +async def test_ppt_parser(mock_presentation): + try: + mock_slide = MagicMock() + mock_shape = MagicMock(text="Slide text") + mock_slide.shapes = [mock_shape] + mock_presentation.return_value.slides = [mock_slide] + parser = PPTParser() + data = b"fake PPT data" + async for result in parser.ingest(data): + assert result == "Slide text" + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +@patch("docx.Document") +async def test_docx_parser(mock_document): + try: + mock_paragraph = MagicMock(text="Paragraph text") + mock_document.return_value.paragraphs = [mock_paragraph] + parser = DOCXParser() + data = b"fake DOCX data" + async for result in parser.ingest(data): + assert result == "Paragraph text" + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_csv_parser(): + try: + parser = CSVParser() + data = "col1,col2\nvalue1,value2" + async for result in parser.ingest(data): + assert result == "col1, col2" + break + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +@patch("openpyxl.load_workbook") +async def test_xlsx_parser(mock_load_workbook): + try: + mock_sheet = MagicMock() + mock_sheet.iter_rows.return_value = [(1, 2), (3, 4)] + mock_workbook = MagicMock(worksheets=[mock_sheet]) + mock_load_workbook.return_value = 
mock_workbook + parser = XLSXParser() + data = b"fake XLSX data" + async for result in parser.ingest(data): + assert result == "1, 2" + break + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_markdown_parser(): + try: + parser = MDParser() + data = "# Header\nContent" + async for result in parser.ingest(data): + assert result.strip() == "Header\nContent" + except asyncio.CancelledError: + pass diff --git a/R2R/tests/test_pipeline.py b/R2R/tests/test_pipeline.py new file mode 100755 index 00000000..1811de0f --- /dev/null +++ b/R2R/tests/test_pipeline.py @@ -0,0 +1,291 @@ +import asyncio +from typing import Any, AsyncGenerator + +import pytest + +from r2r import AsyncPipe, AsyncPipeline, PipeType + + +class MultiplierPipe(AsyncPipe): + def __init__(self, multiplier=1, delay=0, name="multiplier_pipe"): + super().__init__( + type=PipeType.OTHER, + config=self.PipeConfig(name=name), + ) + self.multiplier = multiplier + self.delay = delay + + async def _run_logic( + self, + input: AsyncGenerator[Any, None], + state, + run_id=None, + *args, + **kwargs, + ) -> AsyncGenerator[Any, None]: + async for item in input.message: + if self.delay > 0: + await asyncio.sleep(self.delay) # Simulate processing delay + if isinstance(item, list): + processed = [x * self.multiplier for x in item] + elif isinstance(item, int): + processed = item * self.multiplier + else: + raise ValueError(f"Unsupported type: {type(item)}") + yield processed + + +class FanOutPipe(AsyncPipe): + def __init__(self, multiplier=1, delay=0, name="fan_out_pipe"): + super().__init__( + type=PipeType.OTHER, + config=self.PipeConfig(name=name), + ) + self.multiplier = multiplier + self.delay = delay + + async def _run_logic( + self, + input: AsyncGenerator[Any, None], + state, + run_id=None, + *args, + **kwargs, + ) -> AsyncGenerator[Any, None]: + inputs = [] + async for item in input.message: + inputs.append(item) + for it in range(self.multiplier): + if self.delay > 0: + await asyncio.sleep(self.delay) + yield [(it + 1) * ele for ele in inputs] + + +class FanInPipe(AsyncPipe): + def __init__(self, delay=0, name="fan_in_pipe"): + super().__init__( + type=PipeType.OTHER, + config=self.PipeConfig(name=name), + ) + self.delay = delay + + async def _run_logic( + self, + input: AsyncGenerator[Any, None], + state, + run_id=None, + *args, + **kwargs, + ) -> AsyncGenerator[Any, None]: + total_sum = 0 + async for batch in input.message: + if self.delay > 0: + await asyncio.sleep(self.delay) # Simulate processing delay + total_sum += sum( + batch + ) # Assuming batch is iterable and contains numeric values + yield total_sum + + +@pytest.fixture +def pipe_factory(): + def create_pipe(type, **kwargs): + if type == "multiplier": + return MultiplierPipe(**kwargs) + elif type == "fan_out": + return FanOutPipe(**kwargs) + elif type == "fan_in": + return FanInPipe(**kwargs) + else: + raise ValueError("Unsupported pipe type") + + return create_pipe + + +@pytest.mark.asyncio +@pytest.mark.parametrize("multiplier, delay, name", [(2, 0.1, "pipe")]) +async def test_single_multiplier(pipe_factory, multiplier, delay, name): + pipe = pipe_factory( + "multiplier", multiplier=multiplier, delay=delay, name=name + ) + + async def input_generator(): + for i in [1, 2, 3]: + yield i + + pipeline = AsyncPipeline() + pipeline.add_pipe(pipe) + + result = [] + for output in await pipeline.run(input_generator()): + result.append(output) + + expected_result = [i * multiplier for i in [1, 2, 3]] + assert ( + result == expected_result + ), 
"Pipeline output did not match expected multipliers" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b", + [(2, 0.1, "pipe_a", 2, 0.1, "pipe_b")], +) +async def test_double_multiplier( + pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b +): + pipe_a = pipe_factory( + "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a + ) + pipe_b = pipe_factory( + "multiplier", multiplier=multiplier_b, delay=delay_b, name=name_b + ) + + async def input_generator(): + for i in [1, 2, 3]: + yield i + + pipeline = AsyncPipeline() + pipeline.add_pipe(pipe_a) + pipeline.add_pipe(pipe_b) + + result = [] + for output in await pipeline.run(input_generator()): + result.append(output) + + expected_result = [i * multiplier_a * multiplier_b for i in [1, 2, 3]] + assert ( + result == expected_result + ), "Pipeline output did not match expected multipliers" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("multiplier, delay, name", [(3, 0.1, "pipe")]) +async def test_fan_out(pipe_factory, multiplier, delay, name): + pipe = pipe_factory( + "fan_out", multiplier=multiplier, delay=delay, name=name + ) + + async def input_generator(): + for i in [1, 2, 3]: + yield i + + pipeline = AsyncPipeline() + pipeline.add_pipe(pipe) + + result = [] + for output in await pipeline.run(input_generator()): + result.append(output) + + expected_result = [ + [i + 1, 2 * (i + 1), 3 * (i + 1)] for i in range(multiplier) + ] + assert ( + result == expected_result + ), "Pipeline output did not match expected multipliers" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b", + [ + (2, 0.1, "pipe_a", 2, 0.1, "pipe_b"), + (4, 0.1, "pipe_a", 3, 0.1, "pipe_b"), + ], +) +async def multiply_then_fan_out( + pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b +): + pipe_a = pipe_factory( + "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a + ) + pipe_b = pipe_factory( + "fan_out", multiplier=multiplier_b, delay=delay_b, name=name_b + ) + + async def input_generator(): + for i in [1, 2, 3]: + yield i + + pipeline = AsyncPipeline() + pipeline.add_pipe(pipe_a) + pipeline.add_pipe(pipe_b) + + result = [] + async for output in await pipeline.run(input_generator()): + result.append(output) + + expected_result = [[i * multiplier_a] async for i in input_generator()] + assert ( + result[0] == expected_result + ), "Pipeline output did not match expected multipliers" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("multiplier, delay, name", [(3, 0.1, "pipe")]) +async def test_fan_in_sum(pipe_factory, multiplier, delay, name): + # Create fan-out to generate multiple streams + fan_out_pipe = pipe_factory( + "fan_out", multiplier=multiplier, delay=delay, name=name + "_a" + ) + # Summing fan-in pipe + fan_in_sum_pipe = pipe_factory("fan_in", delay=delay, name=name + "_b") + + async def input_generator(): + for i in [1, 2, 3]: + yield i + + pipeline = AsyncPipeline() + pipeline.add_pipe(fan_out_pipe) + pipeline.add_pipe(fan_in_sum_pipe) + + result = await pipeline.run(input_generator()) + + # Calculate expected results based on the multiplier and the sum of inputs + expected_result = sum( + [sum([j * i for j in [1, 2, 3]]) for i in range(1, multiplier + 1)] + ) + assert ( + result[0] == expected_result + ), "Pipeline output did not match expected sums" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "multiplier_a, delay_a, name_a, multiplier_b, 
delay_b, name_b", + [ + (3, 0.1, "pipe_a", 2, 0.1, "pipe_b"), + (4, 0.1, "pipe_a", 3, 0.1, "pipe_b"), + ], +) +async def test_fan_out_then_multiply( + pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b +): + pipe_a = pipe_factory( + "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a + ) + pipe_b = pipe_factory( + "fan_out", multiplier=multiplier_b, delay=delay_b, name=name_b + ) + pipe_c = pipe_factory("fan_in", delay=0.1, name="pipe_c") + + async def input_generator(): + for i in [1, 2, 3]: + yield i + + pipeline = AsyncPipeline() + pipeline.add_pipe(pipe_a) + pipeline.add_pipe(pipe_b) + pipeline.add_pipe(pipe_c) + + result = await pipeline.run(input_generator()) + + expected_result = sum( + [ + sum([j * i * multiplier_a for j in [1, 2, 3]]) + for i in range(1, multiplier_b + 1) + ] + ) + assert ( + result[0] == expected_result + ), "Pipeline output did not match expected multipliers" diff --git a/R2R/tests/test_vector_db.py b/R2R/tests/test_vector_db.py new file mode 100755 index 00000000..023145ce --- /dev/null +++ b/R2R/tests/test_vector_db.py @@ -0,0 +1,160 @@ +import random + +import pytest +from dotenv import load_dotenv + +from r2r import ( + Vector, + VectorDBConfig, + VectorDBProvider, + VectorEntry, + generate_id_from_label, +) +from r2r.providers.vector_dbs import PGVectorDB + +load_dotenv() + + +# Sample vector entries +def generate_random_vector_entry(id: str, dimension: int) -> VectorEntry: + vector = [random.random() for _ in range(dimension)] + metadata = {"key": f"value_{id}"} + return VectorEntry( + id=generate_id_from_label(id), vector=Vector(vector), metadata=metadata + ) + + +dimension = 3 +num_entries = 100 +sample_entries = [ + generate_random_vector_entry(f"id_{i}", dimension) + for i in range(num_entries) +] + + +# Fixture for PGVectorDB +@pytest.fixture +def pg_vector_db(): + random_collection_name = ( + f"test_collection_{random.randint(0, 1_000_000_000)}" + ) + config = VectorDBConfig.create( + provider="pgvector", vecs_collection=random_collection_name + ) + db = PGVectorDB(config) + db.initialize_collection(dimension=dimension) + yield db + # Teardown + db.vx.delete_collection( + db.config.extra_fields.get("vecs_collection", None) + ) + + +@pytest.mark.parametrize("db_fixture", ["pg_vector_db"]) +def test_get_metadatas(request, db_fixture): + db = request.getfixturevalue(db_fixture) + for entry in sample_entries: + db.upsert(entry) + + unique_metadatas = db.get_metadatas(metadata_fields=["key"]) + unique_values = set([ele["key"] for ele in unique_metadatas]) + assert len(unique_values) == num_entries + assert all(f"value_id_{i}" in unique_values for i in range(num_entries)) + + +@pytest.mark.parametrize("db_fixture", ["pg_vector_db"]) +def test_db_initialization(request, db_fixture): + db = request.getfixturevalue(db_fixture) + assert isinstance(db, VectorDBProvider) + + +@pytest.mark.parametrize("db_fixture", ["pg_vector_db"]) +def test_db_copy_and_search(request, db_fixture): + db = request.getfixturevalue(db_fixture) + db.upsert(sample_entries[0]) + results = db.search(query_vector=sample_entries[0].vector.data) + assert len(results) == 1 + assert results[0].id == sample_entries[0].id + assert results[0].score == pytest.approx(1.0, rel=1e-3) + + +@pytest.mark.parametrize("db_fixture", ["pg_vector_db"]) +def test_db_upsert_and_search(request, db_fixture): + db = request.getfixturevalue(db_fixture) + db.upsert(sample_entries[0]) + results = db.search(query_vector=sample_entries[0].vector.data) + assert len(results) == 1 
+    assert results[0].id == sample_entries[0].id
+    assert results[0].score == pytest.approx(1.0, rel=1e-3)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_imperfect_match(request, db_fixture):
+    db = request.getfixturevalue(db_fixture)
+    db.upsert(sample_entries[0])
+    query_vector = [val + 0.1 for val in sample_entries[0].vector.data]
+    results = db.search(query_vector=query_vector)
+    assert len(results) == 1
+    assert results[0].id == sample_entries[0].id
+    assert results[0].score < 1.0
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_bulk_insert_and_search(request, db_fixture):
+    db = request.getfixturevalue(db_fixture)
+    for entry in sample_entries:
+        db.upsert(entry)
+
+    query_vector = sample_entries[0].vector.data
+    results = db.search(query_vector=query_vector, limit=5)
+    assert len(results) == 5
+    assert results[0].id == sample_entries[0].id
+    assert results[0].score == pytest.approx(1.0, rel=1e-3)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_search_with_filters(request, db_fixture):
+    db = request.getfixturevalue(db_fixture)
+    for entry in sample_entries:
+        db.upsert(entry)
+
+    filtered_id = sample_entries[0].metadata["key"]
+    query_vector = sample_entries[0].vector.data
+    results = db.search(
+        query_vector=query_vector, filters={"key": filtered_id}
+    )
+    assert len(results) == 1
+    assert results[0].id == sample_entries[0].id
+    assert results[0].metadata["key"] == filtered_id
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_delete_by_metadata(request, db_fixture):
+    db = request.getfixturevalue(db_fixture)
+    for entry in sample_entries:
+        db.upsert(entry)
+
+    key_to_delete = sample_entries[0].metadata["key"]
+    db.delete_by_metadata(
+        metadata_fields=["key"], metadata_values=[key_to_delete]
+    )
+
+    results = db.search(query_vector=sample_entries[0].vector.data)
+    assert all(result.metadata["key"] != key_to_delete for result in results)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_upsert(request, db_fixture):
+    db = request.getfixturevalue(db_fixture)
+    db.upsert(sample_entries[0])
+    modified_entry = VectorEntry(
+        id=sample_entries[0].id,
+        vector=Vector([0.5, 0.5, 0.5]),
+        metadata={"key": "new_value"},
+    )
+    db.upsert(modified_entry)
+
+    results = db.search(query_vector=[0.5, 0.5, 0.5])
+    assert len(results) == 1
+    assert results[0].id == sample_entries[0].id
+    assert results[0].metadata["key"] == "new_value"
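The pipeline tests above all follow one pattern: subclass AsyncPipe, override _run_logic, register the pipes on an AsyncPipeline, and drive the pipeline with an async generator. The sketch below illustrates that pattern outside of pytest. It is only a sketch: it assumes the MultiplierPipe, FanOutPipe, and FanInPipe helpers defined in test_pipeline.py above are importable from that module, and that awaiting AsyncPipeline.run returns the list of outputs yielded by the final pipe, as the tests themselves rely on.

import asyncio

from r2r import AsyncPipeline

# Assumed importable from the test module shown above; adjust the path as needed.
from test_pipeline import FanInPipe, FanOutPipe, MultiplierPipe


async def main():
    async def numbers():
        # Same input stream the tests use.
        for i in [1, 2, 3]:
            yield i

    pipeline = AsyncPipeline()
    pipeline.add_pipe(MultiplierPipe(multiplier=2))  # 1, 2, 3 -> 2, 4, 6
    pipeline.add_pipe(FanOutPipe(multiplier=3))      # -> [2, 4, 6], [4, 8, 12], [6, 12, 18]
    pipeline.add_pipe(FanInPipe())                   # sums every batch into one total

    result = await pipeline.run(numbers())
    print(result[0])  # prints 72: (2+4+6) + (4+8+12) + (6+12+18)


asyncio.run(main())

The expected total follows the same formula used in test_fan_out_then_multiply: the fan-in stage sums each fan-out batch, so the result is the sum of the scaled inputs taken once per fan-out repetition.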