about summary refs log tree commit diff
path: root/R2R/r2r/pipes/ingestion/kg_storage_pipe.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/pipes/ingestion/kg_storage_pipe.py')
-rwxr-xr-xR2R/r2r/pipes/ingestion/kg_storage_pipe.py133
1 files changed, 133 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/ingestion/kg_storage_pipe.py b/R2R/r2r/pipes/ingestion/kg_storage_pipe.py
new file mode 100755
index 00000000..9ac63479
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/kg_storage_pipe.py
@@ -0,0 +1,133 @@
+import asyncio
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+    AsyncState,
+    EmbeddingProvider,
+    KGExtraction,
+    KGProvider,
+    KVLoggingSingleton,
+    PipeType,
+)
+from r2r.base.abstractions.llama_abstractions import EntityNode, Relation
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class KGStoragePipe(AsyncPipe):
+    class Input(AsyncPipe.Input):
+        message: AsyncGenerator[KGExtraction, None]
+
+    def __init__(
+        self,
+        kg_provider: KGProvider,
+        embedding_provider: Optional[EmbeddingProvider] = None,
+        storage_batch_size: int = 1,
+        pipe_logger: Optional[KVLoggingSingleton] = None,
+        type: PipeType = PipeType.INGESTOR,
+        config: Optional[AsyncPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initializes the async knowledge graph storage pipe with necessary components and configurations.
+        """
+        logger.info(
+            f"Initializing an `KGStoragePipe` to store knowledge graph extractions in a graph database."
+        )
+
+        super().__init__(
+            pipe_logger=pipe_logger,
+            type=type,
+            config=config,
+            *args,
+            **kwargs,
+        )
+        self.kg_provider = kg_provider
+        self.embedding_provider = embedding_provider
+        self.storage_batch_size = storage_batch_size
+
+    async def store(
+        self,
+        kg_extractions: list[KGExtraction],
+    ) -> None:
+        """
+        Stores a batch of knowledge graph extractions in the graph database.
+        """
+        try:
+            nodes = []
+            relations = []
+            for extraction in kg_extractions:
+                for entity in extraction.entities.values():
+                    embedding = None
+                    if self.embedding_provider:
+                        embedding = self.embedding_provider.get_embedding(
+                            "Entity:\n{entity.value}\nLabel:\n{entity.category}\nSubcategory:\n{entity.subcategory}"
+                        )
+                    nodes.append(
+                        EntityNode(
+                            name=entity.value,
+                            label=entity.category,
+                            embedding=embedding,
+                            properties=(
+                                {"subcategory": entity.subcategory}
+                                if entity.subcategory
+                                else {}
+                            ),
+                        )
+                    )
+                for triple in extraction.triples:
+                    relations.append(
+                        Relation(
+                            source_id=triple.subject,
+                            target_id=triple.object,
+                            label=triple.predicate,
+                        )
+                    )
+            self.kg_provider.upsert_nodes(nodes)
+            self.kg_provider.upsert_relations(relations)
+        except Exception as e:
+            error_message = f"Failed to store knowledge graph extractions in the database: {e}"
+            logger.error(error_message)
+            raise ValueError(error_message)
+
+    async def _run_logic(
+        self,
+        input: Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[None, None]:
+        """
+        Executes the async knowledge graph storage pipe: storing knowledge graph extractions in the graph database.
+        """
+        batch_tasks = []
+        kg_batch = []
+
+        async for kg_extraction in input.message:
+            kg_batch.append(kg_extraction)
+            if len(kg_batch) >= self.storage_batch_size:
+                # Schedule the storage task
+                batch_tasks.append(
+                    asyncio.create_task(
+                        self.store(kg_batch.copy()),
+                        name=f"kg-store-{self.config.name}",
+                    )
+                )
+                kg_batch.clear()
+
+        if kg_batch:  # Process any remaining extractions
+            batch_tasks.append(
+                asyncio.create_task(
+                    self.store(kg_batch.copy()),
+                    name=f"kg-store-{self.config.name}",
+                )
+            )
+
+        # Wait for all storage tasks to complete
+        await asyncio.gather(*batch_tasks)
+        yield None