author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
---|---|---
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes/retrieval/query_transform_pipe.py |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) |
download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz |
Diffstat (limited to 'R2R/r2r/pipes/retrieval/query_transform_pipe.py')
-rwxr-xr-x | R2R/r2r/pipes/retrieval/query_transform_pipe.py | 101 |
1 file changed, 101 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/retrieval/query_transform_pipe.py b/R2R/r2r/pipes/retrieval/query_transform_pipe.py
new file mode 100755
index 00000000..99df6b5b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/query_transform_pipe.py
@@ -0,0 +1,101 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+    AsyncPipe,
+    AsyncState,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class QueryTransformPipe(GeneratorPipe):
+    class QueryTransformConfig(GeneratorPipe.PipeConfig):
+        name: str = "default_query_transform"
+        system_prompt: str = "default_system"
+        task_prompt: str = "hyde"
+
+    class Input(GeneratorPipe.Input):
+        message: AsyncGenerator[str, None]
+
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        type: PipeType = PipeType.TRANSFORM,
+        config: Optional[QueryTransformConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        logger.info("Initializing a `QueryTransformPipe` pipe.")
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config or QueryTransformPipe.QueryTransformConfig(),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: AsyncPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        query_transform_generation_config: GenerationConfig,
+        num_query_xf_outputs: int = 3,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[str, None]:
+        async for query in input.message:
+            logger.info(
+                f"Transforming query: {query} into {num_query_xf_outputs} outputs with {self.config.task_prompt}."
+            )
+
+            query_transform_request = self._get_message_payload(
+                query, num_outputs=num_query_xf_outputs
+            )
+
+            response = await self.llm_provider.aget_completion(
+                messages=query_transform_request,
+                generation_config=query_transform_generation_config,
+            )
+            content = self.llm_provider.extract_content(response)
+            outputs = content.split("\n")
+            outputs = [
+                output.strip() for output in outputs if output.strip() != ""
+            ]
+            await state.update(
+                self.config.name, {"output": {"outputs": outputs}}
+            )
+
+            for output in outputs:
+                logger.info(f"Yielding transformed output: {output}")
+                yield output
+
+    def _get_message_payload(self, input: str, num_outputs: int) -> list[dict]:
+        return [
+            {
+                "role": "system",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.system_prompt,
+                ),
+            },
+            {
+                "role": "user",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.task_prompt,
+                    inputs={
+                        "message": input,
+                        "num_outputs": num_outputs,
+                    },
+                ),
+            },
+        ]
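The pipe added above fans one incoming query out into several rewritten queries (the default `hyde` task prompt points at a HyDE-style query expansion prompt): it builds a system/user message payload, asks the LLM for a newline-separated list, splits the completion, and yields each non-blank line. The snippet below is a minimal standalone sketch of that flow, not R2R code; `fake_completion`, `build_payload`, and `transform_query` are hypothetical stand-ins for `llm_provider.aget_completion`, `_get_message_payload`, and `_run_logic`, and the prompt text is illustrative only.

```python
import asyncio
from typing import AsyncGenerator


async def fake_completion(messages: list[dict]) -> str:
    # Stand-in for the LLM call: a real task prompt is expected to return
    # one rewritten query per line.
    return "variant one\nvariant two\n\nvariant three"


def build_payload(query: str, num_outputs: int) -> list[dict]:
    # Mirrors _get_message_payload: a system prompt plus a task prompt
    # that asks for `num_outputs` rewritten queries.
    return [
        {"role": "system", "content": "You rewrite search queries."},
        {
            "role": "user",
            "content": f"Rewrite the query below into {num_outputs} variants:\n{query}",
        },
    ]


async def transform_query(
    query: str, num_outputs: int = 3
) -> AsyncGenerator[str, None]:
    # Same post-processing as _run_logic: split on newlines, drop blanks,
    # yield each transformed query downstream.
    content = await fake_completion(build_payload(query, num_outputs))
    for line in content.split("\n"):
        line = line.strip()
        if line:
            yield line


async def main() -> None:
    async for out in transform_query("what genes affect BXD longevity?"):
        print(out)


asyncio.run(main())
```

Splitting on newlines and stripping blanks is what turns a single completion into multiple downstream search queries; the real pipe additionally records the outputs in `AsyncState` under the pipe's config name.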