Diffstat (limited to 'R2R/tests')
-rwxr-xr-x  R2R/tests/test_abstractions.py       162
-rwxr-xr-x  R2R/tests/test_config.py             187
-rwxr-xr-x  R2R/tests/test_embedding.py          162
-rwxr-xr-x  R2R/tests/test_end_to_end.py         375
-rwxr-xr-x  R2R/tests/test_ingestion_service.py  443
-rwxr-xr-x  R2R/tests/test_llms.py                59
-rwxr-xr-x  R2R/tests/test_logging.py            360
-rwxr-xr-x  R2R/tests/test_parser.py             159
-rwxr-xr-x  R2R/tests/test_pipeline.py           291
-rwxr-xr-x  R2R/tests/test_vector_db.py          160
10 files changed, 2358 insertions, 0 deletions
diff --git a/R2R/tests/test_abstractions.py b/R2R/tests/test_abstractions.py
new file mode 100755
index 00000000..a360e952
--- /dev/null
+++ b/R2R/tests/test_abstractions.py
@@ -0,0 +1,162 @@
+import asyncio
+import uuid
+
+import pytest
+
+from r2r import (
+ AsyncPipe,
+ AsyncState,
+ Prompt,
+ Vector,
+ VectorEntry,
+ VectorSearchRequest,
+ VectorSearchResult,
+ VectorType,
+ generate_id_from_label,
+)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def event_loop_policy():
+ asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+
+
+@pytest.fixture(scope="function")
+def event_loop():
+ loop = asyncio.get_event_loop_policy().new_event_loop()
+ yield loop
+ loop.close()
+ asyncio.set_event_loop(None)
+
+
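+# Session-wide teardown: once the test session ends, cancel any asyncio tasks
+# that are still pending so the event loop can shut down cleanly.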
+@pytest.fixture(scope="session", autouse=True)
+async def cleanup_tasks():
+ yield
+ for task in asyncio.all_tasks():
+ if task is not asyncio.current_task():
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_async_state_update_and_get():
+ state = AsyncState()
+ outer_key = "test_key"
+ values = {"inner_key": "value"}
+ await state.update(outer_key, values)
+ result = await state.get(outer_key, "inner_key")
+ assert result == "value"
+
+
+@pytest.mark.asyncio
+async def test_async_state_delete():
+ state = AsyncState()
+ outer_key = "test_key"
+ values = {"inner_key": "value"}
+ await state.update(outer_key, values)
+ await state.delete(outer_key, "inner_key")
+ result = await state.get(outer_key, "inner_key")
+    assert result == {}, "Expected empty result after deletion"
+
+
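+# Minimal AsyncPipe subclass for exercising the pipe-run plumbing: it ignores
+# its input and yields a single sentinel value.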
+class MockAsyncPipe(AsyncPipe):
+ async def _run_logic(self, input, state, run_id, *args, **kwargs):
+ yield "processed"
+
+
+@pytest.mark.asyncio
+async def test_async_pipe_run():
+ pipe = MockAsyncPipe()
+
+ async def list_to_generator(lst):
+ for item in lst:
+ yield item
+
+ input = pipe.Input(message=list_to_generator(["test"]))
+ state = AsyncState()
+ try:
+ async_generator = await pipe.run(input, state)
+ results = [result async for result in async_generator]
+ assert results == ["processed"]
+ except asyncio.CancelledError:
+ pass # Task cancelled as expected
+
+
+def test_prompt_initialization_and_formatting():
+ prompt = Prompt(
+ name="greet", template="Hello, {name}!", input_types={"name": "str"}
+ )
+ formatted = prompt.format_prompt({"name": "Alice"})
+ assert formatted == "Hello, Alice!"
+
+
+def test_prompt_missing_input():
+ prompt = Prompt(
+ name="greet", template="Hello, {name}!", input_types={"name": "str"}
+ )
+ with pytest.raises(ValueError):
+ prompt.format_prompt({})
+
+
+def test_prompt_invalid_input_type():
+ prompt = Prompt(
+ name="greet", template="Hello, {name}!", input_types={"name": "int"}
+ )
+ with pytest.raises(TypeError):
+ prompt.format_prompt({"name": "Alice"})
+
+
+def test_search_request_with_optional_filters():
+ request = VectorSearchRequest(
+ query="test", limit=10, filters={"category": "books"}
+ )
+ assert request.query == "test"
+ assert request.limit == 10
+ assert request.filters == {"category": "books"}
+
+
+def test_search_result_to_string():
+ result = VectorSearchResult(
+ id=generate_id_from_label("1"),
+ score=9.5,
+ metadata={"author": "John Doe"},
+ )
+ result_str = str(result)
+ assert (
+ result_str
+ == f"VectorSearchResult(id={str(generate_id_from_label('1'))}, score=9.5, metadata={{'author': 'John Doe'}})"
+ )
+
+
+def test_search_result_repr():
+ result = VectorSearchResult(
+ id=generate_id_from_label("1"),
+ score=9.5,
+ metadata={"author": "John Doe"},
+ )
+ assert (
+ repr(result)
+ == f"VectorSearchResult(id={str(generate_id_from_label('1'))}, score=9.5, metadata={{'author': 'John Doe'}})"
+ )
+
+
+def test_vector_fixed_length_validation():
+ with pytest.raises(ValueError):
+ Vector(data=[1.0, 2.0], type=VectorType.FIXED, length=3)
+
+
+def test_vector_entry_serialization():
+ vector = Vector(data=[1.0, 2.0], type=VectorType.FIXED, length=2)
+ entry_id = uuid.uuid4()
+ entry = VectorEntry(
+ id=entry_id, vector=vector, metadata={"key": uuid.uuid4()}
+ )
+ serializable = entry.to_serializable()
+ assert serializable["id"] == str(entry_id)
+ assert serializable["vector"] == [1.0, 2.0]
+ assert isinstance(
+ serializable["metadata"]["key"], str
+ ) # Check UUID conversion to string
diff --git a/R2R/tests/test_config.py b/R2R/tests/test_config.py
new file mode 100755
index 00000000..5e60833c
--- /dev/null
+++ b/R2R/tests/test_config.py
@@ -0,0 +1,187 @@
+import asyncio
+import json
+from unittest.mock import Mock, mock_open, patch
+
+import pytest
+
+from r2r import DocumentType, R2RConfig
+
+
+@pytest.fixture(scope="session", autouse=True)
+def event_loop_policy():
+ asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+
+
+@pytest.fixture(scope="function")
+def event_loop():
+ loop = asyncio.get_event_loop_policy().new_event_loop()
+ yield loop
+ loop.close()
+ asyncio.set_event_loop(None)
+
+
+@pytest.fixture(scope="session", autouse=True)
+async def cleanup_tasks():
+ yield
+ for task in asyncio.all_tasks():
+ if task is not asyncio.current_task():
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.fixture
+def mock_bad_file():
+ mock_data = json.dumps({})
+ with patch("builtins.open", mock_open(read_data=mock_data)) as m:
+ yield m
+
+
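+# Minimal configuration payload covering every section R2RConfig.from_json
+# requires, patched over builtins.open so no file is read from disk.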
+@pytest.fixture
+def mock_file():
+ mock_data = json.dumps(
+ {
+ "app": {"max_file_size_in_mb": 128},
+ "embedding": {
+ "provider": "example_provider",
+ "base_model": "model",
+ "base_dimension": 128,
+ "batch_size": 16,
+ "text_splitter": "default",
+ },
+ "kg": {
+ "provider": "None",
+ "batch_size": 1,
+ "text_splitter": {
+ "type": "recursive_character",
+ "chunk_size": 2048,
+ "chunk_overlap": 0,
+ },
+ },
+ "eval": {"llm": {"provider": "local"}},
+ "ingestion": {"excluded_parsers": {}},
+ "completions": {"provider": "lm_provider"},
+ "logging": {
+ "provider": "local",
+ "log_table": "logs",
+ "log_info_table": "log_info",
+ },
+ "prompt": {"provider": "prompt_provider"},
+ "vector_database": {"provider": "vector_db"},
+ }
+ )
+ with patch("builtins.open", mock_open(read_data=mock_data)) as m:
+ yield m
+
+
+@pytest.mark.asyncio
+async def test_r2r_config_loading_required_keys(mock_bad_file):
+ with pytest.raises(KeyError):
+ R2RConfig.from_json("config.json")
+
+
+@pytest.mark.asyncio
+async def test_r2r_config_loading(mock_file):
+ config = R2RConfig.from_json("config.json")
+ assert (
+ config.embedding.provider == "example_provider"
+ ), "Provider should match the mock data"
+
+
+@pytest.fixture
+def mock_redis_client():
+ client = Mock()
+ return client
+
+
+def test_r2r_config_serialization(mock_file, mock_redis_client):
+ config = R2RConfig.from_json("config.json")
+ config.save_to_redis(mock_redis_client, "test_key")
+ mock_redis_client.set.assert_called_once()
+ saved_data = json.loads(mock_redis_client.set.call_args[0][1])
+ assert saved_data["app"]["max_file_size_in_mb"] == 128
+
+
+def test_r2r_config_deserialization(mock_file, mock_redis_client):
+ config_data = {
+ "app": {"max_file_size_in_mb": 128},
+ "embedding": {
+ "provider": "example_provider",
+ "base_model": "model",
+ "base_dimension": 128,
+ "batch_size": 16,
+ "text_splitter": "default",
+ },
+ "kg": {
+ "provider": "None",
+ "batch_size": 1,
+ "text_splitter": {
+ "type": "recursive_character",
+ "chunk_size": 2048,
+ "chunk_overlap": 0,
+ },
+ },
+ "eval": {"llm": {"provider": "local"}},
+ "ingestion": {"excluded_parsers": ["pdf"]},
+ "completions": {"provider": "lm_provider"},
+ "logging": {
+ "provider": "local",
+ "log_table": "logs",
+ "log_info_table": "log_info",
+ },
+ "prompt": {"provider": "prompt_provider"},
+ "vector_database": {"provider": "vector_db"},
+ }
+ mock_redis_client.get.return_value = json.dumps(config_data)
+ config = R2RConfig.load_from_redis(mock_redis_client, "test_key")
+ assert config.app["max_file_size_in_mb"] == 128
+ assert DocumentType.PDF in config.ingestion["excluded_parsers"]
+
+
+def test_r2r_config_missing_section():
+ invalid_data = {
+ "embedding": {
+ "provider": "example_provider",
+ "base_model": "model",
+ "base_dimension": 128,
+ "batch_size": 16,
+ "text_splitter": "default",
+ }
+ }
+ with patch("builtins.open", mock_open(read_data=json.dumps(invalid_data))):
+ with pytest.raises(KeyError):
+ R2RConfig.from_json("config.json")
+
+
+def test_r2r_config_missing_required_key():
+ invalid_data = {
+ "app": {"max_file_size_in_mb": 128},
+ "embedding": {
+ "base_model": "model",
+ "base_dimension": 128,
+ "batch_size": 16,
+ "text_splitter": "default",
+ },
+ "kg": {
+ "provider": "None",
+ "batch_size": 1,
+ "text_splitter": {
+ "type": "recursive_character",
+ "chunk_size": 2048,
+ "chunk_overlap": 0,
+ },
+ },
+ "completions": {"provider": "lm_provider"},
+ "logging": {
+ "provider": "local",
+ "log_table": "logs",
+ "log_info_table": "log_info",
+ },
+ "prompt": {"provider": "prompt_provider"},
+ "vector_database": {"provider": "vector_db"},
+ }
+ with patch("builtins.open", mock_open(read_data=json.dumps(invalid_data))):
+ with pytest.raises(KeyError):
+ R2RConfig.from_json("config.json")
diff --git a/R2R/tests/test_embedding.py b/R2R/tests/test_embedding.py
new file mode 100755
index 00000000..7a3e760a
--- /dev/null
+++ b/R2R/tests/test_embedding.py
@@ -0,0 +1,162 @@
+import asyncio
+
+import pytest
+
+from r2r import EmbeddingConfig, VectorSearchResult, generate_id_from_label
+from r2r.providers.embeddings import (
+ OpenAIEmbeddingProvider,
+ SentenceTransformerEmbeddingProvider,
+)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def event_loop_policy():
+ asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+
+
+@pytest.fixture(scope="function")
+def event_loop():
+ loop = asyncio.get_event_loop_policy().new_event_loop()
+ yield loop
+ loop.close()
+ asyncio.set_event_loop(None)
+
+
+@pytest.fixture(scope="session", autouse=True)
+async def cleanup_tasks():
+ yield
+ for task in asyncio.all_tasks():
+ if task is not asyncio.current_task():
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+
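+# The OpenAI tests below call the live embeddings API; they assume a valid
+# OpenAI API key is configured in the environment.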
+@pytest.fixture
+def openai_provider():
+ config = EmbeddingConfig(
+ provider="openai",
+ base_model="text-embedding-3-small",
+ base_dimension=1536,
+ )
+ return OpenAIEmbeddingProvider(config)
+
+
+def test_openai_initialization(openai_provider):
+ assert isinstance(openai_provider, OpenAIEmbeddingProvider)
+ assert openai_provider.base_model == "text-embedding-3-small"
+ assert openai_provider.base_dimension == 1536
+
+
+def test_openai_invalid_provider_initialization():
+ config = EmbeddingConfig(provider="invalid_provider")
+ with pytest.raises(ValueError):
+ OpenAIEmbeddingProvider(config)
+
+
+def test_openai_get_embedding(openai_provider):
+ embedding = openai_provider.get_embedding("test text")
+ assert len(embedding) == 1536
+ assert isinstance(embedding, list)
+
+
+@pytest.mark.asyncio
+async def test_openai_async_get_embedding(openai_provider):
+ try:
+ embedding = await openai_provider.async_get_embedding("test text")
+ assert len(embedding) == 1536
+ assert isinstance(embedding, list)
+ except asyncio.CancelledError:
+ pass # Task cancelled as expected
+
+
+def test_openai_get_embeddings(openai_provider):
+ embeddings = openai_provider.get_embeddings(["text1", "text2"])
+ assert len(embeddings) == 2
+ assert all(len(emb) == 1536 for emb in embeddings)
+
+
+@pytest.mark.asyncio
+async def test_openai_async_get_embeddings(openai_provider):
+ try:
+ embeddings = await openai_provider.async_get_embeddings(
+ ["text1", "text2"]
+ )
+ assert len(embeddings) == 2
+ assert all(len(emb) == 1536 for emb in embeddings)
+ except asyncio.CancelledError:
+ pass # Task cancelled as expected
+
+
+def test_openai_tokenize_string(openai_provider):
+ tokens = openai_provider.tokenize_string(
+ "test text", "text-embedding-3-small"
+ )
+ assert isinstance(tokens, list)
+ assert all(isinstance(token, int) for token in tokens)
+
+
+@pytest.fixture
+def sentence_transformer_provider():
+ config = EmbeddingConfig(
+ provider="sentence-transformers",
+ base_model="mixedbread-ai/mxbai-embed-large-v1",
+ base_dimension=512,
+ rerank_model="jinaai/jina-reranker-v1-turbo-en",
+ rerank_dimension=384,
+ )
+ return SentenceTransformerEmbeddingProvider(config)
+
+
+def test_sentence_transformer_initialization(sentence_transformer_provider):
+ assert isinstance(
+ sentence_transformer_provider, SentenceTransformerEmbeddingProvider
+ )
+ assert sentence_transformer_provider.do_search
+ # assert sentence_transformer_provider.do_rerank
+
+
+def test_sentence_transformer_invalid_provider_initialization():
+ config = EmbeddingConfig(provider="invalid_provider")
+ with pytest.raises(ValueError):
+ SentenceTransformerEmbeddingProvider(config)
+
+
+def test_sentence_transformer_get_embedding(sentence_transformer_provider):
+ embedding = sentence_transformer_provider.get_embedding("test text")
+ assert len(embedding) == 512
+ assert isinstance(embedding, list)
+
+
+def test_sentence_transformer_get_embeddings(sentence_transformer_provider):
+ embeddings = sentence_transformer_provider.get_embeddings(
+ ["text1", "text2"]
+ )
+ assert len(embeddings) == 2
+ assert all(len(emb) == 512 for emb in embeddings)
+
+
+def test_sentence_transformer_rerank(sentence_transformer_provider):
+ results = [
+ VectorSearchResult(
+ id=generate_id_from_label("x"),
+ score=0.9,
+ metadata={"text": "doc1"},
+ ),
+ VectorSearchResult(
+ id=generate_id_from_label("y"),
+ score=0.8,
+ metadata={"text": "doc2"},
+ ),
+ ]
+ reranked_results = sentence_transformer_provider.rerank("query", results)
+ assert len(reranked_results) == 2
+ assert reranked_results[0].metadata["text"] == "doc1"
+ assert reranked_results[1].metadata["text"] == "doc2"
+
+
+def test_sentence_transformer_tokenize_string(sentence_transformer_provider):
+ with pytest.raises(ValueError):
+ sentence_transformer_provider.tokenize_string("test text")
diff --git a/R2R/tests/test_end_to_end.py b/R2R/tests/test_end_to_end.py
new file mode 100755
index 00000000..5e13ab5c
--- /dev/null
+++ b/R2R/tests/test_end_to_end.py
@@ -0,0 +1,375 @@
+import asyncio
+import os
+import uuid
+
+import pytest
+from fastapi.datastructures import UploadFile
+
+from r2r import (
+ Document,
+ KVLoggingSingleton,
+ R2RConfig,
+ R2REngine,
+ R2RPipeFactory,
+ R2RPipelineFactory,
+ R2RProviderFactory,
+ VectorSearchSettings,
+ generate_id_from_label,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+
+@pytest.fixture(scope="session", autouse=True)
+def event_loop_policy():
+ asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+
+
+@pytest.fixture(scope="function")
+def event_loop():
+ loop = asyncio.get_event_loop_policy().new_event_loop()
+ yield loop
+ loop.close()
+ asyncio.set_event_loop(None)
+
+
+@pytest.fixture(scope="session", autouse=True)
+async def cleanup_tasks():
+ yield
+ for task in asyncio.all_tasks():
+ if task is not asyncio.current_task():
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+
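+# Builds a full R2REngine for the vector database named by the test's indirect
+# parameter, writing logs to a throw-away path that is deleted on teardown.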
+@pytest.fixture(scope="function")
+def app(request):
+ config = R2RConfig.from_json()
+ config.logging.provider = "local"
+ config.logging.logging_path = uuid.uuid4().hex
+
+ vector_db_provider = request.param
+ if vector_db_provider == "pgvector":
+ config.vector_database.provider = "pgvector"
+ config.vector_database.extra_fields["vecs_collection"] = (
+ config.logging.logging_path
+ )
+ try:
+ providers = R2RProviderFactory(config).create_providers()
+ pipes = R2RPipeFactory(config, providers).create_pipes()
+ pipelines = R2RPipelineFactory(config, pipes).create_pipelines()
+
+ r2r = R2REngine(
+ config=config,
+ providers=providers,
+ pipelines=pipelines,
+ )
+
+ try:
+ KVLoggingSingleton.configure(config.logging)
+        except Exception:
+ KVLoggingSingleton._config.logging_path = (
+ config.logging.logging_path
+ )
+
+ yield r2r
+ finally:
+ if os.path.exists(config.logging.logging_path):
+ os.remove(config.logging.logging_path)
+
+
+@pytest.fixture
+def logging_connection():
+ return KVLoggingSingleton()
+
+
+@pytest.mark.parametrize("app", ["pgvector"], indirect=True)
+@pytest.mark.asyncio
+async def test_ingest_txt_document(app, logging_connection):
+ try:
+ await app.aingest_documents(
+ [
+ Document(
+ id=generate_id_from_label("doc_1"),
+ data="The quick brown fox jumps over the lazy dog.",
+ type="txt",
+ metadata={"author": "John Doe"},
+ ),
+ ]
+ )
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.parametrize("app", ["pgvector"], indirect=True)
+@pytest.mark.asyncio
+async def test_ingest_txt_file(app, logging_connection):
+ try:
+ # Prepare the test data
+ metadata = {"author": "John Doe"}
+ files = [
+ UploadFile(
+ filename="test.txt",
+ file=open(
+ os.path.join(
+ os.path.dirname(__file__),
+ "..",
+ "r2r",
+ "examples",
+ "data",
+ "test.txt",
+ ),
+ "rb",
+ ),
+ )
+ ]
+ # Set file size manually
+ for file in files:
+ file.file.seek(0, 2) # Move to the end of the file
+ file.size = file.file.tell() # Get the file size
+ file.file.seek(0) # Move back to the start of the file
+
+ await app.aingest_files(metadatas=[metadata], files=files)
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.parametrize("app", ["pgvector"], indirect=True)
+@pytest.mark.asyncio
+async def test_ingest_search_txt_file(app, logging_connection):
+ try:
+ # Prepare the test data
+ metadata = {}
+ files = [
+ UploadFile(
+ filename="aristotle.txt",
+ file=open(
+ os.path.join(
+ os.path.dirname(__file__),
+ "..",
+ "r2r",
+ "examples",
+ "data",
+ "aristotle.txt",
+ ),
+ "rb",
+ ),
+ ),
+ ]
+
+ # Set file size manually
+ for file in files:
+ file.file.seek(0, 2) # Move to the end of the file
+ file.size = file.file.tell() # Get the file size
+ file.file.seek(0) # Move back to the start of the file
+
+ await app.aingest_files(metadatas=[metadata], files=files)
+
+ search_results = await app.asearch("who was aristotle?")
+ assert len(search_results["vector_search_results"]) == 10
+ assert (
+ "was an Ancient Greek philosopher and polymath"
+ in search_results["vector_search_results"][0]["metadata"]["text"]
+ )
+
+ search_results = await app.asearch(
+ "who was aristotle?",
+ vector_search_settings=VectorSearchSettings(search_limit=20),
+ )
+ assert len(search_results["vector_search_results"]) == 20
+ assert (
+ "was an Ancient Greek philosopher and polymath"
+ in search_results["vector_search_results"][0]["metadata"]["text"]
+ )
+ run_info = await logging_connection.get_run_info(
+ log_type_filter="search"
+ )
+
+ assert len(run_info) == 2, f"Expected 2 runs, but got {len(run_info)}"
+
+ logs = await logging_connection.get_logs(
+ [run.run_id for run in run_info], 100
+ )
+ assert len(logs) == 6, f"Expected 6 logs, but got {len(logs)}"
+
+ ## test stream
+ response = await app.arag(
+ query="Who was aristotle?",
+ rag_generation_config=GenerationConfig(
+ **{"model": "gpt-3.5-turbo", "stream": True}
+ ),
+ )
+ collector = ""
+ async for chunk in response:
+ collector += chunk
+ assert "Aristotle" in collector
+ assert "Greek" in collector
+ assert "philosopher" in collector
+ assert "polymath" in collector
+ assert "Ancient" in collector
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.parametrize("app", ["pgvector"], indirect=True)
+@pytest.mark.asyncio
+async def test_ingest_search_then_delete(app, logging_connection):
+ try:
+ # Ingest a document
+ await app.aingest_documents(
+ [
+ Document(
+ id=generate_id_from_label("doc_1"),
+ data="The quick brown fox jumps over the lazy dog.",
+ type="txt",
+ metadata={"author": "John Doe"},
+ ),
+ ]
+ )
+
+ # Search for the document
+ search_results = await app.asearch("who was aristotle?")
+
+ # Verify that the search results are not empty
+ assert (
+ len(search_results["vector_search_results"]) > 0
+ ), "Expected search results, but got none"
+ assert (
+ search_results["vector_search_results"][0]["metadata"]["text"]
+ == "The quick brown fox jumps over the lazy dog."
+ )
+
+ # Delete the document
+ delete_result = await app.adelete(["author"], ["John Doe"])
+
+ # Verify the deletion was successful
+ expected_deletion_message = "deleted successfully"
+ assert (
+ expected_deletion_message in delete_result
+ ), f"Expected successful deletion message, but got {delete_result}"
+
+ # Search for the document again
+ search_results_2 = await app.asearch("who was aristotle?")
+
+ # Verify that the search results are empty
+ assert (
+ len(search_results_2["vector_search_results"]) == 0
+        ), f"Expected no search results, but got {search_results_2['vector_search_results']}"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.parametrize("app", ["local", "pgvector"], indirect=True)
+@pytest.mark.asyncio
+async def test_ingest_user_documents(app, logging_connection):
+ try:
+ user_id_0 = generate_id_from_label("user_0")
+ user_id_1 = generate_id_from_label("user_1")
+
+ try:
+ await app.aingest_documents(
+ [
+ Document(
+ id=generate_id_from_label("doc_01"),
+ data="The quick brown fox jumps over the lazy dog.",
+ type="txt",
+ metadata={"author": "John Doe", "user_id": user_id_0},
+ ),
+ Document(
+ id=generate_id_from_label("doc_11"),
+ data="The lazy dog jumps over the quick brown fox.",
+ type="txt",
+ metadata={"author": "John Doe", "user_id": user_id_1},
+ ),
+ ]
+ )
+ user_id_results = await app.ausers_overview([user_id_0, user_id_1])
+ assert set([stats.user_id for stats in user_id_results]) == set(
+ [user_id_0, user_id_1]
+ ), f"Expected user ids {user_id_0} and {user_id_1}, but got {user_id_results}"
+
+ user_0_docs = await app.adocuments_overview(user_ids=[user_id_0])
+ user_1_docs = await app.adocuments_overview(user_ids=[user_id_1])
+
+ assert (
+ len(user_0_docs) == 1
+ ), f"Expected 1 document for user {user_id_0}, but got {len(user_0_docs)}"
+ assert (
+ len(user_1_docs) == 1
+ ), f"Expected 1 document for user {user_id_1}, but got {len(user_1_docs)}"
+ assert user_0_docs[0].document_id == generate_id_from_label(
+ "doc_01"
+            ), f"Expected document id {str(generate_id_from_label('doc_01'))} for user {user_id_0}, but got {user_0_docs[0].document_id}"
+ assert user_1_docs[0].document_id == generate_id_from_label(
+ "doc_11"
+            ), f"Expected document id {str(generate_id_from_label('doc_11'))} for user {user_id_1}, but got {user_1_docs[0].document_id}"
+ finally:
+ await app.adelete(
+ ["document_id", "document_id"],
+ [
+ str(generate_id_from_label("doc_01")),
+ str(generate_id_from_label("doc_11")),
+ ],
+ )
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.parametrize("app", ["pgvector"], indirect=True)
+@pytest.mark.asyncio
+async def test_delete_by_id(app, logging_connection):
+ try:
+ await app.aingest_documents(
+ [
+ Document(
+ id=generate_id_from_label("doc_1"),
+ data="The quick brown fox jumps over the lazy dog.",
+ type="txt",
+ metadata={"author": "John Doe"},
+ ),
+ ]
+ )
+ search_results = await app.asearch("who was aristotle?")
+
+ assert len(search_results["vector_search_results"]) > 0
+ await app.adelete(
+ ["document_id"], [str(generate_id_from_label("doc_1"))]
+ )
+ search_results = await app.asearch("who was aristotle?")
+ assert len(search_results["vector_search_results"]) == 0
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.parametrize("app", ["pgvector"], indirect=True)
+@pytest.mark.asyncio
+async def test_double_ingest(app, logging_connection):
+ try:
+ await app.aingest_documents(
+ [
+ Document(
+ id=generate_id_from_label("doc_1"),
+ data="The quick brown fox jumps over the lazy dog.",
+ type="txt",
+ metadata={"author": "John Doe"},
+ ),
+ ]
+ )
+ search_results = await app.asearch("who was aristotle?")
+
+ assert len(search_results["vector_search_results"]) == 1
+ with pytest.raises(Exception):
+ await app.aingest_documents(
+ [
+ Document(
+ id=generate_id_from_label("doc_1"),
+ data="The quick brown fox jumps over the lazy dog.",
+ type="txt",
+ metadata={"author": "John Doe"},
+ ),
+ ]
+ )
+ except asyncio.CancelledError:
+ pass
diff --git a/R2R/tests/test_ingestion_service.py b/R2R/tests/test_ingestion_service.py
new file mode 100755
index 00000000..375e51f9
--- /dev/null
+++ b/R2R/tests/test_ingestion_service.py
@@ -0,0 +1,443 @@
+import asyncio
+import io
+import uuid
+from datetime import datetime
+from unittest.mock import AsyncMock, MagicMock, Mock
+
+import pytest
+from fastapi import UploadFile
+
+from r2r.base import (
+ Document,
+ DocumentInfo,
+ R2RDocumentProcessingError,
+ R2RException,
+ generate_id_from_label,
+)
+from r2r.main import R2RPipelines, R2RProviders
+from r2r.main.services.ingestion_service import IngestionService
+
+
+@pytest.fixture(scope="session", autouse=True)
+def event_loop_policy():
+ asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+
+
+@pytest.fixture(scope="function")
+def event_loop():
+ loop = asyncio.get_event_loop_policy().new_event_loop()
+ yield loop
+ loop.close()
+ asyncio.set_event_loop(None)
+
+
+@pytest.fixture(scope="session", autouse=True)
+async def cleanup_tasks():
+ yield
+ for task in asyncio.all_tasks():
+ if task is not asyncio.current_task():
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.fixture
+def mock_vector_db():
+ mock_db = MagicMock()
+ mock_db.get_documents_overview.return_value = [] # Default to empty list
+ return mock_db
+
+
+@pytest.fixture
+def mock_embedding_model():
+ return Mock()
+
+
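+# IngestionService wired entirely with mocks: the vector DB, embedding model,
+# and ingestion pipeline are stand-ins whose return values each test scripts.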
+@pytest.fixture
+def ingestion_service(mock_vector_db, mock_embedding_model):
+ config = MagicMock()
+ config.app.get.return_value = 32 # Default max file size
+ providers = Mock(spec=R2RProviders)
+ providers.vector_db = mock_vector_db
+ providers.embedding_model = mock_embedding_model
+ pipelines = Mock(spec=R2RPipelines)
+ pipelines.ingestion_pipeline = AsyncMock()
+ pipelines.ingestion_pipeline.run.return_value = {
+ "embedding_pipeline_output": []
+ }
+ run_manager = Mock()
+ run_manager.run_info = {"mock_run_id": {}}
+ logging_connection = AsyncMock()
+
+ return IngestionService(
+ config, providers, pipelines, run_manager, logging_connection
+ )
+
+
+@pytest.mark.asyncio
+async def test_ingest_single_document(ingestion_service, mock_vector_db):
+ try:
+ document = Document(
+ id=generate_id_from_label("test_id"),
+ data="Test content",
+ type="txt",
+ metadata={},
+ )
+
+ ingestion_service.pipelines.ingestion_pipeline.run.return_value = {
+ "embedding_pipeline_output": [(document.id, None)]
+ }
+ mock_vector_db.get_documents_overview.return_value = (
+ []
+ ) # No existing documents
+
+ result = await ingestion_service.ingest_documents([document])
+
+ assert result["processed_documents"] == [
+ f"Document '{document.id}' processed successfully."
+ ]
+ assert not result["failed_documents"]
+ assert not result["skipped_documents"]
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_ingest_duplicate_document(ingestion_service, mock_vector_db):
+ try:
+ document = Document(
+ id=generate_id_from_label("test_id"),
+ data="Test content",
+ type="txt",
+ metadata={},
+ )
+ mock_vector_db.get_documents_overview.return_value = [
+ DocumentInfo(
+ document_id=document.id,
+ version="v0",
+ size_in_bytes=len(document.data),
+ metadata={},
+ title=str(document.id),
+ user_id=None,
+ created_at=datetime.now(),
+ updated_at=datetime.now(),
+ status="success",
+ )
+ ]
+
+ with pytest.raises(R2RException) as exc_info:
+ await ingestion_service.ingest_documents([document])
+
+ assert (
+ f"Document with ID {document.id} was already successfully processed"
+ in str(exc_info.value)
+ )
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_ingest_file(ingestion_service):
+ try:
+ file_content = b"Test content"
+ file_mock = UploadFile(
+ filename="test.txt", file=io.BytesIO(file_content)
+ )
+ file_mock.file.seek(0)
+ file_mock.size = len(file_content) # Set file size manually
+
+ ingestion_service.pipelines.ingestion_pipeline.run.return_value = {
+ "embedding_pipeline_output": [
+ (generate_id_from_label("test.txt"), None)
+ ]
+ }
+
+ result = await ingestion_service.ingest_files([file_mock])
+
+ assert len(result["processed_documents"]) == 1
+ assert not result["failed_documents"]
+ assert not result["skipped_documents"]
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_ingest_mixed_success_and_failure(
+ ingestion_service, mock_vector_db
+):
+ try:
+ documents = [
+ Document(
+ id=generate_id_from_label("success_id"),
+ data="Success content",
+ type="txt",
+ metadata={},
+ ),
+ Document(
+ id=generate_id_from_label("failure_id"),
+ data="Failure content",
+ type="txt",
+ metadata={},
+ ),
+ ]
+
+ ingestion_service.pipelines.ingestion_pipeline.run.return_value = {
+ "embedding_pipeline_output": [
+ (
+ documents[0].id,
+ f"Processed 1 vectors for document {documents[0].id}.",
+ ),
+ (
+ documents[1].id,
+ R2RDocumentProcessingError(
+ error_message="Embedding failed",
+ document_id=documents[1].id,
+ ),
+ ),
+ ]
+ }
+
+ result = await ingestion_service.ingest_documents(documents)
+
+ assert len(result["processed_documents"]) == 1
+ assert len(result["failed_documents"]) == 1
+ assert str(documents[0].id) in result["processed_documents"][0]
+ assert str(documents[1].id) in result["failed_documents"][0]
+ assert "Embedding failed" in result["failed_documents"][0]
+
+ assert mock_vector_db.upsert_documents_overview.call_count == 2
+ upserted_docs = mock_vector_db.upsert_documents_overview.call_args[0][
+ 0
+ ]
+ assert len(upserted_docs) == 2
+ assert upserted_docs[0].document_id == documents[0].id
+ assert upserted_docs[0].status == "success"
+ assert upserted_docs[1].document_id == documents[1].id
+ assert upserted_docs[1].status == "failure"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_ingest_unsupported_file_type(ingestion_service):
+ try:
+ file_mock = UploadFile(
+ filename="test.unsupported", file=io.BytesIO(b"Test content")
+ )
+ file_mock.file.seek(0)
+ file_mock.size = 12 # Set file size manually
+
+ with pytest.raises(R2RException) as exc_info:
+ await ingestion_service.ingest_files([file_mock])
+
+ assert "is not a valid DocumentType" in str(exc_info.value)
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_ingest_large_file(ingestion_service):
+ try:
+ large_content = b"Large content" * 1000000 # 12MB content
+ file_mock = UploadFile(
+ filename="large_file.txt", file=io.BytesIO(large_content)
+ )
+ file_mock.file.seek(0)
+ file_mock.size = len(large_content) # Set file size manually
+
+ ingestion_service.config.app.get.return_value = (
+ 10 # Set max file size to 10MB
+ )
+
+ with pytest.raises(R2RException) as exc_info:
+ await ingestion_service.ingest_files([file_mock])
+
+ assert "File size exceeds maximum allowed size" in str(exc_info.value)
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_partial_ingestion_success(ingestion_service, mock_vector_db):
+ try:
+ documents = [
+ Document(
+ id=generate_id_from_label("success_1"),
+ data="Success content 1",
+ type="txt",
+ metadata={},
+ ),
+ Document(
+ id=generate_id_from_label("fail"),
+ data="Fail content",
+ type="txt",
+ metadata={},
+ ),
+ Document(
+ id=generate_id_from_label("success_2"),
+ data="Success content 2",
+ type="txt",
+ metadata={},
+ ),
+ ]
+
+ ingestion_service.pipelines.ingestion_pipeline.run.return_value = {
+ "embedding_pipeline_output": [
+ (documents[0].id, None),
+ (
+ documents[1].id,
+ R2RDocumentProcessingError(
+ error_message="Embedding failed",
+ document_id=documents[1].id,
+ ),
+ ),
+ (documents[2].id, None),
+ ]
+ }
+
+ result = await ingestion_service.ingest_documents(documents)
+
+ assert len(result["processed_documents"]) == 2
+ assert len(result["failed_documents"]) == 1
+ assert str(documents[1].id) in result["failed_documents"][0]
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_version_increment(ingestion_service, mock_vector_db):
+ try:
+ document = Document(
+ id=generate_id_from_label("test_id"),
+ data="Test content",
+ type="txt",
+ metadata={},
+ )
+ mock_vector_db.get_documents_overview.return_value = [
+ DocumentInfo(
+ document_id=document.id,
+ version="v2",
+ status="success",
+ size_in_bytes=0,
+ metadata={},
+ )
+ ]
+
+ file_mock = UploadFile(
+ filename="test.txt", file=io.BytesIO(b"Updated content")
+ )
+ await ingestion_service.update_files([file_mock], [document.id])
+
+ calls = mock_vector_db.upsert_documents_overview.call_args_list
+ assert len(calls) == 2
+ assert calls[1][0][0][0].version == "v3"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_process_ingestion_results_error_handling(ingestion_service):
+ try:
+ document_infos = [
+ DocumentInfo(
+ document_id=uuid.uuid4(),
+ version="v0",
+ status="processing",
+ size_in_bytes=0,
+ metadata={},
+ )
+ ]
+ ingestion_results = {
+ "embedding_pipeline_output": [
+ (
+ document_infos[0].document_id,
+ R2RDocumentProcessingError(
+ "Unexpected error",
+ document_id=document_infos[0].document_id,
+ ),
+ )
+ ]
+ }
+
+ result = await ingestion_service._process_ingestion_results(
+ ingestion_results,
+ document_infos,
+ [],
+ {document_infos[0].document_id: "test"},
+ )
+
+ assert len(result["failed_documents"]) == 1
+ assert "Unexpected error" in result["failed_documents"][0]
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_file_size_limit_edge_cases(ingestion_service):
+ try:
+ ingestion_service.config.app.get.return_value = 1 # 1MB limit
+
+ just_under_limit = b"x" * (1024 * 1024 - 1)
+ at_limit = b"x" * (1024 * 1024)
+ over_limit = b"x" * (1024 * 1024 + 1)
+
+ file_under = UploadFile(
+ filename="under.txt",
+ file=io.BytesIO(just_under_limit),
+ size=1024 * 1024 - 1,
+ )
+ file_at = UploadFile(
+ filename="at.txt", file=io.BytesIO(at_limit), size=1024 * 1024
+ )
+ file_over = UploadFile(
+ filename="over.txt",
+ file=io.BytesIO(over_limit),
+ size=1024 * 1024 + 1,
+ )
+
+ await ingestion_service.ingest_files([file_under]) # Should succeed
+ await ingestion_service.ingest_files([file_at]) # Should succeed
+
+ with pytest.raises(
+ R2RException, match="File size exceeds maximum allowed size"
+ ):
+ await ingestion_service.ingest_files([file_over])
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_document_status_update_after_ingestion(
+ ingestion_service, mock_vector_db
+):
+ try:
+ document = Document(
+ id=generate_id_from_label("test_id"),
+ data="Test content",
+ type="txt",
+ metadata={},
+ )
+
+ ingestion_service.pipelines.ingestion_pipeline.run.return_value = {
+ "embedding_pipeline_output": [(document.id, None)]
+ }
+ mock_vector_db.get_documents_overview.return_value = (
+ []
+ ) # No existing documents
+
+ await ingestion_service.ingest_documents([document])
+
+ # Check that upsert_documents_overview was called twice
+ assert mock_vector_db.upsert_documents_overview.call_count == 2
+
+ # Check the second call to upsert_documents_overview (status update)
+ second_call_args = (
+ mock_vector_db.upsert_documents_overview.call_args_list[1][0][0]
+ )
+ assert len(second_call_args) == 1
+ assert second_call_args[0].document_id == document.id
+ assert second_call_args[0].status == "success"
+ except asyncio.CancelledError:
+ pass
diff --git a/R2R/tests/test_llms.py b/R2R/tests/test_llms.py
new file mode 100755
index 00000000..666bbff8
--- /dev/null
+++ b/R2R/tests/test_llms.py
@@ -0,0 +1,59 @@
+import pytest
+
+from r2r import LLMConfig
+from r2r.base.abstractions.llm import GenerationConfig
+from r2r.providers.llms import LiteLLM
+
+
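+# Both completion tests share the same LiteLLM provider; only the model string
+# in GenerationConfig selects the backend (a local Ollama model vs. OpenAI), so
+# they assume those services and credentials are reachable when run.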
+@pytest.fixture
+def lite_llm():
+ config = LLMConfig(provider="litellm")
+ return LiteLLM(config)
+
+
+@pytest.mark.parametrize("llm_fixture", ["lite_llm"])
+def test_get_completion_ollama(request, llm_fixture):
+ llm = request.getfixturevalue(llm_fixture)
+
+ messages = [
+ {
+ "role": "user",
+ "content": "This is a test, return only the word `True`",
+ }
+ ]
+ generation_config = GenerationConfig(
+ model="ollama/llama2",
+ temperature=0.0,
+ top_p=0.9,
+ max_tokens_to_sample=50,
+ stream=False,
+ )
+
+ completion = llm.get_completion(messages, generation_config)
+ # assert isinstance(completion, LLMChatCompletion)
+ assert completion.choices[0].message.role == "assistant"
+ assert completion.choices[0].message.content.strip() == "True"
+
+
+@pytest.mark.parametrize("llm_fixture", ["lite_llm"])
+def test_get_completion_openai(request, llm_fixture):
+ llm = request.getfixturevalue(llm_fixture)
+
+ messages = [
+ {
+ "role": "user",
+ "content": "This is a test, return only the word `True`",
+ }
+ ]
+ generation_config = GenerationConfig(
+ model="gpt-3.5-turbo",
+ temperature=0.0,
+ top_p=0.9,
+ max_tokens_to_sample=50,
+ stream=False,
+ )
+
+ completion = llm.get_completion(messages, generation_config)
+ # assert isinstance(completion, LLMChatCompletion)
+ assert completion.choices[0].message.role == "assistant"
+ assert completion.choices[0].message.content.strip() == "True"
diff --git a/R2R/tests/test_logging.py b/R2R/tests/test_logging.py
new file mode 100755
index 00000000..cab5051d
--- /dev/null
+++ b/R2R/tests/test_logging.py
@@ -0,0 +1,360 @@
+import asyncio
+import logging
+import os
+import uuid
+
+import pytest
+
+from r2r import (
+ LocalKVLoggingProvider,
+ LoggingConfig,
+ PostgresKVLoggingProvider,
+ PostgresLoggingConfig,
+ RedisKVLoggingProvider,
+ RedisLoggingConfig,
+ generate_run_id,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def event_loop_policy():
+ asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+
+
+@pytest.fixture(scope="function")
+def event_loop():
+ loop = asyncio.get_event_loop_policy().new_event_loop()
+ yield loop
+ loop.close()
+ asyncio.set_event_loop(None)
+
+
+@pytest.fixture(scope="session", autouse=True)
+async def cleanup_tasks():
+ yield
+ for task in asyncio.all_tasks():
+ if task is not asyncio.current_task():
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.fixture(scope="function")
+def local_provider():
+ """Fixture to create and tear down the LocalKVLoggingProvider with a unique database file."""
+ # Generate a unique file name for the SQLite database
+ unique_id = str(uuid.uuid4())
+ logging_path = f"test_{unique_id}.sqlite"
+
+ # Setup the LocalKVLoggingProvider with the unique file
+ provider = LocalKVLoggingProvider(LoggingConfig(logging_path=logging_path))
+
+ # Provide the setup provider to the test
+ yield provider
+
+ # Cleanup: Remove the SQLite file after test completes
+ provider.close()
+ if os.path.exists(logging_path):
+ os.remove(logging_path)
+
+
+@pytest.mark.asyncio
+async def test_local_logging(local_provider):
+ """Test logging and retrieving from the local logging provider."""
+ try:
+ run_id = generate_run_id()
+ await local_provider.init()
+ await local_provider.log(run_id, "key", "value")
+ logs = await local_provider.get_logs([run_id])
+ assert len(logs) == 1
+ assert logs[0]["key"] == "key"
+ assert logs[0]["value"] == "value"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_multiple_log_entries(local_provider):
+ """Test logging multiple entries and retrieving them."""
+ try:
+ run_id_0 = generate_run_id()
+ run_id_1 = generate_run_id()
+ run_id_2 = generate_run_id()
+ await local_provider.init()
+
+ entries = [
+ (run_id_0, "key_0", "value_0"),
+ (run_id_1, "key_1", "value_1"),
+ (run_id_2, "key_2", "value_2"),
+ ]
+ for run_id, key, value in entries:
+ await local_provider.log(run_id, key, value)
+
+ logs = await local_provider.get_logs([run_id_0, run_id_1, run_id_2])
+ assert len(logs) == 3
+
+        # Check that each returned log matches the entry that produced it
+ for log in logs:
+ selected_entry = [
+ entry for entry in entries if entry[0] == log["log_id"]
+ ][0]
+ assert log["log_id"] == selected_entry[0]
+ assert log["key"] == selected_entry[1]
+ assert log["value"] == selected_entry[2]
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_log_retrieval_limit(local_provider):
+    """Test that log retrieval is limited to the requested run ids."""
+ try:
+ await local_provider.init()
+
+ run_ids = []
+ for i in range(10): # Add 10 entries
+ run_ids.append(generate_run_id())
+ await local_provider.log(run_ids[-1], f"key_{i}", f"value_{i}")
+
+ logs = await local_provider.get_logs(run_ids[0:5])
+ assert len(logs) == 5 # Ensure only 5 logs are returned
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_specific_run_type_retrieval(local_provider):
+ """Test retrieving logs for a specific run type works correctly."""
+ try:
+ await local_provider.init()
+ run_id_0 = generate_run_id()
+ run_id_1 = generate_run_id()
+
+ await local_provider.log(
+ run_id_0, "pipeline_type", "search", is_info_log=True
+ )
+ await local_provider.log(run_id_0, "key_0", "value_0")
+ await local_provider.log(
+ run_id_1, "pipeline_type", "rag", is_info_log=True
+ )
+ await local_provider.log(run_id_1, "key_1", "value_1")
+
+ run_info = await local_provider.get_run_info(log_type_filter="search")
+ logs = await local_provider.get_logs([run.run_id for run in run_info])
+ assert len(logs) == 1
+ assert logs[0]["log_id"] == run_id_0
+ assert logs[0]["key"] == "key_0"
+ assert logs[0]["value"] == "value_0"
+ except asyncio.CancelledError:
+ pass
+
+
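+# The Postgres and Redis fixtures below talk to live services; connection
+# details are assumed to come from the environment via their config classes.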
+@pytest.fixture(scope="function")
+def postgres_provider():
+ """Fixture to create and tear down the PostgresKVLoggingProvider."""
+ log_table = f"logs_{str(uuid.uuid4()).replace('-', '_')}"
+ log_info_table = f"log_info_{str(uuid.uuid4()).replace('-', '_')}"
+
+ provider = PostgresKVLoggingProvider(
+ PostgresLoggingConfig(
+ log_table=log_table, log_info_table=log_info_table
+ )
+ )
+ yield provider
+
+
+@pytest.mark.asyncio
+async def test_postgres_logging(postgres_provider):
+ """Test logging and retrieving from the postgres logging provider."""
+ try:
+ await postgres_provider.init()
+ run_id = generate_run_id()
+ await postgres_provider.log(run_id, "key", "value")
+ logs = await postgres_provider.get_logs([run_id])
+ assert len(logs) == 1
+ assert logs[0]["key"] == "key"
+ assert logs[0]["value"] == "value"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_postgres_multiple_log_entries(postgres_provider):
+ """Test logging multiple entries and retrieving them."""
+ try:
+ await postgres_provider.init()
+ run_id_0 = generate_run_id()
+ run_id_1 = generate_run_id()
+ run_id_2 = generate_run_id()
+
+ entries = [
+ (run_id_0, "key_0", "value_0"),
+ (run_id_1, "key_1", "value_1"),
+ (run_id_2, "key_2", "value_2"),
+ ]
+ for run_id, key, value in entries:
+ await postgres_provider.log(run_id, key, value)
+
+ logs = await postgres_provider.get_logs([run_id_0, run_id_1, run_id_2])
+ assert len(logs) == 3
+
+        # Check that each returned log matches the entry that produced it
+ for log in logs:
+ selected_entry = [
+ entry for entry in entries if entry[0] == log["log_id"]
+ ][0]
+ assert log["log_id"] == selected_entry[0]
+ assert log["key"] == selected_entry[1]
+ assert log["value"] == selected_entry[2]
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_postgres_log_retrieval_limit(postgres_provider):
+    """Test that log retrieval is limited to the requested run ids."""
+ try:
+ await postgres_provider.init()
+ run_ids = []
+ for i in range(10): # Add 10 entries
+ run_ids.append(generate_run_id())
+ await postgres_provider.log(run_ids[-1], f"key_{i}", f"value_{i}")
+
+ logs = await postgres_provider.get_logs(run_ids[:5])
+ assert len(logs) == 5 # Ensure only 5 logs are returned
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_postgres_specific_run_type_retrieval(postgres_provider):
+ """Test retrieving logs for a specific run type works correctly."""
+ try:
+ await postgres_provider.init()
+ run_id_0 = generate_run_id()
+ run_id_1 = generate_run_id()
+
+ await postgres_provider.log(
+ run_id_0, "pipeline_type", "search", is_info_log=True
+ )
+ await postgres_provider.log(run_id_0, "key_0", "value_0")
+ await postgres_provider.log(
+ run_id_1, "pipeline_type", "rag", is_info_log=True
+ )
+ await postgres_provider.log(run_id_1, "key_1", "value_1")
+
+ run_info = await postgres_provider.get_run_info(
+ log_type_filter="search"
+ )
+ logs = await postgres_provider.get_logs(
+ [run.run_id for run in run_info]
+ )
+ assert len(logs) == 1
+ assert logs[0]["log_id"] == run_id_0
+ assert logs[0]["key"] == "key_0"
+ assert logs[0]["value"] == "value_0"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.fixture(scope="function")
+def redis_provider():
+ """Fixture to create and tear down the RedisKVLoggingProvider."""
+ log_table = f"logs_{str(uuid.uuid4()).replace('-', '_')}"
+ log_info_table = f"log_info_{str(uuid.uuid4()).replace('-', '_')}"
+
+ provider = RedisKVLoggingProvider(
+ RedisLoggingConfig(log_table=log_table, log_info_table=log_info_table)
+ )
+ yield provider
+ provider.close()
+
+
+@pytest.mark.asyncio
+async def test_redis_logging(redis_provider):
+ """Test logging and retrieving from the Redis logging provider."""
+ try:
+ run_id = generate_run_id()
+ await redis_provider.log(run_id, "key", "value")
+ logs = await redis_provider.get_logs([run_id])
+ assert len(logs) == 1
+ assert logs[0]["key"] == "key"
+ assert logs[0]["value"] == "value"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_redis_multiple_log_entries(redis_provider):
+ """Test logging multiple entries and retrieving them."""
+ try:
+ run_id_0 = generate_run_id()
+ run_id_1 = generate_run_id()
+ run_id_2 = generate_run_id()
+
+ entries = [
+ (run_id_0, "key_0", "value_0"),
+ (run_id_1, "key_1", "value_1"),
+ (run_id_2, "key_2", "value_2"),
+ ]
+ for run_id, key, value in entries:
+ await redis_provider.log(run_id, key, value)
+
+ logs = await redis_provider.get_logs([run_id_0, run_id_1, run_id_2])
+ assert len(logs) == 3
+
+        # Check that each returned log matches the entry that produced it
+ for log in logs:
+ selected_entry = [
+ entry for entry in entries if entry[0] == log["log_id"]
+ ][0]
+ assert log["log_id"] == selected_entry[0]
+ assert log["key"] == selected_entry[1]
+ assert log["value"] == selected_entry[2]
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_redis_log_retrieval_limit(redis_provider):
+    """Test that log retrieval is limited to the requested run ids."""
+ try:
+ run_ids = []
+ for i in range(10): # Add 10 entries
+ run_ids.append(generate_run_id())
+ await redis_provider.log(run_ids[-1], f"key_{i}", f"value_{i}")
+
+ logs = await redis_provider.get_logs(run_ids[0:5])
+ assert len(logs) == 5 # Ensure only 5 logs are returned
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_redis_specific_run_type_retrieval(redis_provider):
+ """Test retrieving logs for a specific run type works correctly."""
+ try:
+ run_id_0 = generate_run_id()
+ run_id_1 = generate_run_id()
+
+ await redis_provider.log(
+ run_id_0, "pipeline_type", "search", is_info_log=True
+ )
+ await redis_provider.log(run_id_0, "key_0", "value_0")
+ await redis_provider.log(
+ run_id_1, "pipeline_type", "rag", is_info_log=True
+ )
+ await redis_provider.log(run_id_1, "key_1", "value_1")
+
+ run_info = await redis_provider.get_run_info(log_type_filter="search")
+ logs = await redis_provider.get_logs([run.run_id for run in run_info])
+ assert len(logs) == 1
+ assert logs[0]["log_id"] == run_id_0
+ assert logs[0]["key"] == "key_0"
+ assert logs[0]["value"] == "value_0"
+ except asyncio.CancelledError:
+ pass
diff --git a/R2R/tests/test_parser.py b/R2R/tests/test_parser.py
new file mode 100755
index 00000000..6965c5a9
--- /dev/null
+++ b/R2R/tests/test_parser.py
@@ -0,0 +1,159 @@
+import asyncio
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from r2r.parsers.media.docx_parser import DOCXParser
+from r2r.parsers.media.pdf_parser import PDFParser
+from r2r.parsers.media.ppt_parser import PPTParser
+from r2r.parsers.structured.csv_parser import CSVParser
+from r2r.parsers.structured.json_parser import JSONParser
+from r2r.parsers.structured.xlsx_parser import XLSXParser
+from r2r.parsers.text.html_parser import HTMLParser
+from r2r.parsers.text.md_parser import MDParser
+from r2r.parsers.text.text_parser import TextParser
+
+
+@pytest.fixture(scope="session", autouse=True)
+def event_loop_policy():
+ asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+
+
+@pytest.fixture(scope="function")
+def event_loop():
+ loop = asyncio.get_event_loop_policy().new_event_loop()
+ yield loop
+ loop.close()
+ asyncio.set_event_loop(None)
+
+
+@pytest.fixture(scope="session", autouse=True)
+async def cleanup_tasks():
+ yield
+ for task in asyncio.all_tasks():
+ if task is not asyncio.current_task():
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+
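+# Text-based parsers are fed raw strings below; the binary formats (PDF, PPT,
+# DOCX, XLSX) are tested with their third-party readers patched out, so no
+# real documents are parsed.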
+@pytest.mark.asyncio
+async def test_text_parser():
+ try:
+ parser = TextParser()
+ data = "Simple text"
+ async for result in parser.ingest(data):
+ assert result == "Simple text"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_json_parser():
+ try:
+ parser = JSONParser()
+ data = json.dumps({"key": "value", "null_key": None})
+ async for result in parser.ingest(data):
+ assert "key: value" in result
+ assert "null_key" not in result
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_html_parser():
+ try:
+ parser = HTMLParser()
+ data = "<html><body><p>Hello World</p></body></html>"
+ async for result in parser.ingest(data):
+ assert result.strip() == "Hello World"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+@patch("pypdf.PdfReader")
+async def test_pdf_parser(mock_pdf_reader):
+ try:
+ parser = PDFParser()
+ mock_pdf_reader.return_value.pages = [
+ MagicMock(extract_text=lambda: "Page text")
+ ]
+ data = b"fake PDF data"
+ async for result in parser.ingest(data):
+ assert result == "Page text"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+@patch("pptx.Presentation")
+async def test_ppt_parser(mock_presentation):
+ try:
+ mock_slide = MagicMock()
+ mock_shape = MagicMock(text="Slide text")
+ mock_slide.shapes = [mock_shape]
+ mock_presentation.return_value.slides = [mock_slide]
+ parser = PPTParser()
+ data = b"fake PPT data"
+ async for result in parser.ingest(data):
+ assert result == "Slide text"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+@patch("docx.Document")
+async def test_docx_parser(mock_document):
+ try:
+ mock_paragraph = MagicMock(text="Paragraph text")
+ mock_document.return_value.paragraphs = [mock_paragraph]
+ parser = DOCXParser()
+ data = b"fake DOCX data"
+ async for result in parser.ingest(data):
+ assert result == "Paragraph text"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_csv_parser():
+ try:
+ parser = CSVParser()
+ data = "col1,col2\nvalue1,value2"
+ async for result in parser.ingest(data):
+ assert result == "col1, col2"
+ break
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+@patch("openpyxl.load_workbook")
+async def test_xlsx_parser(mock_load_workbook):
+ try:
+ mock_sheet = MagicMock()
+ mock_sheet.iter_rows.return_value = [(1, 2), (3, 4)]
+ mock_workbook = MagicMock(worksheets=[mock_sheet])
+ mock_load_workbook.return_value = mock_workbook
+ parser = XLSXParser()
+ data = b"fake XLSX data"
+ async for result in parser.ingest(data):
+ assert result == "1, 2"
+ break
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_markdown_parser():
+ try:
+ parser = MDParser()
+ data = "# Header\nContent"
+ async for result in parser.ingest(data):
+ assert result.strip() == "Header\nContent"
+ except asyncio.CancelledError:
+ pass
diff --git a/R2R/tests/test_pipeline.py b/R2R/tests/test_pipeline.py
new file mode 100755
index 00000000..1811de0f
--- /dev/null
+++ b/R2R/tests/test_pipeline.py
@@ -0,0 +1,291 @@
+import asyncio
+from typing import Any, AsyncGenerator
+
+import pytest
+
+from r2r import AsyncPipe, AsyncPipeline, PipeType
+
+
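+# Toy pipes used throughout this module: MultiplierPipe scales each item,
+# FanOutPipe expands its input into `multiplier` batches, and FanInPipe
+# reduces incoming batches to a single sum, covering linear, fan-out, and
+# fan-in pipeline topologies.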
+class MultiplierPipe(AsyncPipe):
+ def __init__(self, multiplier=1, delay=0, name="multiplier_pipe"):
+ super().__init__(
+ type=PipeType.OTHER,
+ config=self.PipeConfig(name=name),
+ )
+ self.multiplier = multiplier
+ self.delay = delay
+
+ async def _run_logic(
+ self,
+ input: AsyncGenerator[Any, None],
+ state,
+ run_id=None,
+ *args,
+ **kwargs,
+ ) -> AsyncGenerator[Any, None]:
+ async for item in input.message:
+ if self.delay > 0:
+ await asyncio.sleep(self.delay) # Simulate processing delay
+ if isinstance(item, list):
+ processed = [x * self.multiplier for x in item]
+ elif isinstance(item, int):
+ processed = item * self.multiplier
+ else:
+ raise ValueError(f"Unsupported type: {type(item)}")
+ yield processed
+
+
+class FanOutPipe(AsyncPipe):
+ def __init__(self, multiplier=1, delay=0, name="fan_out_pipe"):
+ super().__init__(
+ type=PipeType.OTHER,
+ config=self.PipeConfig(name=name),
+ )
+ self.multiplier = multiplier
+ self.delay = delay
+
+ async def _run_logic(
+ self,
+ input: AsyncGenerator[Any, None],
+ state,
+ run_id=None,
+ *args,
+ **kwargs,
+ ) -> AsyncGenerator[Any, None]:
+ inputs = []
+ async for item in input.message:
+ inputs.append(item)
+ for it in range(self.multiplier):
+ if self.delay > 0:
+ await asyncio.sleep(self.delay)
+ yield [(it + 1) * ele for ele in inputs]
+
+
+class FanInPipe(AsyncPipe):
+ def __init__(self, delay=0, name="fan_in_pipe"):
+ super().__init__(
+ type=PipeType.OTHER,
+ config=self.PipeConfig(name=name),
+ )
+ self.delay = delay
+
+ async def _run_logic(
+ self,
+ input: AsyncGenerator[Any, None],
+ state,
+ run_id=None,
+ *args,
+ **kwargs,
+ ) -> AsyncGenerator[Any, None]:
+ total_sum = 0
+ async for batch in input.message:
+ if self.delay > 0:
+ await asyncio.sleep(self.delay) # Simulate processing delay
+ total_sum += sum(
+ batch
+ ) # Assuming batch is iterable and contains numeric values
+ yield total_sum
+
+
+@pytest.fixture
+def pipe_factory():
+ def create_pipe(type, **kwargs):
+ if type == "multiplier":
+ return MultiplierPipe(**kwargs)
+ elif type == "fan_out":
+ return FanOutPipe(**kwargs)
+ elif type == "fan_in":
+ return FanInPipe(**kwargs)
+ else:
+ raise ValueError("Unsupported pipe type")
+
+ return create_pipe
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("multiplier, delay, name", [(2, 0.1, "pipe")])
+async def test_single_multiplier(pipe_factory, multiplier, delay, name):
+ pipe = pipe_factory(
+ "multiplier", multiplier=multiplier, delay=delay, name=name
+ )
+
+ async def input_generator():
+ for i in [1, 2, 3]:
+ yield i
+
+ pipeline = AsyncPipeline()
+ pipeline.add_pipe(pipe)
+
+ result = []
+ for output in await pipeline.run(input_generator()):
+ result.append(output)
+
+ expected_result = [i * multiplier for i in [1, 2, 3]]
+ assert (
+ result == expected_result
+ ), "Pipeline output did not match expected multipliers"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b",
+ [(2, 0.1, "pipe_a", 2, 0.1, "pipe_b")],
+)
+async def test_double_multiplier(
+ pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b
+):
+ pipe_a = pipe_factory(
+ "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a
+ )
+ pipe_b = pipe_factory(
+ "multiplier", multiplier=multiplier_b, delay=delay_b, name=name_b
+ )
+
+ async def input_generator():
+ for i in [1, 2, 3]:
+ yield i
+
+ pipeline = AsyncPipeline()
+ pipeline.add_pipe(pipe_a)
+ pipeline.add_pipe(pipe_b)
+
+ result = []
+ for output in await pipeline.run(input_generator()):
+ result.append(output)
+
+ expected_result = [i * multiplier_a * multiplier_b for i in [1, 2, 3]]
+ assert (
+ result == expected_result
+ ), "Pipeline output did not match expected multipliers"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("multiplier, delay, name", [(3, 0.1, "pipe")])
+async def test_fan_out(pipe_factory, multiplier, delay, name):
+ pipe = pipe_factory(
+ "fan_out", multiplier=multiplier, delay=delay, name=name
+ )
+
+ async def input_generator():
+ for i in [1, 2, 3]:
+ yield i
+
+ pipeline = AsyncPipeline()
+ pipeline.add_pipe(pipe)
+
+ result = []
+ for output in await pipeline.run(input_generator()):
+ result.append(output)
+
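+ # fan_out emits one batch per step: the input [1, 2, 3] scaled by 1..multiplier.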
+ expected_result = [
+ [i + 1, 2 * (i + 1), 3 * (i + 1)] for i in range(multiplier)
+ ]
+ assert (
+ result == expected_result
+ ), "Pipeline output did not match expected multipliers"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b",
+ [
+ (2, 0.1, "pipe_a", 2, 0.1, "pipe_b"),
+ (4, 0.1, "pipe_a", 3, 0.1, "pipe_b"),
+ ],
+)
+async def test_multiply_then_fan_out(
+ pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b
+):
+ pipe_a = pipe_factory(
+ "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a
+ )
+ pipe_b = pipe_factory(
+ "fan_out", multiplier=multiplier_b, delay=delay_b, name=name_b
+ )
+
+ async def input_generator():
+ for i in [1, 2, 3]:
+ yield i
+
+ pipeline = AsyncPipeline()
+ pipeline.add_pipe(pipe_a)
+ pipeline.add_pipe(pipe_b)
+
+ result = []
+ for output in await pipeline.run(input_generator()):
+ result.append(output)
+
+ expected_result = [i * multiplier_a for i in [1, 2, 3]]
+ assert (
+ result[0] == expected_result
+ ), "Pipeline output did not match expected multipliers"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("multiplier, delay, name", [(3, 0.1, "pipe")])
+async def test_fan_in_sum(pipe_factory, multiplier, delay, name):
+ # Create fan-out to generate multiple streams
+ fan_out_pipe = pipe_factory(
+ "fan_out", multiplier=multiplier, delay=delay, name=name + "_a"
+ )
+ # Summing fan-in pipe
+ fan_in_sum_pipe = pipe_factory("fan_in", delay=delay, name=name + "_b")
+
+ async def input_generator():
+ for i in [1, 2, 3]:
+ yield i
+
+ pipeline = AsyncPipeline()
+ pipeline.add_pipe(fan_out_pipe)
+ pipeline.add_pipe(fan_in_sum_pipe)
+
+ result = await pipeline.run(input_generator())
+
+ # Calculate expected results based on the multiplier and the sum of inputs
+ expected_result = sum(
+ [sum([j * i for j in [1, 2, 3]]) for i in range(1, multiplier + 1)]
+ )
+ assert (
+ result[0] == expected_result
+ ), "Pipeline output did not match expected sums"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b",
+ [
+ (3, 0.1, "pipe_a", 2, 0.1, "pipe_b"),
+ (4, 0.1, "pipe_a", 3, 0.1, "pipe_b"),
+ ],
+)
+async def test_multiply_fan_out_then_fan_in(
+ pipe_factory, multiplier_a, delay_a, name_a, multiplier_b, delay_b, name_b
+):
+ pipe_a = pipe_factory(
+ "multiplier", multiplier=multiplier_a, delay=delay_a, name=name_a
+ )
+ pipe_b = pipe_factory(
+ "fan_out", multiplier=multiplier_b, delay=delay_b, name=name_b
+ )
+ pipe_c = pipe_factory("fan_in", delay=0.1, name="pipe_c")
+
+ async def input_generator():
+ for i in [1, 2, 3]:
+ yield i
+
+ pipeline = AsyncPipeline()
+ pipeline.add_pipe(pipe_a)
+ pipeline.add_pipe(pipe_b)
+ pipeline.add_pipe(pipe_c)
+
+ result = await pipeline.run(input_generator())
+
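+ # multiply -> fan-out -> fan-in: sum multiplier_a * j * i over j in [1, 2, 3] and i in 1..multiplier_b.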
+ expected_result = sum(
+ [
+ sum([j * i * multiplier_a for j in [1, 2, 3]])
+ for i in range(1, multiplier_b + 1)
+ ]
+ )
+ assert (
+ result[0] == expected_result
+ ), "Pipeline output did not match expected multipliers"
diff --git a/R2R/tests/test_vector_db.py b/R2R/tests/test_vector_db.py
new file mode 100755
index 00000000..023145ce
--- /dev/null
+++ b/R2R/tests/test_vector_db.py
@@ -0,0 +1,160 @@
+import random
+
+import pytest
+from dotenv import load_dotenv
+
+from r2r import (
+ Vector,
+ VectorDBConfig,
+ VectorDBProvider,
+ VectorEntry,
+ generate_id_from_label,
+)
+from r2r.providers.vector_dbs import PGVectorDB
+
+load_dotenv()
+
+
+# Sample vector entries
+def generate_random_vector_entry(id: str, dimension: int) -> VectorEntry:
+ vector = [random.random() for _ in range(dimension)]
+ metadata = {"key": f"value_{id}"}
+ return VectorEntry(
+ id=generate_id_from_label(id), vector=Vector(vector), metadata=metadata
+ )
+
+
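+# Module-level sample data shared by the tests below.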
+dimension = 3
+num_entries = 100
+sample_entries = [
+ generate_random_vector_entry(f"id_{i}", dimension)
+ for i in range(num_entries)
+]
+
+
+# Fixture for PGVectorDB
+@pytest.fixture
+def pg_vector_db():
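+ # A randomized collection name keeps repeated test runs from colliding.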
+ random_collection_name = (
+ f"test_collection_{random.randint(0, 1_000_000_000)}"
+ )
+ config = VectorDBConfig.create(
+ provider="pgvector", vecs_collection=random_collection_name
+ )
+ db = PGVectorDB(config)
+ db.initialize_collection(dimension=dimension)
+ yield db
+ # Teardown
+ db.vx.delete_collection(
+ db.config.extra_fields.get("vecs_collection", None)
+ )
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_get_metadatas(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ for entry in sample_entries:
+ db.upsert(entry)
+
+ unique_metadatas = db.get_metadatas(metadata_fields=["key"])
+ unique_values = {ele["key"] for ele in unique_metadatas}
+ assert len(unique_values) == num_entries
+ assert all(f"value_id_{i}" in unique_values for i in range(num_entries))
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_db_initialization(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ assert isinstance(db, VectorDBProvider)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_db_upsert_and_search(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ db.upsert(sample_entries[0])
+ results = db.search(query_vector=sample_entries[0].vector.data)
+ assert len(results) == 1
+ assert results[0].id == sample_entries[0].id
+ assert results[0].score == pytest.approx(1.0, rel=1e-3)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_imperfect_match(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ db.upsert(sample_entries[0])
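+ # Perturb every component so the query is close to, but not exactly, the stored vector.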
+ query_vector = [val + 0.1 for val in sample_entries[0].vector.data]
+ results = db.search(query_vector=query_vector)
+ assert len(results) == 1
+ assert results[0].id == sample_entries[0].id
+ assert results[0].score < 1.0
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_bulk_insert_and_search(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ for entry in sample_entries:
+ db.upsert(entry)
+
+ query_vector = sample_entries[0].vector.data
+ results = db.search(query_vector=query_vector, limit=5)
+ assert len(results) == 5
+ assert results[0].id == sample_entries[0].id
+ assert results[0].score == pytest.approx(1.0, rel=1e-3)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_search_with_filters(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ for entry in sample_entries:
+ db.upsert(entry)
+
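+ # Filter on the metadata value that is unique to the first sample entry.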
+ filtered_id = sample_entries[0].metadata["key"]
+ query_vector = sample_entries[0].vector.data
+ results = db.search(
+ query_vector=query_vector, filters={"key": filtered_id}
+ )
+ assert len(results) == 1
+ assert results[0].id == sample_entries[0].id
+ assert results[0].metadata["key"] == filtered_id
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_delete_by_metadata(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ for entry in sample_entries:
+ db.upsert(entry)
+
+ key_to_delete = sample_entries[0].metadata["key"]
+ db.delete_by_metadata(
+ metadata_fields=["key"], metadata_values=[key_to_delete]
+ )
+
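+ # Nothing matching the deleted metadata value should remain searchable.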
+ results = db.search(query_vector=sample_entries[0].vector.data)
+ assert all(result.metadata["key"] != key_to_delete for result in results)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_upsert(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ db.upsert(sample_entries[0])
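+ # Re-upserting the same id with a new vector and metadata should overwrite the row.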
+ modified_entry = VectorEntry(
+ id=sample_entries[0].id,
+ vector=Vector([0.5, 0.5, 0.5]),
+ metadata={"key": "new_value"},
+ )
+ db.upsert(modified_entry)
+
+ results = db.search(query_vector=[0.5, 0.5, 0.5])
+ assert len(results) == 1
+ assert results[0].id == sample_entries[0].id
+ assert results[0].metadata["key"] == "new_value"