diff options
Diffstat (limited to 'R2R/r2r/pipes/ingestion/kg_storage_pipe.py')
-rwxr-xr-x | R2R/r2r/pipes/ingestion/kg_storage_pipe.py | 133
1 files changed, 133 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/ingestion/kg_storage_pipe.py b/R2R/r2r/pipes/ingestion/kg_storage_pipe.py
new file mode 100755
index 00000000..9ac63479
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/kg_storage_pipe.py
@@ -0,0 +1,133 @@
+import asyncio
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+    AsyncState,
+    EmbeddingProvider,
+    KGExtraction,
+    KGProvider,
+    KVLoggingSingleton,
+    PipeType,
+)
+from r2r.base.abstractions.llama_abstractions import EntityNode, Relation
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class KGStoragePipe(AsyncPipe):
+    class Input(AsyncPipe.Input):
+        message: AsyncGenerator[KGExtraction, None]
+
+    def __init__(
+        self,
+        kg_provider: KGProvider,
+        embedding_provider: Optional[EmbeddingProvider] = None,
+        storage_batch_size: int = 1,
+        pipe_logger: Optional[KVLoggingSingleton] = None,
+        type: PipeType = PipeType.INGESTOR,
+        config: Optional[AsyncPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initializes the async knowledge graph storage pipe with necessary components and configurations.
+        """
+        logger.info(
+            "Initializing a `KGStoragePipe` to store knowledge graph extractions in a graph database."
+        )
+
+        super().__init__(
+            pipe_logger=pipe_logger,
+            type=type,
+            config=config,
+            *args,
+            **kwargs,
+        )
+        self.kg_provider = kg_provider
+        self.embedding_provider = embedding_provider
+        self.storage_batch_size = storage_batch_size
+
+    async def store(
+        self,
+        kg_extractions: list[KGExtraction],
+    ) -> None:
+        """
+        Stores a batch of knowledge graph extractions in the graph database.
+        """
+        try:
+            nodes = []
+            relations = []
+            for extraction in kg_extractions:
+                for entity in extraction.entities.values():
+                    embedding = None
+                    if self.embedding_provider:
+                        embedding = self.embedding_provider.get_embedding(
+                            f"Entity:\n{entity.value}\nLabel:\n{entity.category}\nSubcategory:\n{entity.subcategory}"
+                        )
+                    nodes.append(
+                        EntityNode(
+                            name=entity.value,
+                            label=entity.category,
+                            embedding=embedding,
+                            properties=(
+                                {"subcategory": entity.subcategory}
+                                if entity.subcategory
+                                else {}
+                            ),
+                        )
+                    )
+                for triple in extraction.triples:
+                    relations.append(
+                        Relation(
+                            source_id=triple.subject,
+                            target_id=triple.object,
+                            label=triple.predicate,
+                        )
+                    )
+            self.kg_provider.upsert_nodes(nodes)
+            self.kg_provider.upsert_relations(relations)
+        except Exception as e:
+            error_message = f"Failed to store knowledge graph extractions in the database: {e}"
+            logger.error(error_message)
+            raise ValueError(error_message) from e
+
+    async def _run_logic(
+        self,
+        input: Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[None, None]:
+        """
+        Executes the async knowledge graph storage pipe: storing knowledge graph extractions in the graph database.
+        """
+        batch_tasks = []
+        kg_batch = []
+
+        async for kg_extraction in input.message:
+            kg_batch.append(kg_extraction)
+            if len(kg_batch) >= self.storage_batch_size:
+                # Schedule the storage task
+                batch_tasks.append(
+                    asyncio.create_task(
+                        self.store(kg_batch.copy()),
+                        name=f"kg-store-{self.config.name}",
+                    )
+                )
+                kg_batch.clear()
+
+        if kg_batch:  # Process any remaining extractions
+            batch_tasks.append(
+                asyncio.create_task(
+                    self.store(kg_batch.copy()),
+                    name=f"kg-store-{self.config.name}",
+                )
+            )
+
+        # Wait for all storage tasks to complete
+        await asyncio.gather(*batch_tasks)
+        yield None
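For context, the batching strategy in `_run_logic` — accumulate extractions from the async input stream, schedule one storage task per full batch, flush any remainder, then gather all tasks — can be reproduced in isolation. The sketch below is illustrative only: `fake_extractions`, `store_batch`, and `BATCH_SIZE` are hypothetical stand-ins for the KGExtraction stream, `KGStoragePipe.store`, and `storage_batch_size`; they are not part of the R2R codebase.

    import asyncio
    from typing import AsyncGenerator

    BATCH_SIZE = 2  # plays the role of storage_batch_size

    async def fake_extractions() -> AsyncGenerator[str, None]:
        # Stand-in for the KGExtraction stream produced by the upstream pipe.
        for i in range(5):
            yield f"extraction-{i}"

    async def store_batch(batch: list[str]) -> None:
        # Stand-in for KGStoragePipe.store; a real provider would upsert
        # nodes and relations here.
        await asyncio.sleep(0.01)
        print("stored", batch)

    async def run() -> None:
        tasks, batch = [], []
        async for item in fake_extractions():
            batch.append(item)
            if len(batch) >= BATCH_SIZE:
                # Schedule storage for a full batch without blocking ingestion.
                tasks.append(asyncio.create_task(store_batch(batch.copy())))
                batch.clear()
        if batch:  # flush the final partial batch
            tasks.append(asyncio.create_task(store_batch(batch.copy())))
        await asyncio.gather(*tasks)

    asyncio.run(run())

Note that copying the batch before clearing it matters: without the `.copy()` used in the pipe (and mirrored above), later mutations of the list would be visible to the already-scheduled storage task.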