about summary refs log tree commit diff
path: root/R2R/r2r/pipes/other/eval_pipe.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/pipes/other/eval_pipe.py')
-rwxr-xr-xR2R/r2r/pipes/other/eval_pipe.py54
1 files changed, 54 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/other/eval_pipe.py b/R2R/r2r/pipes/other/eval_pipe.py
new file mode 100755
index 00000000..b1c60343
--- /dev/null
+++ b/R2R/r2r/pipes/other/eval_pipe.py
@@ -0,0 +1,54 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from pydantic import BaseModel
+
+from r2r import AsyncState, EvalProvider, LLMChatCompletion, PipeType
+from r2r.base.abstractions.llm import GenerationConfig
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class EvalPipe(AsyncPipe):
+    class EvalPayload(BaseModel):
+        query: str
+        context: str
+        completion: str
+
+    class Input(AsyncPipe.Input):
+        message: AsyncGenerator["EvalPipe.EvalPayload", None]
+
+    def __init__(
+        self,
+        eval_provider: EvalProvider,
+        type: PipeType = PipeType.EVAL,
+        config: Optional[AsyncPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        self.eval_provider = eval_provider
+        super().__init__(
+            type=type,
+            config=config or AsyncPipe.PipeConfig(name="default_eval_pipe"),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        eval_generation_config: GenerationConfig,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[LLMChatCompletion, None]:
+        async for item in input.message:
+            yield self.eval_provider.evaluate(
+                item.query,
+                item.context,
+                item.completion,
+                eval_generation_config,
+            )