path: root/R2R/r2r/pipes/retrieval/query_transform_pipe.py
import logging
import uuid
from typing import Any, AsyncGenerator, Optional

from r2r.base import (
    AsyncPipe,
    AsyncState,
    LLMProvider,
    PipeType,
    PromptProvider,
)
from r2r.base.abstractions.llm import GenerationConfig

from ..abstractions.generator_pipe import GeneratorPipe

logger = logging.getLogger(__name__)


class QueryTransformPipe(GeneratorPipe):
    """Uses an LLM to transform each incoming query into one or more reformulated queries."""

    class QueryTransformConfig(GeneratorPipe.PipeConfig):
        name: str = "default_query_transform"
        system_prompt: str = "default_system"
        task_prompt: str = "hyde"

    class Input(GeneratorPipe.Input):
        message: AsyncGenerator[str, None]

    def __init__(
        self,
        llm_provider: LLMProvider,
        prompt_provider: PromptProvider,
        type: PipeType = PipeType.TRANSFORM,
        config: Optional[QueryTransformConfig] = None,
        *args,
        **kwargs,
    ):
        logger.info(f"Initalizing an `QueryTransformPipe` pipe.")
        super().__init__(
            llm_provider=llm_provider,
            prompt_provider=prompt_provider,
            type=type,
            config=config or QueryTransformPipe.QueryTransformConfig(),
            *args,
            **kwargs,
        )

    async def _run_logic(
        self,
        input: AsyncPipe.Input,
        state: AsyncState,
        run_id: uuid.UUID,
        query_transform_generation_config: GenerationConfig,
        num_query_xf_outputs: int = 3,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[str, None]:
        async for query in input.message:
            logger.info(
                f"Transforming query: {query} into {num_query_xf_outputs} outputs with {self.config.task_prompt}."
            )

            # Build the chat payload from the configured system and task prompts.
            query_transform_request = self._get_message_payload(
                query, num_outputs=num_query_xf_outputs
            )

            response = await self.llm_provider.aget_completion(
                messages=query_transform_request,
                generation_config=query_transform_generation_config,
            )
            content = self.llm_provider.extract_content(response)
            # Treat each non-empty line of the completion as one transformed query.
            outputs = content.split("\n")
            outputs = [
                output.strip() for output in outputs if output.strip() != ""
            ]
            await state.update(
                self.config.name, {"output": {"outputs": outputs}}
            )

            for output in outputs:
                logger.info(f"Yielding transformed output: {output}")
                yield output

    def _get_message_payload(self, input: str, num_outputs: int) -> list[dict]:
        return [
            {
                "role": "system",
                "content": self.prompt_provider.get_prompt(
                    self.config.system_prompt,
                ),
            },
            {
                "role": "user",
                "content": self.prompt_provider.get_prompt(
                    self.config.task_prompt,
                    inputs={
                        "message": input,
                        "num_outputs": num_outputs,
                    },
                ),
            },
        ]
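

# Usage sketch (illustrative only, not part of the original module): drives the
# pipe directly with a single query. It assumes concrete `llm_provider` and
# `prompt_provider` instances from the surrounding R2R setup, and that
# `AsyncState()` and `GenerationConfig(model=...)` can be constructed as shown;
# in a full deployment the pipe is normally executed by the pipeline runner
# rather than being called like this.
async def _example_query_transform(llm_provider, prompt_provider):
    pipe = QueryTransformPipe(
        llm_provider=llm_provider,
        prompt_provider=prompt_provider,
    )

    async def queries() -> AsyncGenerator[str, None]:
        yield "What are the health benefits of green tea?"

    async for transformed in pipe._run_logic(
        input=QueryTransformPipe.Input(message=queries()),
        state=AsyncState(),
        run_id=uuid.uuid4(),
        query_transform_generation_config=GenerationConfig(model="gpt-4"),
    ):
        print(f"Transformed query: {transformed}")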