from r2r.main import R2RPipeFactory
from r2r.pipes.retrieval.multi_search import MultiSearchPipe
from r2r.pipes.retrieval.query_transform_pipe import QueryTransformPipe


class R2RPipeFactoryWithMultiSearch(R2RPipeFactory):
    """Pipe factory whose vector search pipe is a `MultiSearchPipe`.

    The multi-search pipe is built from a query transform pipe and an inner
    search pipe (by default the one produced by the parent factory).
    """

    # Default prompt for splitting one user query into several sub-queries.
    # TODO - Can we have stricter typing like so?
    #   `{"template": str, "input_types": dict[str, str]}`
    QUERY_GENERATION_TEMPLATE: dict = {
        "template": (
            "### Instruction:\n\n"
            "Given the following query that follows to write a double newline separated list of up to {num_outputs} queries meant to help answer the original query. \n"
            "DO NOT generate any single query which is likely to require information from multiple distinct documents, \n"
            "EACH single query will be used to carry out a cosine similarity semantic search over distinct indexed documents, such as varied medical documents. \n"
            "FOR EXAMPLE if asked `how do the key themes of Great Gatsby compare with 1984`, the two queries would be \n"
            "`What are the key themes of Great Gatsby?` and `What are the key themes of 1984?`.\n"
            "Here is the original user query to be transformed into answers:\n\n"
            "### Query:\n{message}\n\n"
            "### Response:\n"
        ),
        "input_types": {"num_outputs": "int", "message": "str"},
    }

    def create_vector_search_pipe(self, *args, **kwargs):
        """Create a `MultiSearchPipe` wrapping a query transform and a search pipe.

        Recognised keyword overrides:
            task_prompt_name: str
                Name of the prompt the query transform pipe should use.
            multi_query_transform_pipe_override: QueryTransformPipe
                Used instead of constructing a new query transform pipe.
            multi_inner_search_pipe_override: SearchPipe
                Used instead of the parent factory's vector search pipe.
            query_generation_template_override:
                {'template': str, 'input_types': dict[str, str]}, optionally
                with a 'name' key. It is passed to
                `self.providers.prompt.add_prompt`; when 'name' is missing,
                the default `<config name>_task_prompt` name is used.

        All positional and keyword arguments are also forwarded to the parent
        `create_vector_search_pipe` when no inner search pipe override is
        given.

        Returns:
            The assembled `MultiSearchPipe`.

        Raises:
            ValueError: If both `task_prompt_name` and
                `query_generation_template_override` are provided.
        """
        multi_search_config = MultiSearchPipe.PipeConfig()

        requested_prompt_name = kwargs.get("task_prompt_name")
        template_override = kwargs.get("query_generation_template_override")
        if requested_prompt_name and template_override:
            raise ValueError(
                "Cannot provide both `task_prompt_name` and `query_generation_template_override`"
            )

        task_prompt_name = (
            requested_prompt_name or f"{multi_search_config.name}_task_prompt"
        )

        if template_override:
            # Add a prompt for transforming the user query. The documented
            # override shape has no "name" key, so fall back to the default
            # task prompt name instead of failing on template["name"].
            task_prompt_name = template_override.get("name") or task_prompt_name
            self.providers.prompt.add_prompt(
                **{**template_override, "name": task_prompt_name}
            )
        # NOTE(review): without an override nothing is added to the prompt
        # provider here, so QUERY_GENERATION_TEMPLATE is not used by this
        # method; presumably the prompt named `task_prompt_name` already
        # exists in the provider -- confirm.

        # Initialize the new query transform pipe (only built when the caller
        # did not supply one).
        query_transform_pipe = kwargs.get(
            "multi_query_transform_pipe_override"
        ) or QueryTransformPipe(
            llm_provider=self.providers.llm,
            prompt_provider=self.providers.prompt,
            config=QueryTransformPipe.QueryTransformConfig(
                name=multi_search_config.name,
                task_prompt=task_prompt_name,
            ),
        )

        # Create search pipe override and pipes.
        # NOTE(review): kwargs (including the multi-search-only override keys)
        # are forwarded unchanged to the parent factory; verify it accepts them.
        inner_search_pipe = kwargs.get(
            "multi_inner_search_pipe_override"
        ) or super().create_vector_search_pipe(*args, **kwargs)

        # TODO - modify `create_..._pipe` to allow naming the pipe
        # This renames the inner pipe in place, even when it was supplied by
        # the caller through the override.
        inner_search_pipe.config.name = multi_search_config.name

        return MultiSearchPipe(
            query_transform_pipe=query_transform_pipe,
            inner_search_pipe=inner_search_pipe,
            config=multi_search_config,
        )