aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/pipelines/ingestion_pipeline.py
blob: df1263f97c11d240aa177c282cff226bc3610639 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import asyncio
import logging
from asyncio import Queue
from typing import Any, Optional

from r2r.base.logging.kv_logger import KVLoggingSingleton
from r2r.base.logging.run_manager import RunManager, manage_run
from r2r.base.pipeline.base_pipeline import AsyncPipeline, dequeue_requests
from r2r.base.pipes.base_pipe import AsyncPipe, AsyncState

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class IngestionPipeline(AsyncPipeline):
    """Pipeline that parses input once and fans the documents out.

    A single parsing pipe produces documents. Every document is copied onto
    one queue per configured consumer -- the embedding sub-pipeline and/or
    the knowledge-graph (KG) sub-pipeline -- and the consumers run
    concurrently as asyncio tasks while parsing is still in progress.
    """

    pipeline_type: str = "ingestion"

    def __init__(
        self,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        run_manager: Optional[RunManager] = None,
    ):
        """Create an empty ingestion pipeline.

        Args:
            pipe_logger: Forwarded unchanged to the base pipeline constructor.
            run_manager: Forwarded unchanged to the base pipeline constructor.
        """
        super().__init__(pipe_logger, run_manager)
        # All three are populated through add_pipe(); run() refuses to start
        # until the parsing pipe and at least one sub-pipeline are set.
        self.parsing_pipe: Optional[AsyncPipe] = None
        self.embedding_pipeline: Optional[AsyncPipeline] = None
        self.kg_pipeline: Optional[AsyncPipeline] = None

    async def run(
        self,
        input: Any,
        state: Optional[AsyncState] = None,
        stream: bool = False,
        run_manager: Optional[RunManager] = None,
        log_run_info: bool = True,
        *args: Any,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Parse ``input`` and feed the documents to the sub-pipelines.

        Args:
            input: Payload wrapped in ``parsing_pipe.Input(message=input)``.
            state: Optional shared state; stored on ``self.state`` (a fresh
                ``AsyncState`` is created when omitted).
            stream: Forwarded to each sub-pipeline's ``run``.
            run_manager: Run manager used for logging and forwarded to the
                parsing pipe and sub-pipelines. When omitted, the instance's
                own ``run_manager`` attribute is used if it exists.
            log_run_info: Whether to log the pipeline type for this run.
            *args: Extra positional arguments forwarded to the pipes.
            **kwargs: Extra keyword arguments forwarded to the pipes.

        Returns:
            A dict holding ``"embedding_pipeline_output"`` and/or
            ``"kg_pipeline_output"``, one entry per configured sub-pipeline.

        Raises:
            ValueError: If no parsing pipe is set, or if neither the
                embedding nor the KG sub-pipeline is set.
        """
        # A run manager was handed to the base constructor; presumably it is
        # kept as ``self.run_manager``. Fall back to it so that calling run()
        # without an explicit manager does not dereference None below.
        run_manager = run_manager or getattr(self, "run_manager", None)

        self.state = state or AsyncState()
        async with manage_run(run_manager, self.pipeline_type):
            if log_run_info:
                await run_manager.log_run_info(
                    key="pipeline_type",
                    value=self.pipeline_type,
                    is_info_log=True,
                )
            if self.parsing_pipe is None:
                raise ValueError(
                    "parsing_pipeline must be set before running the ingestion pipeline"
                )
            if self.embedding_pipeline is None and self.kg_pipeline is None:
                raise ValueError(
                    "At least one of embedding_pipeline or kg_pipeline must be set before running the ingestion pipeline"
                )

            # One queue per consumer so each sub-pipeline sees every document.
            embedding_queue: Queue = Queue()
            kg_queue: Queue = Queue()

            # NOTE(review): the pipes receive the caller's ``state`` (possibly
            # None) rather than ``self.state``; kept as in the original --
            # confirm whether sharing ``self.state`` was intended.
            async def enqueue_documents() -> None:
                """Copy every parsed document onto the consumer queues."""
                try:
                    async for document in await self.parsing_pipe.run(
                        self.parsing_pipe.Input(message=input),
                        state,
                        run_manager,
                        *args,
                        **kwargs,
                    ):
                        if self.embedding_pipeline is not None:
                            await embedding_queue.put(document)
                        if self.kg_pipeline is not None:
                            await kg_queue.put(document)
                finally:
                    # None marks end-of-stream for the consumers. Emit it
                    # even when parsing fails so no consumer is left waiting
                    # on its queue forever. The queues are unbounded, so
                    # put_nowait cannot raise QueueFull.
                    embedding_queue.put_nowait(None)
                    kg_queue.put_nowait(None)

            def start_consumer(
                pipeline: AsyncPipeline, queue: Queue
            ) -> "asyncio.Task[Any]":
                """Schedule ``pipeline`` to consume documents from ``queue``."""
                return asyncio.create_task(
                    pipeline.run(
                        dequeue_requests(queue),
                        state,
                        stream,
                        run_manager,
                        *args,
                        # Run info has already been logged above.
                        log_run_info=False,
                        **kwargs,
                    )
                )

            # Start the producer, then the consumers, so they all overlap.
            enqueue_task = asyncio.create_task(enqueue_documents())

            consumer_tasks: dict[str, "asyncio.Task[Any]"] = {}
            if self.embedding_pipeline is not None:
                consumer_tasks["embedding_pipeline_output"] = start_consumer(
                    self.embedding_pipeline, embedding_queue
                )
            if self.kg_pipeline is not None:
                consumer_tasks["kg_pipeline_output"] = start_consumer(
                    self.kg_pipeline, kg_queue
                )

            try:
                # Wait for parsing to finish, then collect each consumer's
                # output in a fixed order (embedding first, then KG).
                await enqueue_task
                results: dict[str, Any] = {}
                for key, task in consumer_tasks.items():
                    results[key] = await task
                return results
            except BaseException:
                # On any failure (or cancellation of this coroutine) stop the
                # remaining tasks instead of leaving them orphaned, and drain
                # them so their outcomes are retrieved before re-raising.
                pending = [enqueue_task, *consumer_tasks.values()]
                for task in pending:
                    if not task.done():
                        task.cancel()
                await asyncio.gather(*pending, return_exceptions=True)
                raise

    def add_pipe(
        self,
        pipe: AsyncPipe,
        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
        parsing_pipe: bool = False,
        kg_pipe: bool = False,
        embedding_pipe: bool = False,
        *args,
        **kwargs,
    ) -> None:
        """Register ``pipe`` in the role selected by the boolean flags.

        The flags are checked in the order ``parsing_pipe``, ``kg_pipe``,
        ``embedding_pipe``; the first one that is true wins.

        Args:
            pipe: The pipe to register.
            add_upstream_outputs: Forwarded to the sub-pipeline's
                ``add_pipe``; not used for the parsing pipe.
            parsing_pipe: Install ``pipe`` as the parsing pipe, replacing any
                previously set one.
            kg_pipe: Append ``pipe`` to the KG sub-pipeline.
            embedding_pipe: Append ``pipe`` to the embedding sub-pipeline.
            *args: Forwarded to the sub-pipeline's ``add_pipe``.
            **kwargs: Forwarded to the sub-pipeline's ``add_pipe``.

        Raises:
            ValueError: If none of the three role flags is set.
        """
        logger.debug(
            "Adding pipe %s to the IngestionPipeline", pipe.config.name
        )

        if parsing_pipe:
            self.parsing_pipe = pipe
            return

        if kg_pipe:
            # The sub-pipeline is created lazily on first use.
            if self.kg_pipeline is None:
                self.kg_pipeline = AsyncPipeline()
            self.kg_pipeline.add_pipe(
                pipe, add_upstream_outputs, *args, **kwargs
            )
            return

        if embedding_pipe:
            # The sub-pipeline is created lazily on first use.
            if self.embedding_pipeline is None:
                self.embedding_pipeline = AsyncPipeline()
            self.embedding_pipeline.add_pipe(
                pipe, add_upstream_outputs, *args, **kwargs
            )
            return

        raise ValueError("Pipe must be a parsing, embedding, or KG pipe")