Diffstat (limited to 'R2R/r2r/pipes/ingestion/kg_storage_pipe.py')
-rwxr-xr-x  R2R/r2r/pipes/ingestion/kg_storage_pipe.py  133
1 file changed, 133 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/ingestion/kg_storage_pipe.py b/R2R/r2r/pipes/ingestion/kg_storage_pipe.py
new file mode 100755
index 00000000..9ac63479
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/kg_storage_pipe.py
@@ -0,0 +1,133 @@
+import asyncio
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+ AsyncState,
+ EmbeddingProvider,
+ KGExtraction,
+ KGProvider,
+ KVLoggingSingleton,
+ PipeType,
+)
+from r2r.base.abstractions.llama_abstractions import EntityNode, Relation
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class KGStoragePipe(AsyncPipe):
+ class Input(AsyncPipe.Input):
+ message: AsyncGenerator[KGExtraction, None]
+
+ def __init__(
+ self,
+ kg_provider: KGProvider,
+ embedding_provider: Optional[EmbeddingProvider] = None,
+ storage_batch_size: int = 1,
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ type: PipeType = PipeType.INGESTOR,
+ config: Optional[AsyncPipe.PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ """
+ Initializes the async knowledge graph storage pipe with necessary components and configurations.
+ """
+ logger.info(
+            "Initializing a `KGStoragePipe` to store knowledge graph extractions in a graph database."
+ )
+
+ super().__init__(
+ pipe_logger=pipe_logger,
+ type=type,
+ config=config,
+ *args,
+ **kwargs,
+ )
+ self.kg_provider = kg_provider
+ self.embedding_provider = embedding_provider
+ self.storage_batch_size = storage_batch_size
+
+ async def store(
+ self,
+ kg_extractions: list[KGExtraction],
+ ) -> None:
+ """
+ Stores a batch of knowledge graph extractions in the graph database.
+ """
+ try:
+ nodes = []
+ relations = []
+ for extraction in kg_extractions:
+ for entity in extraction.entities.values():
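+                    # Optionally embed a textual description of the entity so the
+                    # graph provider can store a vector alongside the node.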
+ embedding = None
+ if self.embedding_provider:
+ embedding = self.embedding_provider.get_embedding(
+                            f"Entity:\n{entity.value}\nLabel:\n{entity.category}\nSubcategory:\n{entity.subcategory}"
+ )
+ nodes.append(
+ EntityNode(
+ name=entity.value,
+ label=entity.category,
+ embedding=embedding,
+ properties=(
+ {"subcategory": entity.subcategory}
+ if entity.subcategory
+ else {}
+ ),
+ )
+ )
+ for triple in extraction.triples:
+ relations.append(
+ Relation(
+ source_id=triple.subject,
+ target_id=triple.object,
+ label=triple.predicate,
+ )
+ )
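+            # Persist the accumulated nodes and relations through the KG provider.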
+ self.kg_provider.upsert_nodes(nodes)
+ self.kg_provider.upsert_relations(relations)
+ except Exception as e:
+ error_message = f"Failed to store knowledge graph extractions in the database: {e}"
+ logger.error(error_message)
+            raise ValueError(error_message) from e
+
+ async def _run_logic(
+ self,
+ input: Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[None, None]:
+ """
+ Executes the async knowledge graph storage pipe: storing knowledge graph extractions in the graph database.
+ """
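+        # Collect extractions into batches of `storage_batch_size` and schedule
+        # each full batch as a concurrent storage task.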
+ batch_tasks = []
+ kg_batch = []
+
+ async for kg_extraction in input.message:
+ kg_batch.append(kg_extraction)
+ if len(kg_batch) >= self.storage_batch_size:
+ # Schedule the storage task
+ batch_tasks.append(
+ asyncio.create_task(
+ self.store(kg_batch.copy()),
+ name=f"kg-store-{self.config.name}",
+ )
+ )
+ kg_batch.clear()
+
+ if kg_batch: # Process any remaining extractions
+ batch_tasks.append(
+ asyncio.create_task(
+ self.store(kg_batch.copy()),
+ name=f"kg-store-{self.config.name}",
+ )
+ )
+
+ # Wait for all storage tasks to complete
+ await asyncio.gather(*batch_tasks)
+ yield None
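
A rough usage sketch for the new pipe (not part of the diff above): `provider` and `extractions` are hypothetical placeholders for a configured `KGProvider` and a list of `KGExtraction` objects, the import path simply mirrors the file location, and the example assumes the `AsyncPipe` base class falls back to a default `PipeConfig` when none is supplied.

    import asyncio

    from r2r.pipes.ingestion.kg_storage_pipe import KGStoragePipe

    async def store_extractions(provider, extractions):
        # storage_batch_size only affects _run_logic; store() writes everything it is given.
        pipe = KGStoragePipe(kg_provider=provider, storage_batch_size=2)
        await pipe.store(extractions)

    # asyncio.run(store_extractions(provider, extractions))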