aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/utils/base_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/shared/utils/base_utils.py783
1 files changed, 783 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py b/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py
new file mode 100644
index 00000000..1864d0b4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py
@@ -0,0 +1,783 @@
+import json
+import logging
+import math
+import uuid
+from abc import ABCMeta
+from copy import deepcopy
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Optional, Tuple, TypeVar
+from uuid import NAMESPACE_DNS, UUID, uuid4, uuid5
+
+import tiktoken
+
+from ..abstractions import (
+ AggregateSearchResult,
+ AsyncSyncMeta,
+ GraphCommunityResult,
+ GraphEntityResult,
+ GraphRelationshipResult,
+)
+from ..abstractions.vector import VectorQuantizationType
+
+if TYPE_CHECKING:
+ pass
+
+
+logger = logging.getLogger()
+
+
+def id_to_shorthand(id: str | UUID):
+ return str(id)[:7]
+
+
+def format_search_results_for_llm(
+ results: AggregateSearchResult,
+ collector: Any, # SearchResultsCollector
+) -> str:
+ """
+ Instead of resetting 'source_counter' to 1, we:
+ - For each chunk / graph / web / doc in `results`,
+ - Find the aggregator index from the collector,
+ - Print 'Source [X]:' with that aggregator index.
+ """
+ lines = []
+
+ # We'll build a quick helper to locate aggregator indices for each object:
+ # Or you can rely on the fact that we've added them to the collector
+ # in the same order. But let's do a "lookup aggregator index" approach:
+
+ # 1) Chunk search
+ if results.chunk_search_results:
+ lines.append("Vector Search Results:")
+ for c in results.chunk_search_results:
+ lines.append(f"Source ID [{id_to_shorthand(c.id)}]:")
+ lines.append(c.text or "") # or c.text[:200] to truncate
+
+ # 2) Graph search
+ if results.graph_search_results:
+ lines.append("Graph Search Results:")
+ for g in results.graph_search_results:
+ lines.append(f"Source ID [{id_to_shorthand(g.id)}]:")
+ if isinstance(g.content, GraphCommunityResult):
+ lines.append(f"Community Name: {g.content.name}")
+ lines.append(f"ID: {g.content.id}")
+ lines.append(f"Summary: {g.content.summary}")
+ # etc. ...
+ elif isinstance(g.content, GraphEntityResult):
+ lines.append(f"Entity Name: {g.content.name}")
+ lines.append(f"Description: {g.content.description}")
+ elif isinstance(g.content, GraphRelationshipResult):
+ lines.append(
+ f"Relationship: {g.content.subject}-{g.content.predicate}-{g.content.object}"
+ )
+ # Add metadata if needed
+
+ # 3) Web search
+ if results.web_search_results:
+ lines.append("Web Search Results:")
+ for w in results.web_search_results:
+ lines.append(f"Source ID [{id_to_shorthand(w.id)}]:")
+ lines.append(f"Title: {w.title}")
+ lines.append(f"Link: {w.link}")
+ lines.append(f"Snippet: {w.snippet}")
+
+ # 4) Local context docs
+ if results.document_search_results:
+ lines.append("Local Context Documents:")
+ for doc_result in results.document_search_results:
+ doc_title = doc_result.title or "Untitled Document"
+ doc_id = doc_result.id
+ summary = doc_result.summary
+
+ lines.append(f"Full Document ID: {doc_id}")
+ lines.append(f"Shortened Document ID: {id_to_shorthand(doc_id)}")
+ lines.append(f"Document Title: {doc_title}")
+ if summary:
+ lines.append(f"Summary: {summary}")
+
+ if doc_result.chunks:
+ # Then each chunk inside:
+ for chunk in doc_result.chunks:
+ lines.append(
+ f"\nChunk ID {id_to_shorthand(chunk['id'])}:\n{chunk['text']}"
+ )
+
+ result = "\n".join(lines)
+ return result
+
+
+def _generate_id_from_label(label) -> UUID:
+ return uuid5(NAMESPACE_DNS, label)
+
+
+def generate_id(label: Optional[str] = None) -> UUID:
+ """Generates a unique run id."""
+ return _generate_id_from_label(
+ label if label is not None else str(uuid4())
+ )
+
+
+def generate_document_id(filename: str, user_id: UUID) -> UUID:
+ """Generates a unique document id from a given filename and user id."""
+ safe_filename = filename.replace("/", "_")
+ return _generate_id_from_label(f"{safe_filename}-{str(user_id)}")
+
+
+def generate_extraction_id(
+ document_id: UUID, iteration: int = 0, version: str = "0"
+) -> UUID:
+ """Generates a unique extraction id from a given document id and
+ iteration."""
+ return _generate_id_from_label(f"{str(document_id)}-{iteration}-{version}")
+
+
+def generate_default_user_collection_id(user_id: UUID) -> UUID:
+ """Generates a unique collection id from a given user id."""
+ return _generate_id_from_label(str(user_id))
+
+
+def generate_user_id(email: str) -> UUID:
+ """Generates a unique user id from a given email."""
+ return _generate_id_from_label(email)
+
+
+def generate_default_prompt_id(prompt_name: str) -> UUID:
+ """Generates a unique prompt id."""
+ return _generate_id_from_label(prompt_name)
+
+
+def generate_entity_document_id() -> UUID:
+ """Generates a unique document id inserting entities into a graph."""
+ generation_time = datetime.now().isoformat()
+ return _generate_id_from_label(f"entity-{generation_time}")
+
+
+def increment_version(version: str) -> str:
+ prefix = version[:-1]
+ suffix = int(version[-1])
+ return f"{prefix}{suffix + 1}"
+
+
+def decrement_version(version: str) -> str:
+ prefix = version[:-1]
+ suffix = int(version[-1])
+ return f"{prefix}{max(0, suffix - 1)}"
+
+
+def validate_uuid(uuid_str: str) -> UUID:
+ return UUID(uuid_str)
+
+
+def update_settings_from_dict(server_settings, settings_dict: dict):
+ """Updates a settings object with values from a dictionary."""
+ settings = deepcopy(server_settings)
+ for key, value in settings_dict.items():
+ if value is not None:
+ if isinstance(value, dict):
+ for k, v in value.items():
+ if isinstance(getattr(settings, key), dict):
+ getattr(settings, key)[k] = v
+ else:
+ setattr(getattr(settings, key), k, v)
+ else:
+ setattr(settings, key, value)
+
+ return settings
+
+
+def _decorate_vector_type(
+ input_str: str,
+ quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
+) -> str:
+ return f"{quantization_type.db_type}{input_str}"
+
+
+def _get_vector_column_str(
+ dimension: int | float, quantization_type: VectorQuantizationType
+) -> str:
+ """Returns a string representation of a vector column type.
+
+ Explicitly handles the case where the dimension is not a valid number meant
+ to support embedding models that do not allow for specifying the dimension.
+ """
+ if math.isnan(dimension) or dimension <= 0:
+ vector_dim = "" # Allows for Postgres to handle any dimension
+ else:
+ vector_dim = f"({dimension})"
+ return _decorate_vector_type(vector_dim, quantization_type)
+
+
+KeyType = TypeVar("KeyType")
+
+
+def deep_update(
+ mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]
+) -> dict[KeyType, Any]:
+ """
+ Taken from Pydantic v1:
+ https://github.com/pydantic/pydantic/blob/fd2991fe6a73819b48c906e3c3274e8e47d0f761/pydantic/utils.py#L200
+ """
+ updated_mapping = mapping.copy()
+ for updating_mapping in updating_mappings:
+ for k, v in updating_mapping.items():
+ if (
+ k in updated_mapping
+ and isinstance(updated_mapping[k], dict)
+ and isinstance(v, dict)
+ ):
+ updated_mapping[k] = deep_update(updated_mapping[k], v)
+ else:
+ updated_mapping[k] = v
+ return updated_mapping
+
+
+def tokens_count_for_message(message, encoding):
+ """Return the number of tokens used by a single message."""
+ tokens_per_message = 3
+
+ num_tokens = 0
+ num_tokens += tokens_per_message
+ if message.get("function_call"):
+ num_tokens += len(encoding.encode(message["function_call"]["name"]))
+ num_tokens += len(
+ encoding.encode(message["function_call"]["arguments"])
+ )
+ elif message.get("tool_calls"):
+ for tool_call in message["tool_calls"]:
+ num_tokens += len(encoding.encode(tool_call["function"]["name"]))
+ num_tokens += len(
+ encoding.encode(tool_call["function"]["arguments"])
+ )
+ else:
+ if "content" in message:
+ num_tokens += len(encoding.encode(message["content"]))
+
+ return num_tokens
+
+
+def num_tokens_from_messages(messages, model="gpt-4o"):
+ """Return the number of tokens used by a list of messages for both user and assistant."""
+ try:
+ encoding = tiktoken.encoding_for_model(model)
+ except KeyError:
+ logger.warning("Warning: model not found. Using cl100k_base encoding.")
+ encoding = tiktoken.get_encoding("cl100k_base")
+
+ tokens = 0
+ for message_ in messages:
+ tokens += tokens_count_for_message(message_, encoding)
+
+ tokens += 3 # every reply is primed with assistant
+ return tokens
+
+
+class SearchResultsCollector:
+ """
+ Collects search results in the form (source_type, result_obj).
+ Handles both object-oriented and dictionary-based search results.
+ """
+
+ def __init__(self):
+ # We'll store a list of (source_type, result_obj)
+ self._results_in_order = []
+
+ @property
+ def results(self):
+ """Get the results list"""
+ return self._results_in_order
+
+ @results.setter
+ def results(self, value):
+ """
+ Set the results directly, with automatic type detection for 'unknown' items
+ Handles the format: [('unknown', {...}), ('unknown', {...})]
+ """
+ self._results_in_order = []
+
+ if isinstance(value, list):
+ for item in value:
+ if isinstance(item, tuple) and len(item) == 2:
+ source_type, result_obj = item
+
+ # Only auto-detect if the source type is "unknown"
+ if source_type == "unknown":
+ detected_type = self._detect_result_type(result_obj)
+ self._results_in_order.append(
+ (detected_type, result_obj)
+ )
+ else:
+ self._results_in_order.append(
+ (source_type, result_obj)
+ )
+ else:
+ # If not a tuple, detect and add
+ detected_type = self._detect_result_type(item)
+ self._results_in_order.append((detected_type, item))
+ else:
+ raise ValueError("Results must be a list")
+
+ def add_aggregate_result(self, agg):
+ """
+ Flatten the chunk_search_results, graph_search_results, web_search_results,
+ and document_search_results into the collector, including nested chunks.
+ """
+ if hasattr(agg, "chunk_search_results") and agg.chunk_search_results:
+ for c in agg.chunk_search_results:
+ self._results_in_order.append(("chunk", c))
+
+ if hasattr(agg, "graph_search_results") and agg.graph_search_results:
+ for g in agg.graph_search_results:
+ self._results_in_order.append(("graph", g))
+
+ if hasattr(agg, "web_search_results") and agg.web_search_results:
+ for w in agg.web_search_results:
+ self._results_in_order.append(("web", w))
+
+ # Add documents and extract their chunks
+ if (
+ hasattr(agg, "document_search_results")
+ and agg.document_search_results
+ ):
+ for doc in agg.document_search_results:
+ # Add the document itself
+ self._results_in_order.append(("doc", doc))
+
+ # Extract and add chunks from the document
+ chunks = None
+ if isinstance(doc, dict):
+ chunks = doc.get("chunks", [])
+ elif hasattr(doc, "chunks") and doc.chunks is not None:
+ chunks = doc.chunks
+
+ if chunks:
+ for chunk in chunks:
+ # Ensure each chunk has the minimum required attributes
+ if isinstance(chunk, dict) and "id" in chunk:
+ # Add the chunk directly to results for citation lookup
+ self._results_in_order.append(("chunk", chunk))
+ elif hasattr(chunk, "id"):
+ self._results_in_order.append(("chunk", chunk))
+
+ def add_result(self, result_obj, source_type=None):
+ """
+ Add a single result object to the collector.
+ If source_type is not provided, automatically detect the type.
+ """
+ if source_type:
+ self._results_in_order.append((source_type, result_obj))
+ return source_type
+
+ detected_type = self._detect_result_type(result_obj)
+ self._results_in_order.append((detected_type, result_obj))
+ return detected_type
+
+ def _detect_result_type(self, obj):
+ """
+ Detect the type of a result object based on its properties.
+ Works with both object attributes and dictionary keys.
+ """
+ # Handle dictionary types first (common for web search results)
+ if isinstance(obj, dict):
+ # Web search pattern
+ if all(k in obj for k in ["title", "link"]) and any(
+ k in obj for k in ["snippet", "description"]
+ ):
+ return "web"
+
+ # Check for graph dictionary patterns
+ if "content" in obj and isinstance(obj["content"], dict):
+ content = obj["content"]
+ if all(k in content for k in ["name", "description"]):
+ return "graph" # Entity
+ if all(
+ k in content for k in ["subject", "predicate", "object"]
+ ):
+ return "graph" # Relationship
+ if all(k in content for k in ["name", "summary"]):
+ return "graph" # Community
+
+ # Chunk pattern
+ if all(k in obj for k in ["text", "id"]) and any(
+ k in obj for k in ["score", "metadata"]
+ ):
+ return "chunk"
+
+ # Context document pattern
+ if "document" in obj and "chunks" in obj:
+ return "doc"
+
+ # Check for explicit type indicator
+ if "type" in obj:
+ type_val = str(obj["type"]).lower()
+ if any(t in type_val for t in ["web", "organic"]):
+ return "web"
+ if "graph" in type_val:
+ return "graph"
+ if "chunk" in type_val:
+ return "chunk"
+ if "document" in type_val:
+ return "doc"
+
+ # Handle object attributes for OOP-style results
+ if hasattr(obj, "result_type"):
+ result_type = str(obj.result_type).lower()
+ if result_type in ["entity", "relationship", "community"]:
+ return "graph"
+
+ # Check class name hints
+ class_name = obj.__class__.__name__
+ if "Graph" in class_name:
+ return "graph"
+ if "Chunk" in class_name:
+ return "chunk"
+ if "Web" in class_name:
+ return "web"
+ if "Document" in class_name:
+ return "doc"
+
+ # Check for object attribute patterns
+ if hasattr(obj, "content"):
+ content = obj.content
+ if hasattr(content, "name") and hasattr(content, "description"):
+ return "graph" # Entity
+ if hasattr(content, "subject") and hasattr(content, "predicate"):
+ return "graph" # Relationship
+ if hasattr(content, "name") and hasattr(content, "summary"):
+ return "graph" # Community
+
+ if (
+ hasattr(obj, "text")
+ and hasattr(obj, "id")
+ and (hasattr(obj, "score") or hasattr(obj, "metadata"))
+ ):
+ return "chunk"
+
+ if (
+ hasattr(obj, "title")
+ and hasattr(obj, "link")
+ and hasattr(obj, "snippet")
+ ):
+ return "web"
+
+ if hasattr(obj, "document") and hasattr(obj, "chunks"):
+ return "doc"
+
+ # Default when type can't be determined
+ return "unknown"
+
+ def find_by_short_id(self, short_id):
+ """Find a result by its short ID prefix with better chunk handling"""
+ if not short_id:
+ return None
+
+ # First try direct lookup using regular iteration
+ for _, result_obj in self._results_in_order:
+ # Check dictionary objects
+ if isinstance(result_obj, dict) and "id" in result_obj:
+ result_id = str(result_obj["id"])
+ if result_id.startswith(short_id):
+ return result_obj
+
+ # Check object with id attribute
+ elif hasattr(result_obj, "id"):
+ obj_id = getattr(result_obj, "id", None)
+ if obj_id and str(obj_id).startswith(short_id):
+ # Convert to dict if possible
+ if hasattr(result_obj, "as_dict"):
+ return result_obj.as_dict()
+ elif hasattr(result_obj, "model_dump"):
+ return result_obj.model_dump()
+ elif hasattr(result_obj, "dict"):
+ return result_obj.dict()
+ else:
+ return result_obj
+
+ # If not found, look for chunks inside documents that weren't extracted properly
+ for source_type, result_obj in self._results_in_order:
+ if source_type == "doc":
+ # Try various ways to access chunks
+ chunks = None
+ if isinstance(result_obj, dict) and "chunks" in result_obj:
+ chunks = result_obj["chunks"]
+ elif (
+ hasattr(result_obj, "chunks")
+ and result_obj.chunks is not None
+ ):
+ chunks = result_obj.chunks
+
+ if chunks:
+ for chunk in chunks:
+ # Try each chunk
+ chunk_id = None
+ if isinstance(chunk, dict) and "id" in chunk:
+ chunk_id = chunk["id"]
+ elif hasattr(chunk, "id"):
+ chunk_id = chunk.id
+
+ if chunk_id and str(chunk_id).startswith(short_id):
+ return chunk
+
+ return None
+
+ def get_results_by_type(self, type_name):
+ """Get all results of a specific type"""
+ return [
+ result_obj
+ for source_type, result_obj in self._results_in_order
+ if source_type == type_name
+ ]
+
+ def __repr__(self):
+ """String representation showing counts by type"""
+ type_counts = {}
+ for source_type, _ in self._results_in_order:
+ type_counts[source_type] = type_counts.get(source_type, 0) + 1
+
+ return f"SearchResultsCollector with {len(self._results_in_order)} results: {type_counts}"
+
+ def get_all_results(self) -> list[Tuple[str, Any]]:
+ """
+ Return list of (source_type, result_obj, aggregator_index),
+ in the order appended.
+ """
+ return self._results_in_order
+
+
+def convert_nonserializable_objects(obj):
+ if hasattr(obj, "model_dump"):
+ obj = obj.model_dump()
+ if hasattr(obj, "as_dict"):
+ obj = obj.as_dict()
+ if hasattr(obj, "to_dict"):
+ obj = obj.to_dict()
+
+ if isinstance(obj, dict):
+ new_obj = {}
+ for key, value in obj.items():
+ # Convert key to string if it is a UUID or not already a string.
+ new_key = str(key) if not isinstance(key, str) else key
+ new_obj[new_key] = convert_nonserializable_objects(value)
+ return new_obj
+ elif isinstance(obj, list):
+ return [convert_nonserializable_objects(item) for item in obj]
+ elif isinstance(obj, tuple):
+ return tuple(convert_nonserializable_objects(item) for item in obj)
+ elif isinstance(obj, set):
+ return {convert_nonserializable_objects(item) for item in obj}
+ elif isinstance(obj, uuid.UUID):
+ return str(obj)
+ elif isinstance(obj, datetime):
+ return obj.isoformat() # Convert datetime to ISO formatted string
+ else:
+ return obj
+
+
+def dump_obj(obj) -> list[dict[str, Any]]:
+ if hasattr(obj, "model_dump"):
+ obj = obj.model_dump()
+ elif hasattr(obj, "dict"):
+ obj = obj.dict()
+ elif hasattr(obj, "as_dict"):
+ obj = obj.as_dict()
+ elif hasattr(obj, "to_dict"):
+ obj = obj.to_dict()
+ obj = convert_nonserializable_objects(obj)
+
+ return obj
+
+
+def dump_collector(collector: SearchResultsCollector) -> list[dict[str, Any]]:
+ dumped = []
+ for source_type, result_obj in collector.get_all_results():
+ # Get the dictionary from the result object
+ if hasattr(result_obj, "model_dump"):
+ result_dict = result_obj.model_dump()
+ elif hasattr(result_obj, "dict"):
+ result_dict = result_obj.dict()
+ elif hasattr(result_obj, "as_dict"):
+ result_dict = result_obj.as_dict()
+ elif hasattr(result_obj, "to_dict"):
+ result_dict = result_obj.to_dict()
+ else:
+ result_dict = (
+ result_obj # Fallback if no conversion method is available
+ )
+
+ # Use the recursive conversion on the entire dictionary
+ result_dict = convert_nonserializable_objects(result_dict)
+
+ dumped.append(
+ {
+ "source_type": source_type,
+ "result": result_dict,
+ }
+ )
+ return dumped
+
+
+def num_tokens(text, model="gpt-4o"):
+ try:
+ encoding = tiktoken.encoding_for_model(model)
+ except KeyError:
+ encoding = tiktoken.get_encoding("cl100k_base")
+
+ """Return the number of tokens used by a list of messages for both user and assistant."""
+ return len(encoding.encode(text, disallowed_special=()))
+
+
+class CombinedMeta(AsyncSyncMeta, ABCMeta):
+ pass
+
+
+async def yield_sse_event(event_name: str, payload: dict, chunk_size=1024):
+ """
+ Helper that yields a single SSE event in properly chunked lines.
+
+ e.g. event: event_name
+ data: (partial JSON 1)
+ data: (partial JSON 2)
+ ...
+ [blank line to end event]
+ """
+
+ # SSE: first the "event: ..."
+ yield f"event: {event_name}\n"
+
+ # Convert payload to JSON
+ content_str = json.dumps(payload, default=str)
+
+ # data
+ yield f"data: {content_str}\n"
+
+ # blank line signals end of SSE event
+ yield "\n"
+
+
+class SSEFormatter:
+ """
+ Enhanced formatter for Server-Sent Events (SSE) with citation tracking.
+ Extends the existing SSEFormatter with improved citation handling.
+ """
+
+ @staticmethod
+ async def yield_citation_event(
+ citation_data: dict,
+ ):
+ """
+ Emits a citation event with optimized payload.
+
+ Args:
+ citation_id: The short ID of the citation (e.g., 'abc1234')
+ span: (start, end) position tuple for this occurrence
+ payload: Source object (included only for first occurrence)
+ is_new: Whether this is the first time we've seen this citation
+ citation_id_counter: Optional counter for citation occurrences
+
+ Yields:
+ Formatted SSE event lines
+ """
+
+ # Include the full payload only for new citations
+ if not citation_data.get("is_new") or "payload" not in citation_data:
+ citation_data["payload"] = None
+
+ # Yield the event
+ async for line in yield_sse_event("citation", citation_data):
+ yield line
+
+ @staticmethod
+ async def yield_final_answer_event(
+ final_data: dict,
+ ):
+ # Yield the event
+ async for line in yield_sse_event("final_answer", final_data):
+ yield line
+
+ # Include other existing SSEFormatter methods for compatibility
+ @staticmethod
+ async def yield_message_event(text_segment, msg_id=None):
+ msg_id = msg_id or f"msg_{uuid.uuid4().hex[:8]}"
+ msg_payload = {
+ "id": msg_id,
+ "object": "agent.message.delta",
+ "delta": {
+ "content": [
+ {
+ "type": "text",
+ "payload": {
+ "value": text_segment,
+ "annotations": [],
+ },
+ }
+ ]
+ },
+ }
+ async for line in yield_sse_event("message", msg_payload):
+ yield line
+
+ @staticmethod
+ async def yield_thinking_event(text_segment, thinking_id=None):
+ thinking_id = thinking_id or f"think_{uuid.uuid4().hex[:8]}"
+ thinking_data = {
+ "id": thinking_id,
+ "object": "agent.thinking.delta",
+ "delta": {
+ "content": [
+ {
+ "type": "text",
+ "payload": {
+ "value": text_segment,
+ "annotations": [],
+ },
+ }
+ ]
+ },
+ }
+ async for line in yield_sse_event("thinking", thinking_data):
+ yield line
+
+ @staticmethod
+ def yield_done_event():
+ return "event: done\ndata: [DONE]\n\n"
+
+ @staticmethod
+ async def yield_error_event(error_message, error_id=None):
+ error_id = error_id or f"err_{uuid.uuid4().hex[:8]}"
+ error_payload = {
+ "id": error_id,
+ "object": "agent.error",
+ "error": {"message": error_message, "type": "agent_error"},
+ }
+ async for line in yield_sse_event("error", error_payload):
+ yield line
+
+ @staticmethod
+ async def yield_tool_call_event(tool_call_data):
+ from ..api.models.retrieval.responses import ToolCallEvent
+
+ tc_event = ToolCallEvent(event="tool_call", data=tool_call_data)
+ async for line in yield_sse_event(
+ "tool_call", tc_event.dict()["data"]
+ ):
+ yield line
+
+ # New helper for emitting search results:
+ @staticmethod
+ async def yield_search_results_event(aggregated_results):
+ payload = {
+ "id": "search_1",
+ "object": "rag.search_results",
+ "data": aggregated_results.as_dict(),
+ }
+ async for line in yield_sse_event("search_results", payload):
+ yield line
+
+ @staticmethod
+ async def yield_tool_result_event(tool_result_data):
+ from ..api.models.retrieval.responses import ToolResultEvent
+
+ tr_event = ToolResultEvent(event="tool_result", data=tool_result_data)
+ async for line in yield_sse_event(
+ "tool_result", tr_event.dict()["data"]
+ ):
+ yield line