diff options
Diffstat (limited to 'R2R/r2r/main/assembly/factory_extensions.py')
-rwxr-xr-x | R2R/r2r/main/assembly/factory_extensions.py | 69 |
1 files changed, 69 insertions, 0 deletions
from r2r.main import R2RPipeFactory
from r2r.pipes.retrieval.multi_search import MultiSearchPipe
from r2r.pipes.retrieval.query_transform_pipe import QueryTransformPipe


class R2RPipeFactoryWithMultiSearch(R2RPipeFactory):
    """Pipe factory whose vector search pipe is a `MultiSearchPipe`.

    The returned pipe combines a `QueryTransformPipe` with an inner search
    pipe; by default the inner pipe is whatever the parent factory's
    `create_vector_search_pipe` builds.
    """

    # Prompt text for splitting one user query into several sub-queries.
    # NOTE(review): nothing in this class registers this template with the
    # prompt provider; unless `task_prompt_name` or
    # `query_generation_template_override` is supplied, the prompt named
    # "<multi search config name>_task_prompt" is presumably expected to
    # exist already -- confirm against the prompt provider setup.
    # TODO - Can we have stricter typing like so? `: {"template": str, "input_types": dict[str, str]} = {``
    QUERY_GENERATION_TEMPLATE: dict = {
        "template": "### Instruction:\n\nGiven the following query that follows to write a double newline separated list of up to {num_outputs} queries meant to help answer the original query. \nDO NOT generate any single query which is likely to require information from multiple distinct documents, \nEACH single query will be used to carry out a cosine similarity semantic search over distinct indexed documents, such as varied medical documents. \nFOR EXAMPLE if asked `how do the key themes of Great Gatsby compare with 1984`, the two queries would be \n`What are the key themes of Great Gatsby?` and `What are the key themes of 1984?`.\nHere is the original user query to be transformed into answers:\n\n### Query:\n{message}\n\n### Response:\n",
        "input_types": {"num_outputs": "int", "message": "str"},
    }

    def create_vector_search_pipe(self, *args, **kwargs):
        """
        A factory method to create a multi-search pipe.

        Recognised keyword overrides (all optional):
            task_prompt_name: str
                Name of the prompt the query transform pipe should use.
                Defaults to "<multi search config name>_task_prompt".
            multi_query_transform_pipe_override: QueryTransformPipe
                Used instead of building a new query transform pipe.
            multi_inner_search_pipe_override: SearchPipe
                Used instead of the parent factory's vector search pipe.
            query_generation_template_override: mapping
                Passed as keyword arguments to the prompt provider's
                `add_prompt`; its `name` entry becomes the task prompt
                name, so it must contain `name` in addition to the prompt
                fields (e.g. `template`, `input_types`).

        All positional and keyword arguments are forwarded unchanged to the
        parent `create_vector_search_pipe` when no inner search pipe
        override is given.

        Returns:
            A `MultiSearchPipe` wrapping the query transform pipe and the
            inner search pipe.

        Raises:
            ValueError: If both `task_prompt_name` and
                `query_generation_template_override` are provided, or if
                the template override has no `name` key.
        """
        multi_search_config = MultiSearchPipe.PipeConfig()
        task_prompt_name = kwargs.get("task_prompt_name")
        template_override = kwargs.get("query_generation_template_override")

        if task_prompt_name and template_override:
            raise ValueError(
                "Cannot provide both `task_prompt_name` and `query_generation_template_override`"
            )

        if template_override:
            # Validate before touching the prompt provider so a malformed
            # override fails early with a clear message.
            if "name" not in template_override:
                raise ValueError(
                    "`query_generation_template_override` must include a `name` key"
                )
            # Add a prompt for transforming the user query
            self.providers.prompt.add_prompt(**template_override)
            task_prompt_name = template_override["name"]
        elif not task_prompt_name:
            task_prompt_name = f"{multi_search_config.name}_task_prompt"

        # Initialize the new query transform pipe unless one was supplied
        query_transform_pipe = kwargs.get(
            "multi_query_transform_pipe_override"
        ) or QueryTransformPipe(
            llm_provider=self.providers.llm,
            prompt_provider=self.providers.prompt,
            config=QueryTransformPipe.QueryTransformConfig(
                name=multi_search_config.name,
                task_prompt=task_prompt_name,
            ),
        )

        # Use the supplied inner search pipe, else build the parent's one
        inner_search_pipe = kwargs.get(
            "multi_inner_search_pipe_override"
        ) or super().create_vector_search_pipe(*args, **kwargs)

        # The inner pipe's config name is overwritten with the multi-search
        # config name.
        # TODO - modify `create_..._pipe` to allow naming the pipe
        inner_search_pipe.config.name = multi_search_config.name

        return MultiSearchPipe(
            query_transform_pipe=query_transform_pipe,
            inner_search_pipe=inner_search_pipe,
            config=multi_search_config,
        )