import re
from typing import Set, Tuple
from shared.utils.base_utils import (
SearchResultsCollector,
SSEFormatter,
convert_nonserializable_objects,
decrement_version,
deep_update,
dump_collector,
dump_obj,
format_search_results_for_llm,
generate_default_user_collection_id,
generate_document_id,
generate_extraction_id,
generate_id,
generate_user_id,
increment_version,
num_tokens,
num_tokens_from_messages,
update_settings_from_dict,
validate_uuid,
yield_sse_event,
)
from shared.utils.splitter.text import (
RecursiveCharacterTextSplitter,
TextSplitter,
)
def extract_citations(text: str) -> list[str]:
    """
    Collect every bracketed citation ID found in ``text``.

    A citation is a run of 7 or 8 ASCII letters/digits wrapped in square
    brackets, e.g. ``[abc1234]``.

    Args:
        text: The text to scan for citations
    Returns:
        list of citation IDs (brackets removed) in order of appearance;
        an ID that occurs several times is listed once per occurrence
    """
    # Group 1 captures the ID itself, without the surrounding brackets.
    return [
        found.group(1)
        for found in re.finditer(r"\[([A-Za-z0-9]{7,8})\]", text)
    ]
def extract_citation_spans(text: str) -> dict[str, list[Tuple[int, int]]]:
    """
    Locate every bracketed citation in ``text`` and record where it occurs.

    Uses the same ``[<7-8 alphanumerics>]`` pattern as ``extract_citations``.

    Args:
        text: The text to search for citations
    Returns:
        dictionary mapping each citation ID to the list of (start, end)
        offsets of its occurrences; offsets cover the whole bracketed token
        and are listed in order of appearance
    """
    spans_by_id: dict[str, list[Tuple[int, int]]] = {}
    for found in re.finditer(r"\[([A-Za-z0-9]{7,8})\]", text):
        # span() is the (start, end) of the full match, brackets included.
        spans_by_id.setdefault(found.group(1), []).append(found.span())
    return spans_by_id
class CitationTracker:
    """
    Tracks citation spans to ensure each one is only emitted once.

    Two independent records are kept:
      * ``seen_citation_ids`` - IDs passed to ``is_new_citation``
      * ``processed_spans``   - (start, end) spans passed to ``is_new_span``,
        grouped by citation ID
    """

    def __init__(self):
        # Format: {citation_id: {(start, end), (start, end), ...}}
        self.processed_spans: dict[str, Set[Tuple[int, int]]] = {}
        # Citation IDs that have been reported via is_new_citation()
        self.seen_citation_ids: Set[str] = set()

    def is_new_citation(self, citation_id: str) -> bool:
        """
        Check if this is the first occurrence of this citation ID.

        The ID is recorded as seen, so a second call with the same ID
        returns False.

        Args:
            citation_id: The citation ID
        Returns:
            True the first time the ID is passed in, False afterwards
        """
        if citation_id in self.seen_citation_ids:
            return False
        self.seen_citation_ids.add(citation_id)
        return True

    def is_new_span(self, citation_id: str, span: Tuple[int, int]) -> bool:
        """
        Check if this span has already been processed for this citation ID.

        A span that has not been seen is recorded before returning, so the
        same (citation_id, span) pair yields True only once.

        Args:
            citation_id: The citation ID
            span: (start, end) position tuple
        Returns:
            True if this span hasn't been processed yet, False otherwise
        """
        # setdefault creates the per-ID set on first use.
        known_spans = self.processed_spans.setdefault(citation_id, set())
        if span in known_spans:
            return False
        known_spans.add(span)
        return True

    def get_all_spans(self) -> dict[str, list[Tuple[int, int]]]:
        """
        Get all processed spans for final answer.

        Returns:
            dictionary mapping each citation ID to its recorded spans,
            sorted by position. The spans are stored in a set, so sorting
            is what makes the output order deterministic.
        """
        return {
            cid: sorted(spans) for cid, spans in self.processed_spans.items()
        }
def find_new_citation_spans(
    text: str, tracker: CitationTracker
) -> dict[str, list[Tuple[int, int]]]:
    """
    Return the citation spans in ``text`` that ``tracker`` has not seen yet.

    Every span found is passed to ``tracker.is_new_span``, which records
    it; calling this again with the same text and tracker therefore
    returns an empty dictionary.

    Args:
        text: Text to search
        tracker: The CitationTracker instance
    Returns:
        dictionary of citation IDs to lists of new (start, end) spans;
        IDs with no new spans are omitted
    """
    fresh: dict = {}
    for cid, spans in extract_citation_spans(text).items():
        # is_new_span both tests and records the span, in text order.
        unseen = [span for span in spans if tracker.is_new_span(cid, span)]
        if unseen:
            fresh[cid] = unseen
    return fresh
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
# Re-exported from shared.utils.base_utils
"format_search_results_for_llm",
"generate_id",
"generate_document_id",
"generate_extraction_id",
"generate_user_id",
"increment_version",
"decrement_version",
"generate_default_user_collection_id",
"validate_uuid",
"yield_sse_event",
"dump_collector",
"dump_obj",
"convert_nonserializable_objects",
"num_tokens",
"num_tokens_from_messages",
"SSEFormatter",
"SearchResultsCollector",
"update_settings_from_dict",
"deep_update",
# Text splitters, re-exported from shared.utils.splitter.text
"RecursiveCharacterTextSplitter",
"TextSplitter",
# Citation helpers defined in this module
"extract_citations",
"extract_citation_spans",
"CitationTracker",
"find_new_citation_spans",
]