# type: ignore
import asyncio
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from json import JSONDecodeError
from typing import Any, AsyncGenerator, Optional, Type
from pydantic import BaseModel
from core.base.abstractions import (
GenerationConfig,
LLMChatCompletion,
Message,
)
from core.base.providers import CompletionProvider, DatabaseProvider
from .base import Tool, ToolResult
# Module-wide logger. `logging.getLogger()` called without a name returns the
# root logger, so records emitted here go through the root logger's handlers.
logger = logging.getLogger()
class Conversation:
    """Append-only message history guarded by an asyncio lock."""

    def __init__(self):
        # One lock covers both appends and snapshot reads.
        self._lock = asyncio.Lock()
        self.messages: list[Message] = []

    async def add_message(self, message):
        """Append ``message`` to the end of the history."""
        async with self._lock:
            self.messages.append(message)

    async def get_messages(self) -> list[dict[str, Any]]:
        """Return the history as a list of plain dicts.

        Each entry is the message's ``model_dump(exclude_none=True)`` output
        with its ``"role"`` value replaced by ``str(message.role)``.
        """
        async with self._lock:
            snapshot: list[dict[str, Any]] = []
            for entry in self.messages:
                # Copy the dump so the role override never touches the source.
                payload = dict(entry.model_dump(exclude_none=True))
                payload["role"] = str(entry.role)
                snapshot.append(payload)
            return snapshot
# TODO - Move agents to provider pattern
class AgentConfig(BaseModel):
    """Settings shared by agents: prompt names, streaming, tool use, budget."""

    # Name passed to the prompt handler when no system instruction is given.
    rag_rag_agent_static_prompt: str = "static_rag_agent"
    rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted"
    stream: bool = False
    # When False, generation configs are built without tool definitions.
    include_tools: bool = True
    # Iteration / tool-call budget quoted in the default system prompt.
    max_iterations: int = 10

    @classmethod
    def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
        """Build a config from ``kwargs``, ignoring keys that are not fields.

        A value equal to the string ``"None"`` is passed to the model as
        ``None``.
        """
        known_fields = cls.model_fields.keys()
        accepted: dict[str, Any] = {}
        for key, value in kwargs.items():
            if key not in known_fields:
                continue
            if value != "None":
                accepted[key] = value
            else:
                accepted[key] = None
        return cls(**accepted)  # type: ignore
class Agent(ABC):
    """Abstract base class for tool-using LLM agents.

    An agent owns its providers, a `Conversation`, the list of registered
    tools and a log of the tool calls it has executed. Subclasses must
    implement `_register_tools`, `arun` and `process_llm_response`.
    """

    def __init__(
        self,
        llm_provider: CompletionProvider,
        database_provider: DatabaseProvider,
        config: AgentConfig,
        rag_generation_config: GenerationConfig,
    ):
        """Store the collaborators and let the subclass register its tools.

        Args:
            llm_provider: Completion provider kept on the instance.
            database_provider: Provider whose ``prompts_handler`` supplies the
                default system prompt in `_setup`.
            config: Agent settings (prompt names, tool usage, iteration cap).
            rag_generation_config: Base generation settings from which
                per-call configs are derived.
        """
        self.llm_provider = llm_provider
        self.database_provider: DatabaseProvider = database_provider
        self.config = config
        self.conversation = Conversation()
        self._completed = False
        self._tools: list[Tool] = []
        self.tool_calls: list[dict] = []
        self.rag_generation_config = rag_generation_config
        # Called last so subclasses can rely on every attribute above.
        self._register_tools()

    @abstractmethod
    def _register_tools(self):
        """Populate the agent's tools; invoked at the end of ``__init__``."""
        pass

    async def _setup(
        self, system_instruction: Optional[str] = None, *args, **kwargs
    ):
        """Seed the conversation with a system message.

        Args:
            system_instruction: Text used verbatim as the system message. When
                falsy, the cached static prompt is fetched (with today's date
                as input) and a note about the iteration budget is appended.
        """
        if not system_instruction:
            static_prompt = (
                await self.database_provider.prompts_handler.get_cached_prompt(
                    self.config.rag_rag_agent_static_prompt,
                    inputs={"date": str(datetime.now().strftime("%m/%d/%Y"))},
                )
            )
            system_instruction = (
                static_prompt
                + f"\n Note,you only have {self.config.max_iterations} iterations or tool calls to reach a conclusion before your operation terminates."
            )
        await self.conversation.add_message(
            Message(role="system", content=system_instruction)
        )

    @property
    def tools(self) -> list[Tool]:
        """Tools currently registered on this agent."""
        return self._tools

    @tools.setter
    def tools(self, tools: list[Tool]):
        self._tools = tools

    @abstractmethod
    async def arun(
        self,
        system_instruction: Optional[str] = None,
        messages: Optional[list[Message]] = None,
        *args,
        **kwargs,
    ) -> list[LLMChatCompletion] | AsyncGenerator[LLMChatCompletion, None]:
        """Run the agent; implemented by subclasses."""
        pass

    @abstractmethod
    async def process_llm_response(
        self,
        response: Any,
        *args,
        **kwargs,
    ) -> None | AsyncGenerator[str, None]:
        """Handle one LLM response; implemented by subclasses."""
        pass

    async def execute_tool(self, tool_name: str, *args, **kwargs) -> str:
        """Call the named tool's ``results_function`` and return its result.

        Returns an error string instead of raising when no registered tool
        has the given name.
        """
        if tool := next((t for t in self.tools if t.name == tool_name), None):
            return await tool.results_function(*args, **kwargs)
        else:
            return f"Error: Tool {tool_name} not found."

    def get_generation_config(
        self, last_message: dict, stream: bool = False
    ) -> GenerationConfig:
        """Derive the generation config for the next LLM call.

        Tool definitions are omitted when tools are disabled in the agent
        config, or when the last message is a non-empty tool/function reply
        and the configured model name contains ``"ollama"``.

        Args:
            last_message: Most recent message as a dict with ``"role"`` and
                ``"content"`` keys.
            stream: Value for the ``stream`` field of the returned config.
        """
        # Parenthesised to make the original `and`/`or` precedence explicit.
        replying_to_tool_on_ollama = (
            last_message["role"] in ["tool", "function"]
            and last_message["content"] != ""
            and "ollama" in self.rag_generation_config.model
        )
        omit_tools = replying_to_tool_on_ollama or not self.config.include_tools

        base_kwargs = self.rag_generation_config.model_dump(
            exclude={"functions", "tools", "stream"}
        )
        if omit_tools:
            return GenerationConfig(**base_kwargs, stream=stream)

        return GenerationConfig(
            **base_kwargs,
            # FIXME: Use tools instead of functions
            # TODO - Investigate why `tools` fails with OpenAI+LiteLLM
            tools=(
                [
                    {
                        "function": {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.parameters,
                        },
                        "type": "function",
                        "name": tool.name,
                    }
                    for tool in self.tools
                ]
                if self.tools
                else None
            ),
            stream=stream,
        )

    async def handle_function_or_tool_call(
        self,
        function_name: str,
        function_arguments: str,
        tool_id: Optional[str] = None,
        save_messages: bool = True,
        *args,
        **kwargs,
    ) -> ToolResult:
        """Execute a tool requested by the LLM and record the outcome.

        Args:
            function_name: Name of the tool to invoke.
            function_arguments: JSON-encoded keyword arguments for the tool.
            tool_id: Provider-supplied call id; when set the reply message
                uses the ``"tool"`` role, otherwise ``"function"``.
            save_messages: Whether to append the reply to the conversation.
            *args: Extra positional arguments forwarded to the tool.
            **kwargs: Extra keyword arguments forwarded to the tool; keys
                present in the decoded arguments take precedence.

        Returns:
            A `ToolResult`. If the tool is unknown, the arguments are not
            valid JSON, or the tool raises, both result fields carry an error
            message instead of an exception being propagated.
        """
        logger.debug(
            f"Calling function: {function_name}, args: {function_arguments}, tool_id: {tool_id}"
        )
        reply_role = "tool" if tool_id else "function"

        tool = next((t for t in self.tools if t.name == function_name), None)
        if not tool:
            # Previously this path reached `return tool_result` with the name
            # unbound; report the problem to the caller and the LLM instead.
            error_message = f"Error: Tool {function_name} not found."
            logger.error(error_message)
            if save_messages:
                await self.conversation.add_message(
                    Message(
                        role=reply_role,
                        content=error_message,
                        name=function_name,
                        tool_call_id=tool_id,
                    )
                )
            return ToolResult(
                raw_result=error_message,
                llm_formatted_result=error_message,
            )

        try:
            function_args = json.loads(function_arguments)
        except JSONDecodeError:
            error_message = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with `JSONDecodeError`."
            logger.error(error_message)
            if save_messages:
                await self.conversation.add_message(
                    Message(
                        role=reply_role,
                        content=error_message,
                        name=function_name,
                        tool_call_id=tool_id,
                    )
                )
            # Stop here: there are no decoded arguments to call the tool with.
            return ToolResult(
                raw_result=error_message,
                llm_formatted_result=error_message,
            )

        try:
            # Merged inside the `try` so a JSON payload that is not an object
            # is reported as a tool failure rather than escaping as TypeError.
            merged_kwargs = {**kwargs, **function_args}
            raw_result = await tool.results_function(*args, **merged_kwargs)
            llm_formatted_result = tool.llm_format_function(raw_result)
        except Exception as e:
            raw_result = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with an exception: {e}."
            logger.error(raw_result)
            llm_formatted_result = raw_result

        tool_result = ToolResult(
            raw_result=raw_result,
            llm_formatted_result=llm_formatted_result,
        )
        if tool.stream_function:
            tool_result.stream_result = tool.stream_function(raw_result)

        if save_messages:
            await self.conversation.add_message(
                Message(
                    role=reply_role,
                    content=str(tool_result.llm_formatted_result),
                    name=function_name,
                    tool_call_id=tool_id,
                )
            )
            # HACK - to fix issues with claude thinking + tool use [https://github.com/anthropics/anthropic-cookbook/blob/main/extended_thinking/extended_thinking_with_tool_use.ipynb]
            if self.rag_generation_config.extended_thinking:
                await self.conversation.add_message(
                    Message(
                        role="user",
                        content="Continue...",
                    )
                )

        self.tool_calls.append(
            {
                "name": function_name,
                "args": function_arguments,
            }
        )
        return tool_result
# TODO - Move agents to provider pattern
class RAGAgentConfig(AgentConfig):
    """Agent settings extended with default RAG and research tool names."""

    rag_rag_agent_static_prompt: str = "static_rag_agent"
    rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted"
    stream: bool = False
    include_tools: bool = True
    max_iterations: int = 10
    # tools: list[str] = [] # HACK - unused variable.

    # Tool names used by default for RAG.
    rag_tools: list[str] = [
        "search_file_descriptions",
        "search_file_knowledge",
        "get_file_content",
    ]

    # Tool names used by default for research.
    research_tools: list[str] = [
        "rag",
        "reasoning",
        # NOTE(review): the original comment labels the next two entries
        # "DISABLED by default", yet they are part of this default list —
        # confirm the intent.
        "critique",
        "python_executor",
    ]

    @classmethod
    def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
        """Build a config from ``kwargs``, keeping only declared fields.

        A value equal to the string ``"None"`` is passed as ``None``. A
        ``tools`` entry is always added, taken from ``kwargs["tools"]`` or,
        when that is missing or falsy, from ``kwargs["tool_names"]``.
        """
        declared = cls.model_fields.keys()
        selected: dict[str, Any] = {}
        for key, value in kwargs.items():
            if key in declared:
                selected[key] = value if value != "None" else None

        requested_tools = kwargs.get("tools", None)
        if not requested_tools:
            requested_tools = kwargs.get("tool_names", None)
        # NOTE(review): `tools` is not declared as a field on this class (its
        # declaration above is commented out), so what the model does with
        # this key depends on its pydantic `extra` setting — confirm.
        selected["tools"] = requested_tools
        return cls(**selected)  # type: ignore