author     S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/base/agent
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/agent')
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/agent/__init__.py |  17
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/agent/agent.py    | 291
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/agent/base.py     |  22
3 files changed, 330 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py b/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py
new file mode 100644
index 00000000..815b9ae7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py
@@ -0,0 +1,17 @@
+# FIXME: Once the agent is properly type annotated, remove the type: ignore comments
+from .agent import (  # type: ignore
+    Agent,
+    AgentConfig,
+    Conversation,
+    Tool,
+    ToolResult,
+)
+
+__all__ = [
+    # Agent abstractions
+    "Agent",
+    "AgentConfig",
+    "Conversation",
+    "Tool",
+    "ToolResult",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/agent.py b/.venv/lib/python3.12/site-packages/core/base/agent/agent.py
new file mode 100644
index 00000000..6813dd21
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/agent/agent.py
@@ -0,0 +1,291 @@
+# type: ignore
+import asyncio
+import json
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime
+from json import JSONDecodeError
+from typing import Any, AsyncGenerator, Optional, Type
+
+from pydantic import BaseModel
+
+from core.base.abstractions import (
+    GenerationConfig,
+    LLMChatCompletion,
+    Message,
+)
+from core.base.providers import CompletionProvider, DatabaseProvider
+
+from .base import Tool, ToolResult
+
+logger = logging.getLogger()
+
+
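+# Async-safe, append-only message history for a single agent run; the lock
+# serialises appends from concurrently running tool callbacks.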
+class Conversation:
+    def __init__(self):
+        self.messages: list[Message] = []
+        self._lock = asyncio.Lock()
+
+    async def add_message(self, message: Message):
+        async with self._lock:
+            self.messages.append(message)
+
+    async def get_messages(self) -> list[dict[str, Any]]:
+        async with self._lock:
+            return [
+                {**msg.model_dump(exclude_none=True), "role": str(msg.role)}
+                for msg in self.messages
+            ]
+
+
+# TODO - Move agents to provider pattern
+class AgentConfig(BaseModel):
+    rag_rag_agent_static_prompt: str = "static_rag_agent"
+    rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted"
+    stream: bool = False
+    include_tools: bool = True
+    max_iterations: int = 10
+
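+    # Build a config from loosely typed kwargs (e.g. values parsed from a
+    # config file): unknown keys are dropped, and the literal string "None"
+    # is coerced to a real None.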
+    @classmethod
+    def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
+        base_args = cls.model_fields.keys()
+        filtered_kwargs = {
+            k: v if v != "None" else None
+            for k, v in kwargs.items()
+            if k in base_args
+        }
+        return cls(**filtered_kwargs)  # type: ignore
+
+
+class Agent(ABC):
+    def __init__(
+        self,
+        llm_provider: CompletionProvider,
+        database_provider: DatabaseProvider,
+        config: AgentConfig,
+        rag_generation_config: GenerationConfig,
+    ):
+        self.llm_provider = llm_provider
+        self.database_provider: DatabaseProvider = database_provider
+        self.config = config
+        self.conversation = Conversation()
+        self._completed = False
+        self._tools: list[Tool] = []
+        self.tool_calls: list[dict] = []
+        self.rag_generation_config = rag_generation_config
+        self._register_tools()
+
+    @abstractmethod
+    def _register_tools(self):
+        pass
+
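+    # Seed the conversation with a system message: the caller-supplied
+    # instruction if given, otherwise the cached static RAG prompt stamped
+    # with today's date and the iteration budget.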
+    async def _setup(
+        self, system_instruction: Optional[str] = None, *args, **kwargs
+    ):
+        await self.conversation.add_message(
+            Message(
+                role="system",
+                content=system_instruction
+                or (
+                    await self.database_provider.prompts_handler.get_cached_prompt(
+                        self.config.rag_rag_agent_static_prompt,
+                        inputs={
+                            "date": str(datetime.now().strftime("%m/%d/%Y"))
+                        },
+                    )
+                    + f"\nNote, you only have {self.config.max_iterations} iterations or tool calls to reach a conclusion before your operation terminates."
+                ),
+            )
+        )
+
+    @property
+    def tools(self) -> list[Tool]:
+        return self._tools
+
+    @tools.setter
+    def tools(self, tools: list[Tool]):
+        self._tools = tools
+
+    @abstractmethod
+    async def arun(
+        self,
+        system_instruction: Optional[str] = None,
+        messages: Optional[list[Message]] = None,
+        *args,
+        **kwargs,
+    ) -> list[LLMChatCompletion] | AsyncGenerator[LLMChatCompletion, None]:
+        pass
+
+    @abstractmethod
+    async def process_llm_response(
+        self,
+        response: Any,
+        *args,
+        **kwargs,
+    ) -> None | AsyncGenerator[str, None]:
+        pass
+
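+    # Dispatch a tool invocation by name; an unknown tool yields an error
+    # string for the model instead of raising.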
+    async def execute_tool(self, tool_name: str, *args, **kwargs) -> str:
+        if tool := next((t for t in self.tools if t.name == tool_name), None):
+            return await tool.results_function(*args, **kwargs)
+        else:
+            return f"Error: Tool {tool_name} not found."
+
+    def get_generation_config(
+        self, last_message: dict, stream: bool = False
+    ) -> GenerationConfig:
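+        # Ollama models answering a tool/function result, and configs with
+        # tools disabled, get a plain completion config; every other case
+        # gets the tool schema attached below.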
+        if (
+            (
+                last_message["role"] in ["tool", "function"]
+                and last_message["content"] != ""
+                and "ollama" in self.rag_generation_config.model
+            )
+            or not self.config.include_tools
+        ):
+            return GenerationConfig(
+                **self.rag_generation_config.model_dump(
+                    exclude={"functions", "tools", "stream"}
+                ),
+                stream=stream,
+            )
+
+        return GenerationConfig(
+            **self.rag_generation_config.model_dump(
+                exclude={"functions", "tools", "stream"}
+            ),
+            # FIXME: Use tools instead of functions
+            # TODO - Investigate why `tools` fails with OpenAI+LiteLLM
+            tools=(
+                [
+                    {
+                        "function": {
+                            "name": tool.name,
+                            "description": tool.description,
+                            "parameters": tool.parameters,
+                        },
+                        "type": "function",
+                        "name": tool.name,
+                    }
+                    for tool in self.tools
+                ]
+                if self.tools
+                else None
+            ),
+            stream=stream,
+        )
+
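+    # Parse the model's function-call payload, run the matching tool, record
+    # the outcome in the conversation, and return it as a ToolResult.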
+    async def handle_function_or_tool_call(
+        self,
+        function_name: str,
+        function_arguments: str,
+        tool_id: Optional[str] = None,
+        save_messages: bool = True,
+        *args,
+        **kwargs,
+    ) -> ToolResult:
+        logger.debug(
+            f"Calling function: {function_name}, args: {function_arguments}, tool_id: {tool_id}"
+        )
+        if tool := next(
+            (t for t in self.tools if t.name == function_name), None
+        ):
+            try:
+                function_args = json.loads(function_arguments)
+
+            except JSONDecodeError as e:
+                error_message = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with `JSONDecodeError`: {e}."
+                if save_messages:
+                    await self.conversation.add_message(
+                        Message(
+                            role="tool" if tool_id else "function",
+                            content=error_message,
+                            name=function_name,
+                            tool_call_id=tool_id,
+                        )
+                    )
+                # Return early: `function_args` is unbound after a parse
+                # failure, so falling through would raise a NameError below.
+                return ToolResult(
+                    raw_result=error_message,
+                    llm_formatted_result=error_message,
+                )
+
+            merged_kwargs = {**kwargs, **function_args}
+            try:
+                raw_result = await tool.results_function(
+                    *args, **merged_kwargs
+                )
+                llm_formatted_result = tool.llm_format_function(raw_result)
+            except Exception as e:
+                raw_result = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with an exception: {e}."
+                logger.error(raw_result)
+                llm_formatted_result = raw_result
+
+            tool_result = ToolResult(
+                raw_result=raw_result,
+                llm_formatted_result=llm_formatted_result,
+            )
+            if tool.stream_function:
+                tool_result.stream_result = tool.stream_function(raw_result)
+
+            if save_messages:
+                await self.conversation.add_message(
+                    Message(
+                        role="tool" if tool_id else "function",
+                        content=str(tool_result.llm_formatted_result),
+                        name=function_name,
+                        tool_call_id=tool_id,
+                    )
+                )
+                # HACK - to fix issues with claude thinking + tool use [https://github.com/anthropics/anthropic-cookbook/blob/main/extended_thinking/extended_thinking_with_tool_use.ipynb]
+                if self.rag_generation_config.extended_thinking:
+                    await self.conversation.add_message(
+                        Message(
+                            role="user",
+                            content="Continue...",
+                        )
+                    )
+
+            self.tool_calls.append(
+                {
+                    "name": function_name,
+                    "args": function_arguments,
+                }
+            )
+            return tool_result
+
+        # No matching tool is registered; report the failure as a ToolResult
+        # rather than letting the unbound `tool_result` raise a NameError.
+        error_message = f"Error: Tool {function_name} not found."
+        logger.error(error_message)
+        return ToolResult(
+            raw_result=error_message, llm_formatted_result=error_message
+        )
+
+
+# TODO - Move agents to provider pattern
+class RAGAgentConfig(AgentConfig):
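+    # AgentConfig specialised with the default RAG and Research tool sets.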
+    rag_rag_agent_static_prompt: str = "static_rag_agent"
+    rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted"
+    stream: bool = False
+    include_tools: bool = True
+    max_iterations: int = 10
+    # tools: list[str] = [] # HACK - unused variable.
+
+    # Default RAG tools
+    rag_tools: list[str] = [
+        "search_file_descriptions",
+        "search_file_knowledge",
+        "get_file_content",
+    ]
+
+    # Default Research tools
+    research_tools: list[str] = [
+        "rag",
+        "reasoning",
+        # DISABLED by default
+        "critique",
+        "python_executor",
+    ]
+
+    @classmethod
+    def create(cls: Type["RAGAgentConfig"], **kwargs: Any) -> "RAGAgentConfig":
+        base_args = cls.model_fields.keys()
+        filtered_kwargs = {
+            k: v if v != "None" else None
+            for k, v in kwargs.items()
+            if k in base_args
+        }
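+        # `tools` is not a declared field (see the HACK above), so Pydantic's
+        # default extra-field handling silently drops it; kept only for
+        # backwards-compatible call sites that still pass `tool_names`.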
+        filtered_kwargs["tools"] = kwargs.get("tools", None) or kwargs.get(
+            "tool_names", None
+        )
+        return cls(**filtered_kwargs)  # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/base.py b/.venv/lib/python3.12/site-packages/core/base/agent/base.py
new file mode 100644
index 00000000..0d8f15ee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/agent/base.py
@@ -0,0 +1,22 @@
+from typing import Any, Callable, Optional
+
+from ..abstractions import R2RSerializable
+
+
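+# A callable capability exposed to the LLM. `parameters` carries the JSON
+# schema advertised to the model; `llm_format_function` converts a raw
+# result into the string handed back to it.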
+class Tool(R2RSerializable):
+    name: str
+    description: str
+    results_function: Callable
+    llm_format_function: Callable
+    stream_function: Optional[Callable] = None
+    parameters: Optional[dict[str, Any]] = None
+
+    class Config:
+        populate_by_name = True
+        arbitrary_types_allowed = True
+
+
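+# Raw tool output alongside the stringified form returned to the LLM and an
+# optional streaming variant.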
+class ToolResult(R2RSerializable):
+    raw_result: Any
+    llm_formatted_result: str
+    stream_result: Optional[str] = None