"""Re-exports of shared utilities plus helpers for extracting and tracking bracketed citation IDs."""
import re
from typing import Set, Tuple

from shared.utils.base_utils import (
    SearchResultsCollector,
    SSEFormatter,
    convert_nonserializable_objects,
    decrement_version,
    deep_update,
    dump_collector,
    dump_obj,
    format_search_results_for_llm,
    generate_default_user_collection_id,
    generate_document_id,
    generate_extraction_id,
    generate_id,
    generate_user_id,
    increment_version,
    num_tokens,
    num_tokens_from_messages,
    update_settings_from_dict,
    validate_uuid,
    yield_sse_event,
)
from shared.utils.splitter.text import (
    RecursiveCharacterTextSplitter,
    TextSplitter,
)


def extract_citations(text: str) -> list[str]:
    """
    Collect every bracketed citation ID that appears in the text.

    A citation is a run of 7 or 8 ASCII letters/digits wrapped in square
    brackets, for example ``[abc1234]``.

    Args:
        text: The text to scan for citations

    Returns:
        The citation IDs (without brackets) in order of appearance; an ID
        cited several times appears several times.
    """
    bracketed_id = re.compile(r"\[([A-Za-z0-9]{7,8})\]")

    # The pattern has exactly one capture group, so findall yields the
    # captured ID strings directly, left to right.
    return bracketed_id.findall(text)


def extract_citation_spans(text: str) -> dict[str, list[Tuple[int, int]]]:
    """
    Locate every bracketed citation in the text and group positions by ID.

    Citations use the same form as in extract_citations: 7 or 8 ASCII
    letters/digits inside square brackets.

    Args:
        text: The text to search for citations

    Returns:
        dictionary mapping each citation ID to the (start, end) offsets of
        its occurrences, in order of appearance. Offsets cover the whole
        match, brackets included, with ``end`` exclusive.
    """
    bracketed_id = re.compile(r"\[([A-Za-z0-9]{7,8})\]")

    spans_by_id: dict = {}
    for found in bracketed_id.finditer(text):
        # span() is the (start, end) pair of the full match, i.e. it
        # includes the surrounding brackets, not just the captured ID.
        spans_by_id.setdefault(found.group(1), []).append(found.span())

    return spans_by_id


class CitationTracker:
    """
    Remembers which citation IDs and citation spans have already been
    reported, so each one is only emitted once.
    """

    def __init__(self):
        # Spans already handled, grouped per citation:
        # {citation_id: {(start, end), (start, end), ...}}
        self.processed_spans: dict[str, Set[Tuple[int, int]]] = {}

        # Every citation ID that has been passed to is_new_citation
        self.seen_citation_ids: Set[str] = set()

    def is_new_citation(self, citation_id: str) -> bool:
        """
        Report whether this citation ID is being seen for the first time.

        The ID is recorded as seen as a side effect, so a second call with
        the same ID returns False.
        """
        if citation_id in self.seen_citation_ids:
            return False

        self.seen_citation_ids.add(citation_id)
        return True

    def is_new_span(self, citation_id: str, span: Tuple[int, int]) -> bool:
        """
        Report whether this span is new for the given citation ID.

        A new span is recorded as processed as a side effect, so asking
        again about the same (citation_id, span) pair returns False.

        Args:
            citation_id: The citation ID
            span: (start, end) position tuple

        Returns:
            True if this span hasn't been processed yet, False otherwise
        """
        # Fetch this ID's span set, creating an empty one on first use
        known_spans = self.processed_spans.setdefault(citation_id, set())

        if span in known_spans:
            return False

        known_spans.add(span)
        return True

    def get_all_spans(self) -> dict[str, list[Tuple[int, int]]]:
        """
        Return every processed span, grouped by citation ID.

        Each ID's spans are copied out of the internal set into a new list,
        so the order of spans within a list is not guaranteed.
        """
        snapshot = {}
        for cid, spans in self.processed_spans.items():
            snapshot[cid] = list(spans)
        return snapshot


def find_new_citation_spans(
    text: str, tracker: CitationTracker
) -> dict[str, list[Tuple[int, int]]]:
    """
    Find the citation spans in the text that the tracker has not seen yet.

    Every span found is passed to ``tracker.is_new_span``, which records it
    as processed; a later call over the same text therefore reports nothing
    for those positions.

    Args:
        text: Text to search
        tracker: The CitationTracker instance

    Returns:
        dictionary of citation IDs to lists of new (start, end) spans;
        IDs with no new spans are left out
    """
    unseen_by_id: dict = {}

    for cid, spans in extract_citation_spans(text).items():
        # is_new_span both tests and records the span, so it must run
        # exactly once per span, in order.
        unseen = [span for span in spans if tracker.is_new_span(cid, span)]
        if unseen:
            unseen_by_id[cid] = unseen

    return unseen_by_id


# Public names of this module (what `from <module> import *` exposes).
__all__ = [
    # Helpers imported above from shared.utils.base_utils
    "format_search_results_for_llm",
    "generate_id",
    "generate_document_id",
    "generate_extraction_id",
    "generate_user_id",
    "increment_version",
    "decrement_version",
    "generate_default_user_collection_id",
    "validate_uuid",
    "yield_sse_event",
    "dump_collector",
    "dump_obj",
    "convert_nonserializable_objects",
    "num_tokens",
    "num_tokens_from_messages",
    "SSEFormatter",
    "SearchResultsCollector",
    "update_settings_from_dict",
    "deep_update",
    # Text splitter
    "RecursiveCharacterTextSplitter",
    "TextSplitter",
    # Citation helpers defined in this module
    "extract_citations",
    "extract_citation_spans",
    "CitationTracker",
    "find_new_citation_spans",
]