author    | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit    | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes/ingestion/vector_storage_pipe.py
parent    | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
Diffstat (limited to 'R2R/r2r/pipes/ingestion/vector_storage_pipe.py')
-rwxr-xr-x | R2R/r2r/pipes/ingestion/vector_storage_pipe.py | 128
1 file changed, 128 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/pipes/ingestion/vector_storage_pipe.py b/R2R/r2r/pipes/ingestion/vector_storage_pipe.py
new file mode 100755
index 00000000..9564fd22
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/vector_storage_pipe.py
@@ -0,0 +1,128 @@
+import asyncio
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional, Tuple, Union
+
+from r2r.base import (
+    AsyncState,
+    KVLoggingSingleton,
+    PipeType,
+    VectorDBProvider,
+    VectorEntry,
+)
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+from ...base.abstractions.exception import R2RDocumentProcessingError
+
+logger = logging.getLogger(__name__)
+
+
+class VectorStoragePipe(AsyncPipe):
+    class Input(AsyncPipe.Input):
+        message: AsyncGenerator[
+            Union[R2RDocumentProcessingError, VectorEntry], None
+        ]
+        do_upsert: bool = True
+
+    def __init__(
+        self,
+        vector_db_provider: VectorDBProvider,
+        storage_batch_size: int = 128,
+        pipe_logger: Optional[KVLoggingSingleton] = None,
+        type: PipeType = PipeType.INGESTOR,
+        config: Optional[AsyncPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initializes the async vector storage pipe with necessary components and configurations.
+        """
+        super().__init__(
+            pipe_logger=pipe_logger,
+            type=type,
+            config=config,
+            *args,
+            **kwargs,
+        )
+        self.vector_db_provider = vector_db_provider
+        self.storage_batch_size = storage_batch_size
+
+    async def store(
+        self,
+        vector_entries: list[VectorEntry],
+        do_upsert: bool = True,
+    ) -> None:
+        """
+        Stores a batch of vector entries in the database.
+        """
+
+        try:
+            if do_upsert:
+                self.vector_db_provider.upsert_entries(vector_entries)
+            else:
+                self.vector_db_provider.copy_entries(vector_entries)
+        except Exception as e:
+            error_message = (
+                f"Failed to store vector entries in the database: {e}"
+            )
+            logger.error(error_message)
+            raise ValueError(error_message)
+
+    async def _run_logic(
+        self,
+        input: Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[
+        Tuple[uuid.UUID, Union[str, R2RDocumentProcessingError]], None
+    ]:
+        """
+        Executes the async vector storage pipe: storing embeddings in the vector database.
+        """
+        batch_tasks = []
+        vector_batch = []
+        document_counts = {}
+        i = 0
+        async for msg in input.message:
+            i += 1
+            if isinstance(msg, R2RDocumentProcessingError):
+                yield (msg.document_id, msg)
+                continue
+
+            document_id = msg.metadata.get("document_id", None)
+            if not document_id:
+                raise ValueError("Document ID not found in the metadata.")
+            if document_id not in document_counts:
+                document_counts[document_id] = 1
+            else:
+                document_counts[document_id] += 1
+
+            vector_batch.append(msg)
+            if len(vector_batch) >= self.storage_batch_size:
+                # Schedule the storage task
+                batch_tasks.append(
+                    asyncio.create_task(
+                        self.store(vector_batch.copy(), input.do_upsert),
+                        name=f"vector-store-{self.config.name}",
+                    )
+                )
+                vector_batch.clear()
+
+        if vector_batch:  # Process any remaining vectors
+            batch_tasks.append(
+                asyncio.create_task(
+                    self.store(vector_batch.copy(), input.do_upsert),
+                    name=f"vector-store-{self.config.name}",
+                )
+            )
+
+        # Wait for all storage tasks to complete
+        await asyncio.gather(*batch_tasks)
+
+        for document_id, count in document_counts.items():
+            yield (
+                document_id,
+                f"Processed {count} vectors for document {document_id}.",
+            )
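
The core behavior added here is the batching loop in _run_logic: vector entries are accumulated until storage_batch_size is reached, each full batch is handed to a concurrently scheduled store task, and any remainder is flushed once the input stream ends. The following minimal sketch reproduces that pattern outside the pipeline framework; FakeVectorDB and entry_stream are plain-Python stand-ins for VectorDBProvider and the VectorEntry stream, introduced only for illustration and not part of this commit.

import asyncio

class FakeVectorDB:
    # Stand-in for VectorDBProvider; only the upsert path is sketched.
    def upsert_entries(self, entries):
        print(f"upserted batch of {len(entries)} entries")

async def entry_stream(n):
    # Simulates the async generator of vector entries the pipe consumes.
    for i in range(n):
        yield {"document_id": "doc-1", "chunk": i, "vector": [0.0, 0.1]}

async def store(db, entries):
    # Mirrors VectorStoragePipe.store for the upsert case.
    db.upsert_entries(entries)

async def run(db, message, storage_batch_size=4):
    batch_tasks, vector_batch = [], []
    async for entry in message:
        vector_batch.append(entry)
        if len(vector_batch) >= storage_batch_size:
            # Schedule the full batch without blocking the consumer loop.
            batch_tasks.append(asyncio.create_task(store(db, vector_batch.copy())))
            vector_batch.clear()
    if vector_batch:  # flush the final partial batch
        batch_tasks.append(asyncio.create_task(store(db, vector_batch.copy())))
    await asyncio.gather(*batch_tasks)

asyncio.run(run(FakeVectorDB(), entry_stream(10)))

With 10 entries and a batch size of 4 this prints three batches (4, 4, 2), matching how the pipe schedules two full batches during iteration and flushes the remainder before gathering the tasks.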