about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/core/agent/base.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/agent/base.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/agent/base.py1484
1 files changed, 1484 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/agent/base.py b/.venv/lib/python3.12/site-packages/core/agent/base.py
new file mode 100644
index 00000000..84aae3f2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/agent/base.py
@@ -0,0 +1,1484 @@
+import asyncio
+import json
+import logging
+import re
+from abc import ABCMeta
+from typing import AsyncGenerator, Optional, Tuple
+
+from core.base import AsyncSyncMeta, LLMChatCompletion, Message, syncable
+from core.base.agent import Agent, Conversation
+from core.utils import (
+    CitationTracker,
+    SearchResultsCollector,
+    SSEFormatter,
+    convert_nonserializable_objects,
+    dump_obj,
+    find_new_citation_spans,
+)
+
+logger = logging.getLogger()
+
+
+class CombinedMeta(AsyncSyncMeta, ABCMeta):
+    pass
+
+
+def sync_wrapper(async_gen):
+    loop = asyncio.get_event_loop()
+
+    def wrapper():
+        try:
+            while True:
+                try:
+                    yield loop.run_until_complete(async_gen.__anext__())
+                except StopAsyncIteration:
+                    break
+        finally:
+            loop.run_until_complete(async_gen.aclose())
+
+    return wrapper()
+
+
+class R2RAgent(Agent, metaclass=CombinedMeta):
+    def __init__(self, *args, **kwargs):
+        self.search_results_collector = SearchResultsCollector()
+        super().__init__(*args, **kwargs)
+        self._reset()
+
+    async def _generate_llm_summary(self, iterations_count: int) -> str:
+        """
+        Generate a summary of the conversation using the LLM when max iterations are exceeded.
+
+        Args:
+            iterations_count: The number of iterations that were completed
+
+        Returns:
+            A string containing the LLM-generated summary
+        """
+        try:
+            # Get all messages in the conversation
+            all_messages = await self.conversation.get_messages()
+
+            # Create a prompt for the LLM to summarize
+            summary_prompt = {
+                "role": "user",
+                "content": (
+                    f"The conversation has reached the maximum limit of {iterations_count} iterations "
+                    f"without completing the task. Please provide a concise summary of: "
+                    f"1) The key information you've gathered that's relevant to the original query, "
+                    f"2) What you've attempted so far and why it's incomplete, and "
+                    f"3) A specific recommendation for how to proceed. "
+                    f"Keep your summary brief (3-4 sentences total) and focused on the most valuable insights. If it is possible to answer the original user query, then do so now instead."
+                    f"Start with '⚠️ **Maximum iterations exceeded**'"
+                ),
+            }
+
+            # Create a new message list with just the conversation history and summary request
+            summary_messages = all_messages + [summary_prompt]
+
+            # Get a completion for the summary
+            generation_config = self.get_generation_config(summary_prompt)
+            response = await self.llm_provider.aget_completion(
+                summary_messages,
+                generation_config,
+            )
+
+            return response.choices[0].message.content
+        except Exception as e:
+            logger.error(f"Error generating LLM summary: {str(e)}")
+            # Fall back to basic summary if LLM generation fails
+            return (
+                "⚠️ **Maximum iterations exceeded**\n\n"
+                "The agent reached the maximum iteration limit without completing the task. "
+                "Consider breaking your request into smaller steps or refining your query."
+            )
+
+    def _reset(self):
+        self._completed = False
+        self.conversation = Conversation()
+
+    @syncable
+    async def arun(
+        self,
+        messages: list[Message],
+        system_instruction: Optional[str] = None,
+        *args,
+        **kwargs,
+    ) -> list[dict]:
+        self._reset()
+        await self._setup(system_instruction)
+
+        if messages:
+            for message in messages:
+                await self.conversation.add_message(message)
+        iterations_count = 0
+        while (
+            not self._completed
+            and iterations_count < self.config.max_iterations
+        ):
+            iterations_count += 1
+            messages_list = await self.conversation.get_messages()
+            generation_config = self.get_generation_config(messages_list[-1])
+            response = await self.llm_provider.aget_completion(
+                messages_list,
+                generation_config,
+            )
+            logger.debug(f"R2RAgent response: {response}")
+            await self.process_llm_response(response, *args, **kwargs)
+
+        if not self._completed:
+            # Generate a summary of the conversation using the LLM
+            summary = await self._generate_llm_summary(iterations_count)
+            await self.conversation.add_message(
+                Message(role="assistant", content=summary)
+            )
+
+        # Return final content
+        all_messages: list[dict] = await self.conversation.get_messages()
+        all_messages.reverse()
+
+        output_messages = []
+        for message_2 in all_messages:
+            if (
+                # message_2.get("content")
+                message_2.get("content") != messages[-1].content
+            ):
+                output_messages.append(message_2)
+            else:
+                break
+        output_messages.reverse()
+
+        return output_messages
+
+    async def process_llm_response(
+        self, response: LLMChatCompletion, *args, **kwargs
+    ) -> None:
+        if not self._completed:
+            message = response.choices[0].message
+            finish_reason = response.choices[0].finish_reason
+
+            if finish_reason == "stop":
+                self._completed = True
+
+            # Determine which provider we're using
+            using_anthropic = (
+                "anthropic" in self.rag_generation_config.model.lower()
+            )
+
+            # OPENAI HANDLING
+            if not using_anthropic:
+                if message.tool_calls:
+                    assistant_msg = Message(
+                        role="assistant",
+                        content="",
+                        tool_calls=[msg.dict() for msg in message.tool_calls],
+                    )
+                    await self.conversation.add_message(assistant_msg)
+
+                    # If there are multiple tool_calls, call them sequentially here
+                    for tool_call in message.tool_calls:
+                        await self.handle_function_or_tool_call(
+                            tool_call.function.name,
+                            tool_call.function.arguments,
+                            tool_id=tool_call.id,
+                            *args,
+                            **kwargs,
+                        )
+                else:
+                    await self.conversation.add_message(
+                        Message(role="assistant", content=message.content)
+                    )
+                    self._completed = True
+
+            else:
+                # First handle thinking blocks if present
+                if (
+                    hasattr(message, "structured_content")
+                    and message.structured_content
+                ):
+                    # Check if structured_content contains any tool_use blocks
+                    has_tool_use = any(
+                        block.get("type") == "tool_use"
+                        for block in message.structured_content
+                    )
+
+                    if not has_tool_use and message.tool_calls:
+                        # If it has thinking but no tool_use, add a separate message with structured_content
+                        assistant_msg = Message(
+                            role="assistant",
+                            structured_content=message.structured_content,  # Use structured_content field
+                        )
+                        await self.conversation.add_message(assistant_msg)
+
+                        # Add explicit tool_use blocks in a separate message
+                        tool_uses = []
+                        for tool_call in message.tool_calls:
+                            # Safely parse arguments if they're a string
+                            try:
+                                if isinstance(
+                                    tool_call.function.arguments, str
+                                ):
+                                    input_args = json.loads(
+                                        tool_call.function.arguments
+                                    )
+                                else:
+                                    input_args = tool_call.function.arguments
+                            except json.JSONDecodeError:
+                                logger.error(
+                                    f"Failed to parse tool arguments: {tool_call.function.arguments}"
+                                )
+                                input_args = {
+                                    "_raw": tool_call.function.arguments
+                                }
+
+                            tool_uses.append(
+                                {
+                                    "type": "tool_use",
+                                    "id": tool_call.id,
+                                    "name": tool_call.function.name,
+                                    "input": input_args,
+                                }
+                            )
+
+                        # Add tool_use blocks as a separate assistant message with structured content
+                        if tool_uses:
+                            await self.conversation.add_message(
+                                Message(
+                                    role="assistant",
+                                    structured_content=tool_uses,
+                                    content="",
+                                )
+                            )
+                    else:
+                        # If it already has tool_use or no tool_calls, preserve original structure
+                        assistant_msg = Message(
+                            role="assistant",
+                            structured_content=message.structured_content,
+                        )
+                        await self.conversation.add_message(assistant_msg)
+
+                elif message.content:
+                    # For regular text content
+                    await self.conversation.add_message(
+                        Message(role="assistant", content=message.content)
+                    )
+
+                    # If there are tool calls, add them as structured content
+                    if message.tool_calls:
+                        tool_uses = []
+                        for tool_call in message.tool_calls:
+                            # Same safe parsing as above
+                            try:
+                                if isinstance(
+                                    tool_call.function.arguments, str
+                                ):
+                                    input_args = json.loads(
+                                        tool_call.function.arguments
+                                    )
+                                else:
+                                    input_args = tool_call.function.arguments
+                            except json.JSONDecodeError:
+                                logger.error(
+                                    f"Failed to parse tool arguments: {tool_call.function.arguments}"
+                                )
+                                input_args = {
+                                    "_raw": tool_call.function.arguments
+                                }
+
+                            tool_uses.append(
+                                {
+                                    "type": "tool_use",
+                                    "id": tool_call.id,
+                                    "name": tool_call.function.name,
+                                    "input": input_args,
+                                }
+                            )
+
+                        await self.conversation.add_message(
+                            Message(
+                                role="assistant", structured_content=tool_uses
+                            )
+                        )
+
+                # NEW CASE: Handle tool_calls with no content or structured_content
+                elif message.tool_calls:
+                    # Create tool_uses for the message with only tool_calls
+                    tool_uses = []
+                    for tool_call in message.tool_calls:
+                        try:
+                            if isinstance(tool_call.function.arguments, str):
+                                input_args = json.loads(
+                                    tool_call.function.arguments
+                                )
+                            else:
+                                input_args = tool_call.function.arguments
+                        except json.JSONDecodeError:
+                            logger.error(
+                                f"Failed to parse tool arguments: {tool_call.function.arguments}"
+                            )
+                            input_args = {"_raw": tool_call.function.arguments}
+
+                        tool_uses.append(
+                            {
+                                "type": "tool_use",
+                                "id": tool_call.id,
+                                "name": tool_call.function.name,
+                                "input": input_args,
+                            }
+                        )
+
+                    # Add tool_use blocks as a message before processing tools
+                    if tool_uses:
+                        await self.conversation.add_message(
+                            Message(
+                                role="assistant",
+                                structured_content=tool_uses,
+                            )
+                        )
+
+                # Process the tool calls
+                if message.tool_calls:
+                    for tool_call in message.tool_calls:
+                        await self.handle_function_or_tool_call(
+                            tool_call.function.name,
+                            tool_call.function.arguments,
+                            tool_id=tool_call.id,
+                            *args,
+                            **kwargs,
+                        )
+
+
+class R2RStreamingAgent(R2RAgent):
+    """
+    Base class for all streaming agents with core streaming functionality.
+    Supports emitting messages, tool calls, and results as SSE events.
+    """
+
+    # These two regexes will detect bracket references and then find short IDs.
+    BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]")
+    SHORT_ID_PATTERN = re.compile(
+        r"[A-Za-z0-9]{7,8}"
+    )  # 7-8 chars, for example
+
+    def __init__(self, *args, **kwargs):
+        # Force streaming on
+        if hasattr(kwargs.get("config", {}), "stream"):
+            kwargs["config"].stream = True
+        super().__init__(*args, **kwargs)
+
+    async def arun(
+        self,
+        system_instruction: str | None = None,
+        messages: list[Message] | None = None,
+        *args,
+        **kwargs,
+    ) -> AsyncGenerator[str, None]:
+        """
+        Main streaming entrypoint: returns an async generator of SSE lines.
+        """
+        self._reset()
+        await self._setup(system_instruction)
+
+        if messages:
+            for m in messages:
+                await self.conversation.add_message(m)
+
+        # Initialize citation tracker for this run
+        citation_tracker = CitationTracker()
+
+        # Dictionary to store citation payloads by ID
+        citation_payloads = {}
+
+        # Track all citations emitted during streaming for final persistence
+        self.streaming_citations: list[dict] = []
+
+        async def sse_generator() -> AsyncGenerator[str, None]:
+            pending_tool_calls = {}
+            partial_text_buffer = ""
+            iterations_count = 0
+
+            try:
+                # Keep streaming until we complete
+                while (
+                    not self._completed
+                    and iterations_count < self.config.max_iterations
+                ):
+                    iterations_count += 1
+                    # 1) Get current messages
+                    msg_list = await self.conversation.get_messages()
+                    gen_cfg = self.get_generation_config(
+                        msg_list[-1], stream=True
+                    )
+
+                    accumulated_thinking = ""
+                    thinking_signatures = {}  # Map thinking content to signatures
+
+                    # 2) Start streaming from LLM
+                    llm_stream = self.llm_provider.aget_completion_stream(
+                        msg_list, gen_cfg
+                    )
+                    async for chunk in llm_stream:
+                        delta = chunk.choices[0].delta
+                        finish_reason = chunk.choices[0].finish_reason
+
+                        if hasattr(delta, "thinking") and delta.thinking:
+                            # Accumulate thinking for later use in messages
+                            accumulated_thinking += delta.thinking
+
+                            # Emit SSE "thinking" event
+                            async for (
+                                line
+                            ) in SSEFormatter.yield_thinking_event(
+                                delta.thinking
+                            ):
+                                yield line
+
+                        # Add this new handler for thinking signatures
+                        if hasattr(delta, "thinking_signature"):
+                            thinking_signatures[accumulated_thinking] = (
+                                delta.thinking_signature
+                            )
+                            accumulated_thinking = ""
+
+                        # 3) If new text, accumulate it
+                        if delta.content:
+                            partial_text_buffer += delta.content
+
+                            # (a) Now emit the newly streamed text as a "message" event
+                            async for line in SSEFormatter.yield_message_event(
+                                delta.content
+                            ):
+                                yield line
+
+                            # (b) Find new citation spans in the accumulated text
+                            new_citation_spans = find_new_citation_spans(
+                                partial_text_buffer, citation_tracker
+                            )
+
+                            # Process each new citation span
+                            for cid, spans in new_citation_spans.items():
+                                for span in spans:
+                                    # Check if this is the first time we've seen this citation ID
+                                    is_new_citation = (
+                                        citation_tracker.is_new_citation(cid)
+                                    )
+
+                                    # Get payload if it's a new citation
+                                    payload = None
+                                    if is_new_citation:
+                                        source_obj = self.search_results_collector.find_by_short_id(
+                                            cid
+                                        )
+                                        if source_obj:
+                                            # Store payload for reuse
+                                            payload = dump_obj(source_obj)
+                                            citation_payloads[cid] = payload
+
+                                    # Create citation event payload
+                                    citation_data = {
+                                        "id": cid,
+                                        "object": "citation",
+                                        "is_new": is_new_citation,
+                                        "span": {
+                                            "start": span[0],
+                                            "end": span[1],
+                                        },
+                                    }
+
+                                    # Only include full payload for new citations
+                                    if is_new_citation and payload:
+                                        citation_data["payload"] = payload
+
+                                    # Add to streaming citations for final answer
+                                    self.streaming_citations.append(
+                                        citation_data
+                                    )
+
+                                    # Emit the citation event
+                                    async for (
+                                        line
+                                    ) in SSEFormatter.yield_citation_event(
+                                        citation_data
+                                    ):
+                                        yield line
+
+                        if delta.tool_calls:
+                            for tc in delta.tool_calls:
+                                idx = tc.index
+                                if idx not in pending_tool_calls:
+                                    pending_tool_calls[idx] = {
+                                        "id": tc.id,
+                                        "name": tc.function.name or "",
+                                        "arguments": tc.function.arguments
+                                        or "",
+                                    }
+                                else:
+                                    # Accumulate partial name/arguments
+                                    if tc.function.name:
+                                        pending_tool_calls[idx]["name"] = (
+                                            tc.function.name
+                                        )
+                                    if tc.function.arguments:
+                                        pending_tool_calls[idx][
+                                            "arguments"
+                                        ] += tc.function.arguments
+
+                        # 5) If the stream signals we should handle "tool_calls"
+                        if finish_reason == "tool_calls":
+                            # Handle thinking if present
+                            await self._handle_thinking(
+                                thinking_signatures, accumulated_thinking
+                            )
+
+                            calls_list = []
+                            for idx in sorted(pending_tool_calls.keys()):
+                                cinfo = pending_tool_calls[idx]
+                                calls_list.append(
+                                    {
+                                        "tool_call_id": cinfo["id"]
+                                        or f"call_{idx}",
+                                        "name": cinfo["name"],
+                                        "arguments": cinfo["arguments"],
+                                    }
+                                )
+
+                            # (a) Emit SSE "tool_call" events
+                            for c in calls_list:
+                                tc_data = self._create_tool_call_data(c)
+                                async for (
+                                    line
+                                ) in SSEFormatter.yield_tool_call_event(
+                                    tc_data
+                                ):
+                                    yield line
+
+                            # (b) Add an assistant message capturing these calls
+                            await self._add_tool_calls_message(
+                                calls_list, partial_text_buffer
+                            )
+
+                            # (c) Execute each tool call in parallel
+                            await asyncio.gather(
+                                *[
+                                    self.handle_function_or_tool_call(
+                                        c["name"],
+                                        c["arguments"],
+                                        tool_id=c["tool_call_id"],
+                                    )
+                                    for c in calls_list
+                                ]
+                            )
+
+                            # Reset buffer & calls
+                            pending_tool_calls.clear()
+                            partial_text_buffer = ""
+
+                        elif finish_reason == "stop":
+                            # Handle thinking if present
+                            await self._handle_thinking(
+                                thinking_signatures, accumulated_thinking
+                            )
+
+                            # 6) The LLM is done. If we have any leftover partial text,
+                            #    finalize it in the conversation
+                            if partial_text_buffer:
+                                # Create the final message with metadata including citations
+                                final_message = Message(
+                                    role="assistant",
+                                    content=partial_text_buffer,
+                                    metadata={
+                                        "citations": self.streaming_citations
+                                    },
+                                )
+
+                                # Add it to the conversation
+                                await self.conversation.add_message(
+                                    final_message
+                                )
+
+                            # (a) Prepare final answer with optimized citations
+                            consolidated_citations = []
+                            # Group citations by ID with all their spans
+                            for (
+                                cid,
+                                spans,
+                            ) in citation_tracker.get_all_spans().items():
+                                if cid in citation_payloads:
+                                    consolidated_citations.append(
+                                        {
+                                            "id": cid,
+                                            "object": "citation",
+                                            "spans": [
+                                                {"start": s[0], "end": s[1]}
+                                                for s in spans
+                                            ],
+                                            "payload": citation_payloads[cid],
+                                        }
+                                    )
+
+                            # Create final answer payload
+                            final_evt_payload = {
+                                "id": "msg_final",
+                                "object": "agent.final_answer",
+                                "generated_answer": partial_text_buffer,
+                                "citations": consolidated_citations,
+                            }
+
+                            # Emit final answer event
+                            async for (
+                                line
+                            ) in SSEFormatter.yield_final_answer_event(
+                                final_evt_payload
+                            ):
+                                yield line
+
+                            # (b) Signal the end of the SSE stream
+                            yield SSEFormatter.yield_done_event()
+                            self._completed = True
+                            break
+
+                # If we exit the while loop due to hitting max iterations
+                if not self._completed:
+                    # Generate a summary using the LLM
+                    summary = await self._generate_llm_summary(
+                        iterations_count
+                    )
+
+                    # Send the summary as a message event
+                    async for line in SSEFormatter.yield_message_event(
+                        summary
+                    ):
+                        yield line
+
+                    # Add summary to conversation with citations metadata
+                    await self.conversation.add_message(
+                        Message(
+                            role="assistant",
+                            content=summary,
+                            metadata={"citations": self.streaming_citations},
+                        )
+                    )
+
+                    # Create and emit a final answer payload with the summary
+                    final_evt_payload = {
+                        "id": "msg_final",
+                        "object": "agent.final_answer",
+                        "generated_answer": summary,
+                        "citations": consolidated_citations,
+                    }
+
+                    async for line in SSEFormatter.yield_final_answer_event(
+                        final_evt_payload
+                    ):
+                        yield line
+
+                    # Signal the end of the SSE stream
+                    yield SSEFormatter.yield_done_event()
+                    self._completed = True
+
+            except Exception as e:
+                logger.error(f"Error in streaming agent: {str(e)}")
+                # Emit error event for client
+                async for line in SSEFormatter.yield_error_event(
+                    f"Agent error: {str(e)}"
+                ):
+                    yield line
+                # Send done event to close the stream
+                yield SSEFormatter.yield_done_event()
+
+        # Finally, we return the async generator
+        async for line in sse_generator():
+            yield line
+
+    async def _handle_thinking(
+        self, thinking_signatures, accumulated_thinking
+    ):
+        """Process any accumulated thinking content"""
+        if accumulated_thinking:
+            structured_content = [
+                {
+                    "type": "thinking",
+                    "thinking": accumulated_thinking,
+                    # Anthropic will validate this in their API
+                    "signature": "placeholder_signature",
+                }
+            ]
+
+            assistant_msg = Message(
+                role="assistant",
+                structured_content=structured_content,
+            )
+            await self.conversation.add_message(assistant_msg)
+
+        elif thinking_signatures:
+            for (
+                accumulated_thinking,
+                thinking_signature,
+            ) in thinking_signatures.items():
+                structured_content = [
+                    {
+                        "type": "thinking",
+                        "thinking": accumulated_thinking,
+                        # Anthropic will validate this in their API
+                        "signature": thinking_signature,
+                    }
+                ]
+
+                assistant_msg = Message(
+                    role="assistant",
+                    structured_content=structured_content,
+                )
+                await self.conversation.add_message(assistant_msg)
+
+    async def _add_tool_calls_message(self, calls_list, partial_text_buffer):
+        """Add a message with tool calls to the conversation"""
+        assistant_msg = Message(
+            role="assistant",
+            content=partial_text_buffer or "",
+            tool_calls=[
+                {
+                    "id": c["tool_call_id"],
+                    "type": "function",
+                    "function": {
+                        "name": c["name"],
+                        "arguments": c["arguments"],
+                    },
+                }
+                for c in calls_list
+            ],
+        )
+        await self.conversation.add_message(assistant_msg)
+
+    def _create_tool_call_data(self, call_info):
+        """Create tool call data structure from call info"""
+        return {
+            "tool_call_id": call_info["tool_call_id"],
+            "name": call_info["name"],
+            "arguments": call_info["arguments"],
+        }
+
+    def _create_citation_payload(self, short_id, payload):
+        """Create citation payload for a short ID"""
+        # This will be overridden in RAG subclasses
+        # check if as_dict is on payload
+        if hasattr(payload, "as_dict"):
+            payload = payload.as_dict()
+        if hasattr(payload, "dict"):
+            payload = payload.dict
+        if hasattr(payload, "to_dict"):
+            payload = payload.to_dict()
+
+        return {
+            "id": f"{short_id}",
+            "object": "citation",
+            "payload": dump_obj(payload),  # Will be populated in RAG agents
+        }
+
+    def _create_final_answer_payload(self, answer_text, citations):
+        """Create the final answer payload"""
+        # This will be extended in RAG subclasses
+        return {
+            "id": "msg_final",
+            "object": "agent.final_answer",
+            "generated_answer": answer_text,
+            "citations": citations,
+        }
+
+
+class R2RXMLStreamingAgent(R2RStreamingAgent):
+    """
+    A streaming agent that parses XML-formatted responses with special handling for:
+     - <think> or <Thought> blocks for chain-of-thought reasoning
+     - <Action>, <ToolCalls>, <ToolCall> blocks for tool execution
+    """
+
+    # We treat <think> or <Thought> as the same token boundaries
+    THOUGHT_OPEN = re.compile(r"<(Thought|think)>", re.IGNORECASE)
+    THOUGHT_CLOSE = re.compile(r"</(Thought|think)>", re.IGNORECASE)
+
+    # Regexes to parse out <Action>, <ToolCalls>, <ToolCall>, <Name>, <Parameters>, <Response>
+    ACTION_PATTERN = re.compile(
+        r"<Action>(.*?)</Action>", re.IGNORECASE | re.DOTALL
+    )
+    TOOLCALLS_PATTERN = re.compile(
+        r"<ToolCalls>(.*?)</ToolCalls>", re.IGNORECASE | re.DOTALL
+    )
+    TOOLCALL_PATTERN = re.compile(
+        r"<ToolCall>(.*?)</ToolCall>", re.IGNORECASE | re.DOTALL
+    )
+    NAME_PATTERN = re.compile(r"<Name>(.*?)</Name>", re.IGNORECASE | re.DOTALL)
+    PARAMS_PATTERN = re.compile(
+        r"<Parameters>(.*?)</Parameters>", re.IGNORECASE | re.DOTALL
+    )
+    RESPONSE_PATTERN = re.compile(
+        r"<Response>(.*?)</Response>", re.IGNORECASE | re.DOTALL
+    )
+
+    async def arun(
+        self,
+        system_instruction: str | None = None,
+        messages: list[Message] | None = None,
+        *args,
+        **kwargs,
+    ) -> AsyncGenerator[str, None]:
+        """
+        Main streaming entrypoint: returns an async generator of SSE lines.
+        """
+        self._reset()
+        await self._setup(system_instruction)
+
+        if messages:
+            for m in messages:
+                await self.conversation.add_message(m)
+
+        # Initialize citation tracker for this run
+        citation_tracker = CitationTracker()
+
+        # Dictionary to store citation payloads by ID
+        citation_payloads = {}
+
+        # Track all citations emitted during streaming for final persistence
+        self.streaming_citations: list[dict] = []
+
+        async def sse_generator() -> AsyncGenerator[str, None]:
+            iterations_count = 0
+
+            try:
+                # Keep streaming until we complete
+                while (
+                    not self._completed
+                    and iterations_count < self.config.max_iterations
+                ):
+                    iterations_count += 1
+                    # 1) Get current messages
+                    msg_list = await self.conversation.get_messages()
+                    gen_cfg = self.get_generation_config(
+                        msg_list[-1], stream=True
+                    )
+
+                    # 2) Start streaming from LLM
+                    llm_stream = self.llm_provider.aget_completion_stream(
+                        msg_list, gen_cfg
+                    )
+
+                    # Create state variables for each iteration
+                    iteration_buffer = ""
+                    yielded_first_event = False
+                    in_action_block = False
+                    is_thinking = False
+                    accumulated_thinking = ""
+                    thinking_signatures = {}
+
+                    async for chunk in llm_stream:
+                        delta = chunk.choices[0].delta
+                        finish_reason = chunk.choices[0].finish_reason
+
+                        # Handle thinking if present
+                        if hasattr(delta, "thinking") and delta.thinking:
+                            # Accumulate thinking for later use in messages
+                            accumulated_thinking += delta.thinking
+
+                            # Emit SSE "thinking" event
+                            async for (
+                                line
+                            ) in SSEFormatter.yield_thinking_event(
+                                delta.thinking
+                            ):
+                                yield line
+
+                        # Add this new handler for thinking signatures
+                        if hasattr(delta, "thinking_signature"):
+                            thinking_signatures[accumulated_thinking] = (
+                                delta.thinking_signature
+                            )
+                            accumulated_thinking = ""
+
+                        # 3) If new text, accumulate it
+                        if delta.content:
+                            iteration_buffer += delta.content
+
+                            # Check if we have accumulated enough text for a `<Thought>` block
+                            if len(iteration_buffer) < len("<Thought>"):
+                                continue
+
+                            # Check if we have yielded the first event
+                            if not yielded_first_event:
+                                # Emit the first chunk
+                                if self.THOUGHT_OPEN.findall(iteration_buffer):
+                                    is_thinking = True
+                                    async for (
+                                        line
+                                    ) in SSEFormatter.yield_thinking_event(
+                                        iteration_buffer
+                                    ):
+                                        yield line
+                                else:
+                                    async for (
+                                        line
+                                    ) in SSEFormatter.yield_message_event(
+                                        iteration_buffer
+                                    ):
+                                        yield line
+
+                                # Mark as yielded
+                                yielded_first_event = True
+                                continue
+
+                            # Check if we are in a thinking block
+                            if is_thinking:
+                                # Still thinking, so keep yielding thinking events
+                                if not self.THOUGHT_CLOSE.findall(
+                                    iteration_buffer
+                                ):
+                                    # Emit SSE "thinking" event
+                                    async for (
+                                        line
+                                    ) in SSEFormatter.yield_thinking_event(
+                                        delta.content
+                                    ):
+                                        yield line
+
+                                    continue
+                                # Done thinking, so emit the last thinking event
+                                else:
+                                    is_thinking = False
+                                    thought_text = delta.content.split(
+                                        "</Thought>"
+                                    )[0].split("</think>")[0]
+                                    async for (
+                                        line
+                                    ) in SSEFormatter.yield_thinking_event(
+                                        thought_text
+                                    ):
+                                        yield line
+                                    post_thought_text = delta.content.split(
+                                        "</Thought>"
+                                    )[-1].split("</think>")[-1]
+                                    delta.content = post_thought_text
+
+                            # (b) Find new citation spans in the accumulated text
+                            new_citation_spans = find_new_citation_spans(
+                                iteration_buffer, citation_tracker
+                            )
+
+                            # Process each new citation span
+                            for cid, spans in new_citation_spans.items():
+                                for span in spans:
+                                    # Check if this is the first time we've seen this citation ID
+                                    is_new_citation = (
+                                        citation_tracker.is_new_citation(cid)
+                                    )
+
+                                    # Get payload if it's a new citation
+                                    payload = None
+                                    if is_new_citation:
+                                        source_obj = self.search_results_collector.find_by_short_id(
+                                            cid
+                                        )
+                                        if source_obj:
+                                            # Store payload for reuse
+                                            payload = dump_obj(source_obj)
+                                            citation_payloads[cid] = payload
+
+                                    # Create citation event payload
+                                    citation_data = {
+                                        "id": cid,
+                                        "object": "citation",
+                                        "is_new": is_new_citation,
+                                        "span": {
+                                            "start": span[0],
+                                            "end": span[1],
+                                        },
+                                    }
+
+                                    # Only include full payload for new citations
+                                    if is_new_citation and payload:
+                                        citation_data["payload"] = payload
+
+                                    # Add to streaming citations for final answer
+                                    self.streaming_citations.append(
+                                        citation_data
+                                    )
+
+                                    # Emit the citation event
+                                    async for (
+                                        line
+                                    ) in SSEFormatter.yield_citation_event(
+                                        citation_data
+                                    ):
+                                        yield line
+
+                            # Now prepare to emit the newly streamed text as a "message" event
+                            if (
+                                iteration_buffer.count("<")
+                                and not in_action_block
+                            ):
+                                in_action_block = True
+
+                            if (
+                                in_action_block
+                                and len(
+                                    self.ACTION_PATTERN.findall(
+                                        iteration_buffer
+                                    )
+                                )
+                                < 2
+                            ):
+                                continue
+
+                            elif in_action_block:
+                                in_action_block = False
+                                # Emit the post action block text, if it is there
+                                post_action_text = iteration_buffer.split(
+                                    "</Action>"
+                                )[-1]
+                                if post_action_text:
+                                    async for (
+                                        line
+                                    ) in SSEFormatter.yield_message_event(
+                                        post_action_text
+                                    ):
+                                        yield line
+
+                            else:
+                                async for (
+                                    line
+                                ) in SSEFormatter.yield_message_event(
+                                    delta.content
+                                ):
+                                    yield line
+
+                        elif finish_reason == "stop":
+                            break
+
+                    # Process any accumulated thinking
+                    await self._handle_thinking(
+                        thinking_signatures, accumulated_thinking
+                    )
+
+                    # 6) The LLM is done. If we have any leftover partial text,
+                    #    finalize it in the conversation
+                    if iteration_buffer:
+                        # Create the final message with metadata including citations
+                        final_message = Message(
+                            role="assistant",
+                            content=iteration_buffer,
+                            metadata={"citations": self.streaming_citations},
+                        )
+
+                        # Add it to the conversation
+                        await self.conversation.add_message(final_message)
+
+                    # --- 4) Process any <Action>/<ToolCalls> blocks, or mark completed
+                    action_matches = self.ACTION_PATTERN.findall(
+                        iteration_buffer
+                    )
+
+                    if len(action_matches) > 0:
+                        # Process each ToolCall
+                        xml_toolcalls = "<ToolCalls>"
+
+                        for action_block in action_matches:
+                            tool_calls_text = []
+                            # Look for ToolCalls wrapper, or use the raw action block
+                            calls_wrapper = self.TOOLCALLS_PATTERN.findall(
+                                action_block
+                            )
+                            if calls_wrapper:
+                                for tw in calls_wrapper:
+                                    tool_calls_text.append(tw)
+                            else:
+                                tool_calls_text.append(action_block)
+
+                            for calls_region in tool_calls_text:
+                                calls_found = self.TOOLCALL_PATTERN.findall(
+                                    calls_region
+                                )
+                                for tc_block in calls_found:
+                                    tool_name, tool_params = (
+                                        self._parse_single_tool_call(tc_block)
+                                    )
+                                    if tool_name:
+                                        # Emit SSE event for tool call
+                                        tool_call_id = (
+                                            f"call_{abs(hash(tc_block))}"
+                                        )
+                                        call_evt_data = {
+                                            "tool_call_id": tool_call_id,
+                                            "name": tool_name,
+                                            "arguments": json.dumps(
+                                                tool_params
+                                            ),
+                                        }
+                                        async for line in (
+                                            SSEFormatter.yield_tool_call_event(
+                                                call_evt_data
+                                            )
+                                        ):
+                                            yield line
+
+                                        try:
+                                            tool_result = await self.handle_function_or_tool_call(
+                                                tool_name,
+                                                json.dumps(tool_params),
+                                                tool_id=tool_call_id,
+                                                save_messages=False,
+                                            )
+                                            result_content = tool_result.llm_formatted_result
+                                        except Exception as e:
+                                            result_content = f"Error in tool '{tool_name}': {str(e)}"
+
+                                        xml_toolcalls += (
+                                            f"<ToolCall>"
+                                            f"<Name>{tool_name}</Name>"
+                                            f"<Parameters>{json.dumps(tool_params)}</Parameters>"
+                                            f"<Result>{result_content}</Result>"
+                                            f"</ToolCall>"
+                                        )
+
+                                        # Emit SSE tool result for non-result tools
+                                        result_data = {
+                                            "tool_call_id": tool_call_id,
+                                            "role": "tool",
+                                            "content": json.dumps(
+                                                convert_nonserializable_objects(
+                                                    result_content
+                                                )
+                                            ),
+                                        }
+                                        async for line in SSEFormatter.yield_tool_result_event(
+                                            result_data
+                                        ):
+                                            yield line
+
+                        xml_toolcalls += "</ToolCalls>"
+                        pre_action_text = iteration_buffer[
+                            : iteration_buffer.find(action_block)
+                        ]
+                        post_action_text = iteration_buffer[
+                            iteration_buffer.find(action_block)
+                            + len(action_block) :
+                        ]
+                        iteration_text = (
+                            pre_action_text + xml_toolcalls + post_action_text
+                        )
+
+                        # Update the conversation with tool results
+                        await self.conversation.add_message(
+                            Message(
+                                role="assistant",
+                                content=iteration_text,
+                                metadata={
+                                    "citations": self.streaming_citations
+                                },
+                            )
+                        )
+                    else:
+                        # (a) Prepare final answer with optimized citations
+                        consolidated_citations = []
+                        # Group citations by ID with all their spans
+                        for (
+                            cid,
+                            spans,
+                        ) in citation_tracker.get_all_spans().items():
+                            if cid in citation_payloads:
+                                consolidated_citations.append(
+                                    {
+                                        "id": cid,
+                                        "object": "citation",
+                                        "spans": [
+                                            {"start": s[0], "end": s[1]}
+                                            for s in spans
+                                        ],
+                                        "payload": citation_payloads[cid],
+                                    }
+                                )
+
+                        # Create final answer payload
+                        final_evt_payload = {
+                            "id": "msg_final",
+                            "object": "agent.final_answer",
+                            "generated_answer": iteration_buffer,
+                            "citations": consolidated_citations,
+                        }
+
+                        # Emit final answer event
+                        async for (
+                            line
+                        ) in SSEFormatter.yield_final_answer_event(
+                            final_evt_payload
+                        ):
+                            yield line
+
+                        # (b) Signal the end of the SSE stream
+                        yield SSEFormatter.yield_done_event()
+                        self._completed = True
+
+                # If we exit the while loop due to hitting max iterations
+                if not self._completed:
+                    # Generate a summary using the LLM
+                    summary = await self._generate_llm_summary(
+                        iterations_count
+                    )
+
+                    # Send the summary as a message event
+                    async for line in SSEFormatter.yield_message_event(
+                        summary
+                    ):
+                        yield line
+
+                    # Add summary to conversation with citations metadata
+                    await self.conversation.add_message(
+                        Message(
+                            role="assistant",
+                            content=summary,
+                            metadata={"citations": self.streaming_citations},
+                        )
+                    )
+
+                    # Create and emit a final answer payload with the summary
+                    final_evt_payload = {
+                        "id": "msg_final",
+                        "object": "agent.final_answer",
+                        "generated_answer": summary,
+                        "citations": consolidated_citations,
+                    }
+
+                    async for line in SSEFormatter.yield_final_answer_event(
+                        final_evt_payload
+                    ):
+                        yield line
+
+                    # Signal the end of the SSE stream
+                    yield SSEFormatter.yield_done_event()
+                    self._completed = True
+
+            except Exception as e:
+                logger.error(f"Error in streaming agent: {str(e)}")
+                # Emit error event for client
+                async for line in SSEFormatter.yield_error_event(
+                    f"Agent error: {str(e)}"
+                ):
+                    yield line
+                # Send done event to close the stream
+                yield SSEFormatter.yield_done_event()
+
+        # Finally, we return the async generator
+        async for line in sse_generator():
+            yield line
+
+    def _parse_single_tool_call(
+        self, toolcall_text: str
+    ) -> Tuple[Optional[str], dict]:
+        """
+        Parse a ToolCall block to extract the name and parameters.
+
+        Args:
+            toolcall_text: The text content of a ToolCall block
+
+        Returns:
+            Tuple of (tool_name, tool_parameters)
+        """
+        name_match = self.NAME_PATTERN.search(toolcall_text)
+        if not name_match:
+            return None, {}
+        tool_name = name_match.group(1).strip()
+
+        params_match = self.PARAMS_PATTERN.search(toolcall_text)
+        if not params_match:
+            return tool_name, {}
+
+        raw_params = params_match.group(1).strip()
+        try:
+            # Handle potential JSON parsing issues
+            # First try direct parsing
+            tool_params = json.loads(raw_params)
+        except json.JSONDecodeError:
+            # If that fails, try to clean up the JSON string
+            try:
+                # Replace escaped quotes that might cause issues
+                cleaned_params = raw_params.replace('\\"', '"')
+                # Try again with the cleaned string
+                tool_params = json.loads(cleaned_params)
+            except json.JSONDecodeError:
+                # If all else fails, treat as a plain string value
+                tool_params = {"value": raw_params}
+
+        return tool_name, tool_params
+
+
+class R2RXMLToolsAgent(R2RAgent):
+    """
+    A non-streaming agent that:
+     - parses <think> or <Thought> blocks as chain-of-thought
+     - filters out XML tags related to tool calls and actions
+     - processes <Action><ToolCalls><ToolCall> blocks
+     - properly extracts citations when they appear in the text
+    """
+
+    # We treat <think> or <Thought> as the same token boundaries
+    THOUGHT_OPEN = re.compile(r"<(Thought|think)>", re.IGNORECASE)
+    THOUGHT_CLOSE = re.compile(r"</(Thought|think)>", re.IGNORECASE)
+
+    # Regexes to parse out <Action>, <ToolCalls>, <ToolCall>, <Name>, <Parameters>, <Response>
+    ACTION_PATTERN = re.compile(
+        r"<Action>(.*?)</Action>", re.IGNORECASE | re.DOTALL
+    )
+    TOOLCALLS_PATTERN = re.compile(
+        r"<ToolCalls>(.*?)</ToolCalls>", re.IGNORECASE | re.DOTALL
+    )
+    TOOLCALL_PATTERN = re.compile(
+        r"<ToolCall>(.*?)</ToolCall>", re.IGNORECASE | re.DOTALL
+    )
+    NAME_PATTERN = re.compile(r"<Name>(.*?)</Name>", re.IGNORECASE | re.DOTALL)
+    PARAMS_PATTERN = re.compile(
+        r"<Parameters>(.*?)</Parameters>", re.IGNORECASE | re.DOTALL
+    )
+    RESPONSE_PATTERN = re.compile(
+        r"<Response>(.*?)</Response>", re.IGNORECASE | re.DOTALL
+    )
+
+    async def process_llm_response(self, response, *args, **kwargs):
+        """
+        Override the base process_llm_response to handle XML structured responses
+        including thoughts and tool calls.
+        """
+        if self._completed:
+            return
+
+        message = response.choices[0].message
+        finish_reason = response.choices[0].finish_reason
+
+        if not message.content:
+            # If there's no content, let the parent class handle the normal tool_calls flow
+            return await super().process_llm_response(
+                response, *args, **kwargs
+            )
+
+        # Get the response content
+        content = message.content
+
+        # HACK for gemini
+        content = content.replace("```action", "")
+        content = content.replace("```tool_code", "")
+        content = content.replace("```", "")
+
+        if (
+            not content.startswith("<")
+            and "deepseek" in self.rag_generation_config.model
+        ):  # HACK - fix issues with adding `<think>` to the beginning
+            content = "<think>" + content
+
+        # Process any tool calls in the content
+        action_matches = self.ACTION_PATTERN.findall(content)
+        if action_matches:
+            xml_toolcalls = "<ToolCalls>"
+            for action_block in action_matches:
+                tool_calls_text = []
+                # Look for ToolCalls wrapper, or use the raw action block
+                calls_wrapper = self.TOOLCALLS_PATTERN.findall(action_block)
+                if calls_wrapper:
+                    for tw in calls_wrapper:
+                        tool_calls_text.append(tw)
+                else:
+                    tool_calls_text.append(action_block)
+
+                # Process each ToolCall
+                for calls_region in tool_calls_text:
+                    calls_found = self.TOOLCALL_PATTERN.findall(calls_region)
+                    for tc_block in calls_found:
+                        tool_name, tool_params = self._parse_single_tool_call(
+                            tc_block
+                        )
+                        if tool_name:
+                            tool_call_id = f"call_{abs(hash(tc_block))}"
+                            try:
+                                tool_result = (
+                                    await self.handle_function_or_tool_call(
+                                        tool_name,
+                                        json.dumps(tool_params),
+                                        tool_id=tool_call_id,
+                                        save_messages=False,
+                                    )
+                                )
+
+                                # Add tool result to XML
+                                xml_toolcalls += (
+                                    f"<ToolCall>"
+                                    f"<Name>{tool_name}</Name>"
+                                    f"<Parameters>{json.dumps(tool_params)}</Parameters>"
+                                    f"<Result>{tool_result.llm_formatted_result}</Result>"
+                                    f"</ToolCall>"
+                                )
+
+                            except Exception as e:
+                                logger.error(f"Error in tool call: {str(e)}")
+                                # Add error to XML
+                                xml_toolcalls += (
+                                    f"<ToolCall>"
+                                    f"<Name>{tool_name}</Name>"
+                                    f"<Parameters>{json.dumps(tool_params)}</Parameters>"
+                                    f"<Result>Error: {str(e)}</Result>"
+                                    f"</ToolCall>"
+                                )
+
+            xml_toolcalls += "</ToolCalls>"
+            pre_action_text = content[: content.find(action_block)]
+            post_action_text = content[
+                content.find(action_block) + len(action_block) :
+            ]
+            iteration_text = pre_action_text + xml_toolcalls + post_action_text
+
+            # Create the assistant message
+            await self.conversation.add_message(
+                Message(role="assistant", content=iteration_text)
+            )
+        else:
+            # Create an assistant message with the content as-is
+            await self.conversation.add_message(
+                Message(role="assistant", content=content)
+            )
+
+        # Only mark as completed if the finish_reason is "stop" or there are no action calls
+        # This allows the agent to continue the conversation when tool calls are processed
+        if finish_reason == "stop":
+            self._completed = True
+
+    def _parse_single_tool_call(
+        self, toolcall_text: str
+    ) -> Tuple[Optional[str], dict]:
+        """
+        Parse a ToolCall block to extract the name and parameters.
+
+        Args:
+            toolcall_text: The text content of a ToolCall block
+
+        Returns:
+            Tuple of (tool_name, tool_parameters)
+        """
+        name_match = self.NAME_PATTERN.search(toolcall_text)
+        if not name_match:
+            return None, {}
+        tool_name = name_match.group(1).strip()
+
+        params_match = self.PARAMS_PATTERN.search(toolcall_text)
+        if not params_match:
+            return tool_name, {}
+
+        raw_params = params_match.group(1).strip()
+        try:
+            # Handle potential JSON parsing issues
+            # First try direct parsing
+            tool_params = json.loads(raw_params)
+        except json.JSONDecodeError:
+            # If that fails, try to clean up the JSON string
+            try:
+                # Replace escaped quotes that might cause issues
+                cleaned_params = raw_params.replace('\\"', '"')
+                # Try again with the cleaned string
+                tool_params = json.loads(cleaned_params)
+            except json.JSONDecodeError:
+                # If all else fails, treat as a plain string value
+                tool_params = {"value": raw_params}
+
+        return tool_name, tool_params