aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/pipes/ingestion/vector_storage_pipe.py
blob: 9564fd22cb0f6187ff17de92d5b9e891f0b8be4a (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import asyncio
import logging
import uuid
from typing import Any, AsyncGenerator, Optional, Tuple, Union

from r2r.base import (
    AsyncState,
    KVLoggingSingleton,
    PipeType,
    VectorDBProvider,
    VectorEntry,
)
from r2r.base.pipes.base_pipe import AsyncPipe

from ...base.abstractions.exception import R2RDocumentProcessingError

# Module-level logger; VectorStoragePipe.store reports storage failures here.
logger = logging.getLogger(__name__)


class VectorStoragePipe(AsyncPipe):
    """
    Pipe that persists ``VectorEntry`` objects to a vector database.

    Incoming entries are grouped into batches of ``storage_batch_size`` and
    each batch is written by a separately scheduled asyncio task. Upstream
    ``R2RDocumentProcessingError`` items are forwarded to the output as-is.
    """

    class Input(AsyncPipe.Input):
        # Stream of entries to store, interleaved with per-document errors
        # produced by earlier pipes.
        message: AsyncGenerator[
            Union[R2RDocumentProcessingError, VectorEntry], None
        ]
        # True -> provider.upsert_entries; False -> provider.copy_entries.
        do_upsert: bool = True

    def __init__(
        self,
        vector_db_provider: VectorDBProvider,
        storage_batch_size: int = 128,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        type: PipeType = PipeType.INGESTOR,
        config: Optional[AsyncPipe.PipeConfig] = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the vector storage pipe.

        Args:
            vector_db_provider: Provider whose ``upsert_entries`` /
                ``copy_entries`` method receives each batch.
            storage_batch_size: Number of entries accumulated before a
                storage task is scheduled.
            pipe_logger: Optional logger forwarded to ``AsyncPipe``.
            type: Pipe type forwarded to ``AsyncPipe``.
            config: Optional pipe configuration forwarded to ``AsyncPipe``.
        """
        # Extra positional args bind before the keywords regardless of where
        # ``*args`` is written; placing it first just makes that explicit.
        super().__init__(
            *args,
            pipe_logger=pipe_logger,
            type=type,
            config=config,
            **kwargs,
        )
        self.vector_db_provider = vector_db_provider
        self.storage_batch_size = storage_batch_size

    async def store(
        self,
        vector_entries: list[VectorEntry],
        do_upsert: bool = True,
    ) -> None:
        """
        Write one batch of vector entries to the vector database.

        Args:
            vector_entries: The batch to persist.
            do_upsert: If True call ``upsert_entries``, otherwise
                ``copy_entries``.

        Raises:
            ValueError: If the provider call raises; the provider's exception
                is chained as ``__cause__``.
        """
        # NOTE(review): the provider methods are called without ``await``, so
        # they are presumably synchronous and block the event loop while they
        # run — confirm against VectorDBProvider.
        try:
            if do_upsert:
                self.vector_db_provider.upsert_entries(vector_entries)
            else:
                self.vector_db_provider.copy_entries(vector_entries)
        except Exception as e:
            error_message = (
                f"Failed to store vector entries in the database: {e}"
            )
            logger.error(error_message)
            # Chain the cause so the provider's traceback is not lost.
            raise ValueError(error_message) from e

    def _schedule_store(
        self, vector_entries: list[VectorEntry], do_upsert: bool
    ) -> asyncio.Task:
        """Schedule ``store`` for one batch as a named asyncio task."""
        return asyncio.create_task(
            self.store(vector_entries, do_upsert),
            name=f"vector-store-{self.config.name}",
        )

    async def _run_logic(
        self,
        input: Input,
        state: AsyncState,
        run_id: uuid.UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[
        Tuple[uuid.UUID, Union[str, R2RDocumentProcessingError]], None
    ]:
        """
        Store incoming vector entries in batches and report per-document counts.

        Yields:
            ``(document_id, error)`` immediately for each
            ``R2RDocumentProcessingError`` on the input; then, once every
            batch has been written, ``(document_id, summary)`` for each
            document whose entries were stored.

        Raises:
            ValueError: If an entry's metadata lacks a truthy
                ``"document_id"``, or if any batch fails to store.
        """
        batch_tasks: list[asyncio.Task] = []
        vector_batch: list[VectorEntry] = []
        document_counts: dict[Any, int] = {}

        try:
            async for msg in input.message:
                if isinstance(msg, R2RDocumentProcessingError):
                    # Forward upstream failures unchanged.
                    yield (msg.document_id, msg)
                    continue

                document_id = msg.metadata.get("document_id", None)
                if not document_id:
                    raise ValueError("Document ID not found in the metadata.")
                document_counts[document_id] = (
                    document_counts.get(document_id, 0) + 1
                )

                vector_batch.append(msg)
                if len(vector_batch) >= self.storage_batch_size:
                    batch_tasks.append(
                        self._schedule_store(vector_batch, input.do_upsert)
                    )
                    # Rebind instead of clearing: the task now owns the list.
                    vector_batch = []
        except Exception:
            # Don't orphan batches that were already scheduled: let them
            # finish and retrieve their outcomes before propagating, so their
            # failures are not reported as "never retrieved".
            await asyncio.gather(*batch_tasks, return_exceptions=True)
            raise

        if vector_batch:  # Flush the final, partially filled batch.
            batch_tasks.append(
                self._schedule_store(vector_batch, input.do_upsert)
            )

        # Wait for every batch even if one fails, so no write is left running
        # unobserved; then surface the first failure in scheduling order.
        results = await asyncio.gather(*batch_tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, BaseException):
                raise result

        for document_id, count in document_counts.items():
            yield (
                document_id,
                f"Processed {count} vectors for document {document_id}.",
            )