import asyncio
import logging
import uuid
from typing import Any, AsyncGenerator, Optional, Tuple, Union

from r2r.base import (
    AsyncState,
    KVLoggingSingleton,
    PipeType,
    VectorDBProvider,
    VectorEntry,
)
from r2r.base.pipes.base_pipe import AsyncPipe

from ...base.abstractions.exception import R2RDocumentProcessingError

logger = logging.getLogger(__name__)


class VectorStoragePipe(AsyncPipe):
    """Batches incoming vector entries and writes them to the vector database."""

    class Input(AsyncPipe.Input):
        message: AsyncGenerator[
            Union[R2RDocumentProcessingError, VectorEntry], None
        ]
        do_upsert: bool = True

    def __init__(
        self,
        vector_db_provider: VectorDBProvider,
        storage_batch_size: int = 128,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        type: PipeType = PipeType.INGESTOR,
        config: Optional[AsyncPipe.PipeConfig] = None,
        *args,
        **kwargs,
    ):
        """
        Initializes the async vector storage pipe with necessary components and configurations.
        """
        super().__init__(
            pipe_logger=pipe_logger,
            type=type,
            config=config,
            *args,
            **kwargs,
        )
        self.vector_db_provider = vector_db_provider
        self.storage_batch_size = storage_batch_size

    async def store(
        self,
        vector_entries: list[VectorEntry],
        do_upsert: bool = True,
    ) -> None:
        """
        Stores a batch of vector entries in the database.
        """
        try:
            if do_upsert:
                self.vector_db_provider.upsert_entries(vector_entries)
            else:
                self.vector_db_provider.copy_entries(vector_entries)
        except Exception as e:
            error_message = (
                f"Failed to store vector entries in the database: {e}"
            )
            logger.error(error_message)
            # Re-raise with the original exception chained for easier debugging.
            raise ValueError(error_message) from e

    async def _run_logic(
        self,
        input: Input,
        state: AsyncState,
        run_id: uuid.UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[
        Tuple[uuid.UUID, Union[str, R2RDocumentProcessingError]], None
    ]:
        """
        Executes the async vector storage pipe: storing embeddings in the vector database.
        """
        batch_tasks = []
        vector_batch = []
        document_counts = {}
        async for msg in input.message:
            # Pass document-level processing errors straight through to the caller.
            if isinstance(msg, R2RDocumentProcessingError):
                yield (msg.document_id, msg)
                continue

            document_id = msg.metadata.get("document_id", None)
            if not document_id:
                raise ValueError("Document ID not found in the metadata.")
            if document_id not in document_counts:
                document_counts[document_id] = 1
            else:
                document_counts[document_id] += 1

            vector_batch.append(msg)
            if len(vector_batch) >= self.storage_batch_size:
                # Schedule the storage task for the full batch
                batch_tasks.append(
                    asyncio.create_task(
                        self.store(vector_batch.copy(), input.do_upsert),
                        name=f"vector-store-{self.config.name}",
                    )
                )
                vector_batch.clear()

        if vector_batch:  # Process any remaining vectors
            batch_tasks.append(
                asyncio.create_task(
                    self.store(vector_batch.copy(), input.do_upsert),
                    name=f"vector-store-{self.config.name}",
                )
            )

        # Wait for all storage tasks to complete
        await asyncio.gather(*batch_tasks)

        for document_id, count in document_counts.items():
            yield (
                document_id,
                f"Processed {count} vectors for document {document_id}.",
            )
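

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original r2r module): a minimal,
# self-contained demonstration of the batch-and-flush pattern that
# `_run_logic` uses above. `_DemoStore` and the plain-dict "entries" are
# hypothetical stand-ins for VectorDBProvider and VectorEntry, used only so
# the sketch can run without a real vector database.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _DemoStore:
        """In-memory stand-in that records each stored batch."""

        def __init__(self) -> None:
            self.batches: list[list[dict]] = []

        def upsert_entries(self, entries: list[dict]) -> None:
            self.batches.append(entries)

    async def _demo(batch_size: int = 3) -> None:
        store = _DemoStore()
        tasks = []
        batch: list[dict] = []

        async def entries() -> AsyncGenerator[dict, None]:
            for idx in range(8):
                yield {"metadata": {"document_id": "doc-1"}, "vector": [float(idx)]}

        async def flush(entries_batch: list[dict]) -> None:
            # A real provider call would go here; the demo just records it.
            store.upsert_entries(entries_batch)

        async for entry in entries():
            batch.append(entry)
            if len(batch) >= batch_size:
                # Schedule storage for the full batch, as _run_logic does.
                tasks.append(asyncio.create_task(flush(batch.copy())))
                batch.clear()
        if batch:  # Flush any remainder smaller than one batch.
            tasks.append(asyncio.create_task(flush(batch.copy())))
        await asyncio.gather(*tasks)
        print(f"Stored {len(store.batches)} batches")  # Expected: 3 (3 + 3 + 2 entries)

    asyncio.run(_demo())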