diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/agent')
3 files changed, 330 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py b/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py new file mode 100644 index 00000000..815b9ae7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py @@ -0,0 +1,17 @@ +# FIXME: Once the agent is properly type annotated, remove the type: ignore comments +from .agent import ( # type: ignore + Agent, + AgentConfig, + Conversation, + Tool, + ToolResult, +) + +__all__ = [ + # Agent abstractions + "Agent", + "AgentConfig", + "Conversation", + "Tool", + "ToolResult", +] diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/agent.py b/.venv/lib/python3.12/site-packages/core/base/agent/agent.py new file mode 100644 index 00000000..6813dd21 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/agent/agent.py @@ -0,0 +1,291 @@ +# type: ignore +import asyncio +import json +import logging +from abc import ABC, abstractmethod +from datetime import datetime +from json import JSONDecodeError +from typing import Any, AsyncGenerator, Optional, Type + +from pydantic import BaseModel + +from core.base.abstractions import ( + GenerationConfig, + LLMChatCompletion, + Message, +) +from core.base.providers import CompletionProvider, DatabaseProvider + +from .base import Tool, ToolResult + +logger = logging.getLogger() + + +class Conversation: + def __init__(self): + self.messages: list[Message] = [] + self._lock = asyncio.Lock() + + async def add_message(self, message): + async with self._lock: + self.messages.append(message) + + async def get_messages(self) -> list[dict[str, Any]]: + async with self._lock: + return [ + {**msg.model_dump(exclude_none=True), "role": str(msg.role)} + for msg in self.messages + ] + + +# TODO - Move agents to provider pattern +class AgentConfig(BaseModel): + rag_rag_agent_static_prompt: str = "static_rag_agent" + rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted" + stream: bool = False + include_tools: bool = True + max_iterations: int = 10 + + @classmethod + def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig": + base_args = cls.model_fields.keys() + filtered_kwargs = { + k: v if v != "None" else None + for k, v in kwargs.items() + if k in base_args + } + return cls(**filtered_kwargs) # type: ignore + + +class Agent(ABC): + def __init__( + self, + llm_provider: CompletionProvider, + database_provider: DatabaseProvider, + config: AgentConfig, + rag_generation_config: GenerationConfig, + ): + self.llm_provider = llm_provider + self.database_provider: DatabaseProvider = database_provider + self.config = config + self.conversation = Conversation() + self._completed = False + self._tools: list[Tool] = [] + self.tool_calls: list[dict] = [] + self.rag_generation_config = rag_generation_config + self._register_tools() + + @abstractmethod + def _register_tools(self): + pass + + async def _setup( + self, system_instruction: Optional[str] = None, *args, **kwargs + ): + await self.conversation.add_message( + Message( + role="system", + content=system_instruction + or ( + await self.database_provider.prompts_handler.get_cached_prompt( + self.config.rag_rag_agent_static_prompt, + inputs={ + "date": str(datetime.now().strftime("%m/%d/%Y")) + }, + ) + + f"\n Note,you only have {self.config.max_iterations} iterations or tool calls to reach a conclusion before your operation terminates." + ), + ) + ) + + @property + def tools(self) -> list[Tool]: + return self._tools + + @tools.setter + def tools(self, tools: list[Tool]): + self._tools = tools + + @abstractmethod + async def arun( + self, + system_instruction: Optional[str] = None, + messages: Optional[list[Message]] = None, + *args, + **kwargs, + ) -> list[LLMChatCompletion] | AsyncGenerator[LLMChatCompletion, None]: + pass + + @abstractmethod + async def process_llm_response( + self, + response: Any, + *args, + **kwargs, + ) -> None | AsyncGenerator[str, None]: + pass + + async def execute_tool(self, tool_name: str, *args, **kwargs) -> str: + if tool := next((t for t in self.tools if t.name == tool_name), None): + return await tool.results_function(*args, **kwargs) + else: + return f"Error: Tool {tool_name} not found." + + def get_generation_config( + self, last_message: dict, stream: bool = False + ) -> GenerationConfig: + if ( + last_message["role"] in ["tool", "function"] + and last_message["content"] != "" + and "ollama" in self.rag_generation_config.model + or not self.config.include_tools + ): + return GenerationConfig( + **self.rag_generation_config.model_dump( + exclude={"functions", "tools", "stream"} + ), + stream=stream, + ) + + return GenerationConfig( + **self.rag_generation_config.model_dump( + exclude={"functions", "tools", "stream"} + ), + # FIXME: Use tools instead of functions + # TODO - Investigate why `tools` fails with OpenAI+LiteLLM + tools=( + [ + { + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + }, + "type": "function", + "name": tool.name, + } + for tool in self.tools + ] + if self.tools + else None + ), + stream=stream, + ) + + async def handle_function_or_tool_call( + self, + function_name: str, + function_arguments: str, + tool_id: Optional[str] = None, + save_messages: bool = True, + *args, + **kwargs, + ) -> ToolResult: + logger.debug( + f"Calling function: {function_name}, args: {function_arguments}, tool_id: {tool_id}" + ) + if tool := next( + (t for t in self.tools if t.name == function_name), None + ): + try: + function_args = json.loads(function_arguments) + + except JSONDecodeError as e: + error_message = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with `JSONDecodeError`." + if save_messages: + await self.conversation.add_message( + Message( + role="tool" if tool_id else "function", + content=error_message, + name=function_name, + tool_call_id=tool_id, + ) + ) + + # raise R2RException( + # message=f"Error parsing function arguments: {e}, agent likely produced invalid tool inputs.", + # status_code=400, + # ) + + merged_kwargs = {**kwargs, **function_args} + try: + raw_result = await tool.results_function( + *args, **merged_kwargs + ) + llm_formatted_result = tool.llm_format_function(raw_result) + except Exception as e: + raw_result = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with an exception: {e}." + logger.error(raw_result) + llm_formatted_result = raw_result + + tool_result = ToolResult( + raw_result=raw_result, + llm_formatted_result=llm_formatted_result, + ) + if tool.stream_function: + tool_result.stream_result = tool.stream_function(raw_result) + + if save_messages: + await self.conversation.add_message( + Message( + role="tool" if tool_id else "function", + content=str(tool_result.llm_formatted_result), + name=function_name, + tool_call_id=tool_id, + ) + ) + # HACK - to fix issues with claude thinking + tool use [https://github.com/anthropics/anthropic-cookbook/blob/main/extended_thinking/extended_thinking_with_tool_use.ipynb] + if self.rag_generation_config.extended_thinking: + await self.conversation.add_message( + Message( + role="user", + content="Continue...", + ) + ) + + self.tool_calls.append( + { + "name": function_name, + "args": function_arguments, + } + ) + return tool_result + + +# TODO - Move agents to provider pattern +class RAGAgentConfig(AgentConfig): + rag_rag_agent_static_prompt: str = "static_rag_agent" + rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted" + stream: bool = False + include_tools: bool = True + max_iterations: int = 10 + # tools: list[str] = [] # HACK - unused variable. + + # Default RAG tools + rag_tools: list[str] = [ + "search_file_descriptions", + "search_file_knowledge", + "get_file_content", + ] + + # Default Research tools + research_tools: list[str] = [ + "rag", + "reasoning", + # DISABLED by default + "critique", + "python_executor", + ] + + @classmethod + def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig": + base_args = cls.model_fields.keys() + filtered_kwargs = { + k: v if v != "None" else None + for k, v in kwargs.items() + if k in base_args + } + filtered_kwargs["tools"] = kwargs.get("tools", None) or kwargs.get( + "tool_names", None + ) + return cls(**filtered_kwargs) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/base.py b/.venv/lib/python3.12/site-packages/core/base/agent/base.py new file mode 100644 index 00000000..0d8f15ee --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/agent/base.py @@ -0,0 +1,22 @@ +from typing import Any, Callable, Optional + +from ..abstractions import R2RSerializable + + +class Tool(R2RSerializable): + name: str + description: str + results_function: Callable + llm_format_function: Callable + stream_function: Optional[Callable] = None + parameters: Optional[dict[str, Any]] = None + + class Config: + populate_by_name = True + arbitrary_types_allowed = True + + +class ToolResult(R2RSerializable): + raw_result: Any + llm_formatted_result: str + stream_result: Optional[str] = None |