Diffstat (limited to 'R2R/r2r/pipes/retrieval/vector_search_pipe.py')
-rwxr-xr-x  R2R/r2r/pipes/retrieval/vector_search_pipe.py  123
1 file changed, 123 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/retrieval/vector_search_pipe.py b/R2R/r2r/pipes/retrieval/vector_search_pipe.py
new file mode 100755
index 00000000..742de16b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/vector_search_pipe.py
@@ -0,0 +1,123 @@
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+    AsyncPipe,
+    AsyncState,
+    EmbeddingProvider,
+    PipeType,
+    VectorDBProvider,
+    VectorSearchResult,
+    VectorSearchSettings,
+)
+
+from ..abstractions.search_pipe import SearchPipe
+
+logger = logging.getLogger(__name__)
+
+
+class VectorSearchPipe(SearchPipe):
+    def __init__(
+        self,
+        vector_db_provider: VectorDBProvider,
+        embedding_provider: EmbeddingProvider,
+        type: PipeType = PipeType.SEARCH,
+        config: Optional[SearchPipe.SearchConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            type=type,
+            config=config or SearchPipe.SearchConfig(),
+            *args,
+            **kwargs,
+        )
+        self.embedding_provider = embedding_provider
+        self.vector_db_provider = vector_db_provider
+
+    async def search(
+        self,
+        message: str,
+        run_id: uuid.UUID,
+        vector_search_settings: VectorSearchSettings,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[VectorSearchResult, None]:
+        await self.enqueue_log(
+            run_id=run_id, key="search_query", value=message
+        )
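+        # Per-request settings take precedence over the pipe's configured defaults.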
+        search_filters = (
+            vector_search_settings.search_filters or self.config.search_filters
+        )
+        search_limit = (
+            vector_search_settings.search_limit or self.config.search_limit
+        )
+        results = []
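+        # Embed the query once; the same vector serves either search path below.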
+        query_vector = self.embedding_provider.get_embedding(
+            message,
+        )
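+        # Run hybrid (vector + keyword) search when requested, otherwise
+        # pure vector similarity search.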
+        search_results = (
+            self.vector_db_provider.hybrid_search(
+                query_vector=query_vector,
+                query_text=message,
+                filters=search_filters,
+                limit=search_limit,
+            )
+            if vector_search_settings.do_hybrid_search
+            else self.vector_db_provider.search(
+                query_vector=query_vector,
+                filters=search_filters,
+                limit=search_limit,
+            )
+        )
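+        # Rerank the raw hits against the query before anything is yielded.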
+        reranked_results = self.embedding_provider.rerank(
+            query=message, results=search_results, limit=search_limit
+        )
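+        # Tag each result with the query that produced it, then stream it out.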
+        for result in reranked_results:
+            result.metadata["associatedQuery"] = message
+            results.append(result)
+            yield result
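+        # After streaming completes, log the full serialized result set.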
+        await self.enqueue_log(
+            run_id=run_id,
+            key="search_results",
+            value=json.dumps([ele.json() for ele in results]),
+        )
+
+    async def _run_logic(
+        self,
+        input: AsyncPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[VectorSearchResult, None]:
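+        # Fan each incoming query out to search(), streaming results as they arrive.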
+        search_queries = []
+        search_results = []
+        async for search_request in input.message:
+            search_queries.append(search_request)
+            async for result in self.search(
+                message=search_request,
+                run_id=run_id,
+                vector_search_settings=vector_search_settings,
+                *args,
+                **kwargs,
+            ):
+                search_results.append(result)
+                yield result
+
+        # Persist both the queries and their results in one state update; the
+        # fuller payload already carries everything downstream consumers need.
+        await state.update(
+            self.config.name,
+            {
+                "output": {
+                    "search_queries": search_queries,
+                    "search_results": search_results,
+                }
+            },
+        )
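
A minimal driving sketch for the pipe above, under stated assumptions: db and
embedder stand in for concrete, pre-configured VectorDBProvider and
EmbeddingProvider instances, state for an AsyncState instance, and QueryInput
is a hypothetical stand-in for AsyncPipe.Input exposing the async-iterable
message attribute that _run_logic reads. The do_hybrid_search keyword is
likewise an assumption about the VectorSearchSettings model, inferred from the
attribute access in search().

import asyncio
import uuid

from r2r.base import VectorSearchSettings

async def main():
    pipe = VectorSearchPipe(
        vector_db_provider=db,        # assumed VectorDBProvider instance
        embedding_provider=embedder,  # assumed EmbeddingProvider instance
    )

    async def queries():
        # _run_logic consumes input.message as an async iterator of queries.
        yield "What is hybrid search?"

    class QueryInput:  # hypothetical stand-in for AsyncPipe.Input
        message = queries()

    async for result in pipe._run_logic(
        input=QueryInput(),
        state=state,  # assumed AsyncState instance
        run_id=uuid.uuid4(),
        vector_search_settings=VectorSearchSettings(do_hybrid_search=True),
    ):
        print(result.metadata["associatedQuery"], result)

asyncio.run(main())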