aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/pipes/retrieval/multi_search.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/pipes/retrieval/multi_search.py')
-rwxr-xr-xR2R/r2r/pipes/retrieval/multi_search.py79
1 files changed, 79 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/retrieval/multi_search.py b/R2R/r2r/pipes/retrieval/multi_search.py
new file mode 100755
index 00000000..6da2c34b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/multi_search.py
@@ -0,0 +1,79 @@
+import uuid
+from copy import copy
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base.abstractions.llm import GenerationConfig
+from r2r.base.abstractions.search import VectorSearchResult
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+from ..abstractions.search_pipe import SearchPipe
+from .query_transform_pipe import QueryTransformPipe
+
+
class MultiSearchPipe(AsyncPipe):
    """Chain a query-transform pipe into a search pipe.

    On each run the input is first handed to ``query_transform_pipe``; the
    object that call returns is then wrapped in the search pipe's ``Input``
    (as its ``message``) and handed to the search pipe, whose results are
    re-yielded one by one.

    NOTE(review): the transform pipe presumably emits several rewritten
    queries per input (it is asked for ``num_query_xf_outputs`` outputs) --
    confirm against ``QueryTransformPipe``.
    """

    class PipeConfig(AsyncPipe.PipeConfig):
        # Default name used when the config is built without an explicit name.
        name: str = "multi_search_pipe"

    def __init__(
        self,
        query_transform_pipe: QueryTransformPipe,
        inner_search_pipe: SearchPipe,
        config: Optional[PipeConfig] = None,
        *args,
        **kwargs,
    ):
        """Validate the two wrapped pipes and initialise the base pipe.

        Args:
            query_transform_pipe: Pipe run first on every input.
            inner_search_pipe: Pipe run on the transform pipe's output; it is
                stored as ``self.vector_search_pipe``.
            config: Optional config for this pipe. When omitted (or falsy), a
                ``PipeConfig`` carrying the transform pipe's name is created.
            *args: Extra positional arguments forwarded to
                ``AsyncPipe.__init__``.
            **kwargs: Extra keyword arguments forwarded to
                ``AsyncPipe.__init__``.

        Raises:
            ValueError: If the two wrapped pipes have different config names,
                or if ``config`` is given and its name differs from the
                transform pipe's name.
        """
        transform_name = query_transform_pipe.config.name

        if transform_name != inner_search_pipe.config.name:
            raise ValueError(
                "The query transform pipe and search pipe must have the same name."
            )
        if config and config.name != transform_name:
            raise ValueError(
                "The pipe config name must match the query transform pipe name."
            )

        self.query_transform_pipe = query_transform_pipe
        self.vector_search_pipe = inner_search_pipe

        super().__init__(
            *args,
            config=config or MultiSearchPipe.PipeConfig(name=transform_name),
            **kwargs,
        )

    async def _run_logic(
        self,
        input: Any,
        state: Any,
        run_id: uuid.UUID,
        query_transform_generation_config: Optional[GenerationConfig] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[VectorSearchResult, None]:
        """Run the transform pipe, then stream results from the search pipe.

        Args:
            input: Passed unchanged to the transform pipe's ``run``.
            state: Passed unchanged to both wrapped pipes.
            run_id: Accepted for interface compatibility; not used here.
            query_transform_generation_config: Generation settings for the
                transform step. When falsy, ``kwargs["rag_generation_config"]``
                is used instead, and failing that a default
                ``GenerationConfig(model="gpt-4o")``.
            *args: Extra positional arguments forwarded to both wrapped pipes.
            **kwargs: Extra keyword arguments forwarded to both wrapped pipes.
                ``num_query_xf_outputs`` (default ``3``) is consumed here and
                forwarded to the transform pipe only.

        Yields:
            Each item produced by the search pipe's run.
        """
        # Work on a copy in every case: ``stream`` is forced off below, and
        # that change must not leak into a config object owned by the caller
        # (whether it arrived as the explicit argument or via kwargs).
        base_config = query_transform_generation_config or kwargs.get(
            "rag_generation_config"
        )
        if base_config:
            transform_config = copy(base_config)
        else:
            transform_config = GenerationConfig(model="gpt-4o")
        transform_config.stream = False

        # Taking this out of kwargs lets callers override the fan-out and
        # avoids a duplicate-keyword TypeError when they pass it explicitly.
        num_query_xf_outputs = kwargs.pop("num_query_xf_outputs", 3)

        query_generator = await self.query_transform_pipe.run(
            input,
            state,
            *args,
            query_transform_generation_config=transform_config,
            num_query_xf_outputs=num_query_xf_outputs,
            **kwargs,
        )

        # The transform pipe's output becomes the search pipe's message.
        search_input = self.vector_search_pipe.Input(message=query_generator)
        async for search_result in await self.vector_search_pipe.run(
            search_input,
            state,
            *args,
            **kwargs,
        ):
            yield search_result