aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py')
-rwxr-xr-xR2R/r2r/pipes/retrieval/kg_agent_search_pipe.py103
1 files changed, 103 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py b/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
new file mode 100755
index 00000000..60935265
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
@@ -0,0 +1,103 @@
+import logging
+import uuid
+from typing import Any, Optional
+
+from r2r.base import (
+ AsyncState,
+ KGProvider,
+ KGSearchSettings,
+ KVLoggingSingleton,
+ LLMProvider,
+ PipeType,
+ PromptProvider,
+)
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
class KGAgentSearchPipe(GeneratorPipe):
    """
    Translates a natural-language message into a structured (Cypher) graph
    query via an LLM agent, executes it against the knowledge-graph provider,
    and yields the query together with its execution result.
    """

    def __init__(
        self,
        kg_provider: KGProvider,
        llm_provider: LLMProvider,
        prompt_provider: PromptProvider,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        type: PipeType = PipeType.INGESTOR,
        config: Optional[GeneratorPipe.PipeConfig] = None,
        *args,
        **kwargs,
    ):
        """
        Initializes the KG agent search pipe.

        Args:
            kg_provider: Provider used to run structured graph queries.
            llm_provider: LLM used to translate user input into Cypher.
            prompt_provider: Source of the "kg_agent" prompt template.
            pipe_logger: Optional key/value run logger.
            type: Pipe category. NOTE(review): default is INGESTOR even
                though this pipe performs search — kept for backward
                compatibility; confirm against callers before changing.
            config: Pipe configuration; defaults to a "kg_rag_pipe" config
                whose task prompt is "kg_agent".
        """
        super().__init__(
            llm_provider=llm_provider,
            prompt_provider=prompt_provider,
            type=type,
            config=config
            or GeneratorPipe.Config(
                name="kg_rag_pipe", task_prompt="kg_agent"
            ),
            pipe_logger=pipe_logger,
            *args,
            **kwargs,
        )
        self.kg_provider = kg_provider
        self.llm_provider = llm_provider
        self.prompt_provider = prompt_provider
        # Populated per run by the pipe framework; None until a run starts.
        self.pipe_run_info = None

    @staticmethod
    def _extract_cypher(completion: str) -> str:
        """
        Return the Cypher text inside the first fenced ```cypher ...``` block.

        Raises:
            ValueError: if the completion contains no fenced cypher block.
                (Previously this surfaced as an opaque IndexError from
                unchecked ``split(...)[1]``.)
        """
        if "```cypher" not in completion:
            raise ValueError(
                "LLM response did not contain a fenced ```cypher``` block: "
                f"{completion!r}"
            )
        return completion.split("```cypher")[1].split("```")[0]

    async def _run_logic(
        self,
        input: GeneratorPipe.Input,
        state: AsyncState,
        run_id: uuid.UUID,
        kg_search_settings: KGSearchSettings,
        *args: Any,
        **kwargs: Any,
    ):
        """
        For each incoming message: prompt the LLM agent for a Cypher query,
        execute it on the KG provider, yield ``(query, result)``, and log
        both the raw agent response and the execution result.
        """
        async for message in input.message:
            # TODO - Remove hard-coded "kg_agent" prompt name.
            formatted_prompt = self.prompt_provider.get_prompt(
                "kg_agent", {"input": message}
            )
            messages = self._get_message_payload(formatted_prompt)

            completion = await self.llm_provider.aget_completion(
                messages=messages,
                generation_config=kg_search_settings.agent_generation_config,
            )

            # Agent response is expected to carry the query in a code fence.
            extraction = completion.choices[0].message.content
            query = self._extract_cypher(extraction)
            query_result = self.kg_provider.structured_query(query)
            yield (query, query_result)

            await self.enqueue_log(
                run_id=run_id,
                key="kg_agent_response",
                value=extraction,
            )

            await self.enqueue_log(
                run_id=run_id,
                key="kg_agent_execution_result",
                value=query_result,
            )

    def _get_message_payload(self, message: str) -> list[dict]:
        """Build the system+user chat message payload for the LLM call."""
        return [
            {
                "role": "system",
                "content": self.prompt_provider.get_prompt(
                    self.config.system_prompt,
                ),
            },
            {"role": "user", "content": message},
        ]