about summary refs log tree commit diff
path: root/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to 'R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py')
-rwxr-xr-xR2R/r2r/pipes/retrieval/kg_agent_search_pipe.py103
1 files changed, 103 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py b/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
new file mode 100755
index 00000000..60935265
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
@@ -0,0 +1,103 @@
+import logging
+import uuid
+from typing import Any, Optional
+
+from r2r.base import (
+    AsyncState,
+    KGProvider,
+    KGSearchSettings,
+    KVLoggingSingleton,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
class KGAgentSearchPipe(GeneratorPipe):
    """
    Translates a natural-language message into a Cypher query via an LLM
    "agent" prompt, executes the query against the knowledge-graph provider,
    and yields each (query, result) pair while logging both artifacts.
    """

    def __init__(
        self,
        kg_provider: KGProvider,
        llm_provider: LLMProvider,
        prompt_provider: PromptProvider,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        type: PipeType = PipeType.INGESTOR,
        config: Optional[GeneratorPipe.PipeConfig] = None,
        *args,
        **kwargs,
    ):
        """
        Initializes the KG agent search pipe with its providers and config.

        Args:
            kg_provider: Knowledge-graph backend that executes Cypher queries.
            llm_provider: LLM used to turn the user message into Cypher.
            prompt_provider: Supplies the "kg_agent" prompt templates.
            pipe_logger: Optional key-value logger for run artifacts.
            type: Pipe type tag (kept as INGESTOR for backward compatibility).
            config: Optional pipe config; defaults to name="kg_rag_pipe" with
                task_prompt="kg_agent".
        """
        super().__init__(
            llm_provider=llm_provider,
            prompt_provider=prompt_provider,
            type=type,
            config=config
            or GeneratorPipe.Config(
                name="kg_rag_pipe", task_prompt="kg_agent"
            ),
            pipe_logger=pipe_logger,
            *args,
            **kwargs,
        )
        self.kg_provider = kg_provider
        self.llm_provider = llm_provider
        self.prompt_provider = prompt_provider
        self.pipe_run_info = None

    @staticmethod
    def _extract_cypher(extraction: str) -> str:
        """
        Return the query text inside the first ```cypher ...``` fenced block.

        Mirrors the original split-chain semantics: if the closing ``` fence
        is missing, everything after the opening fence is returned.

        Raises:
            ValueError: if the completion contains no ```cypher fence (the
                previous split-chain raised an opaque IndexError here).
        """
        _, sep, after = extraction.partition("```cypher")
        if not sep:
            raise ValueError(
                "LLM completion did not contain a ```cypher fenced block"
            )
        return after.split("```", 1)[0]

    async def _run_logic(
        self,
        input: GeneratorPipe.Input,
        state: AsyncState,
        run_id: uuid.UUID,
        kg_search_settings: KGSearchSettings,
        *args: Any,
        **kwargs: Any,
    ):
        async for message in input.message:
            # TODO - Remove hard-coded prompt name ("kg_agent").
            formatted_prompt = self.prompt_provider.get_prompt(
                "kg_agent", {"input": message}
            )
            messages = self._get_message_payload(formatted_prompt)

            completion = await self.llm_provider.aget_completion(
                messages=messages,
                generation_config=kg_search_settings.agent_generation_config,
            )

            extraction = completion.choices[0].message.content
            # Pull the Cypher out of the model's fenced code block and run it.
            query = self._extract_cypher(extraction)
            query_result = self.kg_provider.structured_query(query)
            yield (query, query_result)

            await self.enqueue_log(
                run_id=run_id,
                key="kg_agent_response",
                value=extraction,
            )

            await self.enqueue_log(
                run_id=run_id,
                key="kg_agent_execution_result",
                value=query_result,
            )

    def _get_message_payload(self, message: str) -> list:
        """
        Build the chat payload: the configured system prompt followed by the
        formatted user message. (Annotation fixed: this returns a list of
        role/content dicts, not a dict.)
        """
        return [
            {
                "role": "system",
                "content": self.prompt_provider.get_prompt(
                    self.config.system_prompt,
                ),
            },
            {"role": "user", "content": message},
        ]