aboutsummaryrefslogtreecommitdiff
import json
import logging
import math
import uuid
from abc import ABCMeta
from copy import deepcopy
from datetime import datetime
from typing import TYPE_CHECKING, Any, Optional, Tuple, TypeVar
from uuid import NAMESPACE_DNS, UUID, uuid4, uuid5

import tiktoken

from ..abstractions import (
    AggregateSearchResult,
    AsyncSyncMeta,
    GraphCommunityResult,
    GraphEntityResult,
    GraphRelationshipResult,
)
from ..abstractions.vector import VectorQuantizationType

if TYPE_CHECKING:
    pass


# Logger used by this module. `logging.getLogger()` with no name returns the
# root logger, so records emitted here carry the root logger's name.
# NOTE(review): `logging.getLogger(__name__)` is the usual convention; confirm
# whether sharing the root logger is intentional before changing it.
logger = logging.getLogger()


def id_to_shorthand(id: str | UUID):
    return str(id)[:7]


def format_search_results_for_llm(
    results: AggregateSearchResult,
    collector: Any,  # SearchResultsCollector
) -> str:
    """
    Render aggregated search results as plain text for an LLM prompt.

    Sections are emitted in this order, and only when non-empty: vector
    (chunk) results, graph results, web results, then local context
    documents. Chunk, graph and web entries are labelled
    ``Source ID [<short id>]`` using ``id_to_shorthand`` (first seven
    characters of the id), so the model can cite them by that short id.

    Args:
        results: The aggregated search results to render.
        collector: Not used by this function; kept so existing call sites
            keep working.

    Returns:
        The rendered lines joined with newlines; an empty string when no
        section has results.
    """
    lines = []

    # 1) Chunk search
    if results.chunk_search_results:
        lines.append("Vector Search Results:")
        for c in results.chunk_search_results:
            lines.append(f"Source ID [{id_to_shorthand(c.id)}]:")
            lines.append(c.text or "")

    # 2) Graph search: render the fields of whichever result kind this is
    if results.graph_search_results:
        lines.append("Graph Search Results:")
        for g in results.graph_search_results:
            lines.append(f"Source ID [{id_to_shorthand(g.id)}]:")
            if isinstance(g.content, GraphCommunityResult):
                lines.append(f"Community Name: {g.content.name}")
                lines.append(f"ID: {g.content.id}")
                lines.append(f"Summary: {g.content.summary}")
            elif isinstance(g.content, GraphEntityResult):
                lines.append(f"Entity Name: {g.content.name}")
                lines.append(f"Description: {g.content.description}")
            elif isinstance(g.content, GraphRelationshipResult):
                lines.append(
                    f"Relationship: {g.content.subject}-{g.content.predicate}-{g.content.object}"
                )

    # 3) Web search
    if results.web_search_results:
        lines.append("Web Search Results:")
        for w in results.web_search_results:
            lines.append(f"Source ID [{id_to_shorthand(w.id)}]:")
            lines.append(f"Title: {w.title}")
            lines.append(f"Link: {w.link}")
            lines.append(f"Snippet: {w.snippet}")

    # 4) Local context docs, followed by each of their chunks
    if results.document_search_results:
        lines.append("Local Context Documents:")
        for doc_result in results.document_search_results:
            doc_title = doc_result.title or "Untitled Document"
            doc_id = doc_result.id
            summary = doc_result.summary

            lines.append(f"Full Document ID: {doc_id}")
            lines.append(f"Shortened Document ID: {id_to_shorthand(doc_id)}")
            lines.append(f"Document Title: {doc_title}")
            if summary:
                lines.append(f"Summary: {summary}")

            if doc_result.chunks:
                for chunk in doc_result.chunks:
                    # Chunks may be plain dicts or objects exposing `id` and
                    # `text` (SearchResultsCollector accepts both shapes).
                    if isinstance(chunk, dict):
                        chunk_id, chunk_text = chunk["id"], chunk["text"]
                    else:
                        chunk_id, chunk_text = chunk.id, chunk.text
                    lines.append(
                        f"\nChunk ID {id_to_shorthand(chunk_id)}:\n{chunk_text}"
                    )

    return "\n".join(lines)


def _generate_id_from_label(label) -> UUID:
    return uuid5(NAMESPACE_DNS, label)


def generate_id(label: Optional[str] = None) -> UUID:
    """Generate a run id.

    With ``label`` the id is the deterministic name-based UUID of that label.
    Without it, a random UUID4 string is used as the label, so the result is
    effectively unique.
    """
    if label is None:
        label = str(uuid4())
    return _generate_id_from_label(label)


def generate_document_id(filename: str, user_id: UUID) -> UUID:
    """Derive a deterministic document id from a filename and its owner.

    Slashes in ``filename`` are replaced by underscores before building the
    label ``"<filename>-<user_id>"``.
    """
    label = f"{filename.replace('/', '_')}-{user_id}"
    return _generate_id_from_label(label)


def generate_extraction_id(
    document_id: UUID, iteration: int = 0, version: str = "0"
) -> UUID:
    """Derive a deterministic extraction id.

    The id is built from the label ``"<document_id>-<iteration>-<version>"``.
    """
    label = f"{document_id}-{iteration}-{version}"
    return _generate_id_from_label(label)


def generate_default_user_collection_id(user_id: UUID) -> UUID:
    """Derive the deterministic default-collection id for ``user_id``."""
    label = str(user_id)
    return _generate_id_from_label(label)


def generate_user_id(email: str) -> UUID:
    """Derive a deterministic user id from an email address (used verbatim)."""
    return _generate_id_from_label(label=email)


def generate_default_prompt_id(prompt_name: str) -> UUID:
    """Derive a deterministic prompt id from the prompt's name."""
    return _generate_id_from_label(label=prompt_name)


def generate_entity_document_id() -> UUID:
    """Generate a document id for entities inserted directly into a graph.

    The label embeds the current local (naive) timestamp from
    ``datetime.now()``, so calls at different instants yield different ids.
    """
    timestamp = datetime.now().isoformat()
    label = "entity-" + timestamp
    return _generate_id_from_label(label)


def _split_version(version: str) -> tuple[str, str]:
    """Split ``version`` into (prefix, trailing run of ASCII digits).

    Raises:
        ValueError: if ``version`` does not end with a digit (including "").
    """
    prefix = version.rstrip("0123456789")
    digits = version[len(prefix):]
    if not digits:
        raise ValueError(f"Version {version!r} does not end with a number")
    return prefix, digits


def _render_version(prefix: str, original_digits: str, value: int) -> str:
    """Append ``value`` to ``prefix``, keeping zero-padding if it was used.

    A number is treated as zero-padded when it has more than one digit and
    starts with "0" (e.g. "01"). Its width is then preserved ("01" -> "02");
    otherwise the number is written plainly ("9" -> "10").
    """
    zero_padded = len(original_digits) > 1 and original_digits.startswith("0")
    number = str(value).zfill(len(original_digits)) if zero_padded else str(value)
    return f"{prefix}{number}"


def increment_version(version: str) -> str:
    """Return ``version`` with its trailing number incremented by one.

    The whole trailing digit run is treated as the number, so multi-digit
    versions work: "v1" -> "v2", "v9" -> "v10", "v19" -> "v20".

    Raises:
        ValueError: if ``version`` does not end with a digit.
    """
    prefix, digits = _split_version(version)
    return _render_version(prefix, digits, int(digits) + 1)


def decrement_version(version: str) -> str:
    """Return ``version`` with its trailing number decremented, floored at 0.

    Examples: "v2" -> "v1", "v10" -> "v9", "v0" -> "v0".

    Raises:
        ValueError: if ``version`` does not end with a digit.
    """
    prefix, digits = _split_version(version)
    return _render_version(prefix, digits, max(0, int(digits) - 1))


def validate_uuid(uuid_str: str) -> UUID:
    """Parse ``uuid_str`` into a ``UUID``.

    Raises:
        ValueError: if the string is not a well-formed UUID.
    """
    parsed = UUID(hex=uuid_str)
    return parsed


def update_settings_from_dict(server_settings, settings_dict: dict):
    """Return a deep copy of ``server_settings`` with overrides applied.

    The original object is never modified. For each entry in
    ``settings_dict``:

    * ``None`` values are skipped.
    * Dict values are merged one level deep into the existing attribute, by
      item assignment if that attribute is a dict, otherwise by ``setattr``.
    * Any other value replaces the attribute outright.
    """
    updated = deepcopy(server_settings)
    for field_name, new_value in settings_dict.items():
        if new_value is None:
            continue
        if not isinstance(new_value, dict):
            setattr(updated, field_name, new_value)
            continue
        for sub_key, sub_value in new_value.items():
            target = getattr(updated, field_name)
            if isinstance(target, dict):
                target[sub_key] = sub_value
            else:
                setattr(target, sub_key, sub_value)
    return updated


def _decorate_vector_type(
    input_str: str,
    quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
) -> str:
    """Prefix ``input_str`` with the column type for ``quantization_type``.

    The prefix is ``quantization_type.db_type``; ``input_str`` is typically a
    dimension suffix such as "(512)" or "".
    """
    column_type = quantization_type.db_type
    return f"{column_type}{input_str}"


def _get_vector_column_str(
    dimension: int | float, quantization_type: VectorQuantizationType
) -> str:
    """Returns a string representation of a vector column type.

    Explicitly handles the case where the dimension is not a valid number,
    which supports embedding models that do not allow specifying the
    dimension. A NaN or non-positive dimension yields an unsized column type.
    A float with an integral value (e.g. 512.0) is rendered as an integer:
    "(512.0)" is not a valid integer type modifier.
    """
    if math.isnan(dimension) or dimension <= 0:
        vector_dim = ""  # Allows for Postgres to handle any dimension
    else:
        if isinstance(dimension, float) and dimension.is_integer():
            dimension = int(dimension)
        vector_dim = f"({dimension})"
    return _decorate_vector_type(vector_dim, quantization_type)


KeyType = TypeVar("KeyType")


def deep_update(
    mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]
) -> dict[KeyType, Any]:
    """
    Recursively merge ``updating_mappings`` into a shallow copy of ``mapping``.

    Later mappings win. When both the existing value and the incoming value
    for a key are dicts, they are merged recursively; otherwise the incoming
    value replaces the existing one. ``mapping`` itself is not modified.

    Taken from Pydantic v1:
    https://github.com/pydantic/pydantic/blob/fd2991fe6a73819b48c906e3c3274e8e47d0f761/pydantic/utils.py#L200
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            existing = merged.get(key)
            if isinstance(existing, dict) and isinstance(new_value, dict):
                merged[key] = deep_update(existing, new_value)
            else:
                merged[key] = new_value
    return merged


def tokens_count_for_message(message, encoding):
    """Estimate the number of tokens a single chat message contributes.

    A fixed overhead of 3 tokens per message is added, plus one of:

    * for a truthy ``function_call``: tokens of its ``name`` and ``arguments``;
    * else for a non-empty ``tool_calls`` list: tokens of each call's
      ``function.name`` and ``function.arguments``;
    * otherwise: tokens of ``content``. ``None`` or missing content counts
      as 0. List content (multimodal parts) counts string parts and the
      ``text`` field of dict parts. Any other non-string value is counted via
      ``str()``.

    Args:
        message: Chat message mapping (OpenAI-style keys assumed).
        encoding: Object whose ``encode(str)`` returns a sequence of tokens,
            e.g. a tiktoken ``Encoding``.

    Returns:
        The estimated token count.
    """
    tokens_per_message = 3

    num_tokens = tokens_per_message
    if message.get("function_call"):
        num_tokens += len(encoding.encode(message["function_call"]["name"]))
        num_tokens += len(
            encoding.encode(message["function_call"]["arguments"])
        )
    elif message.get("tool_calls"):
        for tool_call in message["tool_calls"]:
            num_tokens += len(encoding.encode(tool_call["function"]["name"]))
            num_tokens += len(
                encoding.encode(tool_call["function"]["arguments"])
            )
    else:
        content = message.get("content")
        if isinstance(content, list):
            # Multimodal content: only textual parts contribute text tokens.
            for part in content:
                if isinstance(part, str):
                    num_tokens += len(encoding.encode(part))
                elif isinstance(part, dict) and isinstance(
                    part.get("text"), str
                ):
                    num_tokens += len(encoding.encode(part["text"]))
        elif content is not None:
            # e.g. assistant turns with an empty tool_calls list carry
            # content=None; encode() would reject None.
            text = content if isinstance(content, str) else str(content)
            num_tokens += len(encoding.encode(text))

    return num_tokens


def num_tokens_from_messages(messages, model="gpt-4o"):
    """Return the number of tokens used by a list of messages for both user and assistant.

    Each message is counted with ``tokens_count_for_message``. A fixed 3
    tokens are then added once for the assistant reply priming, which is a
    per-request cost rather than a per-message one. If tiktoken has no
    encoding mapped to ``model``, ``cl100k_base`` is used and a warning is
    logged.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.warning("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    tokens = sum(
        tokens_count_for_message(message_, encoding) for message_ in messages
    )
    tokens += 3  # every reply is primed with assistant (once per request)
    return tokens


class SearchResultsCollector:
    """
    Collects search results in the form (source_type, result_obj).
    Handles both object-oriented and dictionary-based search results.

    Auto-detected source types are "chunk", "graph", "web", "doc", or
    "unknown". Callers may also pass their own source_type string. Results
    are kept in insertion order.
    """

    def __init__(self):
        # Ordered list of (source_type, result_obj) pairs.
        self._results_in_order = []

    @property
    def results(self):
        """The live list of (source_type, result_obj) pairs (not a copy)."""
        return self._results_in_order

    @results.setter
    def results(self, value):
        """
        Replace all results, auto-detecting the type of 'unknown' items.

        Accepts a list whose items are either ``(source_type, result_obj)``
        2-tuples, e.g. ``[('unknown', {...}), ('web', {...})]``, or bare
        result objects. Tuples tagged ``"unknown"`` and bare objects are typed
        with ``_detect_result_type``; other tags are kept as given.

        Raises:
            ValueError: if ``value`` is not a list. Existing results are left
                untouched in that case.
        """
        # Validate before touching state so a rejected assignment does not
        # wipe the collector.
        if not isinstance(value, list):
            raise ValueError("Results must be a list")

        normalized = []
        for item in value:
            if isinstance(item, tuple) and len(item) == 2:
                source_type, result_obj = item
                # Only auto-detect if the source type is "unknown"
                if source_type == "unknown":
                    source_type = self._detect_result_type(result_obj)
                normalized.append((source_type, result_obj))
            else:
                # Not a (type, obj) pair: detect the type of the item itself
                normalized.append((self._detect_result_type(item), item))
        self._results_in_order = normalized

    def add_aggregate_result(self, agg):
        """
        Flatten the chunk_search_results, graph_search_results, web_search_results,
        and document_search_results into the collector, including nested chunks.

        Each document is appended as "doc". Its nested chunks (dicts with an
        "id" key, or objects with an ``id`` attribute) are then appended as
        "chunk" so they can be found by short id.
        """
        if hasattr(agg, "chunk_search_results") and agg.chunk_search_results:
            for c in agg.chunk_search_results:
                self._results_in_order.append(("chunk", c))

        if hasattr(agg, "graph_search_results") and agg.graph_search_results:
            for g in agg.graph_search_results:
                self._results_in_order.append(("graph", g))

        if hasattr(agg, "web_search_results") and agg.web_search_results:
            for w in agg.web_search_results:
                self._results_in_order.append(("web", w))

        # Add documents and extract their chunks
        if (
            hasattr(agg, "document_search_results")
            and agg.document_search_results
        ):
            for doc in agg.document_search_results:
                # Add the document itself
                self._results_in_order.append(("doc", doc))

                # Extract and add chunks from the document
                chunks = None
                if isinstance(doc, dict):
                    chunks = doc.get("chunks", [])
                elif hasattr(doc, "chunks") and doc.chunks is not None:
                    chunks = doc.chunks

                if chunks:
                    for chunk in chunks:
                        # Ensure each chunk has the minimum required attributes
                        if isinstance(chunk, dict) and "id" in chunk:
                            # Add the chunk directly to results for citation lookup
                            self._results_in_order.append(("chunk", chunk))
                        elif hasattr(chunk, "id"):
                            self._results_in_order.append(("chunk", chunk))

    def add_result(self, result_obj, source_type=None):
        """
        Add a single result object to the collector.
        If source_type is not provided, automatically detect the type.

        Returns:
            The source type the result was stored under.
        """
        if source_type:
            self._results_in_order.append((source_type, result_obj))
            return source_type

        detected_type = self._detect_result_type(result_obj)
        self._results_in_order.append((detected_type, result_obj))
        return detected_type

    def _detect_result_type(self, obj):
        """
        Detect the type of a result object based on its properties.
        Works with both object attributes and dictionary keys.

        Checks run in order: dict key patterns, a ``result_type`` attribute,
        class-name hints, then attribute patterns. Returns "unknown" when
        nothing matches.
        """
        # Handle dictionary types first (common for web search results)
        if isinstance(obj, dict):
            # Web search pattern
            if all(k in obj for k in ["title", "link"]) and any(
                k in obj for k in ["snippet", "description"]
            ):
                return "web"

            # Check for graph dictionary patterns
            if "content" in obj and isinstance(obj["content"], dict):
                content = obj["content"]
                if all(k in content for k in ["name", "description"]):
                    return "graph"  # Entity
                if all(
                    k in content for k in ["subject", "predicate", "object"]
                ):
                    return "graph"  # Relationship
                if all(k in content for k in ["name", "summary"]):
                    return "graph"  # Community

            # Chunk pattern
            if all(k in obj for k in ["text", "id"]) and any(
                k in obj for k in ["score", "metadata"]
            ):
                return "chunk"

            # Context document pattern
            if "document" in obj and "chunks" in obj:
                return "doc"

            # Check for explicit type indicator
            if "type" in obj:
                type_val = str(obj["type"]).lower()
                if any(t in type_val for t in ["web", "organic"]):
                    return "web"
                if "graph" in type_val:
                    return "graph"
                if "chunk" in type_val:
                    return "chunk"
                if "document" in type_val:
                    return "doc"

        # Handle object attributes for OOP-style results
        if hasattr(obj, "result_type"):
            result_type = str(obj.result_type).lower()
            if result_type in ["entity", "relationship", "community"]:
                return "graph"

        # Check class name hints
        class_name = obj.__class__.__name__
        if "Graph" in class_name:
            return "graph"
        if "Chunk" in class_name:
            return "chunk"
        if "Web" in class_name:
            return "web"
        if "Document" in class_name:
            return "doc"

        # Check for object attribute patterns
        if hasattr(obj, "content"):
            content = obj.content
            if hasattr(content, "name") and hasattr(content, "description"):
                return "graph"  # Entity
            if hasattr(content, "subject") and hasattr(content, "predicate"):
                return "graph"  # Relationship
            if hasattr(content, "name") and hasattr(content, "summary"):
                return "graph"  # Community

        if (
            hasattr(obj, "text")
            and hasattr(obj, "id")
            and (hasattr(obj, "score") or hasattr(obj, "metadata"))
        ):
            return "chunk"

        if (
            hasattr(obj, "title")
            and hasattr(obj, "link")
            and hasattr(obj, "snippet")
        ):
            return "web"

        if hasattr(obj, "document") and hasattr(obj, "chunks"):
            return "doc"

        # Default when type can't be determined
        return "unknown"

    def find_by_short_id(self, short_id):
        """
        Find the first result whose id string starts with ``short_id``.

        Dict results are returned as-is. Object results are converted with
        ``as_dict``, ``model_dump`` or ``dict`` when available (first match
        wins), otherwise returned as-is. If nothing matches at the top level,
        chunks nested inside "doc" results are searched and a matching chunk
        is returned unconverted.

        Returns:
            The matching result, or None if ``short_id`` is empty or nothing
            matches.
        """
        if not short_id:
            return None

        # First try direct lookup using regular iteration
        for _, result_obj in self._results_in_order:
            # Check dictionary objects
            if isinstance(result_obj, dict) and "id" in result_obj:
                result_id = str(result_obj["id"])
                if result_id.startswith(short_id):
                    return result_obj

            # Check object with id attribute
            elif hasattr(result_obj, "id"):
                obj_id = getattr(result_obj, "id", None)
                if obj_id and str(obj_id).startswith(short_id):
                    # Convert to dict if possible
                    if hasattr(result_obj, "as_dict"):
                        return result_obj.as_dict()
                    elif hasattr(result_obj, "model_dump"):
                        return result_obj.model_dump()
                    elif hasattr(result_obj, "dict"):
                        return result_obj.dict()
                    else:
                        return result_obj

        # If not found, look for chunks inside documents that weren't extracted properly
        for source_type, result_obj in self._results_in_order:
            if source_type == "doc":
                # Try various ways to access chunks
                chunks = None
                if isinstance(result_obj, dict) and "chunks" in result_obj:
                    chunks = result_obj["chunks"]
                elif (
                    hasattr(result_obj, "chunks")
                    and result_obj.chunks is not None
                ):
                    chunks = result_obj.chunks

                if chunks:
                    for chunk in chunks:
                        # Try each chunk
                        chunk_id = None
                        if isinstance(chunk, dict) and "id" in chunk:
                            chunk_id = chunk["id"]
                        elif hasattr(chunk, "id"):
                            chunk_id = chunk.id

                        if chunk_id and str(chunk_id).startswith(short_id):
                            return chunk

        return None

    def get_results_by_type(self, type_name):
        """Get all results of a specific type, in insertion order."""
        return [
            result_obj
            for source_type, result_obj in self._results_in_order
            if source_type == type_name
        ]

    def __repr__(self):
        """String representation showing counts by type"""
        type_counts = {}
        for source_type, _ in self._results_in_order:
            type_counts[source_type] = type_counts.get(source_type, 0) + 1

        return f"SearchResultsCollector with {len(self._results_in_order)} results: {type_counts}"

    def get_all_results(self) -> list[Tuple[str, Any]]:
        """
        Return the list of (source_type, result_obj) pairs, in the order
        appended. This is the live internal list, not a copy.
        """
        return self._results_in_order


def convert_nonserializable_objects(obj):
    """Recursively convert ``obj`` into JSON-friendly built-in values.

    Conversion steps:

    * Objects exposing ``model_dump``, ``as_dict`` and/or ``to_dict`` are
      first converted by those methods, applied in that order.
    * Dict keys that are not strings are stringified.
    * Lists and tuples are converted element-wise, keeping their container
      type.
    * Sets and frozensets become lists. JSON has no set type, and converted
      members (such as dicts) may be unhashable and could not go back into a
      set.
    * ``uuid.UUID`` values become strings and ``datetime`` values become ISO
      8601 strings.
    * Anything else is returned unchanged.
    """
    if hasattr(obj, "model_dump"):
        obj = obj.model_dump()
    if hasattr(obj, "as_dict"):
        obj = obj.as_dict()
    if hasattr(obj, "to_dict"):
        obj = obj.to_dict()

    if isinstance(obj, dict):
        # Convert key to string if it is a UUID or not already a string.
        return {
            (key if isinstance(key, str) else str(key)): (
                convert_nonserializable_objects(value)
            )
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [convert_nonserializable_objects(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_nonserializable_objects(item) for item in obj)
    if isinstance(obj, (set, frozenset)):
        return [convert_nonserializable_objects(item) for item in obj]
    if isinstance(obj, uuid.UUID):
        return str(obj)
    if isinstance(obj, datetime):
        return obj.isoformat()  # Convert datetime to ISO formatted string
    return obj


def dump_obj(obj) -> Any:
    """Convert ``obj`` to JSON-friendly built-in values.

    The first available method among ``model_dump``, ``dict``, ``as_dict``
    and ``to_dict`` is applied. The result is then passed through
    ``convert_nonserializable_objects``.

    The return value mirrors the input's shape (dict, list, scalar, ...), so
    it is annotated ``Any``; the previous ``list[dict[str, Any]]`` annotation
    did not match what the function returns.
    """
    if hasattr(obj, "model_dump"):
        obj = obj.model_dump()
    elif hasattr(obj, "dict"):
        obj = obj.dict()
    elif hasattr(obj, "as_dict"):
        obj = obj.as_dict()
    elif hasattr(obj, "to_dict"):
        obj = obj.to_dict()
    return convert_nonserializable_objects(obj)


def dump_collector(collector: SearchResultsCollector) -> list[dict[str, Any]]:
    """Serialize every result in ``collector`` to JSON-friendly dicts.

    Each (source_type, result_obj) pair becomes
    ``{"source_type": ..., "result": ...}``, in collector order. The result
    object is converted by ``dump_obj``, which applies the same
    model_dump / dict / as_dict / to_dict chain this function used to inline,
    followed by ``convert_nonserializable_objects``.
    """
    return [
        {"source_type": source_type, "result": dump_obj(result_obj)}
        for source_type, result_obj in collector.get_all_results()
    ]


def num_tokens(text, model="gpt-4o"):
    """Return the number of tokens in ``text`` for ``model``'s tokenizer.

    If tiktoken has no encoding mapped to ``model``, ``cl100k_base`` is used.
    Special-token strings in ``text`` are encoded as ordinary text rather
    than raising (``disallowed_special=()``).
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")

    return len(encoding.encode(text, disallowed_special=()))


class CombinedMeta(AsyncSyncMeta, ABCMeta):
    """Metaclass that derives from both ``AsyncSyncMeta`` and ``ABCMeta``.

    Python requires a class's metaclass to be a subclass of the metaclasses
    of all its bases. A class mixing an ``AsyncSyncMeta``-based base with an
    ``ABCMeta``-based one can use this metaclass to avoid a metaclass
    conflict.
    """


async def yield_sse_event(event_name: str, payload: dict, chunk_size=1024):
    """
    Yield the lines of a single Server-Sent Event.

    Exactly three strings are yielded, each ending in a newline:

        event: <event_name>
        data: <payload as compact JSON>
        (an empty line, which terminates the event)

    The payload is serialized with ``json.dumps(payload, default=str)``, so
    values json cannot encode natively are stringified. Without ``indent``,
    the JSON contains no raw newlines, so a single ``data:`` line always
    suffices.

    ``chunk_size`` is accepted for backward compatibility and is not used.
    Splitting the JSON across several ``data:`` lines would make SSE clients
    rejoin the pieces with newline characters, which could corrupt the JSON.
    """
    # SSE: first the "event: ..." field
    yield f"event: {event_name}\n"

    content_str = json.dumps(payload, default=str)
    yield f"data: {content_str}\n"

    # Blank line signals the end of the SSE event
    yield "\n"


class SSEFormatter:
    """
    Formatter for Server-Sent Events (SSE) emitted by the agent.

    Most helpers are async generators that yield the SSE lines produced by
    ``yield_sse_event``. ``yield_done_event`` instead returns the complete
    terminal event as a single string.
    """

    @staticmethod
    async def yield_citation_event(
        citation_data: dict,
    ):
        """
        Emit a "citation" event.

        Args:
            citation_data: Event payload. Its "payload" entry is kept only
                when "is_new" is truthy and "payload" is present; otherwise
                the emitted event carries ``"payload": None``. The caller's
                dict is not modified.

        Yields:
            Formatted SSE event lines.
        """
        # Work on a shallow copy so the caller's dict is left untouched.
        event_data = dict(citation_data)

        # Include the full payload only for new citations
        if not event_data.get("is_new") or "payload" not in event_data:
            event_data["payload"] = None

        async for line in yield_sse_event("citation", event_data):
            yield line

    @staticmethod
    async def yield_final_answer_event(
        final_data: dict,
    ):
        """Emit a "final_answer" event carrying ``final_data`` unchanged."""
        async for line in yield_sse_event("final_answer", final_data):
            yield line

    @staticmethod
    async def yield_message_event(text_segment, msg_id=None):
        """Emit a "message" delta event containing ``text_segment``.

        A random ``msg_<8 hex>`` id is generated when ``msg_id`` is falsy.
        """
        msg_id = msg_id or f"msg_{uuid.uuid4().hex[:8]}"
        msg_payload = {
            "id": msg_id,
            "object": "agent.message.delta",
            "delta": {
                "content": [
                    {
                        "type": "text",
                        "payload": {
                            "value": text_segment,
                            "annotations": [],
                        },
                    }
                ]
            },
        }
        async for line in yield_sse_event("message", msg_payload):
            yield line

    @staticmethod
    async def yield_thinking_event(text_segment, thinking_id=None):
        """Emit a "thinking" delta event containing ``text_segment``.

        A random ``think_<8 hex>`` id is generated when ``thinking_id`` is
        falsy.
        """
        thinking_id = thinking_id or f"think_{uuid.uuid4().hex[:8]}"
        thinking_data = {
            "id": thinking_id,
            "object": "agent.thinking.delta",
            "delta": {
                "content": [
                    {
                        "type": "text",
                        "payload": {
                            "value": text_segment,
                            "annotations": [],
                        },
                    }
                ]
            },
        }
        async for line in yield_sse_event("thinking", thinking_data):
            yield line

    @staticmethod
    def yield_done_event():
        """Return the terminal "done" event as one string (not a generator)."""
        return "event: done\ndata: [DONE]\n\n"

    @staticmethod
    async def yield_error_event(error_message, error_id=None):
        """Emit an "error" event with ``error_message``.

        A random ``err_<8 hex>`` id is generated when ``error_id`` is falsy.
        """
        error_id = error_id or f"err_{uuid.uuid4().hex[:8]}"
        error_payload = {
            "id": error_id,
            "object": "agent.error",
            "error": {"message": error_message, "type": "agent_error"},
        }
        async for line in yield_sse_event("error", error_payload):
            yield line

    @staticmethod
    async def yield_tool_call_event(tool_call_data):
        """Emit a "tool_call" event.

        The data is first passed through ``ToolCallEvent`` and then
        re-serialized from it.
        """
        # Imported lazily (presumably to avoid an import cycle — verify).
        from ..api.models.retrieval.responses import ToolCallEvent

        tc_event = ToolCallEvent(event="tool_call", data=tool_call_data)
        async for line in yield_sse_event(
            "tool_call", tc_event.dict()["data"]
        ):
            yield line

    @staticmethod
    async def yield_search_results_event(aggregated_results):
        """Emit a "search_results" event with ``aggregated_results.as_dict()``.

        The event id is always the fixed string "search_1".
        """
        payload = {
            "id": "search_1",
            "object": "rag.search_results",
            "data": aggregated_results.as_dict(),
        }
        async for line in yield_sse_event("search_results", payload):
            yield line

    @staticmethod
    async def yield_tool_result_event(tool_result_data):
        """Emit a "tool_result" event.

        The data is first passed through ``ToolResultEvent`` and then
        re-serialized from it.
        """
        # Imported lazily (presumably to avoid an import cycle — verify).
        from ..api.models.retrieval.responses import ToolResultEvent

        tr_event = ToolResultEvent(event="tool_result", data=tool_result_data)
        async for line in yield_sse_event(
            "tool_result", tr_event.dict()["data"]
        ):
            yield line