Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/utils/base_utils.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/utils/base_utils.py  783
1 file changed, 783 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py b/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py
new file mode 100644
index 00000000..1864d0b4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py
@@ -0,0 +1,783 @@
+import json
+import logging
+import math
+import uuid
+from abc import ABCMeta
+from copy import deepcopy
+from datetime import datetime
+from typing import Any, Optional, Tuple, TypeVar
+from uuid import NAMESPACE_DNS, UUID, uuid4, uuid5
+
+import tiktoken
+
+from ..abstractions import (
+    AggregateSearchResult,
+    AsyncSyncMeta,
+    GraphCommunityResult,
+    GraphEntityResult,
+    GraphRelationshipResult,
+)
+from ..abstractions.vector import VectorQuantizationType
+
+
+logger = logging.getLogger()
+
+
+def id_to_shorthand(id: str | UUID) -> str:
+    return str(id)[:7]
+
+
+def format_search_results_for_llm(
+    results: AggregateSearchResult,
+    collector: Any,  # SearchResultsCollector
+) -> str:
+    """
+    Instead of resetting 'source_counter' to 1, we:
+     - For each chunk / graph / web / doc in `results`,
+     - Find the aggregator index from the collector,
+     - Print 'Source [X]:' with that aggregator index.
+    """
+    lines = []
+
+    # We'll build a quick helper to locate aggregator indices for each object:
+    # Or you can rely on the fact that we've added them to the collector
+    # in the same order. But let's do a "lookup aggregator index" approach:
+
+    # 1) Chunk search
+    if results.chunk_search_results:
+        lines.append("Vector Search Results:")
+        for c in results.chunk_search_results:
+            lines.append(f"Source ID [{id_to_shorthand(c.id)}]:")
+            lines.append(c.text or "")  # or c.text[:200] to truncate
+
+    # 2) Graph search
+    if results.graph_search_results:
+        lines.append("Graph Search Results:")
+        for g in results.graph_search_results:
+            lines.append(f"Source ID [{id_to_shorthand(g.id)}]:")
+            if isinstance(g.content, GraphCommunityResult):
+                lines.append(f"Community Name: {g.content.name}")
+                lines.append(f"ID: {g.content.id}")
+                lines.append(f"Summary: {g.content.summary}")
+                # etc. ...
+            elif isinstance(g.content, GraphEntityResult):
+                lines.append(f"Entity Name: {g.content.name}")
+                lines.append(f"Description: {g.content.description}")
+            elif isinstance(g.content, GraphRelationshipResult):
+                lines.append(
+                    f"Relationship: {g.content.subject}-{g.content.predicate}-{g.content.object}"
+                )
+            # Add metadata if needed
+
+    # 3) Web search
+    if results.web_search_results:
+        lines.append("Web Search Results:")
+        for w in results.web_search_results:
+            lines.append(f"Source ID [{id_to_shorthand(w.id)}]:")
+            lines.append(f"Title: {w.title}")
+            lines.append(f"Link: {w.link}")
+            lines.append(f"Snippet: {w.snippet}")
+
+    # 4) Local context docs
+    if results.document_search_results:
+        lines.append("Local Context Documents:")
+        for doc_result in results.document_search_results:
+            doc_title = doc_result.title or "Untitled Document"
+            doc_id = doc_result.id
+            summary = doc_result.summary
+
+            lines.append(f"Full Document ID: {doc_id}")
+            lines.append(f"Shortened Document ID: {id_to_shorthand(doc_id)}")
+            lines.append(f"Document Title: {doc_title}")
+            if summary:
+                lines.append(f"Summary: {summary}")
+
+            if doc_result.chunks:
+                # Then each chunk inside:
+                for chunk in doc_result.chunks:
+                    lines.append(
+                        f"\nChunk ID {id_to_shorthand(chunk['id'])}:\n{chunk['text']}"
+                    )
+
+    return "\n".join(lines)
+
+
+def _generate_id_from_label(label: str) -> UUID:
+    return uuid5(NAMESPACE_DNS, label)
+
+
+def generate_id(label: Optional[str] = None) -> UUID:
+    """Generates a unique run id."""
+    return _generate_id_from_label(
+        label if label is not None else str(uuid4())
+    )
+
+
+def generate_document_id(filename: str, user_id: UUID) -> UUID:
+    """Generates a unique document id from a given filename and user id."""
+    safe_filename = filename.replace("/", "_")
+    return _generate_id_from_label(f"{safe_filename}-{str(user_id)}")
+
+
+def generate_extraction_id(
+    document_id: UUID, iteration: int = 0, version: str = "0"
+) -> UUID:
+    """Generates a unique extraction id from a given document id and
+    iteration."""
+    return _generate_id_from_label(f"{str(document_id)}-{iteration}-{version}")
+
+
+def generate_default_user_collection_id(user_id: UUID) -> UUID:
+    """Generates a unique collection id from a given user id."""
+    return _generate_id_from_label(str(user_id))
+
+
+def generate_user_id(email: str) -> UUID:
+    """Generates a unique user id from a given email."""
+    return _generate_id_from_label(email)
+
+
+def generate_default_prompt_id(prompt_name: str) -> UUID:
+    """Generates a unique prompt id."""
+    return _generate_id_from_label(prompt_name)
+
+
+def generate_entity_document_id() -> UUID:
+    """Generates a unique document id inserting entities into a graph."""
+    generation_time = datetime.now().isoformat()
+    return _generate_id_from_label(f"entity-{generation_time}")
+
+
+def increment_version(version: str) -> str:
+    """Increment the trailing numeric component of a version string,
+    e.g. "v9" -> "v10"."""
+    prefix = version.rstrip("0123456789")
+    suffix = int(version[len(prefix) :])
+    return f"{prefix}{suffix + 1}"
+
+
+def decrement_version(version: str) -> str:
+    """Decrement the trailing numeric component of a version string,
+    clamping at 0."""
+    prefix = version.rstrip("0123456789")
+    suffix = int(version[len(prefix) :])
+    return f"{prefix}{max(0, suffix - 1)}"
+
+
+def validate_uuid(uuid_str: str) -> UUID:
+    return UUID(uuid_str)
+
+
+def update_settings_from_dict(server_settings, settings_dict: dict):
+    """Updates a settings object with values from a dictionary."""
+    settings = deepcopy(server_settings)
+    for key, value in settings_dict.items():
+        if value is not None:
+            if isinstance(value, dict):
+                for k, v in value.items():
+                    if isinstance(getattr(settings, key), dict):
+                        getattr(settings, key)[k] = v
+                    else:
+                        setattr(getattr(settings, key), k, v)
+            else:
+                setattr(settings, key, value)
+
+    return settings
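+
+
+# Usage sketch (illustrative only), using a throwaway SimpleNamespace in place
+# of a real server-settings object: nested dicts are merged key by key while
+# scalar fields are overwritten, and the original object is left untouched.
+def _sketch_update_settings() -> None:
+    from types import SimpleNamespace  # local import: sketch only
+
+    settings = SimpleNamespace(model="gpt-4o", extra={"temperature": 0.7})
+    updated = update_settings_from_dict(
+        settings, {"model": "gpt-4o-mini", "extra": {"top_p": 0.9}}
+    )
+    assert updated.model == "gpt-4o-mini"
+    assert updated.extra == {"temperature": 0.7, "top_p": 0.9}
+    assert settings.model == "gpt-4o"  # deepcopy keeps the input unchanged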
+
+
+def _decorate_vector_type(
+    input_str: str,
+    quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
+) -> str:
+    return f"{quantization_type.db_type}{input_str}"
+
+
+def _get_vector_column_str(
+    dimension: int | float, quantization_type: VectorQuantizationType
+) -> str:
+    """Returns a string representation of a vector column type.
+
+    Explicitly handles the case where the dimension is not a valid number meant
+    to support embedding models that do not allow for specifying the dimension.
+    """
+    if math.isnan(dimension) or dimension <= 0:
+        vector_dim = ""  # Allows for Postgres to handle any dimension
+    else:
+        vector_dim = f"({dimension})"
+    return _decorate_vector_type(vector_dim, quantization_type)
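+
+
+# Usage sketch (illustrative only): the exact column prefix comes from
+# `VectorQuantizationType.db_type`, so only the dimension handling is asserted
+# here. A NaN or non-positive dimension leaves the dimension unspecified so
+# Postgres can accept vectors of any length.
+def _sketch_vector_column() -> None:
+    fixed = _get_vector_column_str(1536, VectorQuantizationType.FP32)
+    flexible = _get_vector_column_str(float("nan"), VectorQuantizationType.FP32)
+    assert fixed.endswith("(1536)")
+    assert "(" not in flexible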
+
+
+KeyType = TypeVar("KeyType")
+
+
+def deep_update(
+    mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]
+) -> dict[KeyType, Any]:
+    """
+    Taken from Pydantic v1:
+    https://github.com/pydantic/pydantic/blob/fd2991fe6a73819b48c906e3c3274e8e47d0f761/pydantic/utils.py#L200
+    """
+    updated_mapping = mapping.copy()
+    for updating_mapping in updating_mappings:
+        for k, v in updating_mapping.items():
+            if (
+                k in updated_mapping
+                and isinstance(updated_mapping[k], dict)
+                and isinstance(v, dict)
+            ):
+                updated_mapping[k] = deep_update(updated_mapping[k], v)
+            else:
+                updated_mapping[k] = v
+    return updated_mapping
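+
+
+# Usage sketch (illustrative only): nested dictionaries are merged recursively
+# into fresh copies, so neither input mapping is mutated.
+def _sketch_deep_update() -> None:
+    base = {"a": {"x": 1, "y": 2}, "b": 3}
+    merged = deep_update(base, {"a": {"y": 20}}, {"b": 30})
+    assert merged == {"a": {"x": 1, "y": 20}, "b": 30}
+    assert base == {"a": {"x": 1, "y": 2}, "b": 3}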
+
+
+def tokens_count_for_message(message, encoding):
+    """Return the number of tokens used by a single message."""
+    tokens_per_message = 3
+
+    num_tokens = 0
+    num_tokens += tokens_per_message
+    if message.get("function_call"):
+        num_tokens += len(encoding.encode(message["function_call"]["name"]))
+        num_tokens += len(
+            encoding.encode(message["function_call"]["arguments"])
+        )
+    elif message.get("tool_calls"):
+        for tool_call in message["tool_calls"]:
+            num_tokens += len(encoding.encode(tool_call["function"]["name"]))
+            num_tokens += len(
+                encoding.encode(tool_call["function"]["arguments"])
+            )
+    elif isinstance(message.get("content"), str):
+        num_tokens += len(encoding.encode(message["content"]))
+
+    return num_tokens
+
+
+def num_tokens_from_messages(messages, model="gpt-4o"):
+    """Return the number of tokens used by a list of chat messages."""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        logger.warning("Model not found; falling back to cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+
+    tokens = 0
+    for message_ in messages:
+        tokens += tokens_count_for_message(message_, encoding)
+
+    tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    return tokens
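+
+
+# Usage sketch (illustrative only): counts are approximate and depend on the
+# tiktoken encoding resolved for the requested model.
+def _sketch_count_tokens() -> None:
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Hello!"},
+    ]
+    total = num_tokens_from_messages(messages, model="gpt-4o")
+    logger.debug("approximate prompt tokens: %d", total)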
+
+
+class SearchResultsCollector:
+    """
+    Collects search results in the form (source_type, result_obj).
+    Handles both object-oriented and dictionary-based search results.
+    """
+
+    def __init__(self):
+        # We'll store a list of (source_type, result_obj)
+        self._results_in_order = []
+
+    @property
+    def results(self):
+        """Get the results list"""
+        return self._results_in_order
+
+    @results.setter
+    def results(self, value):
+        """
+        Set the results directly, with automatic type detection for 'unknown' items
+        Handles the format: [('unknown', {...}), ('unknown', {...})]
+        """
+        self._results_in_order = []
+
+        if isinstance(value, list):
+            for item in value:
+                if isinstance(item, tuple) and len(item) == 2:
+                    source_type, result_obj = item
+
+                    # Only auto-detect if the source type is "unknown"
+                    if source_type == "unknown":
+                        detected_type = self._detect_result_type(result_obj)
+                        self._results_in_order.append(
+                            (detected_type, result_obj)
+                        )
+                    else:
+                        self._results_in_order.append(
+                            (source_type, result_obj)
+                        )
+                else:
+                    # If not a tuple, detect and add
+                    detected_type = self._detect_result_type(item)
+                    self._results_in_order.append((detected_type, item))
+        else:
+            raise ValueError("Results must be a list")
+
+    def add_aggregate_result(self, agg):
+        """
+        Flatten the chunk_search_results, graph_search_results, web_search_results,
+        and document_search_results into the collector, including nested chunks.
+        """
+        if hasattr(agg, "chunk_search_results") and agg.chunk_search_results:
+            for c in agg.chunk_search_results:
+                self._results_in_order.append(("chunk", c))
+
+        if hasattr(agg, "graph_search_results") and agg.graph_search_results:
+            for g in agg.graph_search_results:
+                self._results_in_order.append(("graph", g))
+
+        if hasattr(agg, "web_search_results") and agg.web_search_results:
+            for w in agg.web_search_results:
+                self._results_in_order.append(("web", w))
+
+        # Add documents and extract their chunks
+        if (
+            hasattr(agg, "document_search_results")
+            and agg.document_search_results
+        ):
+            for doc in agg.document_search_results:
+                # Add the document itself
+                self._results_in_order.append(("doc", doc))
+
+                # Extract and add chunks from the document
+                chunks = None
+                if isinstance(doc, dict):
+                    chunks = doc.get("chunks", [])
+                elif hasattr(doc, "chunks") and doc.chunks is not None:
+                    chunks = doc.chunks
+
+                if chunks:
+                    for chunk in chunks:
+                        # Ensure each chunk has the minimum required attributes
+                        if isinstance(chunk, dict) and "id" in chunk:
+                            # Add the chunk directly to results for citation lookup
+                            self._results_in_order.append(("chunk", chunk))
+                        elif hasattr(chunk, "id"):
+                            self._results_in_order.append(("chunk", chunk))
+
+    def add_result(self, result_obj, source_type=None):
+        """
+        Add a single result object to the collector.
+        If source_type is not provided, automatically detect the type.
+        """
+        if source_type:
+            self._results_in_order.append((source_type, result_obj))
+            return source_type
+
+        detected_type = self._detect_result_type(result_obj)
+        self._results_in_order.append((detected_type, result_obj))
+        return detected_type
+
+    def _detect_result_type(self, obj):
+        """
+        Detect the type of a result object based on its properties.
+        Works with both object attributes and dictionary keys.
+        """
+        # Handle dictionary types first (common for web search results)
+        if isinstance(obj, dict):
+            # Web search pattern
+            if all(k in obj for k in ["title", "link"]) and any(
+                k in obj for k in ["snippet", "description"]
+            ):
+                return "web"
+
+            # Check for graph dictionary patterns
+            if "content" in obj and isinstance(obj["content"], dict):
+                content = obj["content"]
+                if all(k in content for k in ["name", "description"]):
+                    return "graph"  # Entity
+                if all(
+                    k in content for k in ["subject", "predicate", "object"]
+                ):
+                    return "graph"  # Relationship
+                if all(k in content for k in ["name", "summary"]):
+                    return "graph"  # Community
+
+            # Chunk pattern
+            if all(k in obj for k in ["text", "id"]) and any(
+                k in obj for k in ["score", "metadata"]
+            ):
+                return "chunk"
+
+            # Context document pattern
+            if "document" in obj and "chunks" in obj:
+                return "doc"
+
+            # Check for explicit type indicator
+            if "type" in obj:
+                type_val = str(obj["type"]).lower()
+                if any(t in type_val for t in ["web", "organic"]):
+                    return "web"
+                if "graph" in type_val:
+                    return "graph"
+                if "chunk" in type_val:
+                    return "chunk"
+                if "document" in type_val:
+                    return "doc"
+
+        # Handle object attributes for OOP-style results
+        if hasattr(obj, "result_type"):
+            result_type = str(obj.result_type).lower()
+            if result_type in ["entity", "relationship", "community"]:
+                return "graph"
+
+        # Check class name hints
+        class_name = obj.__class__.__name__
+        if "Graph" in class_name:
+            return "graph"
+        if "Chunk" in class_name:
+            return "chunk"
+        if "Web" in class_name:
+            return "web"
+        if "Document" in class_name:
+            return "doc"
+
+        # Check for object attribute patterns
+        if hasattr(obj, "content"):
+            content = obj.content
+            if hasattr(content, "name") and hasattr(content, "description"):
+                return "graph"  # Entity
+            if hasattr(content, "subject") and hasattr(content, "predicate"):
+                return "graph"  # Relationship
+            if hasattr(content, "name") and hasattr(content, "summary"):
+                return "graph"  # Community
+
+        if (
+            hasattr(obj, "text")
+            and hasattr(obj, "id")
+            and (hasattr(obj, "score") or hasattr(obj, "metadata"))
+        ):
+            return "chunk"
+
+        if (
+            hasattr(obj, "title")
+            and hasattr(obj, "link")
+            and hasattr(obj, "snippet")
+        ):
+            return "web"
+
+        if hasattr(obj, "document") and hasattr(obj, "chunks"):
+            return "doc"
+
+        # Default when type can't be determined
+        return "unknown"
+
+    def find_by_short_id(self, short_id):
+        """Find a result by its short ID prefix with better chunk handling"""
+        if not short_id:
+            return None
+
+        # First try direct lookup using regular iteration
+        for _, result_obj in self._results_in_order:
+            # Check dictionary objects
+            if isinstance(result_obj, dict) and "id" in result_obj:
+                result_id = str(result_obj["id"])
+                if result_id.startswith(short_id):
+                    return result_obj
+
+            # Check object with id attribute
+            elif hasattr(result_obj, "id"):
+                obj_id = getattr(result_obj, "id", None)
+                if obj_id and str(obj_id).startswith(short_id):
+                    # Convert to dict if possible
+                    if hasattr(result_obj, "as_dict"):
+                        return result_obj.as_dict()
+                    elif hasattr(result_obj, "model_dump"):
+                        return result_obj.model_dump()
+                    elif hasattr(result_obj, "dict"):
+                        return result_obj.dict()
+                    else:
+                        return result_obj
+
+        # If not found, look for chunks inside documents that weren't extracted properly
+        for source_type, result_obj in self._results_in_order:
+            if source_type == "doc":
+                # Try various ways to access chunks
+                chunks = None
+                if isinstance(result_obj, dict) and "chunks" in result_obj:
+                    chunks = result_obj["chunks"]
+                elif (
+                    hasattr(result_obj, "chunks")
+                    and result_obj.chunks is not None
+                ):
+                    chunks = result_obj.chunks
+
+                if chunks:
+                    for chunk in chunks:
+                        # Try each chunk
+                        chunk_id = None
+                        if isinstance(chunk, dict) and "id" in chunk:
+                            chunk_id = chunk["id"]
+                        elif hasattr(chunk, "id"):
+                            chunk_id = chunk.id
+
+                        if chunk_id and str(chunk_id).startswith(short_id):
+                            return chunk
+
+        return None
+
+    def get_results_by_type(self, type_name):
+        """Get all results of a specific type"""
+        return [
+            result_obj
+            for source_type, result_obj in self._results_in_order
+            if source_type == type_name
+        ]
+
+    def __repr__(self):
+        """String representation showing counts by type"""
+        type_counts = {}
+        for source_type, _ in self._results_in_order:
+            type_counts[source_type] = type_counts.get(source_type, 0) + 1
+
+        return f"SearchResultsCollector with {len(self._results_in_order)} results: {type_counts}"
+
+    def get_all_results(self) -> list[Tuple[str, Any]]:
+        """
+        Return the list of (source_type, result_obj) tuples in the order
+        they were appended.
+        """
+        return self._results_in_order
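+
+
+# Usage sketch (illustrative only): dictionary-based results are type-detected
+# automatically when assigned, and short-id lookup matches on the id prefix.
+# The UUID-like strings below are arbitrary examples.
+def _sketch_collector_usage() -> None:
+    collector = SearchResultsCollector()
+    chunk = {
+        "id": "11111111-2222-3333-4444-555555555555",
+        "text": "Example chunk text.",
+        "score": 0.42,
+    }
+    web = {
+        "id": "99999999-8888-7777-6666-555555555555",
+        "title": "Example page",
+        "link": "https://example.com",
+        "snippet": "An example snippet.",
+    }
+    collector.results = [("unknown", chunk), ("unknown", web)]
+    assert collector.get_results_by_type("chunk") == [chunk]
+    assert collector.get_results_by_type("web") == [web]
+    assert collector.find_by_short_id("1111111") is chunk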
+
+
+def convert_nonserializable_objects(obj):
+    if hasattr(obj, "model_dump"):
+        obj = obj.model_dump()
+    if hasattr(obj, "as_dict"):
+        obj = obj.as_dict()
+    if hasattr(obj, "to_dict"):
+        obj = obj.to_dict()
+
+    if isinstance(obj, dict):
+        new_obj = {}
+        for key, value in obj.items():
+            # Convert key to string if it is a UUID or not already a string.
+            new_key = str(key) if not isinstance(key, str) else key
+            new_obj[new_key] = convert_nonserializable_objects(value)
+        return new_obj
+    elif isinstance(obj, list):
+        return [convert_nonserializable_objects(item) for item in obj]
+    elif isinstance(obj, tuple):
+        return tuple(convert_nonserializable_objects(item) for item in obj)
+    elif isinstance(obj, set):
+        return {convert_nonserializable_objects(item) for item in obj}
+    elif isinstance(obj, uuid.UUID):
+        return str(obj)
+    elif isinstance(obj, datetime):
+        return obj.isoformat()  # Convert datetime to ISO formatted string
+    else:
+        return obj
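+
+
+# Usage sketch (illustrative only): UUID keys and datetime values become plain
+# strings, so the converted structure can be passed straight to json.dumps.
+def _sketch_serialization() -> None:
+    key = uuid4()
+    raw = {key: {"created_at": datetime(2024, 1, 1, 12, 0), "score": 0.9}}
+    cleaned = convert_nonserializable_objects(raw)
+    assert str(key) in cleaned
+    json.dumps(cleaned)  # no longer raises on UUIDs or datetimes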
+
+
+def dump_obj(obj) -> Any:
+    if hasattr(obj, "model_dump"):
+        obj = obj.model_dump()
+    elif hasattr(obj, "dict"):
+        obj = obj.dict()
+    elif hasattr(obj, "as_dict"):
+        obj = obj.as_dict()
+    elif hasattr(obj, "to_dict"):
+        obj = obj.to_dict()
+    obj = convert_nonserializable_objects(obj)
+
+    return obj
+
+
+def dump_collector(collector: SearchResultsCollector) -> list[dict[str, Any]]:
+    dumped = []
+    for source_type, result_obj in collector.get_all_results():
+        # Get the dictionary from the result object
+        if hasattr(result_obj, "model_dump"):
+            result_dict = result_obj.model_dump()
+        elif hasattr(result_obj, "dict"):
+            result_dict = result_obj.dict()
+        elif hasattr(result_obj, "as_dict"):
+            result_dict = result_obj.as_dict()
+        elif hasattr(result_obj, "to_dict"):
+            result_dict = result_obj.to_dict()
+        else:
+            result_dict = (
+                result_obj  # Fallback if no conversion method is available
+            )
+
+        # Use the recursive conversion on the entire dictionary
+        result_dict = convert_nonserializable_objects(result_dict)
+
+        dumped.append(
+            {
+                "source_type": source_type,
+                "result": result_dict,
+            }
+        )
+    return dumped
+
+
+def num_tokens(text, model="gpt-4o"):
+    """Return the number of tokens in a text string for the given model."""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        encoding = tiktoken.get_encoding("cl100k_base")
+
+    return len(encoding.encode(text, disallowed_special=()))
+
+
+class CombinedMeta(AsyncSyncMeta, ABCMeta):
+    pass
+
+
+async def yield_sse_event(event_name: str, payload: dict, chunk_size=1024):
+    """
+    Yield a single Server-Sent Event: an `event:` line, a `data:` line with
+    the JSON-encoded payload, and a blank line terminating the event.
+    (`chunk_size` is currently unused.)
+
+    e.g. event: event_name
+         data: {"answer": "..."}
+         [blank line to end event]
+    """
+
+    # SSE: first the "event: ..."
+    yield f"event: {event_name}\n"
+
+    # Convert payload to JSON
+    content_str = json.dumps(payload, default=str)
+
+    # data
+    yield f"data: {content_str}\n"
+
+    # blank line signals end of SSE event
+    yield "\n"
+
+
+class SSEFormatter:
+    """
+    Enhanced formatter for Server-Sent Events (SSE) with citation tracking.
+    Extends the existing SSEFormatter with improved citation handling.
+    """
+
+    @staticmethod
+    async def yield_citation_event(
+        citation_data: dict,
+    ):
+        """
+        Emits a citation event with optimized payload.
+
+        Args:
+            citation_id: The short ID of the citation (e.g., 'abc1234')
+            span: (start, end) position tuple for this occurrence
+            payload: Source object (included only for first occurrence)
+            is_new: Whether this is the first time we've seen this citation
+            citation_id_counter: Optional counter for citation occurrences
+
+        Yields:
+            Formatted SSE event lines
+        """
+
+        # Include the full payload only for new citations
+        if not citation_data.get("is_new") or "payload" not in citation_data:
+            citation_data["payload"] = None
+
+        # Yield the event
+        async for line in yield_sse_event("citation", citation_data):
+            yield line
+
+    @staticmethod
+    async def yield_final_answer_event(
+        final_data: dict,
+    ):
+        # Yield the event
+        async for line in yield_sse_event("final_answer", final_data):
+            yield line
+
+    # Additional SSE event helpers:
+    @staticmethod
+    async def yield_message_event(text_segment, msg_id=None):
+        msg_id = msg_id or f"msg_{uuid.uuid4().hex[:8]}"
+        msg_payload = {
+            "id": msg_id,
+            "object": "agent.message.delta",
+            "delta": {
+                "content": [
+                    {
+                        "type": "text",
+                        "payload": {
+                            "value": text_segment,
+                            "annotations": [],
+                        },
+                    }
+                ]
+            },
+        }
+        async for line in yield_sse_event("message", msg_payload):
+            yield line
+
+    @staticmethod
+    async def yield_thinking_event(text_segment, thinking_id=None):
+        thinking_id = thinking_id or f"think_{uuid.uuid4().hex[:8]}"
+        thinking_data = {
+            "id": thinking_id,
+            "object": "agent.thinking.delta",
+            "delta": {
+                "content": [
+                    {
+                        "type": "text",
+                        "payload": {
+                            "value": text_segment,
+                            "annotations": [],
+                        },
+                    }
+                ]
+            },
+        }
+        async for line in yield_sse_event("thinking", thinking_data):
+            yield line
+
+    @staticmethod
+    def yield_done_event():
+        return "event: done\ndata: [DONE]\n\n"
+
+    @staticmethod
+    async def yield_error_event(error_message, error_id=None):
+        error_id = error_id or f"err_{uuid.uuid4().hex[:8]}"
+        error_payload = {
+            "id": error_id,
+            "object": "agent.error",
+            "error": {"message": error_message, "type": "agent_error"},
+        }
+        async for line in yield_sse_event("error", error_payload):
+            yield line
+
+    @staticmethod
+    async def yield_tool_call_event(tool_call_data):
+        from ..api.models.retrieval.responses import ToolCallEvent
+
+        tc_event = ToolCallEvent(event="tool_call", data=tool_call_data)
+        async for line in yield_sse_event(
+            "tool_call", tc_event.dict()["data"]
+        ):
+            yield line
+
+    # Helper for emitting aggregated search results:
+    @staticmethod
+    async def yield_search_results_event(aggregated_results):
+        payload = {
+            "id": "search_1",
+            "object": "rag.search_results",
+            "data": aggregated_results.as_dict(),
+        }
+        async for line in yield_sse_event("search_results", payload):
+            yield line
+
+    @staticmethod
+    async def yield_tool_result_event(tool_result_data):
+        from ..api.models.retrieval.responses import ToolResultEvent
+
+        tr_event = ToolResultEvent(event="tool_result", data=tool_result_data)
+        async for line in yield_sse_event(
+            "tool_result", tr_event.dict()["data"]
+        ):
+            yield line
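+
+
+# Usage sketch (illustrative only): streaming a message delta followed by the
+# terminal done event; a real server would forward these lines to the client
+# as they are produced. `msg_demo` is an arbitrary example id.
+async def _sketch_sse_stream() -> str:
+    frame: list[str] = []
+    async for line in SSEFormatter.yield_message_event("Hello!", msg_id="msg_demo"):
+        frame.append(line)
+    frame.append(SSEFormatter.yield_done_event())
+    return "".join(frame)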