import asyncio
import logging
import uuid
from typing import Any, AsyncGenerator, Optional

from r2r.base import (
    AsyncState,
    EmbeddingProvider,
    KGExtraction,
    KGProvider,
    KVLoggingSingleton,
    PipeType,
)
from r2r.base.abstractions.llama_abstractions import EntityNode, Relation
from r2r.base.pipes.base_pipe import AsyncPipe

logger = logging.getLogger(__name__)


class KGStoragePipe(AsyncPipe):
    """Stores knowledge graph extractions as entity nodes and relations in a graph database."""

    class Input(AsyncPipe.Input):
        message: AsyncGenerator[KGExtraction, None]

    def __init__(
        self,
        kg_provider: KGProvider,
        embedding_provider: Optional[EmbeddingProvider] = None,
        storage_batch_size: int = 1,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        type: PipeType = PipeType.INGESTOR,
        config: Optional[AsyncPipe.PipeConfig] = None,
        *args,
        **kwargs,
    ):
        """
        Initializes the async knowledge graph storage pipe with the necessary components and configurations.
        """
        logger.info(
            "Initializing a `KGStoragePipe` to store knowledge graph extractions in a graph database."
        )
        super().__init__(
            pipe_logger=pipe_logger,
            type=type,
            config=config,
            *args,
            **kwargs,
        )
        self.kg_provider = kg_provider
        self.embedding_provider = embedding_provider
        self.storage_batch_size = storage_batch_size

    async def store(
        self,
        kg_extractions: list[KGExtraction],
    ) -> None:
        """
        Stores a batch of knowledge graph extractions in the graph database.
        """
        try:
            nodes = []
            relations = []
            for extraction in kg_extractions:
                # Convert each extracted entity into an EntityNode, optionally
                # attaching an embedding of its textual description.
                for entity in extraction.entities.values():
                    embedding = None
                    if self.embedding_provider:
                        embedding = self.embedding_provider.get_embedding(
                            f"Entity:\n{entity.value}\nLabel:\n{entity.category}\nSubcategory:\n{entity.subcategory}"
                        )
                    nodes.append(
                        EntityNode(
                            name=entity.value,
                            label=entity.category,
                            embedding=embedding,
                            properties=(
                                {"subcategory": entity.subcategory}
                                if entity.subcategory
                                else {}
                            ),
                        )
                    )
                # Convert each extracted triple into a Relation edge.
                for triple in extraction.triples:
                    relations.append(
                        Relation(
                            source_id=triple.subject,
                            target_id=triple.object,
                            label=triple.predicate,
                        )
                    )
            self.kg_provider.upsert_nodes(nodes)
            self.kg_provider.upsert_relations(relations)
        except Exception as e:
            error_message = f"Failed to store knowledge graph extractions in the database: {e}"
            logger.error(error_message)
            raise ValueError(error_message) from e

    async def _run_logic(
        self,
        input: Input,
        state: AsyncState,
        run_id: uuid.UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[None, None]:
        """
        Executes the async knowledge graph storage pipe: storing knowledge graph extractions in the graph database.
        """
        batch_tasks = []
        kg_batch = []

        async for kg_extraction in input.message:
            kg_batch.append(kg_extraction)
            if len(kg_batch) >= self.storage_batch_size:
                # Schedule the storage task
                batch_tasks.append(
                    asyncio.create_task(
                        self.store(kg_batch.copy()),
                        name=f"kg-store-{self.config.name}",
                    )
                )
                kg_batch.clear()

        if kg_batch:  # Process any remaining extractions
            batch_tasks.append(
                asyncio.create_task(
                    self.store(kg_batch.copy()),
                    name=f"kg-store-{self.config.name}",
                )
            )

        # Wait for all storage tasks to complete
        await asyncio.gather(*batch_tasks)
        yield None
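

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the pipe implementation).
# Assumes an already-configured `kg_provider` / `embedding_provider` from r2r
# and an async generator `extractions()` that yields KGExtraction objects;
# the exact constructors for those objects live outside this module, and
# `AsyncState()` is assumed to take no required arguments. Depending on the
# AsyncPipe defaults, an `AsyncPipe.PipeConfig` with a `name` may also need
# to be passed. The pipeline framework normally drives `_run_logic`; it is
# shown called directly here only to illustrate the expected inputs.
#
#   pipe = KGStoragePipe(
#       kg_provider=kg_provider,
#       embedding_provider=embedding_provider,
#       storage_batch_size=10,
#   )
#
#   async def main():
#       state = AsyncState()
#       run_id = uuid.uuid4()
#       async for _ in pipe._run_logic(
#           KGStoragePipe.Input(message=extractions()), state, run_id
#       ):
#           pass  # the pipe yields None once every batch has been stored
#
#   asyncio.run(main())
# ---------------------------------------------------------------------------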