author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/agent
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/agent')
-rw-r--r--  .venv/lib/python3.12/site-packages/core/agent/__init__.py  |   36
-rw-r--r--  .venv/lib/python3.12/site-packages/core/agent/base.py      | 1484
-rw-r--r--  .venv/lib/python3.12/site-packages/core/agent/rag.py       |  662
-rw-r--r--  .venv/lib/python3.12/site-packages/core/agent/research.py  |  697
4 files changed, 2879 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/agent/__init__.py b/.venv/lib/python3.12/site-packages/core/agent/__init__.py
new file mode 100644
index 00000000..bd6dda79
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/agent/__init__.py
@@ -0,0 +1,36 @@
+# FIXME: Once the agent is properly type annotated, remove the type: ignore comments
+from .base import ( # type: ignore
+ R2RAgent,
+ R2RStreamingAgent,
+ R2RXMLStreamingAgent,
+)
+from .rag import ( # type: ignore
+ R2RRAGAgent,
+ R2RStreamingRAGAgent,
+ R2RXMLToolsRAGAgent,
+ R2RXMLToolsStreamingRAGAgent,
+)
+
+# Import the concrete implementations
+from .research import (
+ R2RResearchAgent,
+ R2RStreamingResearchAgent,
+ R2RXMLToolsResearchAgent,
+ R2RXMLToolsStreamingResearchAgent,
+)
+
+__all__ = [
+ # Base
+ "R2RAgent",
+ "R2RStreamingAgent",
+ "R2RXMLStreamingAgent",
+ # RAG Agents
+ "R2RRAGAgent",
+ "R2RXMLToolsRAGAgent",
+ "R2RStreamingRAGAgent",
+ "R2RXMLToolsStreamingRAGAgent",
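+    # Research Agents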
+ "R2RResearchAgent",
+ "R2RStreamingResearchAgent",
+ "R2RXMLToolsResearchAgent",
+ "R2RXMLToolsStreamingResearchAgent",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/agent/base.py b/.venv/lib/python3.12/site-packages/core/agent/base.py
new file mode 100644
index 00000000..84aae3f2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/agent/base.py
@@ -0,0 +1,1484 @@
+import asyncio
+import json
+import logging
+import re
+from abc import ABCMeta
+from typing import AsyncGenerator, Optional, Tuple
+
+from core.base import AsyncSyncMeta, LLMChatCompletion, Message, syncable
+from core.base.agent import Agent, Conversation
+from core.utils import (
+ CitationTracker,
+ SearchResultsCollector,
+ SSEFormatter,
+ convert_nonserializable_objects,
+ dump_obj,
+ find_new_citation_spans,
+)
+
+logger = logging.getLogger()
+
+
+class CombinedMeta(AsyncSyncMeta, ABCMeta):
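+    """Metaclass combining AsyncSyncMeta and ABCMeta so agent classes can expose @syncable async methods while remaining ABC-compatible."""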
+ pass
+
+
+def sync_wrapper(async_gen):
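+    """Wrap an async generator so it can be iterated synchronously by driving it on the current event loop."""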
+ loop = asyncio.get_event_loop()
+
+ def wrapper():
+ try:
+ while True:
+ try:
+ yield loop.run_until_complete(async_gen.__anext__())
+ except StopAsyncIteration:
+ break
+ finally:
+ loop.run_until_complete(async_gen.aclose())
+
+ return wrapper()
+
+
+class R2RAgent(Agent, metaclass=CombinedMeta):
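+    """Non-streaming agent loop: repeatedly calls the LLM and executes any requested
+    tool calls until the model finishes or the configured max_iterations is reached.
+    """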
+ def __init__(self, *args, **kwargs):
+ self.search_results_collector = SearchResultsCollector()
+ super().__init__(*args, **kwargs)
+ self._reset()
+
+ async def _generate_llm_summary(self, iterations_count: int) -> str:
+ """
+ Generate a summary of the conversation using the LLM when max iterations are exceeded.
+
+ Args:
+ iterations_count: The number of iterations that were completed
+
+ Returns:
+ A string containing the LLM-generated summary
+ """
+ try:
+ # Get all messages in the conversation
+ all_messages = await self.conversation.get_messages()
+
+ # Create a prompt for the LLM to summarize
+ summary_prompt = {
+ "role": "user",
+ "content": (
+ f"The conversation has reached the maximum limit of {iterations_count} iterations "
+ f"without completing the task. Please provide a concise summary of: "
+ f"1) The key information you've gathered that's relevant to the original query, "
+ f"2) What you've attempted so far and why it's incomplete, and "
+ f"3) A specific recommendation for how to proceed. "
+                    f"Keep your summary brief (3-4 sentences total) and focused on the most valuable insights. If it is possible to answer the original user query, then do so now instead. "
+ f"Start with '⚠️ **Maximum iterations exceeded**'"
+ ),
+ }
+
+ # Create a new message list with just the conversation history and summary request
+ summary_messages = all_messages + [summary_prompt]
+
+ # Get a completion for the summary
+ generation_config = self.get_generation_config(summary_prompt)
+ response = await self.llm_provider.aget_completion(
+ summary_messages,
+ generation_config,
+ )
+
+ return response.choices[0].message.content
+ except Exception as e:
+ logger.error(f"Error generating LLM summary: {str(e)}")
+ # Fall back to basic summary if LLM generation fails
+ return (
+ "⚠️ **Maximum iterations exceeded**\n\n"
+ "The agent reached the maximum iteration limit without completing the task. "
+ "Consider breaking your request into smaller steps or refining your query."
+ )
+
+ def _reset(self):
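+        """Clear completion state and start a fresh conversation."""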
+ self._completed = False
+ self.conversation = Conversation()
+
+ @syncable
+ async def arun(
+ self,
+ messages: list[Message],
+ system_instruction: Optional[str] = None,
+ *args,
+ **kwargs,
+ ) -> list[dict]:
+ self._reset()
+ await self._setup(system_instruction)
+
+ if messages:
+ for message in messages:
+ await self.conversation.add_message(message)
+ iterations_count = 0
+ while (
+ not self._completed
+ and iterations_count < self.config.max_iterations
+ ):
+ iterations_count += 1
+ messages_list = await self.conversation.get_messages()
+ generation_config = self.get_generation_config(messages_list[-1])
+ response = await self.llm_provider.aget_completion(
+ messages_list,
+ generation_config,
+ )
+ logger.debug(f"R2RAgent response: {response}")
+ await self.process_llm_response(response, *args, **kwargs)
+
+ if not self._completed:
+ # Generate a summary of the conversation using the LLM
+ summary = await self._generate_llm_summary(iterations_count)
+ await self.conversation.add_message(
+ Message(role="assistant", content=summary)
+ )
+
+ # Return final content
+ all_messages: list[dict] = await self.conversation.get_messages()
+ all_messages.reverse()
+
+ output_messages = []
+        for message_2 in all_messages:
+            # Walking the reversed history: collect messages until we reach the
+            # caller's final input message, i.e. keep only output from this run.
+            if message_2.get("content") != messages[-1].content:
+                output_messages.append(message_2)
+            else:
+                break
+ output_messages.reverse()
+
+ return output_messages
+
+ async def process_llm_response(
+ self, response: LLMChatCompletion, *args, **kwargs
+ ) -> None:
+ if not self._completed:
+ message = response.choices[0].message
+ finish_reason = response.choices[0].finish_reason
+
+ if finish_reason == "stop":
+ self._completed = True
+
+ # Determine which provider we're using
+ using_anthropic = (
+ "anthropic" in self.rag_generation_config.model.lower()
+ )
+
+ # OPENAI HANDLING
+ if not using_anthropic:
+ if message.tool_calls:
+ assistant_msg = Message(
+ role="assistant",
+ content="",
+ tool_calls=[msg.dict() for msg in message.tool_calls],
+ )
+ await self.conversation.add_message(assistant_msg)
+
+ # If there are multiple tool_calls, call them sequentially here
+ for tool_call in message.tool_calls:
+ await self.handle_function_or_tool_call(
+ tool_call.function.name,
+ tool_call.function.arguments,
+ tool_id=tool_call.id,
+ *args,
+ **kwargs,
+ )
+ else:
+ await self.conversation.add_message(
+ Message(role="assistant", content=message.content)
+ )
+ self._completed = True
+
+ else:
+ # First handle thinking blocks if present
+ if (
+ hasattr(message, "structured_content")
+ and message.structured_content
+ ):
+ # Check if structured_content contains any tool_use blocks
+ has_tool_use = any(
+ block.get("type") == "tool_use"
+ for block in message.structured_content
+ )
+
+ if not has_tool_use and message.tool_calls:
+ # If it has thinking but no tool_use, add a separate message with structured_content
+ assistant_msg = Message(
+ role="assistant",
+ structured_content=message.structured_content, # Use structured_content field
+ )
+ await self.conversation.add_message(assistant_msg)
+
+ # Add explicit tool_use blocks in a separate message
+ tool_uses = []
+ for tool_call in message.tool_calls:
+ # Safely parse arguments if they're a string
+ try:
+ if isinstance(
+ tool_call.function.arguments, str
+ ):
+ input_args = json.loads(
+ tool_call.function.arguments
+ )
+ else:
+ input_args = tool_call.function.arguments
+ except json.JSONDecodeError:
+ logger.error(
+ f"Failed to parse tool arguments: {tool_call.function.arguments}"
+ )
+ input_args = {
+ "_raw": tool_call.function.arguments
+ }
+
+ tool_uses.append(
+ {
+ "type": "tool_use",
+ "id": tool_call.id,
+ "name": tool_call.function.name,
+ "input": input_args,
+ }
+ )
+
+ # Add tool_use blocks as a separate assistant message with structured content
+ if tool_uses:
+ await self.conversation.add_message(
+ Message(
+ role="assistant",
+ structured_content=tool_uses,
+ content="",
+ )
+ )
+ else:
+ # If it already has tool_use or no tool_calls, preserve original structure
+ assistant_msg = Message(
+ role="assistant",
+ structured_content=message.structured_content,
+ )
+ await self.conversation.add_message(assistant_msg)
+
+ elif message.content:
+ # For regular text content
+ await self.conversation.add_message(
+ Message(role="assistant", content=message.content)
+ )
+
+ # If there are tool calls, add them as structured content
+ if message.tool_calls:
+ tool_uses = []
+ for tool_call in message.tool_calls:
+ # Same safe parsing as above
+ try:
+ if isinstance(
+ tool_call.function.arguments, str
+ ):
+ input_args = json.loads(
+ tool_call.function.arguments
+ )
+ else:
+ input_args = tool_call.function.arguments
+ except json.JSONDecodeError:
+ logger.error(
+ f"Failed to parse tool arguments: {tool_call.function.arguments}"
+ )
+ input_args = {
+ "_raw": tool_call.function.arguments
+ }
+
+ tool_uses.append(
+ {
+ "type": "tool_use",
+ "id": tool_call.id,
+ "name": tool_call.function.name,
+ "input": input_args,
+ }
+ )
+
+ await self.conversation.add_message(
+ Message(
+ role="assistant", structured_content=tool_uses
+ )
+ )
+
+ # NEW CASE: Handle tool_calls with no content or structured_content
+ elif message.tool_calls:
+ # Create tool_uses for the message with only tool_calls
+ tool_uses = []
+ for tool_call in message.tool_calls:
+ try:
+ if isinstance(tool_call.function.arguments, str):
+ input_args = json.loads(
+ tool_call.function.arguments
+ )
+ else:
+ input_args = tool_call.function.arguments
+ except json.JSONDecodeError:
+ logger.error(
+ f"Failed to parse tool arguments: {tool_call.function.arguments}"
+ )
+ input_args = {"_raw": tool_call.function.arguments}
+
+ tool_uses.append(
+ {
+ "type": "tool_use",
+ "id": tool_call.id,
+ "name": tool_call.function.name,
+ "input": input_args,
+ }
+ )
+
+ # Add tool_use blocks as a message before processing tools
+ if tool_uses:
+ await self.conversation.add_message(
+ Message(
+ role="assistant",
+ structured_content=tool_uses,
+ )
+ )
+
+ # Process the tool calls
+ if message.tool_calls:
+ for tool_call in message.tool_calls:
+ await self.handle_function_or_tool_call(
+ tool_call.function.name,
+ tool_call.function.arguments,
+ tool_id=tool_call.id,
+ *args,
+ **kwargs,
+ )
+
+
+class R2RStreamingAgent(R2RAgent):
+ """
+ Base class for all streaming agents with core streaming functionality.
+ Supports emitting messages, tool calls, and results as SSE events.
+ """
+
+ # These two regexes will detect bracket references and then find short IDs.
+ BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]")
+ SHORT_ID_PATTERN = re.compile(
+ r"[A-Za-z0-9]{7,8}"
+ ) # 7-8 chars, for example
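+    # Illustrative example: in the text "as shown in [abc1234]", BRACKET_PATTERN
+    # matches "[abc1234]" and SHORT_ID_PATTERN extracts the short ID "abc1234".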
+
+ def __init__(self, *args, **kwargs):
+ # Force streaming on
+ if hasattr(kwargs.get("config", {}), "stream"):
+ kwargs["config"].stream = True
+ super().__init__(*args, **kwargs)
+
+ async def arun(
+ self,
+ system_instruction: str | None = None,
+ messages: list[Message] | None = None,
+ *args,
+ **kwargs,
+ ) -> AsyncGenerator[str, None]:
+ """
+ Main streaming entrypoint: returns an async generator of SSE lines.
+ """
+ self._reset()
+ await self._setup(system_instruction)
+
+ if messages:
+ for m in messages:
+ await self.conversation.add_message(m)
+
+ # Initialize citation tracker for this run
+ citation_tracker = CitationTracker()
+
+ # Dictionary to store citation payloads by ID
+ citation_payloads = {}
+
+ # Track all citations emitted during streaming for final persistence
+ self.streaming_citations: list[dict] = []
+
+ async def sse_generator() -> AsyncGenerator[str, None]:
+            pending_tool_calls = {}
+            partial_text_buffer = ""
+            iterations_count = 0
+            # Defined up front so the max-iterations fallback below can reference it
+            # even if the "stop" branch (which normally builds it) never runs.
+            consolidated_citations: list[dict] = []
+
+ try:
+ # Keep streaming until we complete
+ while (
+ not self._completed
+ and iterations_count < self.config.max_iterations
+ ):
+ iterations_count += 1
+ # 1) Get current messages
+ msg_list = await self.conversation.get_messages()
+ gen_cfg = self.get_generation_config(
+ msg_list[-1], stream=True
+ )
+
+ accumulated_thinking = ""
+ thinking_signatures = {} # Map thinking content to signatures
+
+ # 2) Start streaming from LLM
+ llm_stream = self.llm_provider.aget_completion_stream(
+ msg_list, gen_cfg
+ )
+ async for chunk in llm_stream:
+ delta = chunk.choices[0].delta
+ finish_reason = chunk.choices[0].finish_reason
+
+ if hasattr(delta, "thinking") and delta.thinking:
+ # Accumulate thinking for later use in messages
+ accumulated_thinking += delta.thinking
+
+ # Emit SSE "thinking" event
+ async for (
+ line
+ ) in SSEFormatter.yield_thinking_event(
+ delta.thinking
+ ):
+ yield line
+
+ # Add this new handler for thinking signatures
+ if hasattr(delta, "thinking_signature"):
+ thinking_signatures[accumulated_thinking] = (
+ delta.thinking_signature
+ )
+ accumulated_thinking = ""
+
+ # 3) If new text, accumulate it
+ if delta.content:
+ partial_text_buffer += delta.content
+
+ # (a) Now emit the newly streamed text as a "message" event
+ async for line in SSEFormatter.yield_message_event(
+ delta.content
+ ):
+ yield line
+
+ # (b) Find new citation spans in the accumulated text
+ new_citation_spans = find_new_citation_spans(
+ partial_text_buffer, citation_tracker
+ )
+
+ # Process each new citation span
+ for cid, spans in new_citation_spans.items():
+ for span in spans:
+ # Check if this is the first time we've seen this citation ID
+ is_new_citation = (
+ citation_tracker.is_new_citation(cid)
+ )
+
+ # Get payload if it's a new citation
+ payload = None
+ if is_new_citation:
+ source_obj = self.search_results_collector.find_by_short_id(
+ cid
+ )
+ if source_obj:
+ # Store payload for reuse
+ payload = dump_obj(source_obj)
+ citation_payloads[cid] = payload
+
+ # Create citation event payload
+ citation_data = {
+ "id": cid,
+ "object": "citation",
+ "is_new": is_new_citation,
+ "span": {
+ "start": span[0],
+ "end": span[1],
+ },
+ }
+
+ # Only include full payload for new citations
+ if is_new_citation and payload:
+ citation_data["payload"] = payload
+
+ # Add to streaming citations for final answer
+ self.streaming_citations.append(
+ citation_data
+ )
+
+ # Emit the citation event
+ async for (
+ line
+ ) in SSEFormatter.yield_citation_event(
+ citation_data
+ ):
+ yield line
+
+ if delta.tool_calls:
+ for tc in delta.tool_calls:
+ idx = tc.index
+ if idx not in pending_tool_calls:
+ pending_tool_calls[idx] = {
+ "id": tc.id,
+ "name": tc.function.name or "",
+ "arguments": tc.function.arguments
+ or "",
+ }
+ else:
+ # Accumulate partial name/arguments
+ if tc.function.name:
+ pending_tool_calls[idx]["name"] = (
+ tc.function.name
+ )
+ if tc.function.arguments:
+ pending_tool_calls[idx][
+ "arguments"
+ ] += tc.function.arguments
+
+ # 5) If the stream signals we should handle "tool_calls"
+ if finish_reason == "tool_calls":
+ # Handle thinking if present
+ await self._handle_thinking(
+ thinking_signatures, accumulated_thinking
+ )
+
+ calls_list = []
+ for idx in sorted(pending_tool_calls.keys()):
+ cinfo = pending_tool_calls[idx]
+ calls_list.append(
+ {
+ "tool_call_id": cinfo["id"]
+ or f"call_{idx}",
+ "name": cinfo["name"],
+ "arguments": cinfo["arguments"],
+ }
+ )
+
+ # (a) Emit SSE "tool_call" events
+ for c in calls_list:
+ tc_data = self._create_tool_call_data(c)
+ async for (
+ line
+ ) in SSEFormatter.yield_tool_call_event(
+ tc_data
+ ):
+ yield line
+
+ # (b) Add an assistant message capturing these calls
+ await self._add_tool_calls_message(
+ calls_list, partial_text_buffer
+ )
+
+ # (c) Execute each tool call in parallel
+ await asyncio.gather(
+ *[
+ self.handle_function_or_tool_call(
+ c["name"],
+ c["arguments"],
+ tool_id=c["tool_call_id"],
+ )
+ for c in calls_list
+ ]
+ )
+
+ # Reset buffer & calls
+ pending_tool_calls.clear()
+ partial_text_buffer = ""
+
+ elif finish_reason == "stop":
+ # Handle thinking if present
+ await self._handle_thinking(
+ thinking_signatures, accumulated_thinking
+ )
+
+ # 6) The LLM is done. If we have any leftover partial text,
+ # finalize it in the conversation
+ if partial_text_buffer:
+ # Create the final message with metadata including citations
+ final_message = Message(
+ role="assistant",
+ content=partial_text_buffer,
+ metadata={
+ "citations": self.streaming_citations
+ },
+ )
+
+ # Add it to the conversation
+ await self.conversation.add_message(
+ final_message
+ )
+
+ # (a) Prepare final answer with optimized citations
+ consolidated_citations = []
+ # Group citations by ID with all their spans
+ for (
+ cid,
+ spans,
+ ) in citation_tracker.get_all_spans().items():
+ if cid in citation_payloads:
+ consolidated_citations.append(
+ {
+ "id": cid,
+ "object": "citation",
+ "spans": [
+ {"start": s[0], "end": s[1]}
+ for s in spans
+ ],
+ "payload": citation_payloads[cid],
+ }
+ )
+
+ # Create final answer payload
+ final_evt_payload = {
+ "id": "msg_final",
+ "object": "agent.final_answer",
+ "generated_answer": partial_text_buffer,
+ "citations": consolidated_citations,
+ }
+
+ # Emit final answer event
+ async for (
+ line
+ ) in SSEFormatter.yield_final_answer_event(
+ final_evt_payload
+ ):
+ yield line
+
+ # (b) Signal the end of the SSE stream
+ yield SSEFormatter.yield_done_event()
+ self._completed = True
+ break
+
+ # If we exit the while loop due to hitting max iterations
+ if not self._completed:
+ # Generate a summary using the LLM
+ summary = await self._generate_llm_summary(
+ iterations_count
+ )
+
+ # Send the summary as a message event
+ async for line in SSEFormatter.yield_message_event(
+ summary
+ ):
+ yield line
+
+ # Add summary to conversation with citations metadata
+ await self.conversation.add_message(
+ Message(
+ role="assistant",
+ content=summary,
+ metadata={"citations": self.streaming_citations},
+ )
+ )
+
+ # Create and emit a final answer payload with the summary
+ final_evt_payload = {
+ "id": "msg_final",
+ "object": "agent.final_answer",
+ "generated_answer": summary,
+ "citations": consolidated_citations,
+ }
+
+ async for line in SSEFormatter.yield_final_answer_event(
+ final_evt_payload
+ ):
+ yield line
+
+ # Signal the end of the SSE stream
+ yield SSEFormatter.yield_done_event()
+ self._completed = True
+
+ except Exception as e:
+ logger.error(f"Error in streaming agent: {str(e)}")
+ # Emit error event for client
+ async for line in SSEFormatter.yield_error_event(
+ f"Agent error: {str(e)}"
+ ):
+ yield line
+ # Send done event to close the stream
+ yield SSEFormatter.yield_done_event()
+
+ # Finally, we return the async generator
+ async for line in sse_generator():
+ yield line
+
+ async def _handle_thinking(
+ self, thinking_signatures, accumulated_thinking
+ ):
+ """Process any accumulated thinking content"""
+ if accumulated_thinking:
+ structured_content = [
+ {
+ "type": "thinking",
+ "thinking": accumulated_thinking,
+ # Anthropic will validate this in their API
+ "signature": "placeholder_signature",
+ }
+ ]
+
+ assistant_msg = Message(
+ role="assistant",
+ structured_content=structured_content,
+ )
+ await self.conversation.add_message(assistant_msg)
+
+ elif thinking_signatures:
+ for (
+ accumulated_thinking,
+ thinking_signature,
+ ) in thinking_signatures.items():
+ structured_content = [
+ {
+ "type": "thinking",
+ "thinking": accumulated_thinking,
+ # Anthropic will validate this in their API
+ "signature": thinking_signature,
+ }
+ ]
+
+ assistant_msg = Message(
+ role="assistant",
+ structured_content=structured_content,
+ )
+ await self.conversation.add_message(assistant_msg)
+
+ async def _add_tool_calls_message(self, calls_list, partial_text_buffer):
+ """Add a message with tool calls to the conversation"""
+ assistant_msg = Message(
+ role="assistant",
+ content=partial_text_buffer or "",
+ tool_calls=[
+ {
+ "id": c["tool_call_id"],
+ "type": "function",
+ "function": {
+ "name": c["name"],
+ "arguments": c["arguments"],
+ },
+ }
+ for c in calls_list
+ ],
+ )
+ await self.conversation.add_message(assistant_msg)
+
+ def _create_tool_call_data(self, call_info):
+ """Create tool call data structure from call info"""
+ return {
+ "tool_call_id": call_info["tool_call_id"],
+ "name": call_info["name"],
+ "arguments": call_info["arguments"],
+ }
+
+ def _create_citation_payload(self, short_id, payload):
+ """Create citation payload for a short ID"""
+ # This will be overridden in RAG subclasses
+ # check if as_dict is on payload
+ if hasattr(payload, "as_dict"):
+ payload = payload.as_dict()
+        if hasattr(payload, "dict"):
+            payload = payload.dict()
+ if hasattr(payload, "to_dict"):
+ payload = payload.to_dict()
+
+ return {
+ "id": f"{short_id}",
+ "object": "citation",
+ "payload": dump_obj(payload), # Will be populated in RAG agents
+ }
+
+ def _create_final_answer_payload(self, answer_text, citations):
+ """Create the final answer payload"""
+ # This will be extended in RAG subclasses
+ return {
+ "id": "msg_final",
+ "object": "agent.final_answer",
+ "generated_answer": answer_text,
+ "citations": citations,
+ }
+
+
+class R2RXMLStreamingAgent(R2RStreamingAgent):
+ """
+ A streaming agent that parses XML-formatted responses with special handling for:
+ - <think> or <Thought> blocks for chain-of-thought reasoning
+ - <Action>, <ToolCalls>, <ToolCall> blocks for tool execution
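+
+    Illustrative response shape (tag casing and exact contents vary by model):
+        <Thought>...reasoning about the next step...</Thought>
+        <Action>
+            <ToolCalls>
+                <ToolCall><Name>web_search</Name><Parameters>{"query": "..."}</Parameters></ToolCall>
+            </ToolCalls>
+        </Action>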
+ """
+
+ # We treat <think> or <Thought> as the same token boundaries
+ THOUGHT_OPEN = re.compile(r"<(Thought|think)>", re.IGNORECASE)
+ THOUGHT_CLOSE = re.compile(r"</(Thought|think)>", re.IGNORECASE)
+
+ # Regexes to parse out <Action>, <ToolCalls>, <ToolCall>, <Name>, <Parameters>, <Response>
+ ACTION_PATTERN = re.compile(
+ r"<Action>(.*?)</Action>", re.IGNORECASE | re.DOTALL
+ )
+ TOOLCALLS_PATTERN = re.compile(
+ r"<ToolCalls>(.*?)</ToolCalls>", re.IGNORECASE | re.DOTALL
+ )
+ TOOLCALL_PATTERN = re.compile(
+ r"<ToolCall>(.*?)</ToolCall>", re.IGNORECASE | re.DOTALL
+ )
+ NAME_PATTERN = re.compile(r"<Name>(.*?)</Name>", re.IGNORECASE | re.DOTALL)
+ PARAMS_PATTERN = re.compile(
+ r"<Parameters>(.*?)</Parameters>", re.IGNORECASE | re.DOTALL
+ )
+ RESPONSE_PATTERN = re.compile(
+ r"<Response>(.*?)</Response>", re.IGNORECASE | re.DOTALL
+ )
+
+ async def arun(
+ self,
+ system_instruction: str | None = None,
+ messages: list[Message] | None = None,
+ *args,
+ **kwargs,
+ ) -> AsyncGenerator[str, None]:
+ """
+ Main streaming entrypoint: returns an async generator of SSE lines.
+ """
+ self._reset()
+ await self._setup(system_instruction)
+
+ if messages:
+ for m in messages:
+ await self.conversation.add_message(m)
+
+ # Initialize citation tracker for this run
+ citation_tracker = CitationTracker()
+
+ # Dictionary to store citation payloads by ID
+ citation_payloads = {}
+
+ # Track all citations emitted during streaming for final persistence
+ self.streaming_citations: list[dict] = []
+
+ async def sse_generator() -> AsyncGenerator[str, None]:
+            iterations_count = 0
+            # Defined up front so the max-iterations fallback below can reference it
+            # even if the final-answer branch (which normally builds it) never runs.
+            consolidated_citations: list[dict] = []
+
+ try:
+ # Keep streaming until we complete
+ while (
+ not self._completed
+ and iterations_count < self.config.max_iterations
+ ):
+ iterations_count += 1
+ # 1) Get current messages
+ msg_list = await self.conversation.get_messages()
+ gen_cfg = self.get_generation_config(
+ msg_list[-1], stream=True
+ )
+
+ # 2) Start streaming from LLM
+ llm_stream = self.llm_provider.aget_completion_stream(
+ msg_list, gen_cfg
+ )
+
+ # Create state variables for each iteration
+ iteration_buffer = ""
+ yielded_first_event = False
+ in_action_block = False
+ is_thinking = False
+ accumulated_thinking = ""
+ thinking_signatures = {}
+
+ async for chunk in llm_stream:
+ delta = chunk.choices[0].delta
+ finish_reason = chunk.choices[0].finish_reason
+
+ # Handle thinking if present
+ if hasattr(delta, "thinking") and delta.thinking:
+ # Accumulate thinking for later use in messages
+ accumulated_thinking += delta.thinking
+
+ # Emit SSE "thinking" event
+ async for (
+ line
+ ) in SSEFormatter.yield_thinking_event(
+ delta.thinking
+ ):
+ yield line
+
+ # Add this new handler for thinking signatures
+ if hasattr(delta, "thinking_signature"):
+ thinking_signatures[accumulated_thinking] = (
+ delta.thinking_signature
+ )
+ accumulated_thinking = ""
+
+ # 3) If new text, accumulate it
+ if delta.content:
+ iteration_buffer += delta.content
+
+ # Check if we have accumulated enough text for a `<Thought>` block
+ if len(iteration_buffer) < len("<Thought>"):
+ continue
+
+ # Check if we have yielded the first event
+ if not yielded_first_event:
+ # Emit the first chunk
+ if self.THOUGHT_OPEN.findall(iteration_buffer):
+ is_thinking = True
+ async for (
+ line
+ ) in SSEFormatter.yield_thinking_event(
+ iteration_buffer
+ ):
+ yield line
+ else:
+ async for (
+ line
+ ) in SSEFormatter.yield_message_event(
+ iteration_buffer
+ ):
+ yield line
+
+ # Mark as yielded
+ yielded_first_event = True
+ continue
+
+ # Check if we are in a thinking block
+ if is_thinking:
+ # Still thinking, so keep yielding thinking events
+ if not self.THOUGHT_CLOSE.findall(
+ iteration_buffer
+ ):
+ # Emit SSE "thinking" event
+ async for (
+ line
+ ) in SSEFormatter.yield_thinking_event(
+ delta.content
+ ):
+ yield line
+
+ continue
+ # Done thinking, so emit the last thinking event
+ else:
+ is_thinking = False
+ thought_text = delta.content.split(
+ "</Thought>"
+ )[0].split("</think>")[0]
+ async for (
+ line
+ ) in SSEFormatter.yield_thinking_event(
+ thought_text
+ ):
+ yield line
+ post_thought_text = delta.content.split(
+ "</Thought>"
+ )[-1].split("</think>")[-1]
+ delta.content = post_thought_text
+
+ # (b) Find new citation spans in the accumulated text
+ new_citation_spans = find_new_citation_spans(
+ iteration_buffer, citation_tracker
+ )
+
+ # Process each new citation span
+ for cid, spans in new_citation_spans.items():
+ for span in spans:
+ # Check if this is the first time we've seen this citation ID
+ is_new_citation = (
+ citation_tracker.is_new_citation(cid)
+ )
+
+ # Get payload if it's a new citation
+ payload = None
+ if is_new_citation:
+ source_obj = self.search_results_collector.find_by_short_id(
+ cid
+ )
+ if source_obj:
+ # Store payload for reuse
+ payload = dump_obj(source_obj)
+ citation_payloads[cid] = payload
+
+ # Create citation event payload
+ citation_data = {
+ "id": cid,
+ "object": "citation",
+ "is_new": is_new_citation,
+ "span": {
+ "start": span[0],
+ "end": span[1],
+ },
+ }
+
+ # Only include full payload for new citations
+ if is_new_citation and payload:
+ citation_data["payload"] = payload
+
+ # Add to streaming citations for final answer
+ self.streaming_citations.append(
+ citation_data
+ )
+
+ # Emit the citation event
+ async for (
+ line
+ ) in SSEFormatter.yield_citation_event(
+ citation_data
+ ):
+ yield line
+
+ # Now prepare to emit the newly streamed text as a "message" event
+ if (
+ iteration_buffer.count("<")
+ and not in_action_block
+ ):
+ in_action_block = True
+
+ if (
+ in_action_block
+ and len(
+ self.ACTION_PATTERN.findall(
+ iteration_buffer
+ )
+ )
+ < 2
+ ):
+ continue
+
+ elif in_action_block:
+ in_action_block = False
+ # Emit the post action block text, if it is there
+ post_action_text = iteration_buffer.split(
+ "</Action>"
+ )[-1]
+ if post_action_text:
+ async for (
+ line
+ ) in SSEFormatter.yield_message_event(
+ post_action_text
+ ):
+ yield line
+
+ else:
+ async for (
+ line
+ ) in SSEFormatter.yield_message_event(
+ delta.content
+ ):
+ yield line
+
+ elif finish_reason == "stop":
+ break
+
+ # Process any accumulated thinking
+ await self._handle_thinking(
+ thinking_signatures, accumulated_thinking
+ )
+
+ # 6) The LLM is done. If we have any leftover partial text,
+ # finalize it in the conversation
+ if iteration_buffer:
+ # Create the final message with metadata including citations
+ final_message = Message(
+ role="assistant",
+ content=iteration_buffer,
+ metadata={"citations": self.streaming_citations},
+ )
+
+ # Add it to the conversation
+ await self.conversation.add_message(final_message)
+
+ # --- 4) Process any <Action>/<ToolCalls> blocks, or mark completed
+ action_matches = self.ACTION_PATTERN.findall(
+ iteration_buffer
+ )
+
+ if len(action_matches) > 0:
+ # Process each ToolCall
+ xml_toolcalls = "<ToolCalls>"
+
+ for action_block in action_matches:
+ tool_calls_text = []
+ # Look for ToolCalls wrapper, or use the raw action block
+ calls_wrapper = self.TOOLCALLS_PATTERN.findall(
+ action_block
+ )
+ if calls_wrapper:
+ for tw in calls_wrapper:
+ tool_calls_text.append(tw)
+ else:
+ tool_calls_text.append(action_block)
+
+ for calls_region in tool_calls_text:
+ calls_found = self.TOOLCALL_PATTERN.findall(
+ calls_region
+ )
+ for tc_block in calls_found:
+ tool_name, tool_params = (
+ self._parse_single_tool_call(tc_block)
+ )
+ if tool_name:
+ # Emit SSE event for tool call
+ tool_call_id = (
+ f"call_{abs(hash(tc_block))}"
+ )
+ call_evt_data = {
+ "tool_call_id": tool_call_id,
+ "name": tool_name,
+ "arguments": json.dumps(
+ tool_params
+ ),
+ }
+ async for line in (
+ SSEFormatter.yield_tool_call_event(
+ call_evt_data
+ )
+ ):
+ yield line
+
+ try:
+ tool_result = await self.handle_function_or_tool_call(
+ tool_name,
+ json.dumps(tool_params),
+ tool_id=tool_call_id,
+ save_messages=False,
+ )
+ result_content = tool_result.llm_formatted_result
+ except Exception as e:
+ result_content = f"Error in tool '{tool_name}': {str(e)}"
+
+ xml_toolcalls += (
+ f"<ToolCall>"
+ f"<Name>{tool_name}</Name>"
+ f"<Parameters>{json.dumps(tool_params)}</Parameters>"
+ f"<Result>{result_content}</Result>"
+ f"</ToolCall>"
+ )
+
+ # Emit SSE tool result for non-result tools
+ result_data = {
+ "tool_call_id": tool_call_id,
+ "role": "tool",
+ "content": json.dumps(
+ convert_nonserializable_objects(
+ result_content
+ )
+ ),
+ }
+ async for line in SSEFormatter.yield_tool_result_event(
+ result_data
+ ):
+ yield line
+
+ xml_toolcalls += "</ToolCalls>"
+ pre_action_text = iteration_buffer[
+ : iteration_buffer.find(action_block)
+ ]
+ post_action_text = iteration_buffer[
+ iteration_buffer.find(action_block)
+ + len(action_block) :
+ ]
+ iteration_text = (
+ pre_action_text + xml_toolcalls + post_action_text
+ )
+
+ # Update the conversation with tool results
+ await self.conversation.add_message(
+ Message(
+ role="assistant",
+ content=iteration_text,
+ metadata={
+ "citations": self.streaming_citations
+ },
+ )
+ )
+ else:
+ # (a) Prepare final answer with optimized citations
+ consolidated_citations = []
+ # Group citations by ID with all their spans
+ for (
+ cid,
+ spans,
+ ) in citation_tracker.get_all_spans().items():
+ if cid in citation_payloads:
+ consolidated_citations.append(
+ {
+ "id": cid,
+ "object": "citation",
+ "spans": [
+ {"start": s[0], "end": s[1]}
+ for s in spans
+ ],
+ "payload": citation_payloads[cid],
+ }
+ )
+
+ # Create final answer payload
+ final_evt_payload = {
+ "id": "msg_final",
+ "object": "agent.final_answer",
+ "generated_answer": iteration_buffer,
+ "citations": consolidated_citations,
+ }
+
+ # Emit final answer event
+ async for (
+ line
+ ) in SSEFormatter.yield_final_answer_event(
+ final_evt_payload
+ ):
+ yield line
+
+ # (b) Signal the end of the SSE stream
+ yield SSEFormatter.yield_done_event()
+ self._completed = True
+
+ # If we exit the while loop due to hitting max iterations
+ if not self._completed:
+ # Generate a summary using the LLM
+ summary = await self._generate_llm_summary(
+ iterations_count
+ )
+
+ # Send the summary as a message event
+ async for line in SSEFormatter.yield_message_event(
+ summary
+ ):
+ yield line
+
+ # Add summary to conversation with citations metadata
+ await self.conversation.add_message(
+ Message(
+ role="assistant",
+ content=summary,
+ metadata={"citations": self.streaming_citations},
+ )
+ )
+
+ # Create and emit a final answer payload with the summary
+ final_evt_payload = {
+ "id": "msg_final",
+ "object": "agent.final_answer",
+ "generated_answer": summary,
+ "citations": consolidated_citations,
+ }
+
+ async for line in SSEFormatter.yield_final_answer_event(
+ final_evt_payload
+ ):
+ yield line
+
+ # Signal the end of the SSE stream
+ yield SSEFormatter.yield_done_event()
+ self._completed = True
+
+ except Exception as e:
+ logger.error(f"Error in streaming agent: {str(e)}")
+ # Emit error event for client
+ async for line in SSEFormatter.yield_error_event(
+ f"Agent error: {str(e)}"
+ ):
+ yield line
+ # Send done event to close the stream
+ yield SSEFormatter.yield_done_event()
+
+ # Finally, we return the async generator
+ async for line in sse_generator():
+ yield line
+
+ def _parse_single_tool_call(
+ self, toolcall_text: str
+ ) -> Tuple[Optional[str], dict]:
+ """
+ Parse a ToolCall block to extract the name and parameters.
+
+ Args:
+ toolcall_text: The text content of a ToolCall block
+
+ Returns:
+ Tuple of (tool_name, tool_parameters)
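+
+        Example (hypothetical input):
+            <Name>web_search</Name><Parameters>{"query": "R2R"}</Parameters>
+            parses to ("web_search", {"query": "R2R"})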
+ """
+ name_match = self.NAME_PATTERN.search(toolcall_text)
+ if not name_match:
+ return None, {}
+ tool_name = name_match.group(1).strip()
+
+ params_match = self.PARAMS_PATTERN.search(toolcall_text)
+ if not params_match:
+ return tool_name, {}
+
+ raw_params = params_match.group(1).strip()
+ try:
+ # Handle potential JSON parsing issues
+ # First try direct parsing
+ tool_params = json.loads(raw_params)
+ except json.JSONDecodeError:
+ # If that fails, try to clean up the JSON string
+ try:
+ # Replace escaped quotes that might cause issues
+ cleaned_params = raw_params.replace('\\"', '"')
+ # Try again with the cleaned string
+ tool_params = json.loads(cleaned_params)
+ except json.JSONDecodeError:
+ # If all else fails, treat as a plain string value
+ tool_params = {"value": raw_params}
+
+ return tool_name, tool_params
+
+
+class R2RXMLToolsAgent(R2RAgent):
+ """
+ A non-streaming agent that:
+ - parses <think> or <Thought> blocks as chain-of-thought
+ - filters out XML tags related to tool calls and actions
+ - processes <Action><ToolCalls><ToolCall> blocks
+ - properly extracts citations when they appear in the text
+ """
+
+ # We treat <think> or <Thought> as the same token boundaries
+ THOUGHT_OPEN = re.compile(r"<(Thought|think)>", re.IGNORECASE)
+ THOUGHT_CLOSE = re.compile(r"</(Thought|think)>", re.IGNORECASE)
+
+ # Regexes to parse out <Action>, <ToolCalls>, <ToolCall>, <Name>, <Parameters>, <Response>
+ ACTION_PATTERN = re.compile(
+ r"<Action>(.*?)</Action>", re.IGNORECASE | re.DOTALL
+ )
+ TOOLCALLS_PATTERN = re.compile(
+ r"<ToolCalls>(.*?)</ToolCalls>", re.IGNORECASE | re.DOTALL
+ )
+ TOOLCALL_PATTERN = re.compile(
+ r"<ToolCall>(.*?)</ToolCall>", re.IGNORECASE | re.DOTALL
+ )
+ NAME_PATTERN = re.compile(r"<Name>(.*?)</Name>", re.IGNORECASE | re.DOTALL)
+ PARAMS_PATTERN = re.compile(
+ r"<Parameters>(.*?)</Parameters>", re.IGNORECASE | re.DOTALL
+ )
+ RESPONSE_PATTERN = re.compile(
+ r"<Response>(.*?)</Response>", re.IGNORECASE | re.DOTALL
+ )
+
+ async def process_llm_response(self, response, *args, **kwargs):
+ """
+ Override the base process_llm_response to handle XML structured responses
+ including thoughts and tool calls.
+ """
+ if self._completed:
+ return
+
+ message = response.choices[0].message
+ finish_reason = response.choices[0].finish_reason
+
+ if not message.content:
+ # If there's no content, let the parent class handle the normal tool_calls flow
+ return await super().process_llm_response(
+ response, *args, **kwargs
+ )
+
+ # Get the response content
+ content = message.content
+
+ # HACK for gemini
+ content = content.replace("```action", "")
+ content = content.replace("```tool_code", "")
+ content = content.replace("```", "")
+
+ if (
+ not content.startswith("<")
+ and "deepseek" in self.rag_generation_config.model
+ ): # HACK - fix issues with adding `<think>` to the beginning
+ content = "<think>" + content
+
+ # Process any tool calls in the content
+ action_matches = self.ACTION_PATTERN.findall(content)
+ if action_matches:
+ xml_toolcalls = "<ToolCalls>"
+ for action_block in action_matches:
+ tool_calls_text = []
+ # Look for ToolCalls wrapper, or use the raw action block
+ calls_wrapper = self.TOOLCALLS_PATTERN.findall(action_block)
+ if calls_wrapper:
+ for tw in calls_wrapper:
+ tool_calls_text.append(tw)
+ else:
+ tool_calls_text.append(action_block)
+
+ # Process each ToolCall
+ for calls_region in tool_calls_text:
+ calls_found = self.TOOLCALL_PATTERN.findall(calls_region)
+ for tc_block in calls_found:
+ tool_name, tool_params = self._parse_single_tool_call(
+ tc_block
+ )
+ if tool_name:
+ tool_call_id = f"call_{abs(hash(tc_block))}"
+ try:
+ tool_result = (
+ await self.handle_function_or_tool_call(
+ tool_name,
+ json.dumps(tool_params),
+ tool_id=tool_call_id,
+ save_messages=False,
+ )
+ )
+
+ # Add tool result to XML
+ xml_toolcalls += (
+ f"<ToolCall>"
+ f"<Name>{tool_name}</Name>"
+ f"<Parameters>{json.dumps(tool_params)}</Parameters>"
+ f"<Result>{tool_result.llm_formatted_result}</Result>"
+ f"</ToolCall>"
+ )
+
+ except Exception as e:
+ logger.error(f"Error in tool call: {str(e)}")
+ # Add error to XML
+ xml_toolcalls += (
+ f"<ToolCall>"
+ f"<Name>{tool_name}</Name>"
+ f"<Parameters>{json.dumps(tool_params)}</Parameters>"
+ f"<Result>Error: {str(e)}</Result>"
+ f"</ToolCall>"
+ )
+
+ xml_toolcalls += "</ToolCalls>"
+ pre_action_text = content[: content.find(action_block)]
+ post_action_text = content[
+ content.find(action_block) + len(action_block) :
+ ]
+ iteration_text = pre_action_text + xml_toolcalls + post_action_text
+
+ # Create the assistant message
+ await self.conversation.add_message(
+ Message(role="assistant", content=iteration_text)
+ )
+ else:
+ # Create an assistant message with the content as-is
+ await self.conversation.add_message(
+ Message(role="assistant", content=content)
+ )
+
+ # Only mark as completed if the finish_reason is "stop" or there are no action calls
+ # This allows the agent to continue the conversation when tool calls are processed
+ if finish_reason == "stop":
+ self._completed = True
+
+ def _parse_single_tool_call(
+ self, toolcall_text: str
+ ) -> Tuple[Optional[str], dict]:
+ """
+ Parse a ToolCall block to extract the name and parameters.
+
+ Args:
+ toolcall_text: The text content of a ToolCall block
+
+ Returns:
+ Tuple of (tool_name, tool_parameters)
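+
+        Example (hypothetical input):
+            <Name>web_search</Name><Parameters>{"query": "R2R"}</Parameters>
+            parses to ("web_search", {"query": "R2R"})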
+ """
+ name_match = self.NAME_PATTERN.search(toolcall_text)
+ if not name_match:
+ return None, {}
+ tool_name = name_match.group(1).strip()
+
+ params_match = self.PARAMS_PATTERN.search(toolcall_text)
+ if not params_match:
+ return tool_name, {}
+
+ raw_params = params_match.group(1).strip()
+ try:
+ # Handle potential JSON parsing issues
+ # First try direct parsing
+ tool_params = json.loads(raw_params)
+ except json.JSONDecodeError:
+ # If that fails, try to clean up the JSON string
+ try:
+ # Replace escaped quotes that might cause issues
+ cleaned_params = raw_params.replace('\\"', '"')
+ # Try again with the cleaned string
+ tool_params = json.loads(cleaned_params)
+ except json.JSONDecodeError:
+ # If all else fails, treat as a plain string value
+ tool_params = {"value": raw_params}
+
+ return tool_name, tool_params
diff --git a/.venv/lib/python3.12/site-packages/core/agent/rag.py b/.venv/lib/python3.12/site-packages/core/agent/rag.py
new file mode 100644
index 00000000..6f3ab630
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/agent/rag.py
@@ -0,0 +1,662 @@
+# type: ignore
+import logging
+from typing import Any, Callable, Optional
+
+from core.base import (
+ format_search_results_for_llm,
+)
+from core.base.abstractions import (
+ AggregateSearchResult,
+ GenerationConfig,
+ SearchSettings,
+ WebPageSearchResult,
+ WebSearchResult,
+)
+from core.base.agent import Tool
+from core.base.providers import DatabaseProvider
+from core.providers import (
+ AnthropicCompletionProvider,
+ LiteLLMCompletionProvider,
+ OpenAICompletionProvider,
+ R2RCompletionProvider,
+)
+from core.utils import (
+ SearchResultsCollector,
+ generate_id,
+ num_tokens,
+)
+
+from ..base.agent.agent import RAGAgentConfig
+
+# Import the base classes from the refactored base file
+from .base import (
+ R2RAgent,
+ R2RStreamingAgent,
+ R2RXMLStreamingAgent,
+ R2RXMLToolsAgent,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class RAGAgentMixin:
+ """
+ A Mixin for adding search_file_knowledge, web_search, and content tools
+ to your R2R Agents. This allows your agent to:
+ - call knowledge_search_method (semantic/hybrid search)
+ - call content_method (fetch entire doc/chunk structures)
+ - call an external web search API
+ """
+
+ def __init__(
+ self,
+ *args,
+ search_settings: SearchSettings,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length=10_000,
+ max_context_window_tokens=512_000,
+ **kwargs,
+ ):
+ # Save references to the retrieval logic
+ self.search_settings = search_settings
+ self.knowledge_search_method = knowledge_search_method
+ self.content_method = content_method
+ self.file_search_method = file_search_method
+ self.max_tool_context_length = max_tool_context_length
+ self.max_context_window_tokens = max_context_window_tokens
+ self.search_results_collector = SearchResultsCollector()
+ super().__init__(*args, **kwargs)
+
+ def _register_tools(self):
+ """
+ Called by the base R2RAgent to register all requested tools from self.config.rag_tools.
+ """
+ if not self.config.rag_tools:
+ return
+
+ for tool_name in set(self.config.rag_tools):
+ if tool_name == "get_file_content":
+ self._tools.append(self.content())
+ elif tool_name == "web_scrape":
+ self._tools.append(self.web_scrape())
+ elif tool_name == "search_file_knowledge":
+ self._tools.append(self.search_file_knowledge())
+ elif tool_name == "search_file_descriptions":
+ self._tools.append(self.search_files())
+ elif tool_name == "web_search":
+ self._tools.append(self.web_search())
+ else:
+ raise ValueError(f"Unknown tool requested: {tool_name}")
+ logger.debug(f"Registered {len(self._tools)} RAG tools.")
+
+ # Local Search Tool
+ def search_file_knowledge(self) -> Tool:
+ """
+ Tool to do a semantic/hybrid search on the local knowledge base
+ using self.knowledge_search_method.
+ """
+ return Tool(
+ name="search_file_knowledge",
+ description=(
+ "Search your local knowledge base using the R2R system. "
+ "Use this when you want relevant text chunks or knowledge graph data."
+ ),
+ results_function=self._file_knowledge_search_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "User query to search in the local DB.",
+ },
+ },
+ "required": ["query"],
+ },
+ )
+
+ async def _file_knowledge_search_function(
+ self,
+ query: str,
+ *args,
+ **kwargs,
+ ) -> AggregateSearchResult:
+ """
+ Calls the passed-in `knowledge_search_method(query, search_settings)`.
+ Expects either an AggregateSearchResult or a dict with chunk_search_results, etc.
+ """
+ if not self.knowledge_search_method:
+ raise ValueError(
+ "No knowledge_search_method provided to RAGAgentMixin."
+ )
+
+ raw_response = await self.knowledge_search_method(
+ query=query, search_settings=self.search_settings
+ )
+
+ if isinstance(raw_response, AggregateSearchResult):
+ agg = raw_response
+ else:
+ agg = AggregateSearchResult(
+ chunk_search_results=raw_response.get(
+ "chunk_search_results", []
+ ),
+ graph_search_results=raw_response.get(
+ "graph_search_results", []
+ ),
+ )
+
+ # 1) Store them so that we can do final citations later
+ self.search_results_collector.add_aggregate_result(agg)
+ return agg
+
+ # 2) Local Context
+ def content(self) -> Tool:
+ """Tool to fetch entire documents from the local database.
+
+ Typically used if the agent needs deeper or more structured context
+ from documents, not just chunk-level hits.
+ """
+ if "gemini" in self.rag_generation_config.model:
+ tool = Tool(
+ name="get_file_content",
+ description=(
+ "Fetches the complete contents of all user documents from the local database. "
+                    "Can be used alongside filter criteria (e.g. doc IDs, collection IDs, etc.) to restrict the query. "
+                    "For instance, a single document can be returned with a filter like so: "
+                    "{'document_id': {'$eq': '...'}}. "
+                    "Be sure to use the full 32-character hexadecimal document ID, and not the shortened 8-character ID."
+ ),
+ results_function=self._content_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "filters": {
+ "type": "string",
+ "description": (
+ "Dictionary with filter criteria, such as "
+                                '{"$and": [{"document_id": {"$eq": "6c9d1c39..."}}, {"collection_ids": {"$overlap": [...]}}]}'
+ ),
+ },
+ },
+ "required": ["filters"],
+ },
+ )
+
+ else:
+ tool = Tool(
+ name="get_file_content",
+ description=(
+ "Fetches the complete contents of all user documents from the local database. "
+                    "Can be used alongside filter criteria (e.g. doc IDs, collection IDs, etc.) to restrict the query. "
+                    "For instance, a single document can be returned with a filter like so: "
+                    "{'document_id': {'$eq': '...'}}."
+ ),
+ results_function=self._content_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "filters": {
+ "type": "object",
+ "description": (
+ "Dictionary with filter criteria, such as "
+                                '{"$and": [{"document_id": {"$eq": "6c9d1c39..."}}, {"collection_ids": {"$overlap": [...]}}]}'
+ ),
+ },
+ },
+ "required": ["filters"],
+ },
+ )
+ return tool
+
+ async def _content_function(
+ self,
+ filters: Optional[dict[str, Any]] = None,
+ options: Optional[dict[str, Any]] = None,
+ *args,
+ **kwargs,
+ ) -> AggregateSearchResult:
+ """Calls the passed-in `content_method(filters, options)` to fetch
+ doc+chunk structures.
+
+ Typically returns a list of dicts:
+ [
+ { 'document': {...}, 'chunks': [ {...}, {...}, ... ] },
+ ...
+ ]
+ We'll store these in a new field `document_search_results` of
+ AggregateSearchResult so we don't collide with chunk_search_results.
+ """
+ if not self.content_method:
+ raise ValueError("No content_method provided to RAGAgentMixin.")
+
+ if filters:
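+            # Remap the user-facing "document_id" filter key to "id" (the key the
+            # underlying content_method appears to expect) before merging filters.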
+ if "document_id" in filters:
+ filters["id"] = filters.pop("document_id")
+ if self.search_settings.filters != {}:
+ filters = {"$and": [filters, self.search_settings.filters]}
+ else:
+ filters = self.search_settings.filters
+
+ options = options or {}
+
+ # Actually call your data retrieval
+ content = await self.content_method(filters, options)
+ # raw_context presumably is a list[dict], each with 'document' + 'chunks'.
+
+ # Return them in the new aggregator field
+ agg = AggregateSearchResult(
+ # We won't put them in chunk_search_results:
+ chunk_search_results=None,
+ graph_search_results=None,
+ web_search_results=None,
+ document_search_results=content,
+ )
+ self.search_results_collector.add_aggregate_result(agg)
+ return agg
+
+ # Web Search Tool
+ def web_search(self) -> Tool:
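+        """Tool definition for live web search via an external API (Serper)."""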
+ return Tool(
+ name="web_search",
+ description=(
+ "Search for information on the web - use this tool when the user "
+ "query needs LIVE or recent data from the internet."
+ ),
+ results_function=self._web_search_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "The query to search with an external web API.",
+ },
+ },
+ "required": ["query"],
+ },
+ )
+
+ async def _web_search_function(
+ self,
+ query: str,
+ *args,
+ **kwargs,
+ ) -> AggregateSearchResult:
+ """
+ Calls an external search engine (Serper, Google, etc.) asynchronously
+ and returns results in an AggregateSearchResult.
+ """
+ import asyncio
+
+ from ..utils.serper import SerperClient # adjust your import
+
+ serper_client = SerperClient()
+
+ # If SerperClient.get_raw is not already async, wrap it in run_in_executor
+ raw_results = await asyncio.get_event_loop().run_in_executor(
+ None, # Uses the default executor
+ lambda: serper_client.get_raw(query),
+ )
+
+ # If from_serper_results is not already async, wrap it in run_in_executor too
+ web_response = await asyncio.get_event_loop().run_in_executor(
+ None, lambda: WebSearchResult.from_serper_results(raw_results)
+ )
+
+ agg = AggregateSearchResult(
+ chunk_search_results=None,
+ graph_search_results=None,
+ web_search_results=web_response.organic_results,
+ )
+ self.search_results_collector.add_aggregate_result(agg)
+ return agg
+
+ def search_files(self) -> Tool:
+ """
+ A tool to search over file-level metadata (titles, doc-level descriptions, etc.)
+ returning a list of DocumentResponse objects.
+ """
+ return Tool(
+ name="search_file_descriptions",
+ description=(
+                "Semantic search over AI-generated summaries of the stored documents. "
+ "This does NOT retrieve chunk-level contents or knowledge-graph relationships. "
+ "Use this when you need a broad overview of which documents (files) might be relevant."
+ ),
+ results_function=self._search_files_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+                        "description": "Query string for semantic search over available files, e.g. 'list documents about XYZ'.",
+ }
+ },
+ "required": ["query"],
+ },
+ )
+
+ async def _search_files_function(
+ self, query: str, *args, **kwargs
+ ) -> AggregateSearchResult:
+ if not self.file_search_method:
+ raise ValueError(
+ "No file_search_method provided to RAGAgentMixin."
+ )
+
+ # call the doc-level search
+        # FIXME: This is going to fail, as it requires an embedding NOT a query.
+        # 'search_settings' has been moved to 'settings', which had been causing a
+        # silent failure that left null content in the Message object.
+ doc_results = await self.file_search_method(
+ query=query,
+ settings=self.search_settings,
+ )
+
+ # Wrap them in an AggregateSearchResult
+ agg = AggregateSearchResult(document_search_results=doc_results)
+
+ # Add them to the collector
+ self.search_results_collector.add_aggregate_result(agg)
+ return agg
+
+ def format_search_results_for_llm(
+ self, results: AggregateSearchResult
+ ) -> str:
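+        """Format aggregated search results for the LLM, truncating the text so it
+        stays within roughly max_tool_context_length tokens."""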
+ context = format_search_results_for_llm(
+ results, self.search_results_collector
+ )
+ context_tokens = num_tokens(context) + 1
+ frac_to_return = self.max_tool_context_length / (context_tokens)
+
+ if frac_to_return > 1:
+ return context
+ else:
+ return context[: int(frac_to_return * len(context))]
+
+ def web_scrape(self) -> Tool:
+ """
+ A new Tool that uses Firecrawl to scrape a single URL and return
+ its contents in an LLM-friendly format (e.g. markdown).
+ """
+ return Tool(
+ name="web_scrape",
+ description=(
+ "Use Firecrawl to scrape a single webpage and retrieve its contents "
+ "as clean markdown. Useful when you need the entire body of a page, "
+ "not just a quick snippet or standard web search result."
+ ),
+ results_function=self._web_scrape_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "url": {
+ "type": "string",
+ "description": (
+ "The absolute URL of the webpage you want to scrape. "
+ "Example: 'https://docs.firecrawl.dev/getting-started'"
+ ),
+ }
+ },
+ "required": ["url"],
+ },
+ )
+
+ async def _web_scrape_function(
+ self,
+ url: str,
+ *args,
+ **kwargs,
+ ) -> AggregateSearchResult:
+ """
+ Performs the Firecrawl scrape asynchronously, returning results
+ as an `AggregateSearchResult` with a single WebPageSearchResult.
+ """
+ import asyncio
+
+ from firecrawl import FirecrawlApp
+
+ app = FirecrawlApp()
+ logger.debug(f"[Firecrawl] Scraping URL={url}")
+
+ # Create a proper async wrapper for the synchronous scrape_url method
+ # This offloads the blocking operation to a thread pool
+ response = await asyncio.get_event_loop().run_in_executor(
+ None, # Uses the default executor
+ lambda: app.scrape_url(
+ url=url,
+ params={"formats": ["markdown"]},
+ ),
+ )
+
+ markdown_text = response.get("markdown", "")
+ metadata = response.get("metadata", {})
+ page_title = metadata.get("title", "Untitled page")
+
+ if len(markdown_text) > 100_000:
+ markdown_text = (
+ markdown_text[:100_000] + "...FURTHER CONTENT TRUNCATED..."
+ )
+
+ # Create a single WebPageSearchResult HACK - TODO FIX
+ web_result = WebPageSearchResult(
+ title=page_title,
+ link=url,
+ snippet=markdown_text,
+ position=0,
+ id=generate_id(markdown_text),
+ type="firecrawl",
+ )
+
+ agg = AggregateSearchResult(web_search_results=[web_result])
+
+ # Add results to the collector
+ if self.search_results_collector:
+ self.search_results_collector.add_aggregate_result(agg)
+
+ return agg
+
+
+class R2RRAGAgent(RAGAgentMixin, R2RAgent):
+ """
+ Non-streaming RAG Agent that supports search_file_knowledge, content, web_search.
+ """
+
+ def __init__(
+ self,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 20_000,
+ ):
+ # Initialize base R2RAgent
+ R2RAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ rag_generation_config=rag_generation_config,
+ )
+ # Initialize the RAGAgentMixin
+ RAGAgentMixin.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ file_search_method=file_search_method,
+ content_method=content_method,
+ )
+
+
+class R2RXMLToolsRAGAgent(RAGAgentMixin, R2RXMLToolsAgent):
+ """
+ Non-streaming RAG Agent that supports search_file_knowledge, content, web_search.
+ """
+
+ def __init__(
+ self,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 20_000,
+ ):
+ # Initialize base R2RAgent
+ R2RXMLToolsAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ rag_generation_config=rag_generation_config,
+ )
+ # Initialize the RAGAgentMixin
+ RAGAgentMixin.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ file_search_method=file_search_method,
+ content_method=content_method,
+ )
+
+
+class R2RStreamingRAGAgent(RAGAgentMixin, R2RStreamingAgent):
+ """
+ Streaming-capable RAG Agent that supports search_file_knowledge, content, web_search,
+ and emits citations as [abc1234] short IDs if the LLM includes them in brackets.
+ """
+
+ def __init__(
+ self,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 10_000,
+ ):
+ # Force streaming on
+ config.stream = True
+
+ # Initialize base R2RStreamingAgent
+ R2RStreamingAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ rag_generation_config=rag_generation_config,
+ )
+
+ # Initialize the RAGAgentMixin
+ RAGAgentMixin.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
+
+
+class R2RXMLToolsStreamingRAGAgent(RAGAgentMixin, R2RXMLStreamingAgent):
+ """
+ A streaming agent that:
+ - treats <think> or <Thought> blocks as chain-of-thought
+ and emits them incrementally as SSE "thinking" events.
+ - accumulates user-visible text outside those tags as SSE "message" events.
+ - filters out all XML tags related to tool calls and actions.
+ - upon finishing each iteration, it parses <Action><ToolCalls><ToolCall> blocks,
+ calls the appropriate tool, and emits SSE "tool_call" / "tool_result".
+ - properly emits citations when they appear in the text
+ """
+
+ def __init__(
+ self,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 10_000,
+ ):
+ # Force streaming on
+ config.stream = True
+
+ # Initialize base R2RXMLStreamingAgent
+ R2RXMLStreamingAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ rag_generation_config=rag_generation_config,
+ )
+
+ # Initialize the RAGAgentMixin
+ RAGAgentMixin.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/agent/research.py b/.venv/lib/python3.12/site-packages/core/agent/research.py
new file mode 100644
index 00000000..6ea35783
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/agent/research.py
@@ -0,0 +1,697 @@
+import logging
+import os
+import subprocess
+import sys
+import tempfile
+from copy import copy
+from typing import Any, Callable, Optional
+
+from core.base import AppConfig
+from core.base.abstractions import GenerationConfig, Message, SearchSettings
+from core.base.agent import Tool
+from core.base.providers import DatabaseProvider
+from core.providers import (
+ AnthropicCompletionProvider,
+ LiteLLMCompletionProvider,
+ OpenAICompletionProvider,
+ R2RCompletionProvider,
+)
+from core.utils import extract_citations
+
+from ..base.agent.agent import RAGAgentConfig # type: ignore
+
+# Import the RAG agents we'll leverage
+from .rag import ( # type: ignore
+ R2RRAGAgent,
+ R2RStreamingRAGAgent,
+ R2RXMLToolsRAGAgent,
+ R2RXMLToolsStreamingRAGAgent,
+ RAGAgentMixin,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ResearchAgentMixin(RAGAgentMixin):
+ """
+ A mixin that extends RAGAgentMixin to add research capabilities to any R2R agent.
+
+ This mixin provides all RAG capabilities plus additional research tools:
+ - A RAG tool for knowledge retrieval (which leverages the underlying RAG capabilities)
+ - A Python execution tool for code execution and computation
+ - A reasoning tool for complex problem solving
+ - A critique tool for analyzing conversation history
+ """
+
+ def __init__(
+ self,
+ *args,
+ app_config: AppConfig,
+ search_settings: SearchSettings,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length=10_000,
+ **kwargs,
+ ):
+ # Store the app configuration needed for research tools
+ self.app_config = app_config
+
+ # Call the parent RAGAgentMixin's __init__ with explicitly passed parameters
+ super().__init__(
+ *args,
+ search_settings=search_settings,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ max_tool_context_length=max_tool_context_length,
+ **kwargs,
+ )
+
+ # Register our research-specific tools
+ self._register_research_tools()
+
+ def _register_research_tools(self):
+ """
+ Register research-specific tools to the agent.
+ This is called by the mixin's __init__ after the parent class initialization.
+ """
+        # Build the research tool set from the configured tool names; this
+        # replaces any tools registered by the parent class.
+ research_tools = []
+ for tool_name in set(self.config.research_tools):
+ if tool_name == "rag":
+ research_tools.append(self.rag_tool())
+ elif tool_name == "reasoning":
+ research_tools.append(self.reasoning_tool())
+ elif tool_name == "critique":
+ research_tools.append(self.critique_tool())
+ elif tool_name == "python_executor":
+ research_tools.append(self.python_execution_tool())
+ else:
+ logger.warning(f"Unknown research tool: {tool_name}")
+ raise ValueError(f"Unknown research tool: {tool_name}")
+
+ logger.debug(f"Registered research tools: {research_tools}")
+ self.tools = research_tools
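+        # Illustrative sketch (an assumption about how the config is populated
+        # elsewhere): the tool names recognized above could appear in the agent
+        # config as, for example,
+        #   config.research_tools = ["rag", "reasoning", "critique", "python_executor"]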
+
+ def rag_tool(self) -> Tool:
+ """Tool that provides access to the RAG agent's search capabilities."""
+ return Tool(
+ name="rag",
+ description=(
+ "Search for information using RAG (Retrieval-Augmented Generation). "
+ "This tool searches across relevant sources and returns comprehensive information. "
+ "Use this tool when you need to find specific information on any topic. Be sure to pose your query as a comprehensive query."
+ ),
+ results_function=self._rag,
+ llm_format_function=self._format_search_results,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "The search query to find information.",
+ }
+ },
+ "required": ["query"],
+ },
+ )
+
+ def reasoning_tool(self) -> Tool:
+ """Tool that provides access to a strong reasoning model."""
+ return Tool(
+ name="reasoning",
+ description=(
+ "A dedicated reasoning system that excels at solving complex problems through step-by-step analysis. "
+ "This tool connects to a separate AI system optimized for deep analytical thinking.\n\n"
+ "USAGE GUIDELINES:\n"
+ "1. Formulate your request as a complete, standalone question to a reasoning expert.\n"
+ "2. Clearly state the problem/question at the beginning.\n"
+ "3. Provide all relevant context, data, and constraints.\n\n"
+ "IMPORTANT: This system has no memory of previous interactions or context from your conversation.\n\n"
+ "STRENGTHS: Mathematical reasoning, logical analysis, evaluating complex scenarios, "
+ "solving multi-step problems, and identifying potential errors in reasoning."
+ ),
+ results_function=self._reason,
+ llm_format_function=self._format_search_results,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "A complete, standalone question with all necessary context, appropriate for a dedicated reasoning system.",
+ }
+ },
+ "required": ["query"],
+ },
+ )
+
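+    # Illustrative sketch of a well-formed query for the reasoning tool above:
+    # a standalone question with all necessary context inlined, for example
+    #   "A fair six-sided die is rolled three times. What is the probability
+    #    that the sum of the rolls is at least 15? Show the steps."
+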
+ def critique_tool(self) -> Tool:
+ """Tool that provides critical analysis of the reasoning done so far in the conversation."""
+ return Tool(
+ name="critique",
+ description=(
+ "Analyzes the conversation history to identify potential flaws, biases, and alternative "
+ "approaches to the reasoning presented so far.\n\n"
+ "Use this tool to get a second opinion on your reasoning, find overlooked considerations, "
+ "identify biases or fallacies, explore alternative hypotheses, and improve the robustness "
+ "of your conclusions."
+ ),
+ results_function=self._critique,
+ llm_format_function=self._format_search_results,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "A specific aspect of the reasoning you want critiqued, or leave empty for a general critique.",
+ },
+ "focus_areas": {
+ "type": "array",
+ "items": {"type": "string"},
+ "description": "Optional specific areas to focus the critique (e.g., ['logical fallacies', 'methodology'])",
+ },
+ },
+ "required": ["query"],
+ },
+ )
+
+ def python_execution_tool(self) -> Tool:
+ """Tool that provides Python code execution capabilities."""
+ return Tool(
+ name="python_executor",
+ description=(
+ "Executes Python code and returns the results, output, and any errors. "
+ "Use this tool for complex calculations, statistical operations, or algorithmic implementations.\n\n"
+ "The execution environment includes common libraries such as numpy, pandas, sympy, scipy, statsmodels, biopython, etc.\n\n"
+ "USAGE:\n"
+ "1. Send complete, executable Python code as a string.\n"
+ "2. Use print statements for output you want to see.\n"
+ "3. Assign to the 'result' variable for values you want to return.\n"
+ "4. Do not use input() or plotting (matplotlib). Output is text-based."
+ ),
+ results_function=self._execute_python_with_process_timeout,
+ llm_format_function=self._format_python_results,
+ parameters={
+ "type": "object",
+ "properties": {
+ "code": {
+ "type": "string",
+ "description": "Python code to execute.",
+ }
+ },
+ "required": ["code"],
+ },
+ )
+
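+    # Illustrative sketch of code an LLM might submit to the python_executor
+    # tool above, following its usage guidelines (print for visible output,
+    # assign to `result` for a returned value; numpy availability is taken
+    # from the tool description):
+    #
+    #   import numpy as np
+    #   values = np.arange(1, 11)
+    #   print("mean:", values.mean())
+    #   result = int(values.sum())
+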
+ async def _rag(
+ self,
+ query: str,
+ *args,
+ **kwargs,
+    ) -> str:
+ """Execute a search using an internal RAG agent."""
+ # Create a copy of the current configuration for the RAG agent
+ config_copy = copy(self.config)
+ config_copy.max_iterations = 10 # Could be configurable
+ config_copy.rag_tools = [
+ "web_search",
+ "web_scrape",
+ ] # HACK HACK TODO - Fix.
+
+ # Create a generation config for the RAG agent
+ generation_config = GenerationConfig(
+ model=self.app_config.quality_llm,
+ max_tokens_to_sample=16000,
+ )
+
+ # Create a new RAG agent - we'll use the non-streaming variant for consistent results
+ rag_agent = R2RRAGAgent(
+ database_provider=self.database_provider,
+ llm_provider=self.llm_provider,
+ config=config_copy,
+ search_settings=self.search_settings,
+ rag_generation_config=generation_config,
+ knowledge_search_method=self.knowledge_search_method,
+ content_method=self.content_method,
+ file_search_method=self.file_search_method,
+ max_tool_context_length=self.max_tool_context_length,
+ )
+
+ # Run the RAG agent with the query
+ user_message = Message(role="user", content=query)
+ response = await rag_agent.arun(messages=[user_message])
+
+ # Get the content from the response
+ structured_content = response[-1].get("structured_content")
+ if structured_content:
+ possible_text = structured_content[-1].get("text")
+ content = response[-1].get("content") or possible_text
+ else:
+ content = response[-1].get("content")
+
+ # Extract citations and transfer search results from RAG agent to research agent
+        short_ids = extract_citations(content)
+        if short_ids:
+            logger.info(f"Found citations in RAG response: {short_ids}")
+
+            for short_id in short_ids:
+                result = rag_agent.search_results_collector.find_by_short_id(
+                    short_id
+                )
+                if result:
+                    self.search_results_collector.add_result(result)
+
+            # Log confirmation only when citations were actually transferred
+            logger.info(
+                "Transferred search results from RAG agent to research agent for citations"
+            )
+ return content
+
+ async def _reason(
+ self,
+ query: str,
+ *args,
+ **kwargs,
+    ) -> str:
+ """Execute a reasoning query using a specialized reasoning LLM."""
+ msg_list = await self.conversation.get_messages()
+
+ # Create a specialized generation config for reasoning
+ gen_cfg = self.get_generation_config(msg_list[-1], stream=False)
+ gen_cfg.model = self.app_config.reasoning_llm
+ gen_cfg.top_p = None
+ gen_cfg.temperature = 0.1
+ gen_cfg.max_tokens_to_sample = 64000
+ gen_cfg.stream = False
+ gen_cfg.tools = None
+ gen_cfg.functions = None
+ gen_cfg.reasoning_effort = "high"
+ gen_cfg.add_generation_kwargs = None
+
+ # Call the LLM with the reasoning request
+ response = await self.llm_provider.aget_completion(
+ [{"role": "user", "content": query}], gen_cfg
+ )
+ return response.choices[0].message.content
+
+ async def _critique(
+ self,
+ query: str,
+ focus_areas: Optional[list] = None,
+ *args,
+ **kwargs,
+    ) -> str:
+ """Critique the conversation history."""
+ msg_list = await self.conversation.get_messages()
+ if not focus_areas:
+ focus_areas = []
+ # Build the critique prompt
+ critique_prompt = (
+ "You are a critical reasoning expert. Your task is to analyze the following conversation "
+ "and critique the reasoning. Look for:\n"
+ "1. Logical fallacies or inconsistencies\n"
+ "2. Cognitive biases\n"
+ "3. Overlooked questions or considerations\n"
+ "4. Alternative approaches\n"
+ "5. Improvements in rigor\n\n"
+ )
+
+ if focus_areas:
+ critique_prompt += f"Focus areas: {', '.join(focus_areas)}\n\n"
+
+ if query.strip():
+ critique_prompt += f"Specific question: {query}\n\n"
+
+ critique_prompt += (
+ "Structure your critique:\n"
+ "1. Summary\n"
+ "2. Key strengths\n"
+ "3. Potential issues\n"
+ "4. Alternatives\n"
+ "5. Recommendations\n\n"
+ )
+
+ # Add the conversation history to the prompt
+ conversation_text = "\n--- CONVERSATION HISTORY ---\n\n"
+ for msg in msg_list:
+ role = msg.get("role", "")
+ content = msg.get("content", "")
+ if content and role in ["user", "assistant", "system"]:
+ conversation_text += f"{role.upper()}: {content}\n\n"
+
+ final_prompt = critique_prompt + conversation_text
+
+ # Use the reasoning tool to process the critique
+ return await self._reason(final_prompt, *args, **kwargs)
+
+ async def _execute_python_with_process_timeout(
+ self, code: str, timeout: int = 10, *args, **kwargs
+ ) -> dict[str, Any]:
+ """
+ Executes Python code in a separate subprocess with a timeout.
+ This provides isolation and prevents re-importing the current agent module.
+
+ Parameters:
+ code (str): Python code to execute.
+ timeout (int): Timeout in seconds (default: 10).
+
+ Returns:
+ dict[str, Any]: Dictionary containing stdout, stderr, return code, etc.
+ """
+ # Write user code to a temporary file
+ with tempfile.NamedTemporaryFile(
+ mode="w", suffix=".py", delete=False
+ ) as tmp_file:
+ tmp_file.write(code)
+ script_path = tmp_file.name
+
+ try:
+ # Run the script in a fresh subprocess
+ result = subprocess.run(
+ [sys.executable, script_path],
+ capture_output=True,
+ text=True,
+ timeout=timeout,
+ )
+
+ return {
+ "result": None, # We'll parse from stdout if needed
+ "stdout": result.stdout,
+ "stderr": result.stderr,
+ "error": (
+ None
+ if result.returncode == 0
+ else {
+ "type": "SubprocessError",
+ "message": f"Process exited with code {result.returncode}",
+ "traceback": "",
+ }
+ ),
+ "locals": {}, # No direct local var capture in a separate process
+ "success": (result.returncode == 0),
+ "timed_out": False,
+ "timeout": timeout,
+ }
+ except subprocess.TimeoutExpired as e:
+ return {
+ "result": None,
+ "stdout": e.output or "",
+ "stderr": e.stderr or "",
+ "error": {
+ "type": "TimeoutError",
+ "message": f"Execution exceeded {timeout} second limit.",
+ "traceback": "",
+ },
+ "locals": {},
+ "success": False,
+ "timed_out": True,
+ "timeout": timeout,
+ }
+ finally:
+ # Clean up the temp file
+ if os.path.exists(script_path):
+ os.remove(script_path)
+
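+    # Illustrative sketch (values are made up) of the dictionary returned by
+    # _execute_python_with_process_timeout for a successful run of "print(2 + 2)":
+    #
+    #   {
+    #       "result": None,
+    #       "stdout": "4\n",
+    #       "stderr": "",
+    #       "error": None,
+    #       "locals": {},
+    #       "success": True,
+    #       "timed_out": False,
+    #       "timeout": 10,
+    #   }
+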
+ def _format_python_results(self, results: dict[str, Any]) -> str:
+ """Format Python execution results for display."""
+ output = []
+
+ # Timeout notification
+ if results.get("timed_out", False):
+ output.append(
+ f"⚠️ **Execution Timeout**: Code exceeded the {results.get('timeout', 10)} second limit."
+ )
+ output.append("")
+
+ # Stdout
+ if results.get("stdout"):
+ output.append("## Output:")
+ output.append("```")
+ output.append(results["stdout"].rstrip())
+ output.append("```")
+ output.append("")
+
+ # If there's a 'result' variable to display
+ if results.get("result") is not None:
+ output.append("## Result:")
+ output.append("```")
+ output.append(str(results["result"]))
+ output.append("```")
+ output.append("")
+
+ # Error info
+ if not results.get("success", True):
+ output.append("## Error:")
+ output.append("```")
+ stderr_out = results.get("stderr", "").rstrip()
+ if stderr_out:
+ output.append(stderr_out)
+
+ err_obj = results.get("error")
+ if err_obj and err_obj.get("message"):
+ output.append(err_obj["message"])
+ output.append("```")
+
+ # Return formatted output
+ return (
+ "\n".join(output)
+ if output
+ else "Code executed with no output or result."
+ )
+
+ def _format_search_results(self, results) -> str:
+ """Simple pass-through formatting for RAG search results."""
+ return results
+
+
+class R2RResearchAgent(ResearchAgentMixin, R2RRAGAgent):
+ """
+ A non-streaming research agent that uses the standard R2R agent as its base.
+
+    This agent combines research capabilities with the non-streaming RAG agent,
+    exposing tools for multi-step, in-depth research.
+ """
+
+ def __init__(
+ self,
+ app_config: AppConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 20_000,
+ ):
+ # Set a higher max iterations for research tasks
+ config.max_iterations = config.max_iterations or 15
+
+ # Initialize the RAG agent first
+ R2RRAGAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ max_tool_context_length=max_tool_context_length,
+ )
+
+ # Then initialize the ResearchAgentMixin
+ ResearchAgentMixin.__init__(
+ self,
+ app_config=app_config,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ file_search_method=file_search_method,
+ content_method=content_method,
+ )
+
+
+class R2RStreamingResearchAgent(ResearchAgentMixin, R2RStreamingRAGAgent):
+ """
+ A streaming research agent that uses the streaming RAG agent as its base.
+
+ This agent combines research capabilities with streaming text generation,
+ providing real-time responses while still offering research tools.
+ """
+
+ def __init__(
+ self,
+ app_config: AppConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 10_000,
+ ):
+ # Force streaming on
+ config.stream = True
+ config.max_iterations = config.max_iterations or 15
+
+ # Initialize the streaming RAG agent first
+ R2RStreamingRAGAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ max_tool_context_length=max_tool_context_length,
+ )
+
+ # Then initialize the ResearchAgentMixin
+ ResearchAgentMixin.__init__(
+ self,
+ app_config=app_config,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
+
+
+class R2RXMLToolsResearchAgent(ResearchAgentMixin, R2RXMLToolsRAGAgent):
+ """
+ A non-streaming research agent that uses XML tool formatting.
+
+ This agent combines research capabilities with the XML-based tool calling format,
+ which might be more appropriate for certain LLM providers.
+ """
+
+ def __init__(
+ self,
+ app_config: AppConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 20_000,
+ ):
+ # Set higher max iterations
+ config.max_iterations = config.max_iterations or 15
+
+ # Initialize the XML Tools RAG agent first
+ R2RXMLToolsRAGAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ max_tool_context_length=max_tool_context_length,
+ )
+
+ # Then initialize the ResearchAgentMixin
+        ResearchAgentMixin.__init__(
+            self,
+            app_config=app_config,
+            database_provider=database_provider,
+            llm_provider=llm_provider,
+            config=config,
+            search_settings=search_settings,
+            rag_generation_config=rag_generation_config,
+            max_tool_context_length=max_tool_context_length,
+            knowledge_search_method=knowledge_search_method,
+            content_method=content_method,
+            file_search_method=file_search_method,
+        )
+
+
+class R2RXMLToolsStreamingResearchAgent(
+ ResearchAgentMixin, R2RXMLToolsStreamingRAGAgent
+):
+ """
+ A streaming research agent that uses XML tool formatting.
+
+ This agent combines research capabilities with streaming and XML-based tool calling,
+ providing real-time responses in a format suitable for certain LLM providers.
+ """
+
+ def __init__(
+ self,
+ app_config: AppConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 10_000,
+ ):
+ # Force streaming on
+ config.stream = True
+ config.max_iterations = config.max_iterations or 15
+
+ # Initialize the XML Tools Streaming RAG agent first
+ R2RXMLToolsStreamingRAGAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ max_tool_context_length=max_tool_context_length,
+ )
+
+ # Then initialize the ResearchAgentMixin
+        ResearchAgentMixin.__init__(
+            self,
+            app_config=app_config,
+            database_provider=database_provider,
+            llm_provider=llm_provider,
+            config=config,
+            search_settings=search_settings,
+            rag_generation_config=rag_generation_config,
+            max_tool_context_length=max_tool_context_length,
+            knowledge_search_method=knowledge_search_method,
+            content_method=content_method,
+            file_search_method=file_search_method,
+        )