import json
import logging
import math
import uuid
from abc import ABCMeta
from copy import deepcopy
from datetime import datetime
from typing import TYPE_CHECKING, Any, Optional, Tuple, TypeVar
from uuid import NAMESPACE_DNS, UUID, uuid4, uuid5

import tiktoken

from ..abstractions import (
    AggregateSearchResult,
    AsyncSyncMeta,
    GraphCommunityResult,
    GraphEntityResult,
    GraphRelationshipResult,
)
from ..abstractions.vector import VectorQuantizationType

if TYPE_CHECKING:
    pass

logger = logging.getLogger()


def id_to_shorthand(id: str | UUID):
    return str(id)[:7]


def format_search_results_for_llm(
    results: AggregateSearchResult,
    collector: Any,  # SearchResultsCollector
) -> str:
    """
    Format an AggregateSearchResult as plain text for the LLM.

    Each chunk / graph / web / doc result in `results` is printed under a
    'Source ID [xxxxxxx]' header built from the shorthand of its id, so the
    model can cite sources by their short ids.
    """
    lines = []

    # Results are assumed to have been added to the collector in this same
    # order, so the shortened ids printed below can be resolved back to their
    # source objects.

    # 1) Chunk search
    if results.chunk_search_results:
        lines.append("Vector Search Results:")
        for c in results.chunk_search_results:
            lines.append(f"Source ID [{id_to_shorthand(c.id)}]:")
            lines.append(c.text or "")  # or c.text[:200] to truncate

    # 2) Graph search
    if results.graph_search_results:
        lines.append("Graph Search Results:")
        for g in results.graph_search_results:
            lines.append(f"Source ID [{id_to_shorthand(g.id)}]:")
            if isinstance(g.content, GraphCommunityResult):
                lines.append(f"Community Name: {g.content.name}")
                lines.append(f"ID: {g.content.id}")
                lines.append(f"Summary: {g.content.summary}")
                # etc. ...
            elif isinstance(g.content, GraphEntityResult):
                lines.append(f"Entity Name: {g.content.name}")
                lines.append(f"Description: {g.content.description}")
            elif isinstance(g.content, GraphRelationshipResult):
                lines.append(
                    f"Relationship: {g.content.subject}-{g.content.predicate}-{g.content.object}"
                )
            # Add metadata if needed

    # 3) Web search
    if results.web_search_results:
        lines.append("Web Search Results:")
        for w in results.web_search_results:
            lines.append(f"Source ID [{id_to_shorthand(w.id)}]:")
            lines.append(f"Title: {w.title}")
            lines.append(f"Link: {w.link}")
            lines.append(f"Snippet: {w.snippet}")

    # 4) Local context docs
    if results.document_search_results:
        lines.append("Local Context Documents:")
        for doc_result in results.document_search_results:
            doc_title = doc_result.title or "Untitled Document"
            doc_id = doc_result.id
            summary = doc_result.summary

            lines.append(f"Full Document ID: {doc_id}")
            lines.append(f"Shortened Document ID: {id_to_shorthand(doc_id)}")
            lines.append(f"Document Title: {doc_title}")
            if summary:
                lines.append(f"Summary: {summary}")

            if doc_result.chunks:
                # Then each chunk inside:
                for chunk in doc_result.chunks:
                    lines.append(
                        f"\nChunk ID {id_to_shorthand(chunk['id'])}:\n{chunk['text']}"
                    )

    result = "\n".join(lines)
    return result


def _generate_id_from_label(label) -> UUID:
    return uuid5(NAMESPACE_DNS, label)


def generate_id(label: Optional[str] = None) -> UUID:
    """Generates a unique run id."""
    return _generate_id_from_label(
        label if label is not None else str(uuid4())
    )


def generate_document_id(filename: str, user_id: UUID) -> UUID:
    """Generates a unique document id from a given filename and user id."""
    safe_filename = filename.replace("/", "_")
    return _generate_id_from_label(f"{safe_filename}-{str(user_id)}")
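
# Because _generate_id_from_label hashes its label with uuid5 over
# NAMESPACE_DNS, helpers built on it (such as generate_document_id) are
# deterministic: the same inputs always yield the same UUID. A minimal
# sketch, using hypothetical inputs:
#
#   doc_a = generate_document_id("report.pdf", UUID(int=1))
#   doc_b = generate_document_id("report.pdf", UUID(int=1))
#   assert doc_a == doc_b  # same filename + user id -> same document id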
"0" ) -> UUID: """Generates a unique extraction id from a given document id and iteration.""" return _generate_id_from_label(f"{str(document_id)}-{iteration}-{version}") def generate_default_user_collection_id(user_id: UUID) -> UUID: """Generates a unique collection id from a given user id.""" return _generate_id_from_label(str(user_id)) def generate_user_id(email: str) -> UUID: """Generates a unique user id from a given email.""" return _generate_id_from_label(email) def generate_default_prompt_id(prompt_name: str) -> UUID: """Generates a unique prompt id.""" return _generate_id_from_label(prompt_name) def generate_entity_document_id() -> UUID: """Generates a unique document id inserting entities into a graph.""" generation_time = datetime.now().isoformat() return _generate_id_from_label(f"entity-{generation_time}") def increment_version(version: str) -> str: prefix = version[:-1] suffix = int(version[-1]) return f"{prefix}{suffix + 1}" def decrement_version(version: str) -> str: prefix = version[:-1] suffix = int(version[-1]) return f"{prefix}{max(0, suffix - 1)}" def validate_uuid(uuid_str: str) -> UUID: return UUID(uuid_str) def update_settings_from_dict(server_settings, settings_dict: dict): """Updates a settings object with values from a dictionary.""" settings = deepcopy(server_settings) for key, value in settings_dict.items(): if value is not None: if isinstance(value, dict): for k, v in value.items(): if isinstance(getattr(settings, key), dict): getattr(settings, key)[k] = v else: setattr(getattr(settings, key), k, v) else: setattr(settings, key, value) return settings def _decorate_vector_type( input_str: str, quantization_type: VectorQuantizationType = VectorQuantizationType.FP32, ) -> str: return f"{quantization_type.db_type}{input_str}" def _get_vector_column_str( dimension: int | float, quantization_type: VectorQuantizationType ) -> str: """Returns a string representation of a vector column type. Explicitly handles the case where the dimension is not a valid number meant to support embedding models that do not allow for specifying the dimension. 
""" if math.isnan(dimension) or dimension <= 0: vector_dim = "" # Allows for Postgres to handle any dimension else: vector_dim = f"({dimension})" return _decorate_vector_type(vector_dim, quantization_type) KeyType = TypeVar("KeyType") def deep_update( mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any] ) -> dict[KeyType, Any]: """ Taken from Pydantic v1: https://github.com/pydantic/pydantic/blob/fd2991fe6a73819b48c906e3c3274e8e47d0f761/pydantic/utils.py#L200 """ updated_mapping = mapping.copy() for updating_mapping in updating_mappings: for k, v in updating_mapping.items(): if ( k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict) ): updated_mapping[k] = deep_update(updated_mapping[k], v) else: updated_mapping[k] = v return updated_mapping def tokens_count_for_message(message, encoding): """Return the number of tokens used by a single message.""" tokens_per_message = 3 num_tokens = 0 num_tokens += tokens_per_message if message.get("function_call"): num_tokens += len(encoding.encode(message["function_call"]["name"])) num_tokens += len( encoding.encode(message["function_call"]["arguments"]) ) elif message.get("tool_calls"): for tool_call in message["tool_calls"]: num_tokens += len(encoding.encode(tool_call["function"]["name"])) num_tokens += len( encoding.encode(tool_call["function"]["arguments"]) ) else: if "content" in message: num_tokens += len(encoding.encode(message["content"])) return num_tokens def num_tokens_from_messages(messages, model="gpt-4o"): """Return the number of tokens used by a list of messages for both user and assistant.""" try: encoding = tiktoken.encoding_for_model(model) except KeyError: logger.warning("Warning: model not found. Using cl100k_base encoding.") encoding = tiktoken.get_encoding("cl100k_base") tokens = 0 for message_ in messages: tokens += tokens_count_for_message(message_, encoding) tokens += 3 # every reply is primed with assistant return tokens class SearchResultsCollector: """ Collects search results in the form (source_type, result_obj). Handles both object-oriented and dictionary-based search results. """ def __init__(self): # We'll store a list of (source_type, result_obj) self._results_in_order = [] @property def results(self): """Get the results list""" return self._results_in_order @results.setter def results(self, value): """ Set the results directly, with automatic type detection for 'unknown' items Handles the format: [('unknown', {...}), ('unknown', {...})] """ self._results_in_order = [] if isinstance(value, list): for item in value: if isinstance(item, tuple) and len(item) == 2: source_type, result_obj = item # Only auto-detect if the source type is "unknown" if source_type == "unknown": detected_type = self._detect_result_type(result_obj) self._results_in_order.append( (detected_type, result_obj) ) else: self._results_in_order.append( (source_type, result_obj) ) else: # If not a tuple, detect and add detected_type = self._detect_result_type(item) self._results_in_order.append((detected_type, item)) else: raise ValueError("Results must be a list") def add_aggregate_result(self, agg): """ Flatten the chunk_search_results, graph_search_results, web_search_results, and document_search_results into the collector, including nested chunks. 
""" if hasattr(agg, "chunk_search_results") and agg.chunk_search_results: for c in agg.chunk_search_results: self._results_in_order.append(("chunk", c)) if hasattr(agg, "graph_search_results") and agg.graph_search_results: for g in agg.graph_search_results: self._results_in_order.append(("graph", g)) if hasattr(agg, "web_search_results") and agg.web_search_results: for w in agg.web_search_results: self._results_in_order.append(("web", w)) # Add documents and extract their chunks if ( hasattr(agg, "document_search_results") and agg.document_search_results ): for doc in agg.document_search_results: # Add the document itself self._results_in_order.append(("doc", doc)) # Extract and add chunks from the document chunks = None if isinstance(doc, dict): chunks = doc.get("chunks", []) elif hasattr(doc, "chunks") and doc.chunks is not None: chunks = doc.chunks if chunks: for chunk in chunks: # Ensure each chunk has the minimum required attributes if isinstance(chunk, dict) and "id" in chunk: # Add the chunk directly to results for citation lookup self._results_in_order.append(("chunk", chunk)) elif hasattr(chunk, "id"): self._results_in_order.append(("chunk", chunk)) def add_result(self, result_obj, source_type=None): """ Add a single result object to the collector. If source_type is not provided, automatically detect the type. """ if source_type: self._results_in_order.append((source_type, result_obj)) return source_type detected_type = self._detect_result_type(result_obj) self._results_in_order.append((detected_type, result_obj)) return detected_type def _detect_result_type(self, obj): """ Detect the type of a result object based on its properties. Works with both object attributes and dictionary keys. """ # Handle dictionary types first (common for web search results) if isinstance(obj, dict): # Web search pattern if all(k in obj for k in ["title", "link"]) and any( k in obj for k in ["snippet", "description"] ): return "web" # Check for graph dictionary patterns if "content" in obj and isinstance(obj["content"], dict): content = obj["content"] if all(k in content for k in ["name", "description"]): return "graph" # Entity if all( k in content for k in ["subject", "predicate", "object"] ): return "graph" # Relationship if all(k in content for k in ["name", "summary"]): return "graph" # Community # Chunk pattern if all(k in obj for k in ["text", "id"]) and any( k in obj for k in ["score", "metadata"] ): return "chunk" # Context document pattern if "document" in obj and "chunks" in obj: return "doc" # Check for explicit type indicator if "type" in obj: type_val = str(obj["type"]).lower() if any(t in type_val for t in ["web", "organic"]): return "web" if "graph" in type_val: return "graph" if "chunk" in type_val: return "chunk" if "document" in type_val: return "doc" # Handle object attributes for OOP-style results if hasattr(obj, "result_type"): result_type = str(obj.result_type).lower() if result_type in ["entity", "relationship", "community"]: return "graph" # Check class name hints class_name = obj.__class__.__name__ if "Graph" in class_name: return "graph" if "Chunk" in class_name: return "chunk" if "Web" in class_name: return "web" if "Document" in class_name: return "doc" # Check for object attribute patterns if hasattr(obj, "content"): content = obj.content if hasattr(content, "name") and hasattr(content, "description"): return "graph" # Entity if hasattr(content, "subject") and hasattr(content, "predicate"): return "graph" # Relationship if hasattr(content, "name") and hasattr(content, 
"summary"): return "graph" # Community if ( hasattr(obj, "text") and hasattr(obj, "id") and (hasattr(obj, "score") or hasattr(obj, "metadata")) ): return "chunk" if ( hasattr(obj, "title") and hasattr(obj, "link") and hasattr(obj, "snippet") ): return "web" if hasattr(obj, "document") and hasattr(obj, "chunks"): return "doc" # Default when type can't be determined return "unknown" def find_by_short_id(self, short_id): """Find a result by its short ID prefix with better chunk handling""" if not short_id: return None # First try direct lookup using regular iteration for _, result_obj in self._results_in_order: # Check dictionary objects if isinstance(result_obj, dict) and "id" in result_obj: result_id = str(result_obj["id"]) if result_id.startswith(short_id): return result_obj # Check object with id attribute elif hasattr(result_obj, "id"): obj_id = getattr(result_obj, "id", None) if obj_id and str(obj_id).startswith(short_id): # Convert to dict if possible if hasattr(result_obj, "as_dict"): return result_obj.as_dict() elif hasattr(result_obj, "model_dump"): return result_obj.model_dump() elif hasattr(result_obj, "dict"): return result_obj.dict() else: return result_obj # If not found, look for chunks inside documents that weren't extracted properly for source_type, result_obj in self._results_in_order: if source_type == "doc": # Try various ways to access chunks chunks = None if isinstance(result_obj, dict) and "chunks" in result_obj: chunks = result_obj["chunks"] elif ( hasattr(result_obj, "chunks") and result_obj.chunks is not None ): chunks = result_obj.chunks if chunks: for chunk in chunks: # Try each chunk chunk_id = None if isinstance(chunk, dict) and "id" in chunk: chunk_id = chunk["id"] elif hasattr(chunk, "id"): chunk_id = chunk.id if chunk_id and str(chunk_id).startswith(short_id): return chunk return None def get_results_by_type(self, type_name): """Get all results of a specific type""" return [ result_obj for source_type, result_obj in self._results_in_order if source_type == type_name ] def __repr__(self): """String representation showing counts by type""" type_counts = {} for source_type, _ in self._results_in_order: type_counts[source_type] = type_counts.get(source_type, 0) + 1 return f"SearchResultsCollector with {len(self._results_in_order)} results: {type_counts}" def get_all_results(self) -> list[Tuple[str, Any]]: """ Return list of (source_type, result_obj, aggregator_index), in the order appended. """ return self._results_in_order def convert_nonserializable_objects(obj): if hasattr(obj, "model_dump"): obj = obj.model_dump() if hasattr(obj, "as_dict"): obj = obj.as_dict() if hasattr(obj, "to_dict"): obj = obj.to_dict() if isinstance(obj, dict): new_obj = {} for key, value in obj.items(): # Convert key to string if it is a UUID or not already a string. 
def convert_nonserializable_objects(obj):
    if hasattr(obj, "model_dump"):
        obj = obj.model_dump()
    if hasattr(obj, "as_dict"):
        obj = obj.as_dict()
    if hasattr(obj, "to_dict"):
        obj = obj.to_dict()

    if isinstance(obj, dict):
        new_obj = {}
        for key, value in obj.items():
            # Convert key to string if it is a UUID or not already a string.
            new_key = str(key) if not isinstance(key, str) else key
            new_obj[new_key] = convert_nonserializable_objects(value)
        return new_obj
    elif isinstance(obj, list):
        return [convert_nonserializable_objects(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_nonserializable_objects(item) for item in obj)
    elif isinstance(obj, set):
        return {convert_nonserializable_objects(item) for item in obj}
    elif isinstance(obj, uuid.UUID):
        return str(obj)
    elif isinstance(obj, datetime):
        return obj.isoformat()  # Convert datetime to ISO formatted string
    else:
        return obj


def dump_obj(obj) -> Any:
    if hasattr(obj, "model_dump"):
        obj = obj.model_dump()
    elif hasattr(obj, "dict"):
        obj = obj.dict()
    elif hasattr(obj, "as_dict"):
        obj = obj.as_dict()
    elif hasattr(obj, "to_dict"):
        obj = obj.to_dict()

    obj = convert_nonserializable_objects(obj)

    return obj


def dump_collector(collector: SearchResultsCollector) -> list[dict[str, Any]]:
    dumped = []
    for source_type, result_obj in collector.get_all_results():
        # Get the dictionary from the result object
        if hasattr(result_obj, "model_dump"):
            result_dict = result_obj.model_dump()
        elif hasattr(result_obj, "dict"):
            result_dict = result_obj.dict()
        elif hasattr(result_obj, "as_dict"):
            result_dict = result_obj.as_dict()
        elif hasattr(result_obj, "to_dict"):
            result_dict = result_obj.to_dict()
        else:
            result_dict = (
                result_obj  # Fallback if no conversion method is available
            )

        # Use the recursive conversion on the entire dictionary
        result_dict = convert_nonserializable_objects(result_dict)

        dumped.append(
            {
                "source_type": source_type,
                "result": result_dict,
            }
        )
    return dumped


def num_tokens(text, model="gpt-4o"):
    """Return the number of tokens in a single text string."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")

    return len(encoding.encode(text, disallowed_special=()))


class CombinedMeta(AsyncSyncMeta, ABCMeta):
    pass


async def yield_sse_event(event_name: str, payload: dict, chunk_size=1024):
    """
    Helper that yields a single SSE event as its constituent lines, e.g.

        event: event_name
        data: {...payload serialized as JSON...}
        [blank line to end event]
    """
    # SSE: first the "event: ..."
    yield f"event: {event_name}\n"

    # Convert payload to JSON
    content_str = json.dumps(payload, default=str)

    # data
    yield f"data: {content_str}\n"

    # blank line signals end of SSE event
    yield "\n"


class SSEFormatter:
    """
    Enhanced formatter for Server-Sent Events (SSE) with citation tracking.

    Adds improved citation handling alongside the standard event helpers.
    """

    @staticmethod
    async def yield_citation_event(
        citation_data: dict,
    ):
        """
        Emits a citation event with an optimized payload.

        The citation_data dict is expected to carry:
            citation_id: The short ID of the citation (e.g., 'abc1234')
            span: (start, end) position tuple for this occurrence
            payload: Source object (included only for the first occurrence)
            is_new: Whether this is the first time we've seen this citation
            citation_id_counter: Optional counter for citation occurrences

        Yields:
            Formatted SSE event lines
        """
        # Include the full payload only for new citations
        if not citation_data.get("is_new") or "payload" not in citation_data:
            citation_data["payload"] = None

        # Yield the event
        async for line in yield_sse_event("citation", citation_data):
            yield line

    @staticmethod
    async def yield_final_answer_event(
        final_data: dict,
    ):
        # Yield the event
        async for line in yield_sse_event("final_answer", final_data):
            yield line

    # Include other existing SSEFormatter methods for compatibility
    @staticmethod
    async def yield_message_event(text_segment, msg_id=None):
        msg_id = msg_id or f"msg_{uuid.uuid4().hex[:8]}"
        msg_payload = {
            "id": msg_id,
            "object": "agent.message.delta",
            "delta": {
                "content": [
                    {
                        "type": "text",
                        "payload": {
                            "value": text_segment,
                            "annotations": [],
                        },
                    }
                ]
            },
        }
        async for line in yield_sse_event("message", msg_payload):
            yield line

    @staticmethod
    async def yield_thinking_event(text_segment, thinking_id=None):
        thinking_id = thinking_id or f"think_{uuid.uuid4().hex[:8]}"
        thinking_data = {
            "id": thinking_id,
            "object": "agent.thinking.delta",
            "delta": {
                "content": [
                    {
                        "type": "text",
                        "payload": {
                            "value": text_segment,
                            "annotations": [],
                        },
                    }
                ]
            },
        }
        async for line in yield_sse_event("thinking", thinking_data):
            yield line

    @staticmethod
    def yield_done_event():
        return "event: done\ndata: [DONE]\n\n"

    @staticmethod
    async def yield_error_event(error_message, error_id=None):
        error_id = error_id or f"err_{uuid.uuid4().hex[:8]}"
        error_payload = {
            "id": error_id,
            "object": "agent.error",
            "error": {"message": error_message, "type": "agent_error"},
        }
        async for line in yield_sse_event("error", error_payload):
            yield line

    @staticmethod
    async def yield_tool_call_event(tool_call_data):
        from ..api.models.retrieval.responses import ToolCallEvent

        tc_event = ToolCallEvent(event="tool_call", data=tool_call_data)
        async for line in yield_sse_event(
            "tool_call", tc_event.dict()["data"]
        ):
            yield line

    # New helper for emitting search results:
    @staticmethod
    async def yield_search_results_event(aggregated_results):
        payload = {
            "id": "search_1",
            "object": "rag.search_results",
            "data": aggregated_results.as_dict(),
        }
        async for line in yield_sse_event("search_results", payload):
            yield line

    @staticmethod
    async def yield_tool_result_event(tool_result_data):
        from ..api.models.retrieval.responses import ToolResultEvent

        tr_event = ToolResultEvent(event="tool_result", data=tool_result_data)
        async for line in yield_sse_event(
            "tool_result", tr_event.dict()["data"]
        ):
            yield line
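
# Wire-format sketch: a message delta produced by the helpers above is emitted
# as a single SSE event of the form
#
#   event: message
#   data: {"id": "msg_ab12cd34", "object": "agent.message.delta", "delta": {...}}
#   (blank line)
#
# and a stream is terminated by yield_done_event():
#
#   event: done
#   data: [DONE]
#
# The "msg_ab12cd34" id is illustrative; real ids come from uuid4().hex[:8].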