author      S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer   S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit      4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree        ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes/ingestion/vector_storage_pipe.py
parent      cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download    gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/pipes/ingestion/vector_storage_pipe.py')
-rwxr-xr-x  R2R/r2r/pipes/ingestion/vector_storage_pipe.py  128
1 file changed, 128 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/ingestion/vector_storage_pipe.py b/R2R/r2r/pipes/ingestion/vector_storage_pipe.py
new file mode 100755
index 00000000..9564fd22
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/vector_storage_pipe.py
@@ -0,0 +1,128 @@
+import asyncio
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional, Tuple, Union
+
+from r2r.base import (
+    AsyncState,
+    KVLoggingSingleton,
+    PipeType,
+    VectorDBProvider,
+    VectorEntry,
+)
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+from ...base.abstractions.exception import R2RDocumentProcessingError
+
+logger = logging.getLogger(__name__)
+
+
+class VectorStoragePipe(AsyncPipe):
+    class Input(AsyncPipe.Input):
+        message: AsyncGenerator[
+            Union[R2RDocumentProcessingError, VectorEntry], None
+        ]
+        do_upsert: bool = True
+
+    def __init__(
+        self,
+        vector_db_provider: VectorDBProvider,
+        storage_batch_size: int = 128,
+        pipe_logger: Optional[KVLoggingSingleton] = None,
+        type: PipeType = PipeType.INGESTOR,
+        config: Optional[AsyncPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initializes the async vector storage pipe with necessary components and configurations.
+        """
+        super().__init__(
+            pipe_logger=pipe_logger,
+            type=type,
+            config=config,
+            *args,
+            **kwargs,
+        )
+        self.vector_db_provider = vector_db_provider
+        self.storage_batch_size = storage_batch_size
+
+    async def store(
+        self,
+        vector_entries: list[VectorEntry],
+        do_upsert: bool = True,
+    ) -> None:
+        """
+        Stores a batch of vector entries, upserting when do_upsert is set and copying otherwise.
+        """
+
+        try:
+            if do_upsert:
+                self.vector_db_provider.upsert_entries(vector_entries)
+            else:
+                self.vector_db_provider.copy_entries(vector_entries)
+        except Exception as e:
+            error_message = (
+                f"Failed to store vector entries in the database: {e}"
+            )
+            logger.error(error_message)
+            raise ValueError(error_message) from e
+
+    async def _run_logic(
+        self,
+        input: Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[
+        Tuple[uuid.UUID, Union[str, R2RDocumentProcessingError]], None
+    ]:
+        """
+        Executes the async vector storage pipe: storing embeddings in the vector database.
+        """
+        batch_tasks = []
+        vector_batch = []
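+        # Track how many vectors each document contributes, for the final per-document summary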
+        document_counts = {}
+        async for msg in input.message:
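+            # Pass document-level processing errors through without storing them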
+            if isinstance(msg, R2RDocumentProcessingError):
+                yield (msg.document_id, msg)
+                continue
+
+            document_id = msg.metadata.get("document_id", None)
+            if not document_id:
+                raise ValueError("Document ID not found in the metadata.")
+            if document_id not in document_counts:
+                document_counts[document_id] = 1
+            else:
+                document_counts[document_id] += 1
+
+            vector_batch.append(msg)
+            if len(vector_batch) >= self.storage_batch_size:
+                # Schedule the storage task
+                batch_tasks.append(
+                    asyncio.create_task(
+                        self.store(vector_batch.copy(), input.do_upsert),
+                        name=f"vector-store-{self.config.name}",
+                    )
+                )
+                vector_batch.clear()
+
+        if vector_batch:  # Process any remaining vectors
+            batch_tasks.append(
+                asyncio.create_task(
+                    self.store(vector_batch.copy(), input.do_upsert),
+                    name=f"vector-store-{self.config.name}",
+                )
+            )
+
+        # Wait for all storage tasks to complete
+        await asyncio.gather(*batch_tasks)
+
+        for document_id, count in document_counts.items():
+            yield (
+                document_id,
+                f"Processed {count} vectors for document {document_id}.",
+            )
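
For orientation, the sketch below drives the pipe end to end with stand-in objects. It is a minimal sketch, not part of the commit: the stub provider and stub entries are hypothetical and implement only what this file actually touches (upsert_entries, copy_entries, and entry.metadata), and it assumes that PipeConfig accepts a name (as self.config.name implies), that AsyncState() constructs empty, and that Input does not strictly validate the element type of the message generator.

    import asyncio
    import uuid

    class StubVectorDBProvider:
        # Hypothetical stand-in for VectorDBProvider; only the two
        # methods VectorStoragePipe.store() calls are implemented.
        def upsert_entries(self, entries):
            print(f"upserted a batch of {len(entries)} entries")

        def copy_entries(self, entries):
            print(f"copied a batch of {len(entries)} entries")

    class StubEntry:
        # Hypothetical stand-in for VectorEntry; _run_logic only reads .metadata.
        def __init__(self, document_id):
            self.metadata = {"document_id": document_id}

    async def make_entries(document_id, n):
        # Emit more entries than storage_batch_size so batching is exercised.
        for _ in range(n):
            yield StubEntry(document_id)

    async def main():
        pipe = VectorStoragePipe(
            vector_db_provider=StubVectorDBProvider(),
            storage_batch_size=128,
            # Assumption: PipeConfig takes a name, as self.config.name implies.
            config=AsyncPipe.PipeConfig(name="vector_storage_demo"),
        )
        doc_id = uuid.uuid4()
        pipe_input = VectorStoragePipe.Input(message=make_entries(doc_id, 300))
        # Normally the pipeline supplies state and run_id when invoking the pipe.
        async for _, status in pipe._run_logic(
            pipe_input, AsyncState(), run_id=uuid.uuid4()
        ):
            print(status)

    asyncio.run(main())

With 300 entries and a batch size of 128, store() is scheduled for two full batches during iteration and once more for the 44-entry remainder, and a single summary line reports 300 vectors for the document. Note the design: full batches are handed to asyncio.create_task() as they fill, so database writes overlap with whatever upstream pipe is still producing embeddings, and the per-document summaries are only yielded after asyncio.gather() confirms every batch landed.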