author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes/ingestion/embedding_pipe.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to 'R2R/r2r/pipes/ingestion/embedding_pipe.py')
-rwxr-xr-x | R2R/r2r/pipes/ingestion/embedding_pipe.py | 218 |
1 files changed, 218 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/ingestion/embedding_pipe.py b/R2R/r2r/pipes/ingestion/embedding_pipe.py
new file mode 100755
index 00000000..971ccc9d
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/embedding_pipe.py
@@ -0,0 +1,218 @@
+import asyncio
+import copy
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional, Union
+
+from r2r.base import (
+    AsyncState,
+    EmbeddingProvider,
+    Extraction,
+    Fragment,
+    FragmentType,
+    KVLoggingSingleton,
+    PipeType,
+    R2RDocumentProcessingError,
+    TextSplitter,
+    Vector,
+    VectorEntry,
+    generate_id_from_label,
+)
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingPipe(AsyncPipe):
+    """
+    Embeds and stores documents using a specified embedding model and database.
+    """
+
+    class Input(AsyncPipe.Input):
+        message: AsyncGenerator[
+            Union[Extraction, R2RDocumentProcessingError], None
+        ]
+
+    def __init__(
+        self,
+        embedding_provider: EmbeddingProvider,
+        text_splitter: TextSplitter,
+        embedding_batch_size: int = 1,
+        id_prefix: str = "demo",
+        pipe_logger: Optional[KVLoggingSingleton] = None,
+        type: PipeType = PipeType.INGESTOR,
+        config: Optional[AsyncPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initializes the embedding pipe with necessary components and configurations.
+        """
+        super().__init__(
+            pipe_logger=pipe_logger,
+            type=type,
+            config=config
+            or AsyncPipe.PipeConfig(name="default_embedding_pipe"),
+        )
+        self.embedding_provider = embedding_provider
+        self.text_splitter = text_splitter
+        self.embedding_batch_size = embedding_batch_size
+        self.id_prefix = id_prefix
+        self.pipe_run_info = None
+
+    async def fragment(
+        self, extraction: Extraction, run_id: uuid.UUID
+    ) -> AsyncGenerator[Fragment, None]:
+        """
+        Splits text into manageable chunks for embedding.
+        """
+        if not isinstance(extraction, Extraction):
+            raise ValueError(
+                f"Expected an Extraction, but received {type(extraction)}."
+            )
+        if not isinstance(extraction.data, str):
+            raise ValueError(
+                f"Expected a string, but received {type(extraction.data)}."
+            )
+        text_chunks = [
+            ele.page_content
+            for ele in self.text_splitter.create_documents([extraction.data])
+        ]
+        for iteration, chunk in enumerate(text_chunks):
+            fragment = Fragment(
+                id=generate_id_from_label(f"{extraction.id}-{iteration}"),
+                type=FragmentType.TEXT,
+                data=chunk,
+                metadata=copy.deepcopy(extraction.metadata),
+                extraction_id=extraction.id,
+                document_id=extraction.document_id,
+            )
+            yield fragment
+
+    async def transform_fragments(
+        self, fragments: list[Fragment], metadatas: list[dict]
+    ) -> AsyncGenerator[Fragment, None]:
+        """
+        Transforms text chunks based on their metadata, e.g., adding prefixes.
+        """
+        for fragment, metadata in zip(fragments, metadatas):
+            if "chunk_prefix" in metadata:
+                prefix = metadata.pop("chunk_prefix")
+                fragment.data = f"{prefix}\n{fragment.data}"
+            yield fragment
+
+    async def embed(self, fragments: list[Fragment]) -> list[list[float]]:
+        return await self.embedding_provider.async_get_embeddings(
+            [fragment.data for fragment in fragments],
+            EmbeddingProvider.PipeStage.BASE,
+        )
+
+    async def _process_batch(
+        self, fragment_batch: list[Fragment]
+    ) -> list[VectorEntry]:
+        """
+        Embeds a batch of fragments and returns vector entries.
+        """
+        vectors = await self.embed(fragment_batch)
+        return [
+            VectorEntry(
+                id=fragment.id,
+                vector=Vector(data=raw_vector),
+                metadata={
+                    "document_id": fragment.document_id,
+                    "extraction_id": fragment.extraction_id,
+                    "text": fragment.data,
+                    **fragment.metadata,
+                },
+            )
+            for raw_vector, fragment in zip(vectors, fragment_batch)
+        ]
+
+    async def _process_and_enqueue_batch(
+        self, fragment_batch: list[Fragment], vector_entry_queue: asyncio.Queue
+    ):
+        try:
+            batch_result = await self._process_batch(fragment_batch)
+            for vector_entry in batch_result:
+                await vector_entry_queue.put(vector_entry)
+        except Exception as e:
+            logger.error(f"Error processing batch: {e}")
+            await vector_entry_queue.put(
+                R2RDocumentProcessingError(
+                    error_message=str(e),
+                    document_id=fragment_batch[0].document_id,
+                )
+            )
+        finally:
+            await vector_entry_queue.put(None)  # Signal completion
+
+    async def _run_logic(
+        self,
+        input: Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[Union[R2RDocumentProcessingError, VectorEntry], None]:
+        """
+        Executes the embedding pipe: chunking, transforming, embedding, and storing documents.
+        """
+        vector_entry_queue = asyncio.Queue()
+        fragment_batch = []
+        active_tasks = 0
+
+        fragment_info = {}
+        async for extraction in input.message:
+            if isinstance(extraction, R2RDocumentProcessingError):
+                yield extraction
+                continue
+
+            async for fragment in self.fragment(extraction, run_id):
+                if extraction.document_id in fragment_info:
+                    fragment_info[extraction.document_id] += 1
+                else:
+                    fragment_info[extraction.document_id] = 0  # Start with 0
+                fragment.metadata["chunk_order"] = fragment_info[
+                    extraction.document_id
+                ]
+
+                version = fragment.metadata.get("version", "v0")
+
+                # Ensure fragment ID is set correctly
+                if not fragment.id:
+                    fragment.id = generate_id_from_label(
+                        f"{extraction.id}-{fragment_info[extraction.document_id]}-{version}"
+                    )
+
+                fragment_batch.append(fragment)
+                if len(fragment_batch) >= self.embedding_batch_size:
+                    asyncio.create_task(
+                        self._process_and_enqueue_batch(
+                            fragment_batch.copy(), vector_entry_queue
+                        )
+                    )
+                    active_tasks += 1
+                    fragment_batch.clear()
+
+        logger.debug(
+            f"Fragmented the input document ids into counts as shown: {fragment_info}"
+        )
+
+        if fragment_batch:
+            asyncio.create_task(
+                self._process_and_enqueue_batch(
+                    fragment_batch.copy(), vector_entry_queue
+                )
+            )
+            active_tasks += 1
+
+        # Drain the queue until every spawned batch task has signaled completion
+        while active_tasks > 0:
+            vector_entry = await vector_entry_queue.get()
+            if vector_entry is None:  # Check for termination signal
+                active_tasks -= 1
+            elif isinstance(vector_entry, Exception):
+                yield vector_entry  # Propagate the exception
+                # Do not decrement here; this task's trailing None sentinel does it
+            else:
+                yield vector_entry
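The core pattern in `_run_logic` is a fan-out/fan-in pipeline: one `asyncio.create_task` per batch, all tasks writing into a shared `asyncio.Queue`, and exactly one `None` sentinel per task so the consumer knows when every batch has finished. Below is a minimal self-contained sketch of that same pattern, with no r2r imports; `fake_embed`, `process_batch`, and `run` are illustrative stand-ins, not R2R APIs.

```python
import asyncio
from typing import AsyncGenerator, Union


async def fake_embed(batch: list[str]) -> list[list[float]]:
    """Stand-in for an embedding provider call: one vector per text."""
    await asyncio.sleep(0.01)  # simulate provider latency
    return [[float(len(text))] for text in batch]


async def process_batch(batch: list[str], queue: asyncio.Queue) -> None:
    # Mirrors _process_and_enqueue_batch: enqueue results (or the error),
    # then exactly one None sentinel so the consumer can count finished tasks.
    try:
        for vector in await fake_embed(batch):
            await queue.put(vector)
    except Exception as exc:
        await queue.put(exc)
    finally:
        await queue.put(None)


async def run(
    texts: list[str], batch_size: int = 2
) -> AsyncGenerator[Union[list[float], Exception], None]:
    queue: asyncio.Queue = asyncio.Queue()
    active_tasks = 0
    batch: list[str] = []
    for text in texts:
        batch.append(text)
        if len(batch) >= batch_size:
            asyncio.create_task(process_batch(batch.copy(), queue))
            active_tasks += 1
            batch.clear()
    if batch:  # flush the final partial batch
        asyncio.create_task(process_batch(batch.copy(), queue))
        active_tasks += 1
    while active_tasks > 0:  # drain until every task has sent its sentinel
        item = await queue.get()
        if item is None:
            active_tasks -= 1
        else:
            yield item  # a vector, or an Exception to be handled upstream


async def main() -> None:
    async for item in run(["a", "bb", "ccc", "dddd", "eeeee"]):
        print(item)


asyncio.run(main())
```

The sentinel-per-task protocol tracks completion without `asyncio.gather`, so results can be yielded as soon as any batch finishes; the trade-off is that entries arrive in completion order, not submission order, which is why the pipe records `chunk_order` in each fragment's metadata rather than relying on arrival order.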