aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/base/agent/agent.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/agent/agent.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/agent/agent.py291
1 files changed, 291 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/agent.py b/.venv/lib/python3.12/site-packages/core/base/agent/agent.py
new file mode 100644
index 00000000..6813dd21
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/agent/agent.py
@@ -0,0 +1,291 @@
+# type: ignore
+import asyncio
+import json
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime
+from json import JSONDecodeError
+from typing import Any, AsyncGenerator, Optional, Type
+
+from pydantic import BaseModel
+
+from core.base.abstractions import (
+ GenerationConfig,
+ LLMChatCompletion,
+ Message,
+)
+from core.base.providers import CompletionProvider, DatabaseProvider
+
+from .base import Tool, ToolResult
+
+logger = logging.getLogger()
+
+
+class Conversation:
+ def __init__(self):
+ self.messages: list[Message] = []
+ self._lock = asyncio.Lock()
+
+ async def add_message(self, message):
+ async with self._lock:
+ self.messages.append(message)
+
+ async def get_messages(self) -> list[dict[str, Any]]:
+ async with self._lock:
+ return [
+ {**msg.model_dump(exclude_none=True), "role": str(msg.role)}
+ for msg in self.messages
+ ]
+
+
+# TODO - Move agents to provider pattern
+class AgentConfig(BaseModel):
+ rag_rag_agent_static_prompt: str = "static_rag_agent"
+ rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted"
+ stream: bool = False
+ include_tools: bool = True
+ max_iterations: int = 10
+
+ @classmethod
+ def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
+ base_args = cls.model_fields.keys()
+ filtered_kwargs = {
+ k: v if v != "None" else None
+ for k, v in kwargs.items()
+ if k in base_args
+ }
+ return cls(**filtered_kwargs) # type: ignore
+
+
+class Agent(ABC):
+ def __init__(
+ self,
+ llm_provider: CompletionProvider,
+ database_provider: DatabaseProvider,
+ config: AgentConfig,
+ rag_generation_config: GenerationConfig,
+ ):
+ self.llm_provider = llm_provider
+ self.database_provider: DatabaseProvider = database_provider
+ self.config = config
+ self.conversation = Conversation()
+ self._completed = False
+ self._tools: list[Tool] = []
+ self.tool_calls: list[dict] = []
+ self.rag_generation_config = rag_generation_config
+ self._register_tools()
+
+ @abstractmethod
+ def _register_tools(self):
+ pass
+
+ async def _setup(
+ self, system_instruction: Optional[str] = None, *args, **kwargs
+ ):
+ await self.conversation.add_message(
+ Message(
+ role="system",
+ content=system_instruction
+ or (
+ await self.database_provider.prompts_handler.get_cached_prompt(
+ self.config.rag_rag_agent_static_prompt,
+ inputs={
+ "date": str(datetime.now().strftime("%m/%d/%Y"))
+ },
+ )
+ + f"\n Note,you only have {self.config.max_iterations} iterations or tool calls to reach a conclusion before your operation terminates."
+ ),
+ )
+ )
+
+ @property
+ def tools(self) -> list[Tool]:
+ return self._tools
+
+ @tools.setter
+ def tools(self, tools: list[Tool]):
+ self._tools = tools
+
+ @abstractmethod
+ async def arun(
+ self,
+ system_instruction: Optional[str] = None,
+ messages: Optional[list[Message]] = None,
+ *args,
+ **kwargs,
+ ) -> list[LLMChatCompletion] | AsyncGenerator[LLMChatCompletion, None]:
+ pass
+
+ @abstractmethod
+ async def process_llm_response(
+ self,
+ response: Any,
+ *args,
+ **kwargs,
+ ) -> None | AsyncGenerator[str, None]:
+ pass
+
+ async def execute_tool(self, tool_name: str, *args, **kwargs) -> str:
+ if tool := next((t for t in self.tools if t.name == tool_name), None):
+ return await tool.results_function(*args, **kwargs)
+ else:
+ return f"Error: Tool {tool_name} not found."
+
+ def get_generation_config(
+ self, last_message: dict, stream: bool = False
+ ) -> GenerationConfig:
+ if (
+ last_message["role"] in ["tool", "function"]
+ and last_message["content"] != ""
+ and "ollama" in self.rag_generation_config.model
+ or not self.config.include_tools
+ ):
+ return GenerationConfig(
+ **self.rag_generation_config.model_dump(
+ exclude={"functions", "tools", "stream"}
+ ),
+ stream=stream,
+ )
+
+ return GenerationConfig(
+ **self.rag_generation_config.model_dump(
+ exclude={"functions", "tools", "stream"}
+ ),
+ # FIXME: Use tools instead of functions
+ # TODO - Investigate why `tools` fails with OpenAI+LiteLLM
+ tools=(
+ [
+ {
+ "function": {
+ "name": tool.name,
+ "description": tool.description,
+ "parameters": tool.parameters,
+ },
+ "type": "function",
+ "name": tool.name,
+ }
+ for tool in self.tools
+ ]
+ if self.tools
+ else None
+ ),
+ stream=stream,
+ )
+
+ async def handle_function_or_tool_call(
+ self,
+ function_name: str,
+ function_arguments: str,
+ tool_id: Optional[str] = None,
+ save_messages: bool = True,
+ *args,
+ **kwargs,
+ ) -> ToolResult:
+ logger.debug(
+ f"Calling function: {function_name}, args: {function_arguments}, tool_id: {tool_id}"
+ )
+ if tool := next(
+ (t for t in self.tools if t.name == function_name), None
+ ):
+ try:
+ function_args = json.loads(function_arguments)
+
+ except JSONDecodeError as e:
+ error_message = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with `JSONDecodeError`."
+ if save_messages:
+ await self.conversation.add_message(
+ Message(
+ role="tool" if tool_id else "function",
+ content=error_message,
+ name=function_name,
+ tool_call_id=tool_id,
+ )
+ )
+
+ # raise R2RException(
+ # message=f"Error parsing function arguments: {e}, agent likely produced invalid tool inputs.",
+ # status_code=400,
+ # )
+
+ merged_kwargs = {**kwargs, **function_args}
+ try:
+ raw_result = await tool.results_function(
+ *args, **merged_kwargs
+ )
+ llm_formatted_result = tool.llm_format_function(raw_result)
+ except Exception as e:
+ raw_result = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with an exception: {e}."
+ logger.error(raw_result)
+ llm_formatted_result = raw_result
+
+ tool_result = ToolResult(
+ raw_result=raw_result,
+ llm_formatted_result=llm_formatted_result,
+ )
+ if tool.stream_function:
+ tool_result.stream_result = tool.stream_function(raw_result)
+
+ if save_messages:
+ await self.conversation.add_message(
+ Message(
+ role="tool" if tool_id else "function",
+ content=str(tool_result.llm_formatted_result),
+ name=function_name,
+ tool_call_id=tool_id,
+ )
+ )
+ # HACK - to fix issues with claude thinking + tool use [https://github.com/anthropics/anthropic-cookbook/blob/main/extended_thinking/extended_thinking_with_tool_use.ipynb]
+ if self.rag_generation_config.extended_thinking:
+ await self.conversation.add_message(
+ Message(
+ role="user",
+ content="Continue...",
+ )
+ )
+
+ self.tool_calls.append(
+ {
+ "name": function_name,
+ "args": function_arguments,
+ }
+ )
+ return tool_result
+
+
+# TODO - Move agents to provider pattern
+class RAGAgentConfig(AgentConfig):
+ rag_rag_agent_static_prompt: str = "static_rag_agent"
+ rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted"
+ stream: bool = False
+ include_tools: bool = True
+ max_iterations: int = 10
+ # tools: list[str] = [] # HACK - unused variable.
+
+ # Default RAG tools
+ rag_tools: list[str] = [
+ "search_file_descriptions",
+ "search_file_knowledge",
+ "get_file_content",
+ ]
+
+ # Default Research tools
+ research_tools: list[str] = [
+ "rag",
+ "reasoning",
+ # DISABLED by default
+ "critique",
+ "python_executor",
+ ]
+
+ @classmethod
+ def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
+ base_args = cls.model_fields.keys()
+ filtered_kwargs = {
+ k: v if v != "None" else None
+ for k, v in kwargs.items()
+ if k in base_args
+ }
+ filtered_kwargs["tools"] = kwargs.get("tools", None) or kwargs.get(
+ "tool_names", None
+ )
+ return cls(**filtered_kwargs) # type: ignore