author     S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/agent
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/agent')
-rw-r--r--  .venv/lib/python3.12/site-packages/core/agent/__init__.py    36
-rw-r--r--  .venv/lib/python3.12/site-packages/core/agent/base.py      1484
-rw-r--r--  .venv/lib/python3.12/site-packages/core/agent/rag.py        662
-rw-r--r--  .venv/lib/python3.12/site-packages/core/agent/research.py   697
4 files changed, 2879 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/core/agent/__init__.py b/.venv/lib/python3.12/site-packages/core/agent/__init__.py new file mode 100644 index 00000000..bd6dda79 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/agent/__init__.py @@ -0,0 +1,36 @@ +# FIXME: Once the agent is properly type annotated, remove the type: ignore comments +from .base import ( # type: ignore + R2RAgent, + R2RStreamingAgent, + R2RXMLStreamingAgent, +) +from .rag import ( # type: ignore + R2RRAGAgent, + R2RStreamingRAGAgent, + R2RXMLToolsRAGAgent, + R2RXMLToolsStreamingRAGAgent, +) + +# Import the concrete implementations +from .research import ( + R2RResearchAgent, + R2RStreamingResearchAgent, + R2RXMLToolsResearchAgent, + R2RXMLToolsStreamingResearchAgent, +) + +__all__ = [ + # Base + "R2RAgent", + "R2RStreamingAgent", + "R2RXMLStreamingAgent", + # RAG Agents + "R2RRAGAgent", + "R2RXMLToolsRAGAgent", + "R2RStreamingRAGAgent", + "R2RXMLToolsStreamingRAGAgent", + "R2RResearchAgent", + "R2RStreamingResearchAgent", + "R2RXMLToolsResearchAgent", + "R2RXMLToolsStreamingResearchAgent", +] diff --git a/.venv/lib/python3.12/site-packages/core/agent/base.py b/.venv/lib/python3.12/site-packages/core/agent/base.py new file mode 100644 index 00000000..84aae3f2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/agent/base.py @@ -0,0 +1,1484 @@ +import asyncio +import json +import logging +import re +from abc import ABCMeta +from typing import AsyncGenerator, Optional, Tuple + +from core.base import AsyncSyncMeta, LLMChatCompletion, Message, syncable +from core.base.agent import Agent, Conversation +from core.utils import ( + CitationTracker, + SearchResultsCollector, + SSEFormatter, + convert_nonserializable_objects, + dump_obj, + find_new_citation_spans, +) + +logger = logging.getLogger() + + +class CombinedMeta(AsyncSyncMeta, ABCMeta): + pass + + +def sync_wrapper(async_gen): + loop = asyncio.get_event_loop() + + def wrapper(): + try: + while True: + try: + yield loop.run_until_complete(async_gen.__anext__()) + except StopAsyncIteration: + break + finally: + loop.run_until_complete(async_gen.aclose()) + + return wrapper() + + +class R2RAgent(Agent, metaclass=CombinedMeta): + def __init__(self, *args, **kwargs): + self.search_results_collector = SearchResultsCollector() + super().__init__(*args, **kwargs) + self._reset() + + async def _generate_llm_summary(self, iterations_count: int) -> str: + """ + Generate a summary of the conversation using the LLM when max iterations are exceeded. + + Args: + iterations_count: The number of iterations that were completed + + Returns: + A string containing the LLM-generated summary + """ + try: + # Get all messages in the conversation + all_messages = await self.conversation.get_messages() + + # Create a prompt for the LLM to summarize + summary_prompt = { + "role": "user", + "content": ( + f"The conversation has reached the maximum limit of {iterations_count} iterations " + f"without completing the task. Please provide a concise summary of: " + f"1) The key information you've gathered that's relevant to the original query, " + f"2) What you've attempted so far and why it's incomplete, and " + f"3) A specific recommendation for how to proceed. " + f"Keep your summary brief (3-4 sentences total) and focused on the most valuable insights. If it is possible to answer the original user query, then do so now instead." 
+ f"Start with '⚠️ **Maximum iterations exceeded**'" + ), + } + + # Create a new message list with just the conversation history and summary request + summary_messages = all_messages + [summary_prompt] + + # Get a completion for the summary + generation_config = self.get_generation_config(summary_prompt) + response = await self.llm_provider.aget_completion( + summary_messages, + generation_config, + ) + + return response.choices[0].message.content + except Exception as e: + logger.error(f"Error generating LLM summary: {str(e)}") + # Fall back to basic summary if LLM generation fails + return ( + "⚠️ **Maximum iterations exceeded**\n\n" + "The agent reached the maximum iteration limit without completing the task. " + "Consider breaking your request into smaller steps or refining your query." + ) + + def _reset(self): + self._completed = False + self.conversation = Conversation() + + @syncable + async def arun( + self, + messages: list[Message], + system_instruction: Optional[str] = None, + *args, + **kwargs, + ) -> list[dict]: + self._reset() + await self._setup(system_instruction) + + if messages: + for message in messages: + await self.conversation.add_message(message) + iterations_count = 0 + while ( + not self._completed + and iterations_count < self.config.max_iterations + ): + iterations_count += 1 + messages_list = await self.conversation.get_messages() + generation_config = self.get_generation_config(messages_list[-1]) + response = await self.llm_provider.aget_completion( + messages_list, + generation_config, + ) + logger.debug(f"R2RAgent response: {response}") + await self.process_llm_response(response, *args, **kwargs) + + if not self._completed: + # Generate a summary of the conversation using the LLM + summary = await self._generate_llm_summary(iterations_count) + await self.conversation.add_message( + Message(role="assistant", content=summary) + ) + + # Return final content + all_messages: list[dict] = await self.conversation.get_messages() + all_messages.reverse() + + output_messages = [] + for message_2 in all_messages: + if ( + # message_2.get("content") + message_2.get("content") != messages[-1].content + ): + output_messages.append(message_2) + else: + break + output_messages.reverse() + + return output_messages + + async def process_llm_response( + self, response: LLMChatCompletion, *args, **kwargs + ) -> None: + if not self._completed: + message = response.choices[0].message + finish_reason = response.choices[0].finish_reason + + if finish_reason == "stop": + self._completed = True + + # Determine which provider we're using + using_anthropic = ( + "anthropic" in self.rag_generation_config.model.lower() + ) + + # OPENAI HANDLING + if not using_anthropic: + if message.tool_calls: + assistant_msg = Message( + role="assistant", + content="", + tool_calls=[msg.dict() for msg in message.tool_calls], + ) + await self.conversation.add_message(assistant_msg) + + # If there are multiple tool_calls, call them sequentially here + for tool_call in message.tool_calls: + await self.handle_function_or_tool_call( + tool_call.function.name, + tool_call.function.arguments, + tool_id=tool_call.id, + *args, + **kwargs, + ) + else: + await self.conversation.add_message( + Message(role="assistant", content=message.content) + ) + self._completed = True + + else: + # First handle thinking blocks if present + if ( + hasattr(message, "structured_content") + and message.structured_content + ): + # Check if structured_content contains any tool_use blocks + has_tool_use = any( + block.get("type") == 
"tool_use" + for block in message.structured_content + ) + + if not has_tool_use and message.tool_calls: + # If it has thinking but no tool_use, add a separate message with structured_content + assistant_msg = Message( + role="assistant", + structured_content=message.structured_content, # Use structured_content field + ) + await self.conversation.add_message(assistant_msg) + + # Add explicit tool_use blocks in a separate message + tool_uses = [] + for tool_call in message.tool_calls: + # Safely parse arguments if they're a string + try: + if isinstance( + tool_call.function.arguments, str + ): + input_args = json.loads( + tool_call.function.arguments + ) + else: + input_args = tool_call.function.arguments + except json.JSONDecodeError: + logger.error( + f"Failed to parse tool arguments: {tool_call.function.arguments}" + ) + input_args = { + "_raw": tool_call.function.arguments + } + + tool_uses.append( + { + "type": "tool_use", + "id": tool_call.id, + "name": tool_call.function.name, + "input": input_args, + } + ) + + # Add tool_use blocks as a separate assistant message with structured content + if tool_uses: + await self.conversation.add_message( + Message( + role="assistant", + structured_content=tool_uses, + content="", + ) + ) + else: + # If it already has tool_use or no tool_calls, preserve original structure + assistant_msg = Message( + role="assistant", + structured_content=message.structured_content, + ) + await self.conversation.add_message(assistant_msg) + + elif message.content: + # For regular text content + await self.conversation.add_message( + Message(role="assistant", content=message.content) + ) + + # If there are tool calls, add them as structured content + if message.tool_calls: + tool_uses = [] + for tool_call in message.tool_calls: + # Same safe parsing as above + try: + if isinstance( + tool_call.function.arguments, str + ): + input_args = json.loads( + tool_call.function.arguments + ) + else: + input_args = tool_call.function.arguments + except json.JSONDecodeError: + logger.error( + f"Failed to parse tool arguments: {tool_call.function.arguments}" + ) + input_args = { + "_raw": tool_call.function.arguments + } + + tool_uses.append( + { + "type": "tool_use", + "id": tool_call.id, + "name": tool_call.function.name, + "input": input_args, + } + ) + + await self.conversation.add_message( + Message( + role="assistant", structured_content=tool_uses + ) + ) + + # NEW CASE: Handle tool_calls with no content or structured_content + elif message.tool_calls: + # Create tool_uses for the message with only tool_calls + tool_uses = [] + for tool_call in message.tool_calls: + try: + if isinstance(tool_call.function.arguments, str): + input_args = json.loads( + tool_call.function.arguments + ) + else: + input_args = tool_call.function.arguments + except json.JSONDecodeError: + logger.error( + f"Failed to parse tool arguments: {tool_call.function.arguments}" + ) + input_args = {"_raw": tool_call.function.arguments} + + tool_uses.append( + { + "type": "tool_use", + "id": tool_call.id, + "name": tool_call.function.name, + "input": input_args, + } + ) + + # Add tool_use blocks as a message before processing tools + if tool_uses: + await self.conversation.add_message( + Message( + role="assistant", + structured_content=tool_uses, + ) + ) + + # Process the tool calls + if message.tool_calls: + for tool_call in message.tool_calls: + await self.handle_function_or_tool_call( + tool_call.function.name, + tool_call.function.arguments, + tool_id=tool_call.id, + *args, + **kwargs, + ) + + 
+class R2RStreamingAgent(R2RAgent): + """ + Base class for all streaming agents with core streaming functionality. + Supports emitting messages, tool calls, and results as SSE events. + """ + + # These two regexes will detect bracket references and then find short IDs. + BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]") + SHORT_ID_PATTERN = re.compile( + r"[A-Za-z0-9]{7,8}" + ) # 7-8 chars, for example + + def __init__(self, *args, **kwargs): + # Force streaming on + if hasattr(kwargs.get("config", {}), "stream"): + kwargs["config"].stream = True + super().__init__(*args, **kwargs) + + async def arun( + self, + system_instruction: str | None = None, + messages: list[Message] | None = None, + *args, + **kwargs, + ) -> AsyncGenerator[str, None]: + """ + Main streaming entrypoint: returns an async generator of SSE lines. + """ + self._reset() + await self._setup(system_instruction) + + if messages: + for m in messages: + await self.conversation.add_message(m) + + # Initialize citation tracker for this run + citation_tracker = CitationTracker() + + # Dictionary to store citation payloads by ID + citation_payloads = {} + + # Track all citations emitted during streaming for final persistence + self.streaming_citations: list[dict] = [] + + async def sse_generator() -> AsyncGenerator[str, None]: + pending_tool_calls = {} + partial_text_buffer = "" + iterations_count = 0 + + try: + # Keep streaming until we complete + while ( + not self._completed + and iterations_count < self.config.max_iterations + ): + iterations_count += 1 + # 1) Get current messages + msg_list = await self.conversation.get_messages() + gen_cfg = self.get_generation_config( + msg_list[-1], stream=True + ) + + accumulated_thinking = "" + thinking_signatures = {} # Map thinking content to signatures + + # 2) Start streaming from LLM + llm_stream = self.llm_provider.aget_completion_stream( + msg_list, gen_cfg + ) + async for chunk in llm_stream: + delta = chunk.choices[0].delta + finish_reason = chunk.choices[0].finish_reason + + if hasattr(delta, "thinking") and delta.thinking: + # Accumulate thinking for later use in messages + accumulated_thinking += delta.thinking + + # Emit SSE "thinking" event + async for ( + line + ) in SSEFormatter.yield_thinking_event( + delta.thinking + ): + yield line + + # Add this new handler for thinking signatures + if hasattr(delta, "thinking_signature"): + thinking_signatures[accumulated_thinking] = ( + delta.thinking_signature + ) + accumulated_thinking = "" + + # 3) If new text, accumulate it + if delta.content: + partial_text_buffer += delta.content + + # (a) Now emit the newly streamed text as a "message" event + async for line in SSEFormatter.yield_message_event( + delta.content + ): + yield line + + # (b) Find new citation spans in the accumulated text + new_citation_spans = find_new_citation_spans( + partial_text_buffer, citation_tracker + ) + + # Process each new citation span + for cid, spans in new_citation_spans.items(): + for span in spans: + # Check if this is the first time we've seen this citation ID + is_new_citation = ( + citation_tracker.is_new_citation(cid) + ) + + # Get payload if it's a new citation + payload = None + if is_new_citation: + source_obj = self.search_results_collector.find_by_short_id( + cid + ) + if source_obj: + # Store payload for reuse + payload = dump_obj(source_obj) + citation_payloads[cid] = payload + + # Create citation event payload + citation_data = { + "id": cid, + "object": "citation", + "is_new": is_new_citation, + "span": { + "start": span[0], + "end": 
span[1], + }, + } + + # Only include full payload for new citations + if is_new_citation and payload: + citation_data["payload"] = payload + + # Add to streaming citations for final answer + self.streaming_citations.append( + citation_data + ) + + # Emit the citation event + async for ( + line + ) in SSEFormatter.yield_citation_event( + citation_data + ): + yield line + + if delta.tool_calls: + for tc in delta.tool_calls: + idx = tc.index + if idx not in pending_tool_calls: + pending_tool_calls[idx] = { + "id": tc.id, + "name": tc.function.name or "", + "arguments": tc.function.arguments + or "", + } + else: + # Accumulate partial name/arguments + if tc.function.name: + pending_tool_calls[idx]["name"] = ( + tc.function.name + ) + if tc.function.arguments: + pending_tool_calls[idx][ + "arguments" + ] += tc.function.arguments + + # 5) If the stream signals we should handle "tool_calls" + if finish_reason == "tool_calls": + # Handle thinking if present + await self._handle_thinking( + thinking_signatures, accumulated_thinking + ) + + calls_list = [] + for idx in sorted(pending_tool_calls.keys()): + cinfo = pending_tool_calls[idx] + calls_list.append( + { + "tool_call_id": cinfo["id"] + or f"call_{idx}", + "name": cinfo["name"], + "arguments": cinfo["arguments"], + } + ) + + # (a) Emit SSE "tool_call" events + for c in calls_list: + tc_data = self._create_tool_call_data(c) + async for ( + line + ) in SSEFormatter.yield_tool_call_event( + tc_data + ): + yield line + + # (b) Add an assistant message capturing these calls + await self._add_tool_calls_message( + calls_list, partial_text_buffer + ) + + # (c) Execute each tool call in parallel + await asyncio.gather( + *[ + self.handle_function_or_tool_call( + c["name"], + c["arguments"], + tool_id=c["tool_call_id"], + ) + for c in calls_list + ] + ) + + # Reset buffer & calls + pending_tool_calls.clear() + partial_text_buffer = "" + + elif finish_reason == "stop": + # Handle thinking if present + await self._handle_thinking( + thinking_signatures, accumulated_thinking + ) + + # 6) The LLM is done. 
If we have any leftover partial text, + # finalize it in the conversation + if partial_text_buffer: + # Create the final message with metadata including citations + final_message = Message( + role="assistant", + content=partial_text_buffer, + metadata={ + "citations": self.streaming_citations + }, + ) + + # Add it to the conversation + await self.conversation.add_message( + final_message + ) + + # (a) Prepare final answer with optimized citations + consolidated_citations = [] + # Group citations by ID with all their spans + for ( + cid, + spans, + ) in citation_tracker.get_all_spans().items(): + if cid in citation_payloads: + consolidated_citations.append( + { + "id": cid, + "object": "citation", + "spans": [ + {"start": s[0], "end": s[1]} + for s in spans + ], + "payload": citation_payloads[cid], + } + ) + + # Create final answer payload + final_evt_payload = { + "id": "msg_final", + "object": "agent.final_answer", + "generated_answer": partial_text_buffer, + "citations": consolidated_citations, + } + + # Emit final answer event + async for ( + line + ) in SSEFormatter.yield_final_answer_event( + final_evt_payload + ): + yield line + + # (b) Signal the end of the SSE stream + yield SSEFormatter.yield_done_event() + self._completed = True + break + + # If we exit the while loop due to hitting max iterations + if not self._completed: + # Generate a summary using the LLM + summary = await self._generate_llm_summary( + iterations_count + ) + + # Send the summary as a message event + async for line in SSEFormatter.yield_message_event( + summary + ): + yield line + + # Add summary to conversation with citations metadata + await self.conversation.add_message( + Message( + role="assistant", + content=summary, + metadata={"citations": self.streaming_citations}, + ) + ) + + # Create and emit a final answer payload with the summary + final_evt_payload = { + "id": "msg_final", + "object": "agent.final_answer", + "generated_answer": summary, + "citations": consolidated_citations, + } + + async for line in SSEFormatter.yield_final_answer_event( + final_evt_payload + ): + yield line + + # Signal the end of the SSE stream + yield SSEFormatter.yield_done_event() + self._completed = True + + except Exception as e: + logger.error(f"Error in streaming agent: {str(e)}") + # Emit error event for client + async for line in SSEFormatter.yield_error_event( + f"Agent error: {str(e)}" + ): + yield line + # Send done event to close the stream + yield SSEFormatter.yield_done_event() + + # Finally, we return the async generator + async for line in sse_generator(): + yield line + + async def _handle_thinking( + self, thinking_signatures, accumulated_thinking + ): + """Process any accumulated thinking content""" + if accumulated_thinking: + structured_content = [ + { + "type": "thinking", + "thinking": accumulated_thinking, + # Anthropic will validate this in their API + "signature": "placeholder_signature", + } + ] + + assistant_msg = Message( + role="assistant", + structured_content=structured_content, + ) + await self.conversation.add_message(assistant_msg) + + elif thinking_signatures: + for ( + accumulated_thinking, + thinking_signature, + ) in thinking_signatures.items(): + structured_content = [ + { + "type": "thinking", + "thinking": accumulated_thinking, + # Anthropic will validate this in their API + "signature": thinking_signature, + } + ] + + assistant_msg = Message( + role="assistant", + structured_content=structured_content, + ) + await self.conversation.add_message(assistant_msg) + + async def 
_add_tool_calls_message(self, calls_list, partial_text_buffer): + """Add a message with tool calls to the conversation""" + assistant_msg = Message( + role="assistant", + content=partial_text_buffer or "", + tool_calls=[ + { + "id": c["tool_call_id"], + "type": "function", + "function": { + "name": c["name"], + "arguments": c["arguments"], + }, + } + for c in calls_list + ], + ) + await self.conversation.add_message(assistant_msg) + + def _create_tool_call_data(self, call_info): + """Create tool call data structure from call info""" + return { + "tool_call_id": call_info["tool_call_id"], + "name": call_info["name"], + "arguments": call_info["arguments"], + } + + def _create_citation_payload(self, short_id, payload): + """Create citation payload for a short ID""" + # This will be overridden in RAG subclasses + # check if as_dict is on payload + if hasattr(payload, "as_dict"): + payload = payload.as_dict() + if hasattr(payload, "dict"): + payload = payload.dict + if hasattr(payload, "to_dict"): + payload = payload.to_dict() + + return { + "id": f"{short_id}", + "object": "citation", + "payload": dump_obj(payload), # Will be populated in RAG agents + } + + def _create_final_answer_payload(self, answer_text, citations): + """Create the final answer payload""" + # This will be extended in RAG subclasses + return { + "id": "msg_final", + "object": "agent.final_answer", + "generated_answer": answer_text, + "citations": citations, + } + + +class R2RXMLStreamingAgent(R2RStreamingAgent): + """ + A streaming agent that parses XML-formatted responses with special handling for: + - <think> or <Thought> blocks for chain-of-thought reasoning + - <Action>, <ToolCalls>, <ToolCall> blocks for tool execution + """ + + # We treat <think> or <Thought> as the same token boundaries + THOUGHT_OPEN = re.compile(r"<(Thought|think)>", re.IGNORECASE) + THOUGHT_CLOSE = re.compile(r"</(Thought|think)>", re.IGNORECASE) + + # Regexes to parse out <Action>, <ToolCalls>, <ToolCall>, <Name>, <Parameters>, <Response> + ACTION_PATTERN = re.compile( + r"<Action>(.*?)</Action>", re.IGNORECASE | re.DOTALL + ) + TOOLCALLS_PATTERN = re.compile( + r"<ToolCalls>(.*?)</ToolCalls>", re.IGNORECASE | re.DOTALL + ) + TOOLCALL_PATTERN = re.compile( + r"<ToolCall>(.*?)</ToolCall>", re.IGNORECASE | re.DOTALL + ) + NAME_PATTERN = re.compile(r"<Name>(.*?)</Name>", re.IGNORECASE | re.DOTALL) + PARAMS_PATTERN = re.compile( + r"<Parameters>(.*?)</Parameters>", re.IGNORECASE | re.DOTALL + ) + RESPONSE_PATTERN = re.compile( + r"<Response>(.*?)</Response>", re.IGNORECASE | re.DOTALL + ) + + async def arun( + self, + system_instruction: str | None = None, + messages: list[Message] | None = None, + *args, + **kwargs, + ) -> AsyncGenerator[str, None]: + """ + Main streaming entrypoint: returns an async generator of SSE lines. 
+ """ + self._reset() + await self._setup(system_instruction) + + if messages: + for m in messages: + await self.conversation.add_message(m) + + # Initialize citation tracker for this run + citation_tracker = CitationTracker() + + # Dictionary to store citation payloads by ID + citation_payloads = {} + + # Track all citations emitted during streaming for final persistence + self.streaming_citations: list[dict] = [] + + async def sse_generator() -> AsyncGenerator[str, None]: + iterations_count = 0 + + try: + # Keep streaming until we complete + while ( + not self._completed + and iterations_count < self.config.max_iterations + ): + iterations_count += 1 + # 1) Get current messages + msg_list = await self.conversation.get_messages() + gen_cfg = self.get_generation_config( + msg_list[-1], stream=True + ) + + # 2) Start streaming from LLM + llm_stream = self.llm_provider.aget_completion_stream( + msg_list, gen_cfg + ) + + # Create state variables for each iteration + iteration_buffer = "" + yielded_first_event = False + in_action_block = False + is_thinking = False + accumulated_thinking = "" + thinking_signatures = {} + + async for chunk in llm_stream: + delta = chunk.choices[0].delta + finish_reason = chunk.choices[0].finish_reason + + # Handle thinking if present + if hasattr(delta, "thinking") and delta.thinking: + # Accumulate thinking for later use in messages + accumulated_thinking += delta.thinking + + # Emit SSE "thinking" event + async for ( + line + ) in SSEFormatter.yield_thinking_event( + delta.thinking + ): + yield line + + # Add this new handler for thinking signatures + if hasattr(delta, "thinking_signature"): + thinking_signatures[accumulated_thinking] = ( + delta.thinking_signature + ) + accumulated_thinking = "" + + # 3) If new text, accumulate it + if delta.content: + iteration_buffer += delta.content + + # Check if we have accumulated enough text for a `<Thought>` block + if len(iteration_buffer) < len("<Thought>"): + continue + + # Check if we have yielded the first event + if not yielded_first_event: + # Emit the first chunk + if self.THOUGHT_OPEN.findall(iteration_buffer): + is_thinking = True + async for ( + line + ) in SSEFormatter.yield_thinking_event( + iteration_buffer + ): + yield line + else: + async for ( + line + ) in SSEFormatter.yield_message_event( + iteration_buffer + ): + yield line + + # Mark as yielded + yielded_first_event = True + continue + + # Check if we are in a thinking block + if is_thinking: + # Still thinking, so keep yielding thinking events + if not self.THOUGHT_CLOSE.findall( + iteration_buffer + ): + # Emit SSE "thinking" event + async for ( + line + ) in SSEFormatter.yield_thinking_event( + delta.content + ): + yield line + + continue + # Done thinking, so emit the last thinking event + else: + is_thinking = False + thought_text = delta.content.split( + "</Thought>" + )[0].split("</think>")[0] + async for ( + line + ) in SSEFormatter.yield_thinking_event( + thought_text + ): + yield line + post_thought_text = delta.content.split( + "</Thought>" + )[-1].split("</think>")[-1] + delta.content = post_thought_text + + # (b) Find new citation spans in the accumulated text + new_citation_spans = find_new_citation_spans( + iteration_buffer, citation_tracker + ) + + # Process each new citation span + for cid, spans in new_citation_spans.items(): + for span in spans: + # Check if this is the first time we've seen this citation ID + is_new_citation = ( + citation_tracker.is_new_citation(cid) + ) + + # Get payload if it's a new citation + payload = 
None + if is_new_citation: + source_obj = self.search_results_collector.find_by_short_id( + cid + ) + if source_obj: + # Store payload for reuse + payload = dump_obj(source_obj) + citation_payloads[cid] = payload + + # Create citation event payload + citation_data = { + "id": cid, + "object": "citation", + "is_new": is_new_citation, + "span": { + "start": span[0], + "end": span[1], + }, + } + + # Only include full payload for new citations + if is_new_citation and payload: + citation_data["payload"] = payload + + # Add to streaming citations for final answer + self.streaming_citations.append( + citation_data + ) + + # Emit the citation event + async for ( + line + ) in SSEFormatter.yield_citation_event( + citation_data + ): + yield line + + # Now prepare to emit the newly streamed text as a "message" event + if ( + iteration_buffer.count("<") + and not in_action_block + ): + in_action_block = True + + if ( + in_action_block + and len( + self.ACTION_PATTERN.findall( + iteration_buffer + ) + ) + < 2 + ): + continue + + elif in_action_block: + in_action_block = False + # Emit the post action block text, if it is there + post_action_text = iteration_buffer.split( + "</Action>" + )[-1] + if post_action_text: + async for ( + line + ) in SSEFormatter.yield_message_event( + post_action_text + ): + yield line + + else: + async for ( + line + ) in SSEFormatter.yield_message_event( + delta.content + ): + yield line + + elif finish_reason == "stop": + break + + # Process any accumulated thinking + await self._handle_thinking( + thinking_signatures, accumulated_thinking + ) + + # 6) The LLM is done. If we have any leftover partial text, + # finalize it in the conversation + if iteration_buffer: + # Create the final message with metadata including citations + final_message = Message( + role="assistant", + content=iteration_buffer, + metadata={"citations": self.streaming_citations}, + ) + + # Add it to the conversation + await self.conversation.add_message(final_message) + + # --- 4) Process any <Action>/<ToolCalls> blocks, or mark completed + action_matches = self.ACTION_PATTERN.findall( + iteration_buffer + ) + + if len(action_matches) > 0: + # Process each ToolCall + xml_toolcalls = "<ToolCalls>" + + for action_block in action_matches: + tool_calls_text = [] + # Look for ToolCalls wrapper, or use the raw action block + calls_wrapper = self.TOOLCALLS_PATTERN.findall( + action_block + ) + if calls_wrapper: + for tw in calls_wrapper: + tool_calls_text.append(tw) + else: + tool_calls_text.append(action_block) + + for calls_region in tool_calls_text: + calls_found = self.TOOLCALL_PATTERN.findall( + calls_region + ) + for tc_block in calls_found: + tool_name, tool_params = ( + self._parse_single_tool_call(tc_block) + ) + if tool_name: + # Emit SSE event for tool call + tool_call_id = ( + f"call_{abs(hash(tc_block))}" + ) + call_evt_data = { + "tool_call_id": tool_call_id, + "name": tool_name, + "arguments": json.dumps( + tool_params + ), + } + async for line in ( + SSEFormatter.yield_tool_call_event( + call_evt_data + ) + ): + yield line + + try: + tool_result = await self.handle_function_or_tool_call( + tool_name, + json.dumps(tool_params), + tool_id=tool_call_id, + save_messages=False, + ) + result_content = tool_result.llm_formatted_result + except Exception as e: + result_content = f"Error in tool '{tool_name}': {str(e)}" + + xml_toolcalls += ( + f"<ToolCall>" + f"<Name>{tool_name}</Name>" + f"<Parameters>{json.dumps(tool_params)}</Parameters>" + f"<Result>{result_content}</Result>" + f"</ToolCall>" + ) 
+ + # Emit SSE tool result for non-result tools + result_data = { + "tool_call_id": tool_call_id, + "role": "tool", + "content": json.dumps( + convert_nonserializable_objects( + result_content + ) + ), + } + async for line in SSEFormatter.yield_tool_result_event( + result_data + ): + yield line + + xml_toolcalls += "</ToolCalls>" + pre_action_text = iteration_buffer[ + : iteration_buffer.find(action_block) + ] + post_action_text = iteration_buffer[ + iteration_buffer.find(action_block) + + len(action_block) : + ] + iteration_text = ( + pre_action_text + xml_toolcalls + post_action_text + ) + + # Update the conversation with tool results + await self.conversation.add_message( + Message( + role="assistant", + content=iteration_text, + metadata={ + "citations": self.streaming_citations + }, + ) + ) + else: + # (a) Prepare final answer with optimized citations + consolidated_citations = [] + # Group citations by ID with all their spans + for ( + cid, + spans, + ) in citation_tracker.get_all_spans().items(): + if cid in citation_payloads: + consolidated_citations.append( + { + "id": cid, + "object": "citation", + "spans": [ + {"start": s[0], "end": s[1]} + for s in spans + ], + "payload": citation_payloads[cid], + } + ) + + # Create final answer payload + final_evt_payload = { + "id": "msg_final", + "object": "agent.final_answer", + "generated_answer": iteration_buffer, + "citations": consolidated_citations, + } + + # Emit final answer event + async for ( + line + ) in SSEFormatter.yield_final_answer_event( + final_evt_payload + ): + yield line + + # (b) Signal the end of the SSE stream + yield SSEFormatter.yield_done_event() + self._completed = True + + # If we exit the while loop due to hitting max iterations + if not self._completed: + # Generate a summary using the LLM + summary = await self._generate_llm_summary( + iterations_count + ) + + # Send the summary as a message event + async for line in SSEFormatter.yield_message_event( + summary + ): + yield line + + # Add summary to conversation with citations metadata + await self.conversation.add_message( + Message( + role="assistant", + content=summary, + metadata={"citations": self.streaming_citations}, + ) + ) + + # Create and emit a final answer payload with the summary + final_evt_payload = { + "id": "msg_final", + "object": "agent.final_answer", + "generated_answer": summary, + "citations": consolidated_citations, + } + + async for line in SSEFormatter.yield_final_answer_event( + final_evt_payload + ): + yield line + + # Signal the end of the SSE stream + yield SSEFormatter.yield_done_event() + self._completed = True + + except Exception as e: + logger.error(f"Error in streaming agent: {str(e)}") + # Emit error event for client + async for line in SSEFormatter.yield_error_event( + f"Agent error: {str(e)}" + ): + yield line + # Send done event to close the stream + yield SSEFormatter.yield_done_event() + + # Finally, we return the async generator + async for line in sse_generator(): + yield line + + def _parse_single_tool_call( + self, toolcall_text: str + ) -> Tuple[Optional[str], dict]: + """ + Parse a ToolCall block to extract the name and parameters. 
+ + Args: + toolcall_text: The text content of a ToolCall block + + Returns: + Tuple of (tool_name, tool_parameters) + """ + name_match = self.NAME_PATTERN.search(toolcall_text) + if not name_match: + return None, {} + tool_name = name_match.group(1).strip() + + params_match = self.PARAMS_PATTERN.search(toolcall_text) + if not params_match: + return tool_name, {} + + raw_params = params_match.group(1).strip() + try: + # Handle potential JSON parsing issues + # First try direct parsing + tool_params = json.loads(raw_params) + except json.JSONDecodeError: + # If that fails, try to clean up the JSON string + try: + # Replace escaped quotes that might cause issues + cleaned_params = raw_params.replace('\\"', '"') + # Try again with the cleaned string + tool_params = json.loads(cleaned_params) + except json.JSONDecodeError: + # If all else fails, treat as a plain string value + tool_params = {"value": raw_params} + + return tool_name, tool_params + + +class R2RXMLToolsAgent(R2RAgent): + """ + A non-streaming agent that: + - parses <think> or <Thought> blocks as chain-of-thought + - filters out XML tags related to tool calls and actions + - processes <Action><ToolCalls><ToolCall> blocks + - properly extracts citations when they appear in the text + """ + + # We treat <think> or <Thought> as the same token boundaries + THOUGHT_OPEN = re.compile(r"<(Thought|think)>", re.IGNORECASE) + THOUGHT_CLOSE = re.compile(r"</(Thought|think)>", re.IGNORECASE) + + # Regexes to parse out <Action>, <ToolCalls>, <ToolCall>, <Name>, <Parameters>, <Response> + ACTION_PATTERN = re.compile( + r"<Action>(.*?)</Action>", re.IGNORECASE | re.DOTALL + ) + TOOLCALLS_PATTERN = re.compile( + r"<ToolCalls>(.*?)</ToolCalls>", re.IGNORECASE | re.DOTALL + ) + TOOLCALL_PATTERN = re.compile( + r"<ToolCall>(.*?)</ToolCall>", re.IGNORECASE | re.DOTALL + ) + NAME_PATTERN = re.compile(r"<Name>(.*?)</Name>", re.IGNORECASE | re.DOTALL) + PARAMS_PATTERN = re.compile( + r"<Parameters>(.*?)</Parameters>", re.IGNORECASE | re.DOTALL + ) + RESPONSE_PATTERN = re.compile( + r"<Response>(.*?)</Response>", re.IGNORECASE | re.DOTALL + ) + + async def process_llm_response(self, response, *args, **kwargs): + """ + Override the base process_llm_response to handle XML structured responses + including thoughts and tool calls. 
+ """ + if self._completed: + return + + message = response.choices[0].message + finish_reason = response.choices[0].finish_reason + + if not message.content: + # If there's no content, let the parent class handle the normal tool_calls flow + return await super().process_llm_response( + response, *args, **kwargs + ) + + # Get the response content + content = message.content + + # HACK for gemini + content = content.replace("```action", "") + content = content.replace("```tool_code", "") + content = content.replace("```", "") + + if ( + not content.startswith("<") + and "deepseek" in self.rag_generation_config.model + ): # HACK - fix issues with adding `<think>` to the beginning + content = "<think>" + content + + # Process any tool calls in the content + action_matches = self.ACTION_PATTERN.findall(content) + if action_matches: + xml_toolcalls = "<ToolCalls>" + for action_block in action_matches: + tool_calls_text = [] + # Look for ToolCalls wrapper, or use the raw action block + calls_wrapper = self.TOOLCALLS_PATTERN.findall(action_block) + if calls_wrapper: + for tw in calls_wrapper: + tool_calls_text.append(tw) + else: + tool_calls_text.append(action_block) + + # Process each ToolCall + for calls_region in tool_calls_text: + calls_found = self.TOOLCALL_PATTERN.findall(calls_region) + for tc_block in calls_found: + tool_name, tool_params = self._parse_single_tool_call( + tc_block + ) + if tool_name: + tool_call_id = f"call_{abs(hash(tc_block))}" + try: + tool_result = ( + await self.handle_function_or_tool_call( + tool_name, + json.dumps(tool_params), + tool_id=tool_call_id, + save_messages=False, + ) + ) + + # Add tool result to XML + xml_toolcalls += ( + f"<ToolCall>" + f"<Name>{tool_name}</Name>" + f"<Parameters>{json.dumps(tool_params)}</Parameters>" + f"<Result>{tool_result.llm_formatted_result}</Result>" + f"</ToolCall>" + ) + + except Exception as e: + logger.error(f"Error in tool call: {str(e)}") + # Add error to XML + xml_toolcalls += ( + f"<ToolCall>" + f"<Name>{tool_name}</Name>" + f"<Parameters>{json.dumps(tool_params)}</Parameters>" + f"<Result>Error: {str(e)}</Result>" + f"</ToolCall>" + ) + + xml_toolcalls += "</ToolCalls>" + pre_action_text = content[: content.find(action_block)] + post_action_text = content[ + content.find(action_block) + len(action_block) : + ] + iteration_text = pre_action_text + xml_toolcalls + post_action_text + + # Create the assistant message + await self.conversation.add_message( + Message(role="assistant", content=iteration_text) + ) + else: + # Create an assistant message with the content as-is + await self.conversation.add_message( + Message(role="assistant", content=content) + ) + + # Only mark as completed if the finish_reason is "stop" or there are no action calls + # This allows the agent to continue the conversation when tool calls are processed + if finish_reason == "stop": + self._completed = True + + def _parse_single_tool_call( + self, toolcall_text: str + ) -> Tuple[Optional[str], dict]: + """ + Parse a ToolCall block to extract the name and parameters. 
+ + Args: + toolcall_text: The text content of a ToolCall block + + Returns: + Tuple of (tool_name, tool_parameters) + """ + name_match = self.NAME_PATTERN.search(toolcall_text) + if not name_match: + return None, {} + tool_name = name_match.group(1).strip() + + params_match = self.PARAMS_PATTERN.search(toolcall_text) + if not params_match: + return tool_name, {} + + raw_params = params_match.group(1).strip() + try: + # Handle potential JSON parsing issues + # First try direct parsing + tool_params = json.loads(raw_params) + except json.JSONDecodeError: + # If that fails, try to clean up the JSON string + try: + # Replace escaped quotes that might cause issues + cleaned_params = raw_params.replace('\\"', '"') + # Try again with the cleaned string + tool_params = json.loads(cleaned_params) + except json.JSONDecodeError: + # If all else fails, treat as a plain string value + tool_params = {"value": raw_params} + + return tool_name, tool_params diff --git a/.venv/lib/python3.12/site-packages/core/agent/rag.py b/.venv/lib/python3.12/site-packages/core/agent/rag.py new file mode 100644 index 00000000..6f3ab630 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/agent/rag.py @@ -0,0 +1,662 @@ +# type: ignore +import logging +from typing import Any, Callable, Optional + +from core.base import ( + format_search_results_for_llm, +) +from core.base.abstractions import ( + AggregateSearchResult, + GenerationConfig, + SearchSettings, + WebPageSearchResult, + WebSearchResult, +) +from core.base.agent import Tool +from core.base.providers import DatabaseProvider +from core.providers import ( + AnthropicCompletionProvider, + LiteLLMCompletionProvider, + OpenAICompletionProvider, + R2RCompletionProvider, +) +from core.utils import ( + SearchResultsCollector, + generate_id, + num_tokens, +) + +from ..base.agent.agent import RAGAgentConfig + +# Import the base classes from the refactored base file +from .base import ( + R2RAgent, + R2RStreamingAgent, + R2RXMLStreamingAgent, + R2RXMLToolsAgent, +) + +logger = logging.getLogger(__name__) + + +class RAGAgentMixin: + """ + A Mixin for adding search_file_knowledge, web_search, and content tools + to your R2R Agents. This allows your agent to: + - call knowledge_search_method (semantic/hybrid search) + - call content_method (fetch entire doc/chunk structures) + - call an external web search API + """ + + def __init__( + self, + *args, + search_settings: SearchSettings, + knowledge_search_method: Callable, + content_method: Callable, + file_search_method: Callable, + max_tool_context_length=10_000, + max_context_window_tokens=512_000, + **kwargs, + ): + # Save references to the retrieval logic + self.search_settings = search_settings + self.knowledge_search_method = knowledge_search_method + self.content_method = content_method + self.file_search_method = file_search_method + self.max_tool_context_length = max_tool_context_length + self.max_context_window_tokens = max_context_window_tokens + self.search_results_collector = SearchResultsCollector() + super().__init__(*args, **kwargs) + + def _register_tools(self): + """ + Called by the base R2RAgent to register all requested tools from self.config.rag_tools. 
+ """ + if not self.config.rag_tools: + return + + for tool_name in set(self.config.rag_tools): + if tool_name == "get_file_content": + self._tools.append(self.content()) + elif tool_name == "web_scrape": + self._tools.append(self.web_scrape()) + elif tool_name == "search_file_knowledge": + self._tools.append(self.search_file_knowledge()) + elif tool_name == "search_file_descriptions": + self._tools.append(self.search_files()) + elif tool_name == "web_search": + self._tools.append(self.web_search()) + else: + raise ValueError(f"Unknown tool requested: {tool_name}") + logger.debug(f"Registered {len(self._tools)} RAG tools.") + + # Local Search Tool + def search_file_knowledge(self) -> Tool: + """ + Tool to do a semantic/hybrid search on the local knowledge base + using self.knowledge_search_method. + """ + return Tool( + name="search_file_knowledge", + description=( + "Search your local knowledge base using the R2R system. " + "Use this when you want relevant text chunks or knowledge graph data." + ), + results_function=self._file_knowledge_search_function, + llm_format_function=self.format_search_results_for_llm, + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "User query to search in the local DB.", + }, + }, + "required": ["query"], + }, + ) + + async def _file_knowledge_search_function( + self, + query: str, + *args, + **kwargs, + ) -> AggregateSearchResult: + """ + Calls the passed-in `knowledge_search_method(query, search_settings)`. + Expects either an AggregateSearchResult or a dict with chunk_search_results, etc. + """ + if not self.knowledge_search_method: + raise ValueError( + "No knowledge_search_method provided to RAGAgentMixin." + ) + + raw_response = await self.knowledge_search_method( + query=query, search_settings=self.search_settings + ) + + if isinstance(raw_response, AggregateSearchResult): + agg = raw_response + else: + agg = AggregateSearchResult( + chunk_search_results=raw_response.get( + "chunk_search_results", [] + ), + graph_search_results=raw_response.get( + "graph_search_results", [] + ), + ) + + # 1) Store them so that we can do final citations later + self.search_results_collector.add_aggregate_result(agg) + return agg + + # 2) Local Context + def content(self) -> Tool: + """Tool to fetch entire documents from the local database. + + Typically used if the agent needs deeper or more structured context + from documents, not just chunk-level hits. + """ + if "gemini" in self.rag_generation_config.model: + tool = Tool( + name="get_file_content", + description=( + "Fetches the complete contents of all user documents from the local database. " + "Can be used alongside filter criteria (e.g. doc IDs, collection IDs, etc.) to restrict the query." + "For instance, a single document can be returned with a filter like so:" + "{'document_id': {'$eq': '...'}}." + "Be sure to use the full 32 character hexidecimal document ID, and not the shortened 8 character ID." + ), + results_function=self._content_function, + llm_format_function=self.format_search_results_for_llm, + parameters={ + "type": "object", + "properties": { + "filters": { + "type": "string", + "description": ( + "Dictionary with filter criteria, such as " + '{"$and": [{"document_id": {"$eq": "6c9d1c39..."}, {"collection_ids": {"$overlap": [...]}]}' + ), + }, + }, + "required": ["filters"], + }, + ) + + else: + tool = Tool( + name="get_file_content", + description=( + "Fetches the complete contents of all user documents from the local database. 
" + "Can be used alongside filter criteria (e.g. doc IDs, collection IDs, etc.) to restrict the query." + "For instance, a single document can be returned with a filter like so:" + "{'document_id': {'$eq': '...'}}." + ), + results_function=self._content_function, + llm_format_function=self.format_search_results_for_llm, + parameters={ + "type": "object", + "properties": { + "filters": { + "type": "object", + "description": ( + "Dictionary with filter criteria, such as " + '{"$and": [{"document_id": {"$eq": "6c9d1c39..."}, {"collection_ids": {"$overlap": [...]}]}' + ), + }, + }, + "required": ["filters"], + }, + ) + return tool + + async def _content_function( + self, + filters: Optional[dict[str, Any]] = None, + options: Optional[dict[str, Any]] = None, + *args, + **kwargs, + ) -> AggregateSearchResult: + """Calls the passed-in `content_method(filters, options)` to fetch + doc+chunk structures. + + Typically returns a list of dicts: + [ + { 'document': {...}, 'chunks': [ {...}, {...}, ... ] }, + ... + ] + We'll store these in a new field `document_search_results` of + AggregateSearchResult so we don't collide with chunk_search_results. + """ + if not self.content_method: + raise ValueError("No content_method provided to RAGAgentMixin.") + + if filters: + if "document_id" in filters: + filters["id"] = filters.pop("document_id") + if self.search_settings.filters != {}: + filters = {"$and": [filters, self.search_settings.filters]} + else: + filters = self.search_settings.filters + + options = options or {} + + # Actually call your data retrieval + content = await self.content_method(filters, options) + # raw_context presumably is a list[dict], each with 'document' + 'chunks'. + + # Return them in the new aggregator field + agg = AggregateSearchResult( + # We won't put them in chunk_search_results: + chunk_search_results=None, + graph_search_results=None, + web_search_results=None, + document_search_results=content, + ) + self.search_results_collector.add_aggregate_result(agg) + return agg + + # Web Search Tool + def web_search(self) -> Tool: + return Tool( + name="web_search", + description=( + "Search for information on the web - use this tool when the user " + "query needs LIVE or recent data from the internet." + ), + results_function=self._web_search_function, + llm_format_function=self.format_search_results_for_llm, + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The query to search with an external web API.", + }, + }, + "required": ["query"], + }, + ) + + async def _web_search_function( + self, + query: str, + *args, + **kwargs, + ) -> AggregateSearchResult: + """ + Calls an external search engine (Serper, Google, etc.) asynchronously + and returns results in an AggregateSearchResult. 
+ """ + import asyncio + + from ..utils.serper import SerperClient # adjust your import + + serper_client = SerperClient() + + # If SerperClient.get_raw is not already async, wrap it in run_in_executor + raw_results = await asyncio.get_event_loop().run_in_executor( + None, # Uses the default executor + lambda: serper_client.get_raw(query), + ) + + # If from_serper_results is not already async, wrap it in run_in_executor too + web_response = await asyncio.get_event_loop().run_in_executor( + None, lambda: WebSearchResult.from_serper_results(raw_results) + ) + + agg = AggregateSearchResult( + chunk_search_results=None, + graph_search_results=None, + web_search_results=web_response.organic_results, + ) + self.search_results_collector.add_aggregate_result(agg) + return agg + + def search_files(self) -> Tool: + """ + A tool to search over file-level metadata (titles, doc-level descriptions, etc.) + returning a list of DocumentResponse objects. + """ + return Tool( + name="search_file_descriptions", + description=( + "Semantic search over the stored documents over AI generated summaries of input documents. " + "This does NOT retrieve chunk-level contents or knowledge-graph relationships. " + "Use this when you need a broad overview of which documents (files) might be relevant." + ), + results_function=self._search_files_function, + llm_format_function=self.format_search_results_for_llm, + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Query string to semantic search over available files 'list documents about XYZ'.", + } + }, + "required": ["query"], + }, + ) + + async def _search_files_function( + self, query: str, *args, **kwargs + ) -> AggregateSearchResult: + if not self.file_search_method: + raise ValueError( + "No file_search_method provided to RAGAgentMixin." + ) + + # call the doc-level search + """ + FIXME: This is going to fail, as it requires an embedding NOT a query. + I've moved 'search_settings' to 'settings' which had been causing a silent failure + causing null content in the Message object. + """ + doc_results = await self.file_search_method( + query=query, + settings=self.search_settings, + ) + + # Wrap them in an AggregateSearchResult + agg = AggregateSearchResult(document_search_results=doc_results) + + # Add them to the collector + self.search_results_collector.add_aggregate_result(agg) + return agg + + def format_search_results_for_llm( + self, results: AggregateSearchResult + ) -> str: + context = format_search_results_for_llm( + results, self.search_results_collector + ) + context_tokens = num_tokens(context) + 1 + frac_to_return = self.max_tool_context_length / (context_tokens) + + if frac_to_return > 1: + return context + else: + return context[: int(frac_to_return * len(context))] + + def web_scrape(self) -> Tool: + """ + A new Tool that uses Firecrawl to scrape a single URL and return + its contents in an LLM-friendly format (e.g. markdown). + """ + return Tool( + name="web_scrape", + description=( + "Use Firecrawl to scrape a single webpage and retrieve its contents " + "as clean markdown. Useful when you need the entire body of a page, " + "not just a quick snippet or standard web search result." + ), + results_function=self._web_scrape_function, + llm_format_function=self.format_search_results_for_llm, + parameters={ + "type": "object", + "properties": { + "url": { + "type": "string", + "description": ( + "The absolute URL of the webpage you want to scrape. 
" + "Example: 'https://docs.firecrawl.dev/getting-started'" + ), + } + }, + "required": ["url"], + }, + ) + + async def _web_scrape_function( + self, + url: str, + *args, + **kwargs, + ) -> AggregateSearchResult: + """ + Performs the Firecrawl scrape asynchronously, returning results + as an `AggregateSearchResult` with a single WebPageSearchResult. + """ + import asyncio + + from firecrawl import FirecrawlApp + + app = FirecrawlApp() + logger.debug(f"[Firecrawl] Scraping URL={url}") + + # Create a proper async wrapper for the synchronous scrape_url method + # This offloads the blocking operation to a thread pool + response = await asyncio.get_event_loop().run_in_executor( + None, # Uses the default executor + lambda: app.scrape_url( + url=url, + params={"formats": ["markdown"]}, + ), + ) + + markdown_text = response.get("markdown", "") + metadata = response.get("metadata", {}) + page_title = metadata.get("title", "Untitled page") + + if len(markdown_text) > 100_000: + markdown_text = ( + markdown_text[:100_000] + "...FURTHER CONTENT TRUNCATED..." + ) + + # Create a single WebPageSearchResult HACK - TODO FIX + web_result = WebPageSearchResult( + title=page_title, + link=url, + snippet=markdown_text, + position=0, + id=generate_id(markdown_text), + type="firecrawl", + ) + + agg = AggregateSearchResult(web_search_results=[web_result]) + + # Add results to the collector + if self.search_results_collector: + self.search_results_collector.add_aggregate_result(agg) + + return agg + + +class R2RRAGAgent(RAGAgentMixin, R2RAgent): + """ + Non-streaming RAG Agent that supports search_file_knowledge, content, web_search. + """ + + def __init__( + self, + database_provider: DatabaseProvider, + llm_provider: ( + AnthropicCompletionProvider + | LiteLLMCompletionProvider + | OpenAICompletionProvider + | R2RCompletionProvider + ), + config: RAGAgentConfig, + search_settings: SearchSettings, + rag_generation_config: GenerationConfig, + knowledge_search_method: Callable, + content_method: Callable, + file_search_method: Callable, + max_tool_context_length: int = 20_000, + ): + # Initialize base R2RAgent + R2RAgent.__init__( + self, + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + rag_generation_config=rag_generation_config, + ) + # Initialize the RAGAgentMixin + RAGAgentMixin.__init__( + self, + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + search_settings=search_settings, + rag_generation_config=rag_generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + file_search_method=file_search_method, + content_method=content_method, + ) + + +class R2RXMLToolsRAGAgent(RAGAgentMixin, R2RXMLToolsAgent): + """ + Non-streaming RAG Agent that supports search_file_knowledge, content, web_search. 
+ """ + + def __init__( + self, + database_provider: DatabaseProvider, + llm_provider: ( + AnthropicCompletionProvider + | LiteLLMCompletionProvider + | OpenAICompletionProvider + | R2RCompletionProvider + ), + config: RAGAgentConfig, + search_settings: SearchSettings, + rag_generation_config: GenerationConfig, + knowledge_search_method: Callable, + content_method: Callable, + file_search_method: Callable, + max_tool_context_length: int = 20_000, + ): + # Initialize base R2RAgent + R2RXMLToolsAgent.__init__( + self, + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + rag_generation_config=rag_generation_config, + ) + # Initialize the RAGAgentMixin + RAGAgentMixin.__init__( + self, + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + search_settings=search_settings, + rag_generation_config=rag_generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + file_search_method=file_search_method, + content_method=content_method, + ) + + +class R2RStreamingRAGAgent(RAGAgentMixin, R2RStreamingAgent): + """ + Streaming-capable RAG Agent that supports search_file_knowledge, content, web_search, + and emits citations as [abc1234] short IDs if the LLM includes them in brackets. + """ + + def __init__( + self, + database_provider: DatabaseProvider, + llm_provider: ( + AnthropicCompletionProvider + | LiteLLMCompletionProvider + | OpenAICompletionProvider + | R2RCompletionProvider + ), + config: RAGAgentConfig, + search_settings: SearchSettings, + rag_generation_config: GenerationConfig, + knowledge_search_method: Callable, + content_method: Callable, + file_search_method: Callable, + max_tool_context_length: int = 10_000, + ): + # Force streaming on + config.stream = True + + # Initialize base R2RStreamingAgent + R2RStreamingAgent.__init__( + self, + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + rag_generation_config=rag_generation_config, + ) + + # Initialize the RAGAgentMixin + RAGAgentMixin.__init__( + self, + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + search_settings=search_settings, + rag_generation_config=rag_generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + ) + + +class R2RXMLToolsStreamingRAGAgent(RAGAgentMixin, R2RXMLStreamingAgent): + """ + A streaming agent that: + - treats <think> or <Thought> blocks as chain-of-thought + and emits them incrementally as SSE "thinking" events. + - accumulates user-visible text outside those tags as SSE "message" events. + - filters out all XML tags related to tool calls and actions. + - upon finishing each iteration, it parses <Action><ToolCalls><ToolCall> blocks, + calls the appropriate tool, and emits SSE "tool_call" / "tool_result". 
+ - properly emits citations when they appear in the text + """ + + def __init__( + self, + database_provider: DatabaseProvider, + llm_provider: ( + AnthropicCompletionProvider + | LiteLLMCompletionProvider + | OpenAICompletionProvider + | R2RCompletionProvider + ), + config: RAGAgentConfig, + search_settings: SearchSettings, + rag_generation_config: GenerationConfig, + knowledge_search_method: Callable, + content_method: Callable, + file_search_method: Callable, + max_tool_context_length: int = 10_000, + ): + # Force streaming on + config.stream = True + + # Initialize base R2RXMLStreamingAgent + R2RXMLStreamingAgent.__init__( + self, + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + rag_generation_config=rag_generation_config, + ) + + # Initialize the RAGAgentMixin + RAGAgentMixin.__init__( + self, + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + search_settings=search_settings, + rag_generation_config=rag_generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + ) diff --git a/.venv/lib/python3.12/site-packages/core/agent/research.py b/.venv/lib/python3.12/site-packages/core/agent/research.py new file mode 100644 index 00000000..6ea35783 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/agent/research.py @@ -0,0 +1,697 @@ +import logging +import os +import subprocess +import sys +import tempfile +from copy import copy +from typing import Any, Callable, Optional + +from core.base import AppConfig +from core.base.abstractions import GenerationConfig, Message, SearchSettings +from core.base.agent import Tool +from core.base.providers import DatabaseProvider +from core.providers import ( + AnthropicCompletionProvider, + LiteLLMCompletionProvider, + OpenAICompletionProvider, + R2RCompletionProvider, +) +from core.utils import extract_citations + +from ..base.agent.agent import RAGAgentConfig # type: ignore + +# Import the RAG agents we'll leverage +from .rag import ( # type: ignore + R2RRAGAgent, + R2RStreamingRAGAgent, + R2RXMLToolsRAGAgent, + R2RXMLToolsStreamingRAGAgent, + RAGAgentMixin, +) + +logger = logging.getLogger(__name__) + + +class ResearchAgentMixin(RAGAgentMixin): + """ + A mixin that extends RAGAgentMixin to add research capabilities to any R2R agent. 
+ + This mixin provides all RAG capabilities plus additional research tools: + - A RAG tool for knowledge retrieval (which leverages the underlying RAG capabilities) + - A Python execution tool for code execution and computation + - A reasoning tool for complex problem solving + - A critique tool for analyzing conversation history + """ + + def __init__( + self, + *args, + app_config: AppConfig, + search_settings: SearchSettings, + knowledge_search_method: Callable, + content_method: Callable, + file_search_method: Callable, + max_tool_context_length=10_000, + **kwargs, + ): + # Store the app configuration needed for research tools + self.app_config = app_config + + # Call the parent RAGAgentMixin's __init__ with explicitly passed parameters + super().__init__( + *args, + search_settings=search_settings, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + max_tool_context_length=max_tool_context_length, + **kwargs, + ) + + # Register our research-specific tools + self._register_research_tools() + + def _register_research_tools(self): + """ + Register research-specific tools to the agent. + This is called by the mixin's __init__ after the parent class initialization. + """ + # Add our research tools to whatever tools are already registered + research_tools = [] + for tool_name in set(self.config.research_tools): + if tool_name == "rag": + research_tools.append(self.rag_tool()) + elif tool_name == "reasoning": + research_tools.append(self.reasoning_tool()) + elif tool_name == "critique": + research_tools.append(self.critique_tool()) + elif tool_name == "python_executor": + research_tools.append(self.python_execution_tool()) + else: + logger.warning(f"Unknown research tool: {tool_name}") + raise ValueError(f"Unknown research tool: {tool_name}") + + logger.debug(f"Registered research tools: {research_tools}") + self.tools = research_tools + + def rag_tool(self) -> Tool: + """Tool that provides access to the RAG agent's search capabilities.""" + return Tool( + name="rag", + description=( + "Search for information using RAG (Retrieval-Augmented Generation). " + "This tool searches across relevant sources and returns comprehensive information. " + "Use this tool when you need to find specific information on any topic. Be sure to pose your query as a comprehensive query." + ), + results_function=self._rag, + llm_format_function=self._format_search_results, + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query to find information.", + } + }, + "required": ["query"], + }, + ) + + def reasoning_tool(self) -> Tool: + """Tool that provides access to a strong reasoning model.""" + return Tool( + name="reasoning", + description=( + "A dedicated reasoning system that excels at solving complex problems through step-by-step analysis. " + "This tool connects to a separate AI system optimized for deep analytical thinking.\n\n" + "USAGE GUIDELINES:\n" + "1. Formulate your request as a complete, standalone question to a reasoning expert.\n" + "2. Clearly state the problem/question at the beginning.\n" + "3. Provide all relevant context, data, and constraints.\n\n" + "IMPORTANT: This system has no memory of previous interactions or context from your conversation.\n\n" + "STRENGTHS: Mathematical reasoning, logical analysis, evaluating complex scenarios, " + "solving multi-step problems, and identifying potential errors in reasoning." 
+ ), + results_function=self._reason, + llm_format_function=self._format_search_results, + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "A complete, standalone question with all necessary context, appropriate for a dedicated reasoning system.", + } + }, + "required": ["query"], + }, + ) + + def critique_tool(self) -> Tool: + """Tool that provides critical analysis of the reasoning done so far in the conversation.""" + return Tool( + name="critique", + description=( + "Analyzes the conversation history to identify potential flaws, biases, and alternative " + "approaches to the reasoning presented so far.\n\n" + "Use this tool to get a second opinion on your reasoning, find overlooked considerations, " + "identify biases or fallacies, explore alternative hypotheses, and improve the robustness " + "of your conclusions." + ), + results_function=self._critique, + llm_format_function=self._format_search_results, + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "A specific aspect of the reasoning you want critiqued, or leave empty for a general critique.", + }, + "focus_areas": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional specific areas to focus the critique (e.g., ['logical fallacies', 'methodology'])", + }, + }, + "required": ["query"], + }, + ) + + def python_execution_tool(self) -> Tool: + """Tool that provides Python code execution capabilities.""" + return Tool( + name="python_executor", + description=( + "Executes Python code and returns the results, output, and any errors. " + "Use this tool for complex calculations, statistical operations, or algorithmic implementations.\n\n" + "The execution environment includes common libraries such as numpy, pandas, sympy, scipy, statsmodels, biopython, etc.\n\n" + "USAGE:\n" + "1. Send complete, executable Python code as a string.\n" + "2. Use print statements for output you want to see.\n" + "3. Assign to the 'result' variable for values you want to return.\n" + "4. Do not use input() or plotting (matplotlib). Output is text-based." + ), + results_function=self._execute_python_with_process_timeout, + llm_format_function=self._format_python_results, + parameters={ + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute.", + } + }, + "required": ["code"], + }, + ) + + async def _rag( + self, + query: str, + *args, + **kwargs, + ) -> dict[str, Any]: + """Execute a search using an internal RAG agent.""" + # Create a copy of the current configuration for the RAG agent + config_copy = copy(self.config) + config_copy.max_iterations = 10 # Could be configurable + config_copy.rag_tools = [ + "web_search", + "web_scrape", + ] # HACK HACK TODO - Fix. 
+ + # Create a generation config for the RAG agent + generation_config = GenerationConfig( + model=self.app_config.quality_llm, + max_tokens_to_sample=16000, + ) + + # Create a new RAG agent - we'll use the non-streaming variant for consistent results + rag_agent = R2RRAGAgent( + database_provider=self.database_provider, + llm_provider=self.llm_provider, + config=config_copy, + search_settings=self.search_settings, + rag_generation_config=generation_config, + knowledge_search_method=self.knowledge_search_method, + content_method=self.content_method, + file_search_method=self.file_search_method, + max_tool_context_length=self.max_tool_context_length, + ) + + # Run the RAG agent with the query + user_message = Message(role="user", content=query) + response = await rag_agent.arun(messages=[user_message]) + + # Get the content from the response + structured_content = response[-1].get("structured_content") + if structured_content: + possible_text = structured_content[-1].get("text") + content = response[-1].get("content") or possible_text + else: + content = response[-1].get("content") + + # Extract citations and transfer search results from RAG agent to research agent + short_ids = extract_citations(content) + if short_ids: + logger.info(f"Found citations in RAG response: {short_ids}") + + for short_id in short_ids: + result = rag_agent.search_results_collector.find_by_short_id( + short_id + ) + if result: + self.search_results_collector.add_result(result) + + # Log confirmation for successful transfer + logger.info( + "Transferred search results from RAG agent to research agent for citations" + ) + return content + + async def _reason( + self, + query: str, + *args, + **kwargs, + ) -> dict[str, Any]: + """Execute a reasoning query using a specialized reasoning LLM.""" + msg_list = await self.conversation.get_messages() + + # Create a specialized generation config for reasoning + gen_cfg = self.get_generation_config(msg_list[-1], stream=False) + gen_cfg.model = self.app_config.reasoning_llm + gen_cfg.top_p = None + gen_cfg.temperature = 0.1 + gen_cfg.max_tokens_to_sample = 64000 + gen_cfg.stream = False + gen_cfg.tools = None + gen_cfg.functions = None + gen_cfg.reasoning_effort = "high" + gen_cfg.add_generation_kwargs = None + + # Call the LLM with the reasoning request + response = await self.llm_provider.aget_completion( + [{"role": "user", "content": query}], gen_cfg + ) + return response.choices[0].message.content + + async def _critique( + self, + query: str, + focus_areas: Optional[list] = None, + *args, + **kwargs, + ) -> dict[str, Any]: + """Critique the conversation history.""" + msg_list = await self.conversation.get_messages() + if not focus_areas: + focus_areas = [] + # Build the critique prompt + critique_prompt = ( + "You are a critical reasoning expert. Your task is to analyze the following conversation " + "and critique the reasoning. Look for:\n" + "1. Logical fallacies or inconsistencies\n" + "2. Cognitive biases\n" + "3. Overlooked questions or considerations\n" + "4. Alternative approaches\n" + "5. Improvements in rigor\n\n" + ) + + if focus_areas: + critique_prompt += f"Focus areas: {', '.join(focus_areas)}\n\n" + + if query.strip(): + critique_prompt += f"Specific question: {query}\n\n" + + critique_prompt += ( + "Structure your critique:\n" + "1. Summary\n" + "2. Key strengths\n" + "3. Potential issues\n" + "4. Alternatives\n" + "5. 
Recommendations\n\n" + ) + + # Add the conversation history to the prompt + conversation_text = "\n--- CONVERSATION HISTORY ---\n\n" + for msg in msg_list: + role = msg.get("role", "") + content = msg.get("content", "") + if content and role in ["user", "assistant", "system"]: + conversation_text += f"{role.upper()}: {content}\n\n" + + final_prompt = critique_prompt + conversation_text + + # Use the reasoning tool to process the critique + return await self._reason(final_prompt, *args, **kwargs) + + async def _execute_python_with_process_timeout( + self, code: str, timeout: int = 10, *args, **kwargs + ) -> dict[str, Any]: + """ + Executes Python code in a separate subprocess with a timeout. + This provides isolation and prevents re-importing the current agent module. + + Parameters: + code (str): Python code to execute. + timeout (int): Timeout in seconds (default: 10). + + Returns: + dict[str, Any]: Dictionary containing stdout, stderr, return code, etc. + """ + # Write user code to a temporary file + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", delete=False + ) as tmp_file: + tmp_file.write(code) + script_path = tmp_file.name + + try: + # Run the script in a fresh subprocess + result = subprocess.run( + [sys.executable, script_path], + capture_output=True, + text=True, + timeout=timeout, + ) + + return { + "result": None, # We'll parse from stdout if needed + "stdout": result.stdout, + "stderr": result.stderr, + "error": ( + None + if result.returncode == 0 + else { + "type": "SubprocessError", + "message": f"Process exited with code {result.returncode}", + "traceback": "", + } + ), + "locals": {}, # No direct local var capture in a separate process + "success": (result.returncode == 0), + "timed_out": False, + "timeout": timeout, + } + except subprocess.TimeoutExpired as e: + return { + "result": None, + "stdout": e.output or "", + "stderr": e.stderr or "", + "error": { + "type": "TimeoutError", + "message": f"Execution exceeded {timeout} second limit.", + "traceback": "", + }, + "locals": {}, + "success": False, + "timed_out": True, + "timeout": timeout, + } + finally: + # Clean up the temp file + if os.path.exists(script_path): + os.remove(script_path) + + def _format_python_results(self, results: dict[str, Any]) -> str: + """Format Python execution results for display.""" + output = [] + + # Timeout notification + if results.get("timed_out", False): + output.append( + f"⚠️ **Execution Timeout**: Code exceeded the {results.get('timeout', 10)} second limit." + ) + output.append("") + + # Stdout + if results.get("stdout"): + output.append("## Output:") + output.append("```") + output.append(results["stdout"].rstrip()) + output.append("```") + output.append("") + + # If there's a 'result' variable to display + if results.get("result") is not None: + output.append("## Result:") + output.append("```") + output.append(str(results["result"])) + output.append("```") + output.append("") + + # Error info + if not results.get("success", True): + output.append("## Error:") + output.append("```") + stderr_out = results.get("stderr", "").rstrip() + if stderr_out: + output.append(stderr_out) + + err_obj = results.get("error") + if err_obj and err_obj.get("message"): + output.append(err_obj["message"]) + output.append("```") + + # Return formatted output + return ( + "\n".join(output) + if output + else "Code executed with no output or result." 
+ ) + + def _format_search_results(self, results) -> str: + """Simple pass-through formatting for RAG search results.""" + return results + + +class R2RResearchAgent(ResearchAgentMixin, R2RRAGAgent): + """ + A non-streaming research agent that uses the standard R2R agent as its base. + + This agent combines research capabilities with the non-streaming RAG agent, + providing tools for deep research through tool-based interaction. + """ + + def __init__( + self, + app_config: AppConfig, + database_provider: DatabaseProvider, + llm_provider: ( + AnthropicCompletionProvider + | LiteLLMCompletionProvider + | OpenAICompletionProvider + | R2RCompletionProvider + ), + config: RAGAgentConfig, + search_settings: SearchSettings, + rag_generation_config: GenerationConfig, + knowledge_search_method: Callable, + content_method: Callable, + file_search_method: Callable, + max_tool_context_length: int = 20_000, + ): + # Set a higher max iterations for research tasks + config.max_iterations = config.max_iterations or 15 + + # Initialize the RAG agent first + R2RRAGAgent.__init__( + self, + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + search_settings=search_settings, + rag_generation_config=rag_generation_config, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + max_tool_context_length=max_tool_context_length, + ) + + # Then initialize the ResearchAgentMixin + ResearchAgentMixin.__init__( + self, + app_config=app_config, + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + search_settings=search_settings, + rag_generation_config=rag_generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + file_search_method=file_search_method, + content_method=content_method, + ) + + +class R2RStreamingResearchAgent(ResearchAgentMixin, R2RStreamingRAGAgent): + """ + A streaming research agent that uses the streaming RAG agent as its base. + + This agent combines research capabilities with streaming text generation, + providing real-time responses while still offering research tools. 
+ """ + + def __init__( + self, + app_config: AppConfig, + database_provider: DatabaseProvider, + llm_provider: ( + AnthropicCompletionProvider + | LiteLLMCompletionProvider + | OpenAICompletionProvider + | R2RCompletionProvider + ), + config: RAGAgentConfig, + search_settings: SearchSettings, + rag_generation_config: GenerationConfig, + knowledge_search_method: Callable, + content_method: Callable, + file_search_method: Callable, + max_tool_context_length: int = 10_000, + ): + # Force streaming on + config.stream = True + config.max_iterations = config.max_iterations or 15 + + # Initialize the streaming RAG agent first + R2RStreamingRAGAgent.__init__( + self, + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + search_settings=search_settings, + rag_generation_config=rag_generation_config, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + max_tool_context_length=max_tool_context_length, + ) + + # Then initialize the ResearchAgentMixin + ResearchAgentMixin.__init__( + self, + app_config=app_config, + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + search_settings=search_settings, + rag_generation_config=rag_generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + ) + + +class R2RXMLToolsResearchAgent(ResearchAgentMixin, R2RXMLToolsRAGAgent): + """ + A non-streaming research agent that uses XML tool formatting. + + This agent combines research capabilities with the XML-based tool calling format, + which might be more appropriate for certain LLM providers. + """ + + def __init__( + self, + app_config: AppConfig, + database_provider: DatabaseProvider, + llm_provider: ( + AnthropicCompletionProvider + | LiteLLMCompletionProvider + | OpenAICompletionProvider + | R2RCompletionProvider + ), + config: RAGAgentConfig, + search_settings: SearchSettings, + rag_generation_config: GenerationConfig, + knowledge_search_method: Callable, + content_method: Callable, + file_search_method: Callable, + max_tool_context_length: int = 20_000, + ): + # Set higher max iterations + config.max_iterations = config.max_iterations or 15 + + # Initialize the XML Tools RAG agent first + R2RXMLToolsRAGAgent.__init__( + self, + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + search_settings=search_settings, + rag_generation_config=rag_generation_config, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + max_tool_context_length=max_tool_context_length, + ) + + # Then initialize the ResearchAgentMixin + ResearchAgentMixin.__init__( + self, + app_config=app_config, + search_settings=search_settings, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + max_tool_context_length=max_tool_context_length, + ) + + +class R2RXMLToolsStreamingResearchAgent( + ResearchAgentMixin, R2RXMLToolsStreamingRAGAgent +): + """ + A streaming research agent that uses XML tool formatting. + + This agent combines research capabilities with streaming and XML-based tool calling, + providing real-time responses in a format suitable for certain LLM providers. 
+ """ + + def __init__( + self, + app_config: AppConfig, + database_provider: DatabaseProvider, + llm_provider: ( + AnthropicCompletionProvider + | LiteLLMCompletionProvider + | OpenAICompletionProvider + | R2RCompletionProvider + ), + config: RAGAgentConfig, + search_settings: SearchSettings, + rag_generation_config: GenerationConfig, + knowledge_search_method: Callable, + content_method: Callable, + file_search_method: Callable, + max_tool_context_length: int = 10_000, + ): + # Force streaming on + config.stream = True + config.max_iterations = config.max_iterations or 15 + + # Initialize the XML Tools Streaming RAG agent first + R2RXMLToolsStreamingRAGAgent.__init__( + self, + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + search_settings=search_settings, + rag_generation_config=rag_generation_config, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + max_tool_context_length=max_tool_context_length, + ) + + # Then initialize the ResearchAgentMixin + ResearchAgentMixin.__init__( + self, + app_config=app_config, + search_settings=search_settings, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + max_tool_context_length=max_tool_context_length, + ) |