aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/pipelines/eval_pipeline.py
blob: 60aa50d4970d65a3b0207906b9193bb3b85d262f (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import logging
from typing import Any, Optional

from r2r.base.logging.run_manager import RunManager
from r2r.base.pipeline.base_pipeline import AsyncPipeline
from r2r.base.pipes.base_pipe import AsyncPipe, AsyncState

logger = logging.getLogger(__name__)


class EvalPipeline(AsyncPipeline):
    """A pipeline for evaluation.

    Thin specialization of ``AsyncPipeline`` that sets ``pipeline_type`` to
    ``"eval"``. ``run`` forwards to the base implementation unchanged;
    ``add_pipe`` emits a DEBUG log line and then delegates to the base class.
    """

    # Tag identifying this pipeline flavour. NOTE(review): presumably read by
    # the base class or by callers to distinguish pipeline kinds — confirm.
    pipeline_type: str = "eval"

    async def run(
        self,
        input: Any,
        state: Optional[AsyncState] = None,
        stream: bool = False,
        run_manager: Optional[RunManager] = None,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Run the evaluation pipeline on ``input``.

        Every argument is forwarded, in the same order, to
        ``AsyncPipeline.run``; see that method for their semantics.

        Args:
            input: Pipeline input, passed through unchanged.
            state: Optional ``AsyncState``, passed through unchanged.
            stream: Streaming flag, passed through unchanged.
            run_manager: Optional ``RunManager``, passed through unchanged.
            *args: Extra positional arguments for the base implementation.
            **kwargs: Extra keyword arguments for the base implementation.

        Returns:
            Whatever the awaited ``AsyncPipeline.run`` call returns.
        """
        return await super().run(
            input, state, stream, run_manager, *args, **kwargs
        )

    def add_pipe(
        self,
        pipe: AsyncPipe,
        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Add ``pipe`` to this evaluation pipeline.

        Logs the pipe's configured name at DEBUG level, then delegates to
        ``AsyncPipeline.add_pipe`` with all arguments unchanged.

        Args:
            pipe: The pipe to add; its ``config.name`` is used in the log line.
            add_upstream_outputs: Optional list of string-to-string mappings,
                forwarded unchanged to the base implementation.
            *args: Extra positional arguments for the base implementation.
            **kwargs: Extra keyword arguments for the base implementation.

        Returns:
            The result of ``AsyncPipeline.add_pipe`` (annotated as ``None``).
        """
        # Pass the name as a lazy %-style argument so the message is only
        # formatted when DEBUG logging is actually enabled.
        logger.debug("Adding pipe %s to the EvalPipeline", pipe.config.name)
        return super().add_pipe(pipe, add_upstream_outputs, *args, **kwargs)