"""Vector search pipe: embeds an incoming query, searches the vector database
(optionally with hybrid search), reranks the hits, and streams the results
downstream."""

import json
import logging
import uuid
from typing import Any, AsyncGenerator, Optional

from r2r.base import (
    AsyncPipe,
    AsyncState,
    EmbeddingProvider,
    PipeType,
    VectorDBProvider,
    VectorSearchResult,
    VectorSearchSettings,
)

from ..abstractions.search_pipe import SearchPipe

logger = logging.getLogger(__name__)

class VectorSearchPipe(SearchPipe):
    """Embeds queries, searches a vector database, and yields reranked results."""

    def __init__(
        self,
        vector_db_provider: VectorDBProvider,
        embedding_provider: EmbeddingProvider,
        type: PipeType = PipeType.SEARCH,
        config: Optional[SearchPipe.SearchConfig] = None,
        *args,
        **kwargs,
    ):
        # Fall back to the default search configuration when none is supplied.
        super().__init__(
            type=type,
            config=config or SearchPipe.SearchConfig(),
            *args,
            **kwargs,
        )
        self.embedding_provider = embedding_provider
        self.vector_db_provider = vector_db_provider

    async def search(
        self,
        message: str,
        run_id: uuid.UUID,
        vector_search_settings: VectorSearchSettings,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[VectorSearchResult, None]:
        """Embed `message`, search the vector store, and yield reranked hits."""
        await self.enqueue_log(
            run_id=run_id, key="search_query", value=message
        )
        # Per-request settings take precedence over the pipe's static config.
        search_filters = (
            vector_search_settings.search_filters or self.config.search_filters
        )
        search_limit = (
            vector_search_settings.search_limit or self.config.search_limit
        )
        results = []

        query_vector = self.embedding_provider.get_embedding(message)
        # Hybrid search combines the query text with its embedding; otherwise
        # fall back to pure vector similarity search.
        search_results = (
            self.vector_db_provider.hybrid_search(
                query_vector=query_vector,
                query_text=message,
                filters=search_filters,
                limit=search_limit,
            )
            if vector_search_settings.do_hybrid_search
            else self.vector_db_provider.search(
                query_vector=query_vector,
                filters=search_filters,
                limit=search_limit,
            )
        )
        reranked_results = self.embedding_provider.rerank(
            query=message, results=search_results, limit=search_limit
        )
        # Tag each hit with the query that produced it and stream it downstream.
        for result in reranked_results:
            result.metadata["associatedQuery"] = message
            results.append(result)
            yield result

        await self.enqueue_log(
            run_id=run_id,
            key="search_results",
            value=json.dumps([ele.json() for ele in results]),
        )

    async def _run_logic(
        self,
        input: AsyncPipe.Input,
        state: AsyncState,
        run_id: uuid.UUID,
        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[VectorSearchResult, None]:
        """Consume queries from the input stream and yield their search results."""
        search_queries = []
        search_results = []
        async for search_request in input.message:
            search_queries.append(search_request)
            async for result in self.search(
                message=search_request,
                run_id=run_id,
                vector_search_settings=vector_search_settings,
                *args,
                **kwargs,
            ):
                search_results.append(result)
                yield result

        # Record the queries and their results in the shared pipeline state.
        await state.update(
            self.config.name,
            {
                "output": {
                    "search_queries": search_queries,
                    "search_results": search_results,
                }
            },
        )
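

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative only, not part of the pipe implementation).
# It assumes concrete `VectorDBProvider` and `EmbeddingProvider` instances
# (`vector_db` and `embedder` below are placeholders built elsewhere by the
# application's provider wiring) and a configured pipe logger so that
# `enqueue_log` succeeds; field names passed to `VectorSearchSettings` are the
# ones referenced in this module.
#
#     pipe = VectorSearchPipe(
#         vector_db_provider=vector_db,
#         embedding_provider=embedder,
#     )
#
#     async def demo() -> None:
#         async for hit in pipe.search(
#             message="how does hybrid search work?",
#             run_id=uuid.uuid4(),
#             vector_search_settings=VectorSearchSettings(
#                 search_limit=5, do_hybrid_search=True
#             ),
#         ):
#             logger.info("hit for %s: %s", hit.metadata["associatedQuery"], hit)
# ---------------------------------------------------------------------------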