Diffstat (limited to 'R2R/r2r/pipes/ingestion')
-rwxr-xr-x   R2R/r2r/pipes/ingestion/__init__.py               0
-rwxr-xr-x   R2R/r2r/pipes/ingestion/embedding_pipe.py       218
-rwxr-xr-x   R2R/r2r/pipes/ingestion/kg_extraction_pipe.py   226
-rwxr-xr-x   R2R/r2r/pipes/ingestion/kg_storage_pipe.py      133
-rwxr-xr-x   R2R/r2r/pipes/ingestion/parsing_pipe.py         211
-rwxr-xr-x   R2R/r2r/pipes/ingestion/vector_storage_pipe.py  128
6 files changed, 916 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/ingestion/__init__.py b/R2R/r2r/pipes/ingestion/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/__init__.py
diff --git a/R2R/r2r/pipes/ingestion/embedding_pipe.py b/R2R/r2r/pipes/ingestion/embedding_pipe.py
new file mode 100755
index 00000000..971ccc9d
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/embedding_pipe.py
@@ -0,0 +1,218 @@
+import asyncio
+import copy
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional, Union
+
+from r2r.base import (
+    AsyncState,
+    EmbeddingProvider,
+    Extraction,
+    Fragment,
+    FragmentType,
+    KVLoggingSingleton,
+    PipeType,
+    R2RDocumentProcessingError,
+    TextSplitter,
+    Vector,
+    VectorEntry,
+    generate_id_from_label,
+)
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingPipe(AsyncPipe):
+    """
+    Embeds and stores documents using a specified embedding model and database.
+    """
+
+    class Input(AsyncPipe.Input):
+        message: AsyncGenerator[
+            Union[Extraction, R2RDocumentProcessingError], None
+        ]
+
+    def __init__(
+        self,
+        embedding_provider: EmbeddingProvider,
+        text_splitter: TextSplitter,
+        embedding_batch_size: int = 1,
+        id_prefix: str = "demo",
+        pipe_logger: Optional[KVLoggingSingleton] = None,
+        type: PipeType = PipeType.INGESTOR,
+        config: Optional[AsyncPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initializes the embedding pipe with necessary components and configurations.
+        """
+        super().__init__(
+            pipe_logger=pipe_logger,
+            type=type,
+            config=config
+            or AsyncPipe.PipeConfig(name="default_embedding_pipe"),
+        )
+        self.embedding_provider = embedding_provider
+        self.text_splitter = text_splitter
+        self.embedding_batch_size = embedding_batch_size
+        self.id_prefix = id_prefix
+        self.pipe_run_info = None
+
+    async def fragment(
+        self, extraction: Extraction, run_id: uuid.UUID
+    ) -> AsyncGenerator[Fragment, None]:
+        """
+        Splits text into manageable chunks for embedding.
+        """
+        if not isinstance(extraction, Extraction):
+            raise ValueError(
+                f"Expected an Extraction, but received {type(extraction)}."
+            )
+        if not isinstance(extraction.data, str):
+            raise ValueError(
+                f"Expected a string, but received {type(extraction.data)}."
+            )
+        text_chunks = [
+            ele.page_content
+            for ele in self.text_splitter.create_documents([extraction.data])
+        ]
+        for iteration, chunk in enumerate(text_chunks):
+            fragment = Fragment(
+                id=generate_id_from_label(f"{extraction.id}-{iteration}"),
+                type=FragmentType.TEXT,
+                data=chunk,
+                metadata=copy.deepcopy(extraction.metadata),
+                extraction_id=extraction.id,
+                document_id=extraction.document_id,
+            )
+            yield fragment
+            iteration += 1
+
+    async def transform_fragments(
+        self, fragments: list[Fragment], metadatas: list[dict]
+    ) -> AsyncGenerator[Fragment, None]:
+        """
+        Transforms text chunks based on their metadata, e.g., adding prefixes.
+        """
+        for fragment, metadata in zip(fragments, metadatas):
+            if "chunk_prefix" in metadata:
+                prefix = metadata.pop("chunk_prefix")
+                fragment.data = f"{prefix}\n{fragment.data}"
+            yield fragment
+
+    async def embed(self, fragments: list[Fragment]) -> list[list[float]]:
+        return await self.embedding_provider.async_get_embeddings(
+            [fragment.data for fragment in fragments],
+            EmbeddingProvider.PipeStage.BASE,
+        )
+
+    async def _process_batch(
+        self, fragment_batch: list[Fragment]
+    ) -> list[VectorEntry]:
+        """
+        Embeds a batch of fragments and yields vector entries.
+        """
+        vectors = await self.embed(fragment_batch)
+        return [
+            VectorEntry(
+                id=fragment.id,
+                vector=Vector(data=raw_vector),
+                metadata={
+                    "document_id": fragment.document_id,
+                    "extraction_id": fragment.extraction_id,
+                    "text": fragment.data,
+                    **fragment.metadata,
+                },
+            )
+            for raw_vector, fragment in zip(vectors, fragment_batch)
+        ]
+
+    async def _process_and_enqueue_batch(
+        self, fragment_batch: list[Fragment], vector_entry_queue: asyncio.Queue
+    ):
+        try:
+            batch_result = await self._process_batch(fragment_batch)
+            for vector_entry in batch_result:
+                await vector_entry_queue.put(vector_entry)
+        except Exception as e:
+            logger.error(f"Error processing batch: {e}")
+            await vector_entry_queue.put(
+                R2RDocumentProcessingError(
+                    error_message=str(e),
+                    document_id=fragment_batch[0].document_id,
+                )
+            )
+        finally:
+            await vector_entry_queue.put(None)  # Signal completion
+
+    async def _run_logic(
+        self,
+        input: Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[Union[R2RDocumentProcessingError, VectorEntry], None]:
+        """
+        Executes the embedding pipe: chunking, transforming, embedding, and storing documents.
+        """
+        vector_entry_queue = asyncio.Queue()
+        fragment_batch = []
+        active_tasks = 0
+
+        fragment_info = {}
+        async for extraction in input.message:
+            if isinstance(extraction, R2RDocumentProcessingError):
+                yield extraction
+                continue
+
+            async for fragment in self.fragment(extraction, run_id):
+                if extraction.document_id in fragment_info:
+                    fragment_info[extraction.document_id] += 1
+                else:
+                    fragment_info[extraction.document_id] = 0  # Start with 0
+                fragment.metadata["chunk_order"] = fragment_info[
+                    extraction.document_id
+                ]
+
+                version = fragment.metadata.get("version", "v0")
+
+                # Ensure fragment ID is set correctly
+                if not fragment.id:
+                    fragment.id = generate_id_from_label(
+                        f"{extraction.id}-{fragment_info[extraction.document_id]}-{version}"
+                    )
+
+                fragment_batch.append(fragment)
+                if len(fragment_batch) >= self.embedding_batch_size:
+                    asyncio.create_task(
+                        self._process_and_enqueue_batch(
+                            fragment_batch.copy(), vector_entry_queue
+                        )
+                    )
+                    active_tasks += 1
+                    fragment_batch.clear()
+
+        logger.debug(
+            f"Fragmented the input document ids into counts as shown: {fragment_info}"
+        )
+
+        if fragment_batch:
+            asyncio.create_task(
+                self._process_and_enqueue_batch(
+                    fragment_batch.copy(), vector_entry_queue
+                )
+            )
+            active_tasks += 1
+
+        while active_tasks > 0:
+            vector_entry = await vector_entry_queue.get()
+            if vector_entry is None:  # Check for termination signal
+                active_tasks -= 1
+            elif isinstance(vector_entry, Exception):
+                yield vector_entry  # Propagate the exception
+                active_tasks -= 1
+            else:
+                yield vector_entry
diff --git a/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py b/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py
new file mode 100755
index 00000000..13025e39
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py
@@ -0,0 +1,226 @@
+import asyncio
+import copy
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+    AsyncState,
+    Extraction,
+    Fragment,
+    FragmentType,
+    KGExtraction,
+    KGProvider,
+    KVLoggingSingleton,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+    TextSplitter,
+    extract_entities,
+    extract_triples,
+    generate_id_from_label,
+)
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class ClientError(Exception):
+    """Base class for client connection errors."""
+
+    pass
+
+
+class KGExtractionPipe(AsyncPipe):
+    """
+    Extracts knowledge graph entities and triples from document fragments using a specified LLM.
+    """
+
+    def __init__(
+        self,
+        kg_provider: KGProvider,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        text_splitter: TextSplitter,
+        kg_batch_size: int = 1,
+        id_prefix: str = "demo",
+        pipe_logger: Optional[KVLoggingSingleton] = None,
+        type: PipeType = PipeType.INGESTOR,
+        config: Optional[AsyncPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initializes the knowledge graph extraction pipe with necessary components and configurations.
+        """
+        super().__init__(
+            pipe_logger=pipe_logger,
+            type=type,
+            config=config
+            or AsyncPipe.PipeConfig(name="default_embedding_pipe"),
+        )
+
+        self.kg_provider = kg_provider
+        self.prompt_provider = prompt_provider
+        self.llm_provider = llm_provider
+        self.text_splitter = text_splitter
+        self.kg_batch_size = kg_batch_size
+        self.id_prefix = id_prefix
+        self.pipe_run_info = None
+
+    async def fragment(
+        self, extraction: Extraction, run_id: uuid.UUID
+    ) -> AsyncGenerator[Fragment, None]:
+        """
+        Splits text into manageable chunks for embedding.
+        """
+        if not isinstance(extraction, Extraction):
+            raise ValueError(
+                f"Expected an Extraction, but received {type(extraction)}."
+            )
+        if not isinstance(extraction.data, str):
+            raise ValueError(
+                f"Expected a string, but received {type(extraction.data)}."
+            )
+        text_chunks = [
+            ele.page_content
+            for ele in self.text_splitter.create_documents([extraction.data])
+        ]
+        for iteration, chunk in enumerate(text_chunks):
+            fragment = Fragment(
+                id=generate_id_from_label(f"{extraction.id}-{iteration}"),
+                type=FragmentType.TEXT,
+                data=chunk,
+                metadata=copy.deepcopy(extraction.metadata),
+                extraction_id=extraction.id,
+                document_id=extraction.document_id,
+            )
+            yield fragment
+
+    async def transform_fragments(
+        self, fragments: list[Fragment]
+    ) -> AsyncGenerator[Fragment, None]:
+        """
+        Transforms text chunks based on their metadata, e.g., adding prefixes.
+        """
+        async for fragment in fragments:
+            if "chunk_prefix" in fragment.metadata:
+                prefix = fragment.metadata.pop("chunk_prefix")
+                fragment.data = f"{prefix}\n{fragment.data}"
+            yield fragment
+
+    async def extract_kg(
+        self,
+        fragment: Fragment,
+        retries: int = 3,
+        delay: int = 2,
+    ) -> KGExtraction:
+        """
+        Extracts NER triples from a fragment with retries.
+        """
+        task_prompt = self.prompt_provider.get_prompt(
+            self.kg_provider.config.kg_extraction_prompt,
+            inputs={"input": fragment.data},
+        )
+        messages = self.prompt_provider._get_message_payload(
+            self.prompt_provider.get_prompt("default_system"), task_prompt
+        )
+        for attempt in range(retries):
+            try:
+                response = await self.llm_provider.aget_completion(
+                    messages, self.kg_provider.config.kg_extraction_config
+                )
+
+                kg_extraction = response.choices[0].message.content
+
+                # Parsing JSON from the response
+                kg_json = (
+                    json.loads(
+                        kg_extraction.split("```json")[1].split("```")[0]
+                    )
+                    if """```json""" in kg_extraction
+                    else json.loads(kg_extraction)
+                )
+                llm_payload = kg_json.get("entities_and_triples", {})
+
+                # Extract triples with detailed logging
+                entities = extract_entities(llm_payload)
+                triples = extract_triples(llm_payload, entities)
+
+                # Create KG extraction object
+                return KGExtraction(entities=entities, triples=triples)
+            except (
+                ClientError,
+                json.JSONDecodeError,
+                KeyError,
+                IndexError,
+            ) as e:
+                logger.error(f"Error in extract_kg: {e}")
+                if attempt < retries - 1:
+                    await asyncio.sleep(delay)
+                else:
+                    logger.error(f"Failed after retries with {e}")
+                    # raise e  # Ensure the exception is raised after the final attempt
+
+        return KGExtraction(entities={}, triples=[])
+
+    async def _process_batch(
+        self,
+        fragment_batch: list[Fragment],
+    ) -> list[KGExtraction]:
+        """
+        Extracts knowledge graph data from a batch of fragments.
+        """
+        tasks = [
+            asyncio.create_task(self.extract_kg(fragment))
+            for fragment in fragment_batch
+        ]
+        return await asyncio.gather(*tasks)
+
+    async def _run_logic(
+        self,
+        input: AsyncPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[KGExtraction, None]:
+        """
+        Executes the KG extraction pipe: chunking, transforming, and extracting knowledge graph triples.
+        """
+        batch_tasks = []
+        fragment_batch = []
+
+        fragment_info = {}
+        async for extraction in input.message:
+            async for fragment in self.transform_fragments(
+                self.fragment(extraction, run_id)
+            ):
+                if extraction.document_id in fragment_info:
+                    fragment_info[extraction.document_id] += 1
+                else:
+                    fragment_info[extraction.document_id] = 1
+                extraction.metadata["chunk_order"] = fragment_info[
+                    extraction.document_id
+                ]
+                fragment_batch.append(fragment)
+                if len(fragment_batch) >= self.kg_batch_size:
+                    # Here, ensure `_process_batch` is scheduled as a coroutine, not called directly
+                    batch_tasks.append(
+                        self._process_batch(fragment_batch.copy())
+                    )  # pass a copy if necessary
+                    fragment_batch.clear()  # Clear the batch for new fragments
+
+        logger.debug(
+            f"Fragmented the input document ids into counts as shown: {fragment_info}"
+        )
+
+        if fragment_batch:  # Process any remaining fragments
+            batch_tasks.append(self._process_batch(fragment_batch.copy()))
+
+        # Process tasks as they complete
+        for task in asyncio.as_completed(batch_tasks):
+            batch_result = await task  # Wait for the next task to complete
+            for kg_extraction in batch_result:
+                yield kg_extraction
+ """ + try: + nodes = [] + relations = [] + for extraction in kg_extractions: + for entity in extraction.entities.values(): + embedding = None + if self.embedding_provider: + embedding = self.embedding_provider.get_embedding( + "Entity:\n{entity.value}\nLabel:\n{entity.category}\nSubcategory:\n{entity.subcategory}" + ) + nodes.append( + EntityNode( + name=entity.value, + label=entity.category, + embedding=embedding, + properties=( + {"subcategory": entity.subcategory} + if entity.subcategory + else {} + ), + ) + ) + for triple in extraction.triples: + relations.append( + Relation( + source_id=triple.subject, + target_id=triple.object, + label=triple.predicate, + ) + ) + self.kg_provider.upsert_nodes(nodes) + self.kg_provider.upsert_relations(relations) + except Exception as e: + error_message = f"Failed to store knowledge graph extractions in the database: {e}" + logger.error(error_message) + raise ValueError(error_message) + + async def _run_logic( + self, + input: Input, + state: AsyncState, + run_id: uuid.UUID, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[None, None]: + """ + Executes the async knowledge graph storage pipe: storing knowledge graph extractions in the graph database. + """ + batch_tasks = [] + kg_batch = [] + + async for kg_extraction in input.message: + kg_batch.append(kg_extraction) + if len(kg_batch) >= self.storage_batch_size: + # Schedule the storage task + batch_tasks.append( + asyncio.create_task( + self.store(kg_batch.copy()), + name=f"kg-store-{self.config.name}", + ) + ) + kg_batch.clear() + + if kg_batch: # Process any remaining extractions + batch_tasks.append( + asyncio.create_task( + self.store(kg_batch.copy()), + name=f"kg-store-{self.config.name}", + ) + ) + + # Wait for all storage tasks to complete + await asyncio.gather(*batch_tasks) + yield None diff --git a/R2R/r2r/pipes/ingestion/parsing_pipe.py b/R2R/r2r/pipes/ingestion/parsing_pipe.py new file mode 100755 index 00000000..f3c81ca0 --- /dev/null +++ b/R2R/r2r/pipes/ingestion/parsing_pipe.py @@ -0,0 +1,211 @@ +""" +This module contains the `DocumentParsingPipe` class, which is responsible for parsing incoming documents into plaintext. +""" + +import asyncio +import logging +import time +import uuid +from typing import AsyncGenerator, Optional, Union + +from r2r.base import ( + AsyncParser, + AsyncState, + Document, + DocumentType, + Extraction, + ExtractionType, + KVLoggingSingleton, + PipeType, + generate_id_from_label, +) +from r2r.base.abstractions.exception import R2RDocumentProcessingError +from r2r.base.pipes.base_pipe import AsyncPipe +from r2r.parsers.media.audio_parser import AudioParser +from r2r.parsers.media.docx_parser import DOCXParser +from r2r.parsers.media.img_parser import ImageParser +from r2r.parsers.media.movie_parser import MovieParser +from r2r.parsers.media.pdf_parser import PDFParser +from r2r.parsers.media.ppt_parser import PPTParser +from r2r.parsers.structured.csv_parser import CSVParser +from r2r.parsers.structured.json_parser import JSONParser +from r2r.parsers.structured.xlsx_parser import XLSXParser +from r2r.parsers.text.html_parser import HTMLParser +from r2r.parsers.text.md_parser import MDParser +from r2r.parsers.text.text_parser import TextParser + +logger = logging.getLogger(__name__) + + +class ParsingPipe(AsyncPipe): + """ + Processes incoming documents into plaintext based on their data type. + Supports TXT, JSON, HTML, and PDF formats. 
+ """ + + class Input(AsyncPipe.Input): + message: AsyncGenerator[Document, None] + + AVAILABLE_PARSERS = { + DocumentType.CSV: CSVParser, + DocumentType.DOCX: DOCXParser, + DocumentType.HTML: HTMLParser, + DocumentType.JSON: JSONParser, + DocumentType.MD: MDParser, + DocumentType.PDF: PDFParser, + DocumentType.PPTX: PPTParser, + DocumentType.TXT: TextParser, + DocumentType.XLSX: XLSXParser, + DocumentType.GIF: ImageParser, + DocumentType.JPEG: ImageParser, + DocumentType.JPG: ImageParser, + DocumentType.PNG: ImageParser, + DocumentType.SVG: ImageParser, + DocumentType.MP3: AudioParser, + DocumentType.MP4: MovieParser, + } + + IMAGE_TYPES = { + DocumentType.GIF, + DocumentType.JPG, + DocumentType.JPEG, + DocumentType.PNG, + DocumentType.SVG, + } + + def __init__( + self, + excluded_parsers: list[DocumentType], + override_parsers: Optional[dict[DocumentType, AsyncParser]] = None, + pipe_logger: Optional[KVLoggingSingleton] = None, + type: PipeType = PipeType.INGESTOR, + config: Optional[AsyncPipe.PipeConfig] = None, + *args, + **kwargs, + ): + super().__init__( + pipe_logger=pipe_logger, + type=type, + config=config + or AsyncPipe.PipeConfig(name="default_document_parsing_pipe"), + *args, + **kwargs, + ) + + self.parsers = {} + + if not override_parsers: + override_parsers = {} + + # Apply overrides if specified + for doc_type, parser in override_parsers.items(): + self.parsers[doc_type] = parser + + for doc_type, parser_info in self.AVAILABLE_PARSERS.items(): + if ( + doc_type not in excluded_parsers + and doc_type not in self.parsers + ): + self.parsers[doc_type] = parser_info() + + @property + def supported_types(self) -> list[str]: + """ + Lists the data types supported by the pipe. + """ + return [entry_type for entry_type in DocumentType] + + async def _parse( + self, + document: Document, + run_id: uuid.UUID, + version: str, + ) -> AsyncGenerator[Union[R2RDocumentProcessingError, Extraction], None]: + if document.type not in self.parsers: + yield R2RDocumentProcessingError( + document_id=document.id, + error_message=f"Parser for {document.type} not found in `ParsingPipe`.", + ) + return + parser = self.parsers[document.type] + texts = parser.ingest(document.data) + extraction_type = ExtractionType.TXT + t0 = time.time() + if document.type in self.IMAGE_TYPES: + extraction_type = ExtractionType.IMG + document.metadata["image_type"] = document.type.value + # SAVE IMAGE DATA + # try: + # import base64 + # sanitized_data = base64.b64encode(document.data).decode('utf-8') + # except Exception as e: + # sanitized_data = document.data + + # document.metadata["image_data"] = sanitized_data + elif document.type == DocumentType.MP4: + extraction_type = ExtractionType.MOV + document.metadata["audio_type"] = document.type.value + + iteration = 0 + async for text in texts: + extraction_id = generate_id_from_label( + f"{document.id}-{iteration}-{version}" + ) + document.metadata["version"] = version + extraction = Extraction( + id=extraction_id, + data=text, + metadata=document.metadata, + document_id=document.id, + type=extraction_type, + ) + yield extraction + # TODO - Add settings to enable extraction logging + # extraction_dict = extraction.dict() + # await self.enqueue_log( + # run_id=run_id, + # key="extraction", + # value=json.dumps( + # { + # "data": extraction_dict["data"], + # "document_id": str(extraction_dict["document_id"]), + # "extraction_id": str(extraction_dict["id"]), + # } + # ), + # ) + iteration += 1 + logger.debug( + f"Parsed document with id={document.id}, 
title={document.metadata.get('title', None)}, user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} into {iteration} extractions in t={time.time() - t0:.2f} seconds." + ) + + async def _run_logic( + self, + input: Input, + state: AsyncState, + run_id: uuid.UUID, + versions: Optional[list[str]] = None, + *args, + **kwargs, + ) -> AsyncGenerator[Extraction, None]: + parse_tasks = [] + + iteration = 0 + async for document in input.message: + version = versions[iteration] if versions else "v0" + iteration += 1 + parse_tasks.append( + self._handle_parse_task(document, version, run_id) + ) + + # Await all tasks and yield results concurrently + for parse_task in asyncio.as_completed(parse_tasks): + for extraction in await parse_task: + yield extraction + + async def _handle_parse_task( + self, document: Document, version: str, run_id: uuid.UUID + ) -> AsyncGenerator[Extraction, None]: + extractions = [] + async for extraction in self._parse(document, run_id, version): + extractions.append(extraction) + return extractions diff --git a/R2R/r2r/pipes/ingestion/vector_storage_pipe.py b/R2R/r2r/pipes/ingestion/vector_storage_pipe.py new file mode 100755 index 00000000..9564fd22 --- /dev/null +++ b/R2R/r2r/pipes/ingestion/vector_storage_pipe.py @@ -0,0 +1,128 @@ +import asyncio +import logging +import uuid +from typing import Any, AsyncGenerator, Optional, Tuple, Union + +from r2r.base import ( + AsyncState, + KVLoggingSingleton, + PipeType, + VectorDBProvider, + VectorEntry, +) +from r2r.base.pipes.base_pipe import AsyncPipe + +from ...base.abstractions.exception import R2RDocumentProcessingError + +logger = logging.getLogger(__name__) + + +class VectorStoragePipe(AsyncPipe): + class Input(AsyncPipe.Input): + message: AsyncGenerator[ + Union[R2RDocumentProcessingError, VectorEntry], None + ] + do_upsert: bool = True + + def __init__( + self, + vector_db_provider: VectorDBProvider, + storage_batch_size: int = 128, + pipe_logger: Optional[KVLoggingSingleton] = None, + type: PipeType = PipeType.INGESTOR, + config: Optional[AsyncPipe.PipeConfig] = None, + *args, + **kwargs, + ): + """ + Initializes the async vector storage pipe with necessary components and configurations. + """ + super().__init__( + pipe_logger=pipe_logger, + type=type, + config=config, + *args, + **kwargs, + ) + self.vector_db_provider = vector_db_provider + self.storage_batch_size = storage_batch_size + + async def store( + self, + vector_entries: list[VectorEntry], + do_upsert: bool = True, + ) -> None: + """ + Stores a batch of vector entries in the database. + """ + + try: + if do_upsert: + self.vector_db_provider.upsert_entries(vector_entries) + else: + self.vector_db_provider.copy_entries(vector_entries) + except Exception as e: + error_message = ( + f"Failed to store vector entries in the database: {e}" + ) + logger.error(error_message) + raise ValueError(error_message) + + async def _run_logic( + self, + input: Input, + state: AsyncState, + run_id: uuid.UUID, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[ + Tuple[uuid.UUID, Union[str, R2RDocumentProcessingError]], None + ]: + """ + Executes the async vector storage pipe: storing embeddings in the vector database. 
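`ParsingPipe.__init__` above builds its parser registry by applying `override_parsers` first and then filling the remaining document types from `AVAILABLE_PARSERS`, skipping anything listed in `excluded_parsers`. Below is a small, self-contained sketch of that precedence logic; the trimmed-down `DocumentType` enum and dummy parser classes are illustrative stand-ins for the real R2R types.

```python
from enum import Enum


class DocumentType(str, Enum):  # minimal stand-in for r2r.base.DocumentType
    TXT = "txt"
    PDF = "pdf"
    PNG = "png"


class TextParser: ...
class PDFParser: ...
class ImageParser: ...


AVAILABLE_PARSERS = {
    DocumentType.TXT: TextParser,
    DocumentType.PDF: PDFParser,
    DocumentType.PNG: ImageParser,
}


def build_parsers(excluded, overrides=None):
    parsers = dict(overrides or {})  # explicit overrides always win
    for doc_type, parser_cls in AVAILABLE_PARSERS.items():
        if doc_type not in excluded and doc_type not in parsers:
            parsers[doc_type] = parser_cls()  # defaults fill the remaining types
    return parsers


parsers = build_parsers(
    excluded=[DocumentType.PNG],
    overrides={DocumentType.PDF: PDFParser()},
)
print({doc_type.value: type(parser).__name__ for doc_type, parser in parsers.items()})
```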
+ """ + batch_tasks = [] + vector_batch = [] + document_counts = {} + i = 0 + async for msg in input.message: + i += 1 + if isinstance(msg, R2RDocumentProcessingError): + yield (msg.document_id, msg) + continue + + document_id = msg.metadata.get("document_id", None) + if not document_id: + raise ValueError("Document ID not found in the metadata.") + if document_id not in document_counts: + document_counts[document_id] = 1 + else: + document_counts[document_id] += 1 + + vector_batch.append(msg) + if len(vector_batch) >= self.storage_batch_size: + # Schedule the storage task + batch_tasks.append( + asyncio.create_task( + self.store(vector_batch.copy(), input.do_upsert), + name=f"vector-store-{self.config.name}", + ) + ) + vector_batch.clear() + + if vector_batch: # Process any remaining vectors + batch_tasks.append( + asyncio.create_task( + self.store(vector_batch.copy(), input.do_upsert), + name=f"vector-store-{self.config.name}", + ) + ) + + # Wait for all storage tasks to complete + await asyncio.gather(*batch_tasks) + + for document_id, count in document_counts.items(): + yield ( + document_id, + f"Processed {count} vectors for document {document_id}.", + ) |