import asyncio
import logging
from asyncio import Queue
from typing import Any, Optional

from ..base.abstractions.search import (
    AggregateSearchResult,
    KGSearchSettings,
    VectorSearchSettings,
)
from ..base.logging.kv_logger import KVLoggingSingleton
from ..base.logging.run_manager import RunManager, manage_run
from ..base.pipeline.base_pipeline import AsyncPipeline, dequeue_requests
from ..base.pipes.base_pipe import AsyncPipe, AsyncState

logger = logging.getLogger(__name__)


class SearchPipeline(AsyncPipeline):
    """A pipeline that fans search requests out to vector search and/or
    knowledge graph (KG) search and aggregates the results."""

    pipeline_type: str = "search"

    def __init__(
        self,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        run_manager: Optional[RunManager] = None,
    ):
        super().__init__(pipe_logger, run_manager)
        self._parsing_pipe = None
        self._vector_search_pipeline = None
        self._kg_search_pipeline = None

    async def run(
        self,
        input: Any,
        state: Optional[AsyncState] = None,
        stream: bool = False,
        run_manager: Optional[RunManager] = None,
        log_run_info: bool = True,
        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
        kg_search_settings: KGSearchSettings = KGSearchSettings(),
        *args: Any,
        **kwargs: Any,
    ):
        """Run the enabled sub-pipelines over the incoming requests and return
        an aggregated search result."""
        self.state = state or AsyncState()

        do_vector_search = (
            self._vector_search_pipeline is not None
            and vector_search_settings.use_vector_search
        )
        do_kg = (
            self._kg_search_pipeline is not None
            and kg_search_settings.use_kg_search
        )

        async with manage_run(run_manager, self.pipeline_type):
            if log_run_info:
                await run_manager.log_run_info(
                    key="pipeline_type",
                    value=self.pipeline_type,
                    is_info_log=True,
                )

            vector_search_queue = Queue()
            kg_queue = Queue()

            async def enqueue_requests():
                # Copy each incoming request onto the queue of every enabled
                # sub-pipeline, then signal completion with a `None` sentinel.
                async for message in input:
                    if do_vector_search:
                        await vector_search_queue.put(message)
                    if do_kg:
                        await kg_queue.put(message)

                await vector_search_queue.put(None)
                await kg_queue.put(None)

            # Start the request enqueuing process
            enqueue_task = asyncio.create_task(enqueue_requests())

            # Start the vector search and KG search pipelines in parallel
            if do_vector_search:
                vector_search_task = asyncio.create_task(
                    self._vector_search_pipeline.run(
                        dequeue_requests(vector_search_queue),
                        state,
                        stream,
                        run_manager,
                        log_run_info=False,
                        vector_search_settings=vector_search_settings,
                    )
                )

            if do_kg:
                kg_task = asyncio.create_task(
                    self._kg_search_pipeline.run(
                        dequeue_requests(kg_queue),
                        state,
                        stream,
                        run_manager,
                        log_run_info=False,
                        kg_search_settings=kg_search_settings,
                    )
                )

            await enqueue_task

            vector_search_results = (
                await vector_search_task if do_vector_search else None
            )
            kg_results = await kg_task if do_kg else None

            return AggregateSearchResult(
                vector_search_results=vector_search_results,
                kg_search_results=kg_results,
            )

    def add_pipe(
        self,
        pipe: AsyncPipe,
        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
        kg_pipe: bool = False,
        vector_search_pipe: bool = False,
        *args,
        **kwargs,
    ) -> None:
        """Add a pipe to the KG search or vector search sub-pipeline, creating
        the sub-pipeline on first use."""
        logger.debug(f"Adding pipe {pipe.config.name} to the SearchPipeline")

        if kg_pipe:
            if not self._kg_search_pipeline:
                self._kg_search_pipeline = AsyncPipeline()
            self._kg_search_pipeline.add_pipe(
                pipe, add_upstream_outputs, *args, **kwargs
            )
        elif vector_search_pipe:
            if not self._vector_search_pipeline:
                self._vector_search_pipeline = AsyncPipeline()
            self._vector_search_pipeline.add_pipe(
                pipe, add_upstream_outputs, *args, **kwargs
            )
        else:
            raise ValueError("Pipe must be a vector search or KG pipe")
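

# Illustrative usage sketch (a minimal, hypothetical example, not part of this
# module): `ExampleVectorSearchPipe` and `ExampleKGSearchPipe` stand in for
# concrete AsyncPipe implementations; only the SearchPipeline calls mirror the
# class defined above.
#
#     async def example_search() -> AggregateSearchResult:
#         pipeline = SearchPipeline()
#         pipeline.add_pipe(ExampleVectorSearchPipe(), vector_search_pipe=True)
#         pipeline.add_pipe(ExampleKGSearchPipe(), kg_pipe=True)
#
#         async def queries():
#             yield "example search query"
#
#         return await pipeline.run(
#             queries(),
#             vector_search_settings=VectorSearchSettings(),
#             kg_search_settings=KGSearchSettings(),
#         )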