import asyncio
import copy
import logging
import uuid
from typing import Any, AsyncGenerator, Optional, Union

from r2r.base import (
    AsyncState,
    EmbeddingProvider,
    Extraction,
    Fragment,
    FragmentType,
    KVLoggingSingleton,
    PipeType,
    R2RDocumentProcessingError,
    TextSplitter,
    Vector,
    VectorEntry,
    generate_id_from_label,
)
from r2r.base.pipes.base_pipe import AsyncPipe

logger = logging.getLogger(__name__)


class EmbeddingPipe(AsyncPipe):
    """
    Embeds and stores documents using a specified embedding model and database.
    """

    class Input(AsyncPipe.Input):
        message: AsyncGenerator[
            Union[Extraction, R2RDocumentProcessingError], None
        ]

    def __init__(
        self,
        embedding_provider: EmbeddingProvider,
        text_splitter: TextSplitter,
        embedding_batch_size: int = 1,
        id_prefix: str = "demo",
        pipe_logger: Optional[KVLoggingSingleton] = None,
        type: PipeType = PipeType.INGESTOR,
        config: Optional[AsyncPipe.PipeConfig] = None,
        *args,
        **kwargs,
    ):
        """
        Initializes the embedding pipe with necessary components and configurations.
        """
        super().__init__(
            pipe_logger=pipe_logger,
            type=type,
            config=config
            or AsyncPipe.PipeConfig(name="default_embedding_pipe"),
        )
        self.embedding_provider = embedding_provider
        self.text_splitter = text_splitter
        self.embedding_batch_size = embedding_batch_size
        self.id_prefix = id_prefix
        self.pipe_run_info = None

    async def fragment(
        self, extraction: Extraction, run_id: uuid.UUID
    ) -> AsyncGenerator[Fragment, None]:
        """
        Splits an extraction's text into manageable chunks for embedding.
        """
        if not isinstance(extraction, Extraction):
            raise ValueError(
                f"Expected an Extraction, but received {type(extraction)}."
            )
        if not isinstance(extraction.data, str):
            raise ValueError(
                f"Expected a string, but received {type(extraction.data)}."
            )
        text_chunks = [
            ele.page_content
            for ele in self.text_splitter.create_documents([extraction.data])
        ]
        for iteration, chunk in enumerate(text_chunks):
            yield Fragment(
                id=generate_id_from_label(f"{extraction.id}-{iteration}"),
                type=FragmentType.TEXT,
                data=chunk,
                metadata=copy.deepcopy(extraction.metadata),
                extraction_id=extraction.id,
                document_id=extraction.document_id,
            )

    async def transform_fragments(
        self, fragments: list[Fragment], metadatas: list[dict]
    ) -> AsyncGenerator[Fragment, None]:
        """
        Transforms text chunks based on their metadata, e.g., adding prefixes.
        """
        # `zip` returns a plain synchronous iterator, so it must be consumed
        # with a regular `for` loop, not `async for`.
        for fragment, metadata in zip(fragments, metadatas):
            if "chunk_prefix" in metadata:
                prefix = metadata.pop("chunk_prefix")
                fragment.data = f"{prefix}\n{fragment.data}"
            yield fragment

    async def embed(self, fragments: list[Fragment]) -> list[list[float]]:
        return await self.embedding_provider.async_get_embeddings(
            [fragment.data for fragment in fragments],
            EmbeddingProvider.PipeStage.BASE,
        )
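    # Shape expectation for `embed`, as a hedged sketch: `three_fragments`
    # below is a hypothetical list of three Fragment objects, and the
    # provider is assumed to return one embedding per input text.
    #
    #   vectors = await pipe.embed(three_fragments)
    #   assert len(vectors) == 3          # one vector per fragment
    #   assert isinstance(vectors[0][0], float)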
""" vectors = await self.embed(fragment_batch) return [ VectorEntry( id=fragment.id, vector=Vector(data=raw_vector), metadata={ "document_id": fragment.document_id, "extraction_id": fragment.extraction_id, "text": fragment.data, **fragment.metadata, }, ) for raw_vector, fragment in zip(vectors, fragment_batch) ] async def _process_and_enqueue_batch( self, fragment_batch: list[Fragment], vector_entry_queue: asyncio.Queue ): try: batch_result = await self._process_batch(fragment_batch) for vector_entry in batch_result: await vector_entry_queue.put(vector_entry) except Exception as e: logger.error(f"Error processing batch: {e}") await vector_entry_queue.put( R2RDocumentProcessingError( error_message=str(e), document_id=fragment_batch[0].document_id, ) ) finally: await vector_entry_queue.put(None) # Signal completion async def _run_logic( self, input: Input, state: AsyncState, run_id: uuid.UUID, *args: Any, **kwargs: Any, ) -> AsyncGenerator[Union[R2RDocumentProcessingError, VectorEntry], None]: """ Executes the embedding pipe: chunking, transforming, embedding, and storing documents. """ vector_entry_queue = asyncio.Queue() fragment_batch = [] active_tasks = 0 fragment_info = {} async for extraction in input.message: if isinstance(extraction, R2RDocumentProcessingError): yield extraction continue async for fragment in self.fragment(extraction, run_id): if extraction.document_id in fragment_info: fragment_info[extraction.document_id] += 1 else: fragment_info[extraction.document_id] = 0 # Start with 0 fragment.metadata["chunk_order"] = fragment_info[ extraction.document_id ] version = fragment.metadata.get("version", "v0") # Ensure fragment ID is set correctly if not fragment.id: fragment.id = generate_id_from_label( f"{extraction.id}-{fragment_info[extraction.document_id]}-{version}" ) fragment_batch.append(fragment) if len(fragment_batch) >= self.embedding_batch_size: asyncio.create_task( self._process_and_enqueue_batch( fragment_batch.copy(), vector_entry_queue ) ) active_tasks += 1 fragment_batch.clear() logger.debug( f"Fragmented the input document ids into counts as shown: {fragment_info}" ) if fragment_batch: asyncio.create_task( self._process_and_enqueue_batch( fragment_batch.copy(), vector_entry_queue ) ) active_tasks += 1 while active_tasks > 0: vector_entry = await vector_entry_queue.get() if vector_entry is None: # Check for termination signal active_tasks -= 1 elif isinstance(vector_entry, Exception): yield vector_entry # Propagate the exception active_tasks -= 1 else: yield vector_entry