aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/pipes/retrieval/query_transform_pipe.py
diff options
context:
space:
mode:
author: S. Solomon Darnell, 2025-03-28 21:52:21 -0500
committer: S. Solomon Darnell, 2025-03-28 21:52:21 -0500
commit: 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree: ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes/retrieval/query_transform_pipe.py
parent: cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download: gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/pipes/retrieval/query_transform_pipe.py')
-rwxr-xr-x R2R/r2r/pipes/retrieval/query_transform_pipe.py 101
1 files changed, 101 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/retrieval/query_transform_pipe.py b/R2R/r2r/pipes/retrieval/query_transform_pipe.py
new file mode 100755
index 00000000..99df6b5b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/query_transform_pipe.py
@@ -0,0 +1,101 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+ AsyncPipe,
+ AsyncState,
+ LLMProvider,
+ PipeType,
+ PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class QueryTransformPipe(GeneratorPipe):
+    """Pipe that asks an LLM to rewrite each incoming query into several outputs.
+
+    For every query read from ``input.message`` the pipe builds a
+    system/user chat payload from the configured prompts, requests a
+    completion from the LLM provider, splits the returned text on newlines
+    and yields each non-blank line as one transformed output.
+    """
+
+    class QueryTransformConfig(GeneratorPipe.PipeConfig):
+        # Key under which each batch of outputs is written to the pipe state.
+        name: str = "default_query_transform"
+        # Prompt names passed to ``prompt_provider.get_prompt``.
+        system_prompt: str = "default_system"
+        task_prompt: str = "hyde"
+
+    class Input(GeneratorPipe.Input):
+        # Asynchronous stream of the query strings to transform.
+        message: AsyncGenerator[str, None]
+
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        type: PipeType = PipeType.TRANSFORM,
+        config: Optional[QueryTransformConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        """Set up the pipe and forward everything to ``GeneratorPipe``.
+
+        Args:
+            llm_provider: Provider used to request completions.
+            prompt_provider: Provider used to resolve prompt names.
+            type: Pipe type reported to the base class.
+            config: Pipe configuration; a default ``QueryTransformConfig``
+                is created when omitted or falsy.
+            *args: Extra positional arguments passed to the base class.
+            **kwargs: Extra keyword arguments passed to the base class.
+        """
+        logger.info("Initalizing an `QueryTransformPipe` pipe.")
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config or QueryTransformPipe.QueryTransformConfig(),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: AsyncPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        query_transform_generation_config: GenerationConfig,
+        num_query_xf_outputs: int = 3,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[str, None]:
+        """Transform every query in ``input.message`` and yield the results.
+
+        Args:
+            input: Pipe input whose ``message`` attribute is iterated
+                asynchronously for query strings.
+            state: Shared state; updated once per query under
+                ``self.config.name`` with ``{"output": {"outputs": [...]}}``.
+            run_id: Identifier of the current run (not used here).
+            query_transform_generation_config: Generation settings handed
+                to the LLM provider.
+            num_query_xf_outputs: Number of outputs requested in the prompt.
+
+        Yields:
+            Each non-blank, whitespace-stripped line of the LLM response.
+        """
+        async for query in input.message:
+            logger.info(
+                f"Transforming query: {query} into {num_query_xf_outputs} outputs with {self.config.task_prompt}."
+            )
+
+            query_transform_request = self._get_message_payload(
+                query, num_outputs=num_query_xf_outputs
+            )
+
+            response = await self.llm_provider.aget_completion(
+                messages=query_transform_request,
+                generation_config=query_transform_generation_config,
+            )
+            # NOTE(review): assumes extract_content may return None for an
+            # empty completion -- confirm against the provider. Falling back
+            # to "" makes that case yield nothing instead of raising.
+            content = self.llm_provider.extract_content(response) or ""
+
+            # One candidate per line; strip once and drop blank lines.
+            stripped_lines = (line.strip() for line in content.split("\n"))
+            outputs = [line for line in stripped_lines if line]
+
+            await state.update(
+                self.config.name, {"output": {"outputs": outputs}}
+            )
+
+            for output in outputs:
+                logger.info(f"Yielding transformed output: {output}")
+                yield output
+
+    def _get_message_payload(self, input: str, num_outputs: int) -> list:
+        """Build the chat messages sent to the LLM for one query.
+
+        Args:
+            input: The query text, passed to the task prompt as ``message``.
+            num_outputs: Passed to the task prompt as ``num_outputs``.
+
+        Returns:
+            A two-element list of ``{"role": ..., "content": ...}`` dicts:
+            the system prompt first, then the user (task) prompt.
+        """
+        system_content = self.prompt_provider.get_prompt(
+            self.config.system_prompt,
+        )
+        user_content = self.prompt_provider.get_prompt(
+            self.config.task_prompt,
+            inputs={
+                "message": input,
+                "num_outputs": num_outputs,
+            },
+        )
+        return [
+            {"role": "system", "content": system_content},
+            {"role": "user", "content": user_content},
+        ]