aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/pipes/other/eval_pipe.py
blob: b1c603431a0ac20e0e18390275564965fc6bb3d3 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import logging
import uuid
from typing import Any, AsyncGenerator, Optional

from pydantic import BaseModel

from r2r import AsyncState, EvalProvider, LLMChatCompletion, PipeType
from r2r.base.abstractions.llm import GenerationConfig
from r2r.base.pipes.base_pipe import AsyncPipe

logger = logging.getLogger(__name__)


class EvalPipe(AsyncPipe):
    """Pipe that runs generated completions through an ``EvalProvider``.

    Every ``EvalPayload`` pulled from the input stream is handed to
    ``eval_provider.evaluate`` together with the generation config supplied
    to ``_run_logic``. The provider's result is yielded downstream unchanged.
    """

    # A single evaluation request: the user query, the context it was
    # answered from, and the completion to be judged.
    class EvalPayload(BaseModel):
        query: str
        context: str
        completion: str

    # Pipe input: an async stream of EvalPayload items.
    class Input(AsyncPipe.Input):
        message: AsyncGenerator["EvalPipe.EvalPayload", None]

    def __init__(
        self,
        eval_provider: EvalProvider,
        type: PipeType = PipeType.EVAL,
        config: Optional[AsyncPipe.PipeConfig] = None,
        *args,
        **kwargs,
    ):
        """Keep a reference to the evaluation provider and set up the base pipe.

        Args:
            eval_provider: Provider whose ``evaluate`` is called once per payload.
            type: Pipe type forwarded to ``AsyncPipe``; ``PipeType.EVAL`` by default.
            config: Pipe configuration. When falsy (e.g. ``None``), a config
                named ``"default_eval_pipe"`` is created instead.
            *args, **kwargs: Forwarded unchanged to ``AsyncPipe.__init__``.
        """
        self.eval_provider = eval_provider

        pipe_config = config
        if not pipe_config:
            pipe_config = AsyncPipe.PipeConfig(name="default_eval_pipe")

        # Positional extras always bind before keywords, so any *args must not
        # target the `type`/`config` parameters of the base initializer.
        super().__init__(*args, type=type, config=pipe_config, **kwargs)

    async def _run_logic(
        self,
        input: Input,
        state: AsyncState,
        run_id: uuid.UUID,
        eval_generation_config: GenerationConfig,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[LLMChatCompletion, None]:
        """Evaluate each payload from ``input.message`` and yield the outcome.

        Args:
            input: Pipe input whose ``message`` stream supplies ``EvalPayload`` items.
            state: Shared pipeline state; not read by this method.
            run_id: Identifier of the current run; not read by this method.
            eval_generation_config: Passed as the fourth positional argument
                to ``eval_provider.evaluate`` for every payload.
            *args, **kwargs: Accepted for interface compatibility; unused.

        Yields:
            Whatever ``eval_provider.evaluate`` returns for each payload, as-is.
        """
        # NOTE(review): the provider's result is yielded without `await`, which
        # assumes `evaluate` is synchronous -- confirm; otherwise consumers get
        # un-awaited coroutines. The return annotation likewise assumes
        # `evaluate` produces LLMChatCompletion objects -- verify.
        async for payload in input.message:
            verdict = self.eval_provider.evaluate(
                payload.query,
                payload.context,
                payload.completion,
                eval_generation_config,
            )
            yield verdict