about summary refs log tree commit diff
path: root/R2R/r2r/pipes/retrieval/query_transform_pipe.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/pipes/retrieval/query_transform_pipe.py')
-rwxr-xr-xR2R/r2r/pipes/retrieval/query_transform_pipe.py101
1 files changed, 101 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/retrieval/query_transform_pipe.py b/R2R/r2r/pipes/retrieval/query_transform_pipe.py
new file mode 100755
index 00000000..99df6b5b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/query_transform_pipe.py
@@ -0,0 +1,101 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+    AsyncPipe,
+    AsyncState,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class QueryTransformPipe(GeneratorPipe):
+    class QueryTransformConfig(GeneratorPipe.PipeConfig):
+        name: str = "default_query_transform"
+        system_prompt: str = "default_system"
+        task_prompt: str = "hyde"
+
+    class Input(GeneratorPipe.Input):
+        message: AsyncGenerator[str, None]
+
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        type: PipeType = PipeType.TRANSFORM,
+        config: Optional[QueryTransformConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        logger.info(f"Initalizing an `QueryTransformPipe` pipe.")
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config or QueryTransformPipe.QueryTransformConfig(),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: AsyncPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        query_transform_generation_config: GenerationConfig,
+        num_query_xf_outputs: int = 3,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[str, None]:
+        async for query in input.message:
+            logger.info(
+                f"Transforming query: {query} into {num_query_xf_outputs} outputs with {self.config.task_prompt}."
+            )
+
+            query_transform_request = self._get_message_payload(
+                query, num_outputs=num_query_xf_outputs
+            )
+
+            response = await self.llm_provider.aget_completion(
+                messages=query_transform_request,
+                generation_config=query_transform_generation_config,
+            )
+            content = self.llm_provider.extract_content(response)
+            outputs = content.split("\n")
+            outputs = [
+                output.strip() for output in outputs if output.strip() != ""
+            ]
+            await state.update(
+                self.config.name, {"output": {"outputs": outputs}}
+            )
+
+            for output in outputs:
+                logger.info(f"Yielding transformed output: {output}")
+                yield output
+
+    def _get_message_payload(self, input: str, num_outputs: int) -> dict:
+        return [
+            {
+                "role": "system",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.system_prompt,
+                ),
+            },
+            {
+                "role": "user",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.task_prompt,
+                    inputs={
+                        "message": input,
+                        "num_outputs": num_outputs,
+                    },
+                ),
+            },
+        ]