path: root/R2R/r2r/pipes/other
author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes/other
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/pipes/other')
-rwxr-xr-x  R2R/r2r/pipes/other/eval_pipe.py         54
-rwxr-xr-x  R2R/r2r/pipes/other/web_search_pipe.py  105
2 files changed, 159 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/other/eval_pipe.py b/R2R/r2r/pipes/other/eval_pipe.py
new file mode 100755
index 00000000..b1c60343
--- /dev/null
+++ b/R2R/r2r/pipes/other/eval_pipe.py
@@ -0,0 +1,54 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from pydantic import BaseModel
+
+from r2r import AsyncState, EvalProvider, LLMChatCompletion, PipeType
+from r2r.base.abstractions.llm import GenerationConfig
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class EvalPipe(AsyncPipe):
+    class EvalPayload(BaseModel):
+        query: str
+        context: str
+        completion: str
+
+    class Input(AsyncPipe.Input):
+        message: AsyncGenerator["EvalPipe.EvalPayload", None]
+
+    def __init__(
+        self,
+        eval_provider: EvalProvider,
+        type: PipeType = PipeType.EVAL,
+        config: Optional[AsyncPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        self.eval_provider = eval_provider
+        super().__init__(
+            type=type,
+            config=config or AsyncPipe.PipeConfig(name="default_eval_pipe"),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        eval_generation_config: GenerationConfig,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[LLMChatCompletion, None]:
+        async for item in input.message:
+            yield self.eval_provider.evaluate(
+                item.query,
+                item.context,
+                item.completion,
+                eval_generation_config,
+            )
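
The eval pipe above simply drains an async generator of (query, context, completion) payloads and hands each one to the configured EvalProvider. The following is a minimal standalone sketch of that consume-and-evaluate pattern, not the real r2r wiring: StubEvalProvider, its scoring rule, and the payloads() generator are all invented for illustration.

import asyncio
from typing import AsyncGenerator, Tuple

# Hypothetical stand-in for r2r's EvalProvider; the name and the
# scoring rule are assumptions made for this sketch only.
class StubEvalProvider:
    def evaluate(self, query: str, context: str, completion: str) -> dict:
        # A real provider would score the completion against the context
        # with an LLM; here we only flag whether the context text
        # appears verbatim in the completion.
        return {"query": query, "score": float(context in completion)}

async def payloads() -> AsyncGenerator[Tuple[str, str, str], None]:
    # Mimics EvalPipe.Input.message: an async stream of
    # (query, context, completion) items.
    yield (
        "What is R2R?",
        "R2R is a RAG framework",
        "R2R is a RAG framework for building LLM apps",
    )

async def main() -> None:
    provider = StubEvalProvider()
    # Same shape as EvalPipe._run_logic: iterate the stream and
    # evaluate each payload as it arrives.
    async for query, context, completion in payloads():
        print(provider.evaluate(query, context, completion))

asyncio.run(main())
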
diff --git a/R2R/r2r/pipes/other/web_search_pipe.py b/R2R/r2r/pipes/other/web_search_pipe.py
new file mode 100755
index 00000000..92e3feee
--- /dev/null
+++ b/R2R/r2r/pipes/other/web_search_pipe.py
@@ -0,0 +1,105 @@
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+    AsyncPipe,
+    AsyncState,
+    PipeType,
+    VectorSearchResult,
+    generate_id_from_label,
+)
+from r2r.integrations import SerperClient
+
+from ..abstractions.search_pipe import SearchPipe
+
+logger = logging.getLogger(__name__)
+
+
+class WebSearchPipe(SearchPipe):
+    def __init__(
+        self,
+        serper_client: SerperClient,
+        type: PipeType = PipeType.SEARCH,
+        config: Optional[SearchPipe.SearchConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            type=type,
+            config=config or SearchPipe.SearchConfig(),
+            *args,
+            **kwargs,
+        )
+        self.serper_client = serper_client
+
+    async def search(
+        self,
+        message: str,
+        run_id: uuid.UUID,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[VectorSearchResult, None]:
+        search_limit_override = kwargs.get("search_limit", None)
+        await self.enqueue_log(
+            run_id=run_id, key="search_query", value=message
+        )
+        # TODO - Make more general in the future by creating a SearchProvider interface
+        results = self.serper_client.get_raw(
+            query=message,
+            limit=search_limit_override or self.config.search_limit,
+        )
+
+        search_results = []
+        for result in results:
+            if result.get("snippet") is None:
+                continue
+            result["text"] = result.pop("snippet")
+            search_result = VectorSearchResult(
+                id=generate_id_from_label(str(result)),
+                score=result.get(
+                    "score", 0
+                ),  # TODO - Consider dynamically generating scores based on similarity
+                metadata=result,
+            )
+            search_results.append(search_result)
+            yield search_result
+
+        await self.enqueue_log(
+            run_id=run_id,
+            key="search_results",
+            value=json.dumps([ele.json() for ele in search_results]),
+        )
+
+    async def _run_logic(
+        self,
+        input: AsyncPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        *args: Any,
+        **kwargs,
+    ) -> AsyncGenerator[VectorSearchResult, None]:
+        search_queries = []
+        search_results = []
+        async for search_request in input.message:
+            search_queries.append(search_request)
+            async for result in self.search(
+                message=search_request, run_id=run_id, *args, **kwargs
+            ):
+                search_results.append(result)
+                yield result
+
+        await state.update(
+            self.config.name,
+            {
+                "output": {
+                    "search_queries": search_queries,
+                    "search_results": search_results,
+                }
+            },
+        )
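
The core of the web search pipe is the loop that normalizes raw Serper hits into VectorSearchResult objects: hits without a snippet are dropped, "snippet" is renamed to "text", and the score defaults to 0. The sketch below isolates that transformation so it can be run on its own; RAW_HITS and the local generate_id_from_label stand-in are assumptions for illustration, not the real SerperClient.get_raw output or the r2r helper.

import asyncio
import uuid
from typing import AsyncGenerator

# Hypothetical raw hits in the shape the pipe expects from
# SerperClient.get_raw; field names mirror the code above,
# the values are made up.
RAW_HITS = [
    {"title": "R2R docs", "snippet": "R2R pipes overview",
     "link": "https://example.com/a"},
    {"title": "No snippet here", "link": "https://example.com/b"},
]

def generate_id_from_label(label: str) -> uuid.UUID:
    # Stand-in for r2r's generate_id_from_label: any deterministic
    # label -> UUID mapping works for this sketch.
    return uuid.uuid5(uuid.NAMESPACE_DNS, label)

async def to_search_results(hits: list) -> AsyncGenerator[dict, None]:
    # Same filtering and renaming the pipe performs: skip hits without
    # a snippet, move "snippet" into "text", default the score to 0.
    for hit in hits:
        if hit.get("snippet") is None:
            continue
        hit["text"] = hit.pop("snippet")
        yield {
            "id": generate_id_from_label(str(hit)),
            "score": hit.get("score", 0),
            "metadata": hit,
        }

async def main() -> None:
    async for result in to_search_results(RAW_HITS):
        print(result["id"], result["metadata"]["text"])

asyncio.run(main())
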