# type: ignore
import asyncio
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from json import JSONDecodeError
from typing import Any, AsyncGenerator, Optional, Type

from pydantic import BaseModel

from core.base.abstractions import (
    GenerationConfig,
    LLMChatCompletion,
    Message,
)
from core.base.providers import CompletionProvider, DatabaseProvider

from .base import Tool, ToolResult

logger = logging.getLogger(__name__)


class Conversation:
    """Ordered list of chat messages guarded by an ``asyncio.Lock``."""

    def __init__(self):
        self.messages: list[Message] = []
        self._lock = asyncio.Lock()

    async def add_message(self, message: Message) -> None:
        """Append ``message`` to the conversation while holding the lock."""
        async with self._lock:
            self.messages.append(message)

    async def get_messages(self) -> list[dict[str, Any]]:
        """Return a snapshot of the conversation as plain dicts.

        Fields that are ``None`` are omitted (``exclude_none=True``) and the
        ``role`` value is coerced with ``str()``.
        """
        async with self._lock:
            return [
                {**msg.model_dump(exclude_none=True), "role": str(msg.role)}
                for msg in self.messages
            ]


# TODO - Move agents to provider pattern
class AgentConfig(BaseModel):
    """Settings shared by every agent."""

    # Prompt name passed to ``prompts_handler.get_cached_prompt`` by
    # ``Agent._setup`` when no explicit system instruction is given.
    rag_rag_agent_static_prompt: str = "static_rag_agent"
    rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted"
    stream: bool = False
    # When False, ``Agent.get_generation_config`` never attaches tool schemas.
    include_tools: bool = True
    # Quoted to the model in the default system prompt; enforcement (if any)
    # presumably lives in subclasses' ``arun`` — TODO confirm.
    max_iterations: int = 10

    @classmethod
    def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
        """Build a config from arbitrary keyword arguments.

        Keys that are not model fields are discarded, and the literal string
        ``"None"`` is converted to ``None``.
        """
        base_args = cls.model_fields.keys()
        filtered_kwargs = {
            k: v if v != "None" else None
            for k, v in kwargs.items()
            if k in base_args
        }
        return cls(**filtered_kwargs)  # type: ignore


class Agent(ABC):
    """Base class for tool-using LLM agents.

    Holds the providers, the configuration, a ``Conversation`` and the list
    of registered tools. Subclasses register their tools in
    ``_register_tools`` and implement the run loop in ``arun`` /
    ``process_llm_response``.
    """

    def __init__(
        self,
        llm_provider: CompletionProvider,
        database_provider: DatabaseProvider,
        config: AgentConfig,
        rag_generation_config: GenerationConfig,
    ):
        self.llm_provider = llm_provider
        self.database_provider: DatabaseProvider = database_provider
        self.config = config
        self.conversation = Conversation()
        self._completed = False
        self._tools: list[Tool] = []
        # One {"name", "args"} entry per executed tool call.
        self.tool_calls: list[dict] = []
        self.rag_generation_config = rag_generation_config
        # Called last so subclasses can rely on every attribute above.
        self._register_tools()

    @abstractmethod
    def _register_tools(self):
        pass

    async def _setup(
        self, system_instruction: Optional[str] = None, *args, **kwargs
    ):
        """Seed the conversation with a system message.

        Args:
            system_instruction: Explicit system prompt. When falsy, the
                cached prompt named by ``config.rag_rag_agent_static_prompt``
                is fetched (with today's date as ``date``), and a note about
                ``config.max_iterations`` is appended to it.
        """
        if not system_instruction:
            base_prompt = (
                await self.database_provider.prompts_handler.get_cached_prompt(
                    self.config.rag_rag_agent_static_prompt,
                    inputs={"date": datetime.now().strftime("%m/%d/%Y")},
                )
            )
            system_instruction = (
                base_prompt
                + f"\n Note,you only have {self.config.max_iterations} iterations or tool calls to reach a conclusion before your operation terminates."
            )
        await self.conversation.add_message(
            Message(role="system", content=system_instruction)
        )

    @property
    def tools(self) -> list[Tool]:
        """Tools currently available to the agent."""
        return self._tools

    @tools.setter
    def tools(self, tools: list[Tool]):
        self._tools = tools

    @abstractmethod
    async def arun(
        self,
        system_instruction: Optional[str] = None,
        messages: Optional[list[Message]] = None,
        *args,
        **kwargs,
    ) -> list[LLMChatCompletion] | AsyncGenerator[LLMChatCompletion, None]:
        pass

    @abstractmethod
    async def process_llm_response(
        self,
        response: Any,
        *args,
        **kwargs,
    ) -> None | AsyncGenerator[str, None]:
        pass

    async def execute_tool(self, tool_name: str, *args, **kwargs) -> str:
        """Run the named tool's ``results_function`` directly.

        Returns an error string instead of raising when no tool matches.
        """
        if tool := next((t for t in self.tools if t.name == tool_name), None):
            return await tool.results_function(*args, **kwargs)
        else:
            return f"Error: Tool {tool_name} not found."

    def get_generation_config(
        self, last_message: dict, stream: bool = False
    ) -> GenerationConfig:
        """Derive the generation config for the next LLM call.

        Tool schemas are omitted when ``config.include_tools`` is False, or
        when the previous message is a non-empty tool/function result and
        the model name contains ``"ollama"``.

        Args:
            last_message: Most recent conversation message as a dict; its
                ``role`` and ``content`` keys are read.
            stream: Value for the returned config's ``stream`` flag.
        """
        base_config = self.rag_generation_config.model_dump(
            exclude={"functions", "tools", "stream"}
        )
        omit_tools = (
            last_message["role"] in ("tool", "function")
            and last_message["content"] != ""
            and "ollama" in self.rag_generation_config.model
        ) or not self.config.include_tools
        if omit_tools:
            return GenerationConfig(**base_config, stream=stream)

        return GenerationConfig(
            **base_config,
            # FIXME: Use tools instead of functions
            # TODO - Investigate why `tools` fails with OpenAI+LiteLLM
            tools=(
                [
                    {
                        "function": {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.parameters,
                        },
                        "type": "function",
                        "name": tool.name,
                    }
                    for tool in self.tools
                ]
                if self.tools
                else None
            ),
            stream=stream,
        )

    async def _save_tool_message(
        self, function_name: str, content: str, tool_id: Optional[str]
    ) -> None:
        """Append a tool result to the conversation.

        The role is ``"tool"`` when ``tool_id`` is set and ``"function"``
        otherwise.
        """
        await self.conversation.add_message(
            Message(
                role="tool" if tool_id else "function",
                content=content,
                name=function_name,
                tool_call_id=tool_id,
            )
        )

    async def _tool_error_result(
        self,
        function_name: str,
        error_message: str,
        tool_id: Optional[str],
        save_messages: bool,
    ) -> ToolResult:
        """Log ``error_message``, optionally save it, and wrap it in a ToolResult."""
        logger.warning(error_message)
        if save_messages:
            await self._save_tool_message(function_name, error_message, tool_id)
        return ToolResult(
            raw_result=error_message,
            llm_formatted_result=error_message,
        )

    async def handle_function_or_tool_call(
        self,
        function_name: str,
        function_arguments: str,
        tool_id: Optional[str] = None,
        save_messages: bool = True,
        *args,
        **kwargs,
    ) -> ToolResult:
        """Execute a tool call requested by the model.

        Args:
            function_name: Name of the tool to run.
            function_arguments: JSON-encoded argument object produced by the
                model. An empty or blank string means "no arguments".
            tool_id: Tool-call id. When set, the result is saved with role
                ``"tool"``; otherwise it is saved with role ``"function"``.
            save_messages: Whether to append the result to the conversation.
            *args, **kwargs: Extra arguments forwarded to the tool. Arguments
                supplied by the model override same-named ``kwargs``.

        Returns:
            The tool's ``ToolResult``. An unknown tool, undecodable
            arguments, or a tool that raises also produces a ``ToolResult``;
            its text describes the error.
        """
        logger.debug(
            f"Calling function: {function_name}, args: {function_arguments}, tool_id: {tool_id}"
        )
        tool = next((t for t in self.tools if t.name == function_name), None)
        if tool is None:
            # Previously this path hit `return tool_result` with the name
            # unbound (UnboundLocalError); answer the call with an error.
            return await self._tool_error_result(
                function_name,
                f"Error: Tool {function_name} not found.",
                tool_id,
                save_messages,
            )

        # Blank arguments are treated as an empty argument object.
        raw_arguments = (function_arguments or "").strip() or "{}"
        try:
            function_args = json.loads(raw_arguments)
        except JSONDecodeError:
            # Previously execution fell through with `function_args` unbound.
            return await self._tool_error_result(
                function_name,
                f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with `JSONDecodeError`.",
                tool_id,
                save_messages,
            )
        if not isinstance(function_args, dict):
            # `**function_args` below requires a mapping.
            return await self._tool_error_result(
                function_name,
                f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed: arguments must be a JSON object.",
                tool_id,
                save_messages,
            )

        merged_kwargs = {**kwargs, **function_args}
        try:
            raw_result = await tool.results_function(*args, **merged_kwargs)
            llm_formatted_result = tool.llm_format_function(raw_result)
        except Exception as e:
            # Report tool failures back to the model rather than aborting.
            raw_result = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with an exception: {e}."
            logger.exception(raw_result)
            llm_formatted_result = raw_result

        tool_result = ToolResult(
            raw_result=raw_result,
            llm_formatted_result=llm_formatted_result,
        )
        if tool.stream_function:
            tool_result.stream_result = tool.stream_function(raw_result)

        if save_messages:
            await self._save_tool_message(
                function_name, str(tool_result.llm_formatted_result), tool_id
            )
            # HACK - to fix issues with claude thinking + tool use [https://github.com/anthropics/anthropic-cookbook/blob/main/extended_thinking/extended_thinking_with_tool_use.ipynb]
            if self.rag_generation_config.extended_thinking:
                await self.conversation.add_message(
                    Message(
                        role="user",
                        content="Continue...",
                    )
                )

        self.tool_calls.append(
            {
                "name": function_name,
                "args": function_arguments,
            }
        )
        return tool_result


# TODO - Move agents to provider pattern
class RAGAgentConfig(AgentConfig):
    """Agent configuration with default RAG and research tool lists."""

    rag_rag_agent_static_prompt: str = "static_rag_agent"
    rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted"
    stream: bool = False
    include_tools: bool = True
    max_iterations: int = 10
    # tools: list[str] = []  # HACK - unused variable.

    # Default RAG tools
    rag_tools: list[str] = [
        "search_file_descriptions",
        "search_file_knowledge",
        "get_file_content",
    ]

    # Default Research tools
    research_tools: list[str] = [
        "rag",
        "reasoning",
        # NOTE(review): originally annotated "DISABLED by default", yet the
        # entries below are part of the default list — confirm intent.
        "critique",
        "python_executor",
    ]

    @classmethod
    def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
        """Build a config like ``AgentConfig.create``, also forwarding ``tools``.

        ``tools`` is taken from the ``tools`` or ``tool_names`` kwarg. No
        ``tools`` field is declared on this model, so under pydantic's
        default ``extra="ignore"`` the value is dropped unless a subclass
        declares one.
        """
        base_args = cls.model_fields.keys()
        filtered_kwargs = {
            k: v if v != "None" else None
            for k, v in kwargs.items()
            if k in base_args
        }
        filtered_kwargs["tools"] = kwargs.get("tools", None) or kwargs.get(
            "tool_names", None
        )
        return cls(**filtered_kwargs)  # type: ignore