Diffstat (limited to 'R2R/r2r/pipes/retrieval/vector_search_pipe.py')
-rwxr-xr-x  R2R/r2r/pipes/retrieval/vector_search_pipe.py | 123
1 file changed, 123 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/pipes/retrieval/vector_search_pipe.py b/R2R/r2r/pipes/retrieval/vector_search_pipe.py
new file mode 100755
index 00000000..742de16b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/vector_search_pipe.py
@@ -0,0 +1,123 @@
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+    AsyncPipe,
+    AsyncState,
+    EmbeddingProvider,
+    PipeType,
+    VectorDBProvider,
+    VectorSearchResult,
+    VectorSearchSettings,
+)
+
+from ..abstractions.search_pipe import SearchPipe
+
+logger = logging.getLogger(__name__)
+
+
+class VectorSearchPipe(SearchPipe):
+    def __init__(
+        self,
+        vector_db_provider: VectorDBProvider,
+        embedding_provider: EmbeddingProvider,
+        type: PipeType = PipeType.SEARCH,
+        config: Optional[SearchPipe.SearchConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            type=type,
+            config=config or SearchPipe.SearchConfig(),
+            *args,
+            **kwargs,
+        )
+        self.embedding_provider = embedding_provider
+        self.vector_db_provider = vector_db_provider
+
+    async def search(
+        self,
+        message: str,
+        run_id: uuid.UUID,
+        vector_search_settings: VectorSearchSettings,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[VectorSearchResult, None]:
+        await self.enqueue_log(
+            run_id=run_id, key="search_query", value=message
+        )
+        search_filters = (
+            vector_search_settings.search_filters or self.config.search_filters
+        )
+        search_limit = (
+            vector_search_settings.search_limit or self.config.search_limit
+        )
+        results = []
+        query_vector = self.embedding_provider.get_embedding(
+            message,
+        )
+        search_results = (
+            self.vector_db_provider.hybrid_search(
+                query_vector=query_vector,
+                query_text=message,
+                filters=search_filters,
+                limit=search_limit,
+            )
+            if vector_search_settings.do_hybrid_search
+            else self.vector_db_provider.search(
+                query_vector=query_vector,
+                filters=search_filters,
+                limit=search_limit,
+            )
+        )
+        reranked_results = self.embedding_provider.rerank(
+            query=message, results=search_results, limit=search_limit
+        )
+        for result in reranked_results:
+            result.metadata["associatedQuery"] = message
+            results.append(result)
+            yield result
+        await self.enqueue_log(
+            run_id=run_id,
+            key="search_results",
+            value=json.dumps([ele.json() for ele in results]),
+        )
+
+    async def _run_logic(
+        self,
+        input: AsyncPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[VectorSearchResult, None]:
+        search_queries = []
+        search_results = []
+        async for search_request in input.message:
+            search_queries.append(search_request)
+            async for result in self.search(
+                message=search_request,
+                run_id=run_id,
+                vector_search_settings=vector_search_settings,
+                *args,
+                **kwargs,
+            ):
+                search_results.append(result)
+                yield result
+
+        await state.update(
+            self.config.name, {"output": {"search_results": search_results}}
+        )
+
+        await state.update(
+            self.config.name,
+            {
+                "output": {
+                    "search_queries": search_queries,
+                    "search_results": search_results,
+                }
+            },
+        )
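The new pipe follows an embed -> search -> rerank -> stream pattern: the query is embedded, matched against the vector store (via hybrid search when vector_search_settings.do_hybrid_search is set, plain vector search otherwise), reranked, and each hit is tagged with its originating query and yielded one by one rather than returned as a single batch. The standalone sketch below distills that flow; StubEmbeddingProvider, StubVectorDB, and the dict-shaped results are toy stand-ins for illustration, not the real r2r.base provider interfaces.

# Standalone sketch of the embed -> search -> rerank -> stream pattern.
# StubEmbeddingProvider and StubVectorDB are toy stand-ins, not r2r classes.
import asyncio
from typing import AsyncGenerator

class StubEmbeddingProvider:
    def get_embedding(self, text: str) -> list[float]:
        # Toy embedding: a single feature derived from the text length.
        return [float(len(text))]

    def rerank(self, query: str, results: list[dict], limit: int) -> list[dict]:
        # A real provider would reorder by relevance; here we just truncate.
        return results[:limit]

class StubVectorDB:
    def search(self, query_vector: list[float], filters: dict, limit: int) -> list[dict]:
        # Pretend these rows came back from an ANN index.
        return [{"text": "hit-1", "metadata": {}}, {"text": "hit-2", "metadata": {}}][:limit]

async def search(message: str, limit: int = 10) -> AsyncGenerator[dict, None]:
    embeddings, vector_db = StubEmbeddingProvider(), StubVectorDB()
    query_vector = embeddings.get_embedding(message)
    raw = vector_db.search(query_vector, filters={}, limit=limit)
    for result in embeddings.rerank(message, raw, limit):
        # Tag each hit with its originating query, as the pipe does.
        result["metadata"]["associatedQuery"] = message
        yield result  # stream hits one at a time instead of returning a batch

async def main() -> None:
    async for hit in search("what does a vector search pipe do?"):
        print(hit)

asyncio.run(main())

Because search is an async generator, a downstream consumer receives results individually as they are tagged rather than waiting for the full list to be assembled. One detail worth noticing in the diff itself: _run_logic calls state.update twice, and the second call already records both search_queries and search_results, so the first call (results only) appears redundant, depending on how AsyncState merges updates.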