Diffstat (limited to '.venv/lib/python3.12/site-packages/shared')
28 files changed, 6264 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/__init__.py b/.venv/lib/python3.12/site-packages/shared/__init__.py new file mode 100644 index 00000000..5abf8a6a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/__init__.py @@ -0,0 +1,7 @@ +from .abstractions import * +from .abstractions import __all__ as abstractions_all +from .api.models import * +from .api.models import __all__ as api_models_all +from .utils import * + +__all__ = abstractions_all + api_models_all diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/__init__.py b/.venv/lib/python3.12/site-packages/shared/abstractions/__init__.py new file mode 100644 index 00000000..da33ddd7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/__init__.py @@ -0,0 +1,146 @@ +from .base import AsyncSyncMeta, R2RSerializable, syncable +from .document import ( + Document, + DocumentChunk, + DocumentResponse, + DocumentType, + GraphConstructionStatus, + GraphExtractionStatus, + IngestionMode, + IngestionStatus, + RawChunk, + UnprocessedChunk, +) +from .embedding import EmbeddingPurpose, default_embedding_prefixes +from .exception import ( + PDFParsingError, + PopplerNotFoundError, + R2RDocumentProcessingError, + R2RException, +) +from .graph import ( + Community, + Entity, + GraphCommunitySettings, + GraphCreationSettings, + GraphEnrichmentSettings, + GraphExtraction, + Relationship, + StoreType, +) +from .llm import ( + GenerationConfig, + LLMChatCompletion, + LLMChatCompletionChunk, + Message, + MessageType, + RAGCompletion, +) +from .prompt import Prompt +from .search import ( + AggregateSearchResult, + ChunkSearchResult, + ChunkSearchSettings, + GraphCommunityResult, + GraphEntityResult, + GraphRelationshipResult, + GraphSearchResult, + GraphSearchResultType, + GraphSearchSettings, + HybridSearchSettings, + SearchMode, + SearchSettings, + WebPageSearchResult, + select_search_filters, +) +from .user import Token, TokenData, User +from .vector import ( + IndexArgsHNSW, + IndexArgsIVFFlat, + IndexMeasure, + IndexMethod, + StorageResult, + Vector, + VectorEntry, + VectorQuantizationType, + VectorTableName, + VectorType, +) + +__all__ = [ + # Base abstractions + "R2RSerializable", + "AsyncSyncMeta", + "syncable", + # Completion abstractions + "MessageType", + # Document abstractions + "Document", + "DocumentChunk", + "DocumentResponse", + "IngestionMode", + "IngestionStatus", + "GraphExtractionStatus", + "GraphConstructionStatus", + "DocumentType", + "RawChunk", + "UnprocessedChunk", + # Embedding abstractions + "EmbeddingPurpose", + "default_embedding_prefixes", + # Exception abstractions + "R2RDocumentProcessingError", + "R2RException", + "PDFParsingError", + "PopplerNotFoundError", + # Graph abstractions + "Entity", + "Community", + "Community", + "GraphExtraction", + "Relationship", + "StoreType", + # LLM abstractions + "GenerationConfig", + "LLMChatCompletion", + "LLMChatCompletionChunk", + "Message", + "RAGCompletion", + # Prompt abstractions + "Prompt", + # Search abstractions + "AggregateSearchResult", + "GraphSearchResult", + "WebPageSearchResult", + "GraphSearchResultType", + "GraphEntityResult", + "GraphRelationshipResult", + "GraphCommunityResult", + "GraphSearchSettings", + "ChunkSearchSettings", + "ChunkSearchResult", + "SearchSettings", + "select_search_filters", + "HybridSearchSettings", + "SearchMode", + # graph abstractions + "GraphCreationSettings", + "GraphEnrichmentSettings", + "GraphExtraction", + "GraphCommunitySettings", + # User abstractions + "Token", + 
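+ # NOTE: "Community" and "GraphExtraction" each appear twice in this
+ # __all__ list; the duplicate entries are harmless but redundant.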
"TokenData", + "User", + # Vector abstractions + "Vector", + "VectorEntry", + "VectorType", + "IndexMethod", + "IndexMeasure", + "IndexArgsIVFFlat", + "IndexArgsHNSW", + "VectorTableName", + "VectorQuantizationType", + "StorageResult", +] diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/base.py b/.venv/lib/python3.12/site-packages/shared/abstractions/base.py new file mode 100644 index 00000000..d90ba400 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/base.py @@ -0,0 +1,145 @@ +import asyncio +import json +from datetime import datetime +from enum import Enum +from typing import Any, Type, TypeVar +from uuid import UUID + +from pydantic import BaseModel + +T = TypeVar("T", bound="R2RSerializable") + + +class R2RSerializable(BaseModel): + @classmethod + def from_dict(cls: Type[T], data: dict[str, Any] | str) -> T: + if isinstance(data, str): + try: + data_dict = json.loads(data) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON string: {e}") from e + else: + data_dict = data + return cls(**data_dict) + + def as_dict(self) -> dict[str, Any]: + data = self.model_dump(exclude_unset=True) + return self._serialize_values(data) + + def to_dict(self) -> dict[str, Any]: + data = self.model_dump(exclude_unset=True) + return self._serialize_values(data) + + def to_json(self) -> str: + data = self.to_dict() + return json.dumps(data) + + @classmethod + def from_json(cls: Type[T], json_str: str) -> T: + return cls.model_validate_json(json_str) + + @staticmethod + def _serialize_values(data: Any) -> Any: + if isinstance(data, dict): + return { + k: R2RSerializable._serialize_values(v) + for k, v in data.items() + } + elif isinstance(data, list): + return [R2RSerializable._serialize_values(v) for v in data] + elif isinstance(data, UUID): + return str(data) + elif isinstance(data, Enum): + return data.value + elif isinstance(data, datetime): + return data.isoformat() + else: + return data + + class Config: + arbitrary_types_allowed = True + json_encoders = { + UUID: str, + bytes: lambda v: v.decode("utf-8", errors="ignore"), + } + + +class AsyncSyncMeta(type): + _event_loop = None # Class-level shared event loop + + @classmethod + def get_event_loop(cls): + if cls._event_loop is None or cls._event_loop.is_closed(): + cls._event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(cls._event_loop) + return cls._event_loop + + def __new__(cls, name, bases, dct): + new_cls = super().__new__(cls, name, bases, dct) + for attr_name, attr_value in dct.items(): + if asyncio.iscoroutinefunction(attr_value) and getattr( + attr_value, "_syncable", False + ): + sync_method_name = attr_name[ + 1: + ] # Remove leading 'a' for sync method + async_method = attr_value + + def make_sync_method(async_method): + def sync_wrapper(self, *args, **kwargs): + loop = cls.get_event_loop() + if not loop.is_running(): + # Setup to run the loop in a background thread if necessary + # to prevent blocking the main thread in a synchronous call environment + from threading import Thread + + result = None + exception = None + + def run(): + nonlocal result, exception + try: + asyncio.set_event_loop(loop) + result = loop.run_until_complete( + async_method(self, *args, **kwargs) + ) + except Exception as e: + exception = e + finally: + generation_config = kwargs.get( + "rag_generation_config", None + ) + if ( + not generation_config + or not generation_config.stream + ): + loop.run_until_complete( + loop.shutdown_asyncgens() + ) + loop.close() + + thread = Thread(target=run) 
+ thread.start() + thread.join() + if exception: + raise exception + return result + else: + # If there's already a running loop, schedule and execute the coroutine + future = asyncio.run_coroutine_threadsafe( + async_method(self, *args, **kwargs), loop + ) + return future.result() + + return sync_wrapper + + setattr( + new_cls, sync_method_name, make_sync_method(async_method) + ) + return new_cls + + +def syncable(func): + """Decorator to mark methods for synchronous wrapper creation.""" + func._syncable = True + return func diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/document.py b/.venv/lib/python3.12/site-packages/shared/abstractions/document.py new file mode 100644 index 00000000..513392f8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/document.py @@ -0,0 +1,377 @@ +"""Abstractions for documents and their extractions.""" + +import json +import logging +from datetime import datetime +from enum import Enum +from typing import Any, Optional +from uuid import UUID, uuid4 + +from pydantic import Field + +from .base import R2RSerializable +from .llm import GenerationConfig + +logger = logging.getLogger() + + +class DocumentType(str, Enum): + """Types of documents that can be stored.""" + + # Audio + MP3 = "mp3" + + # CSV + CSV = "csv" + + # Email + EML = "eml" + MSG = "msg" + P7S = "p7s" + + # EPUB + EPUB = "epub" + + # Excel + XLS = "xls" + XLSX = "xlsx" + + # HTML + HTML = "html" + HTM = "htm" + + # Image + BMP = "bmp" + HEIC = "heic" + JPEG = "jpeg" + PNG = "png" + TIFF = "tiff" + JPG = "jpg" + SVG = "svg" + + # Markdown + MD = "md" + + # Org Mode + ORG = "org" + + # Open Office + ODT = "odt" + + # PDF + PDF = "pdf" + + # Plain text + TXT = "txt" + JSON = "json" + + # PowerPoint + PPT = "ppt" + PPTX = "pptx" + + # reStructured Text + RST = "rst" + + # Rich Text + RTF = "rtf" + + # TSV + TSV = "tsv" + + # Video/GIF + GIF = "gif" + + # Word + DOC = "doc" + DOCX = "docx" + + # XML + XML = "xml" + + +class Document(R2RSerializable): + id: UUID = Field(default_factory=uuid4) + collection_ids: list[UUID] + owner_id: UUID + document_type: DocumentType + metadata: dict + + class Config: + arbitrary_types_allowed = True + ignore_extra = False + json_encoders = { + UUID: str, + } + populate_by_name = True + + +class IngestionStatus(str, Enum): + """Status of document processing.""" + + PENDING = "pending" + PARSING = "parsing" + EXTRACTING = "extracting" + CHUNKING = "chunking" + EMBEDDING = "embedding" + AUGMENTING = "augmenting" + STORING = "storing" + ENRICHING = "enriching" + + FAILED = "failed" + SUCCESS = "success" + + def __str__(self): + return self.value + + @classmethod + def table_name(cls) -> str: + return "documents" + + @classmethod + def id_column(cls) -> str: + return "document_id" + + +class GraphExtractionStatus(str, Enum): + """Status of graph creation per document.""" + + PENDING = "pending" + PROCESSING = "processing" + SUCCESS = "success" + ENRICHED = "enriched" + FAILED = "failed" + + def __str__(self): + return self.value + + @classmethod + def table_name(cls) -> str: + return "documents" + + @classmethod + def id_column(cls) -> str: + return "id" + + +class GraphConstructionStatus(str, Enum): + """Status of graph enrichment per collection.""" + + PENDING = "pending" + PROCESSING = "processing" + OUTDATED = "outdated" + SUCCESS = "success" + FAILED = "failed" + + def __str__(self): + return self.value + + @classmethod + def table_name(cls) -> str: + return "collections" + + @classmethod + def id_column(cls) -> str: + return 
"id" + + +class DocumentResponse(R2RSerializable): + """Base class for document information handling.""" + + id: UUID + collection_ids: list[UUID] + owner_id: UUID + document_type: DocumentType + metadata: dict + title: Optional[str] = None + version: str + size_in_bytes: Optional[int] + ingestion_status: IngestionStatus = IngestionStatus.PENDING + extraction_status: GraphExtractionStatus = GraphExtractionStatus.PENDING + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + ingestion_attempt_number: Optional[int] = None + summary: Optional[str] = None + summary_embedding: Optional[list[float]] = None # Add optional embedding + total_tokens: Optional[int] = None + chunks: Optional[list] = None + + def convert_to_db_entry(self): + """Prepare the document info for database entry, extracting certain + fields from metadata.""" + now = datetime.now() + + # Format the embedding properly for Postgres vector type + embedding = None + if self.summary_embedding is not None: + embedding = f"[{','.join(str(x) for x in self.summary_embedding)}]" + + return { + "id": self.id, + "collection_ids": self.collection_ids, + "owner_id": self.owner_id, + "document_type": self.document_type, + "metadata": json.dumps(self.metadata), + "title": self.title or "N/A", + "version": self.version, + "size_in_bytes": self.size_in_bytes, + "ingestion_status": self.ingestion_status.value, + "extraction_status": self.extraction_status.value, + "created_at": self.created_at or now, + "updated_at": self.updated_at or now, + "ingestion_attempt_number": self.ingestion_attempt_number or 0, + "summary": self.summary, + "summary_embedding": embedding, + "total_tokens": self.total_tokens or 0, # ensure we pass 0 if None + } + + class Config: + json_schema_extra = { + "example": { + "id": "123e4567-e89b-12d3-a456-426614174000", + "collection_ids": ["123e4567-e89b-12d3-a456-426614174000"], + "owner_id": "123e4567-e89b-12d3-a456-426614174000", + "document_type": "pdf", + "metadata": {"title": "Sample Document"}, + "title": "Sample Document", + "version": "1.0", + "size_in_bytes": 123456, + "ingestion_status": "pending", + "extraction_status": "pending", + "created_at": "2021-01-01T00:00:00", + "updated_at": "2021-01-01T00:00:00", + "ingestion_attempt_number": 0, + "summary": "A summary of the document", + "summary_embedding": [0.1, 0.2, 0.3], + "total_tokens": 1000, + } + } + + +class UnprocessedChunk(R2RSerializable): + """An extraction from a document.""" + + id: Optional[UUID] = None + document_id: Optional[UUID] = None + collection_ids: list[UUID] = [] + metadata: dict = {} + text: str + + +class UpdateChunk(R2RSerializable): + """An extraction from a document.""" + + id: UUID + metadata: Optional[dict] = None + text: str + + +class DocumentChunk(R2RSerializable): + """An extraction from a document.""" + + id: UUID + document_id: UUID + collection_ids: list[UUID] + owner_id: UUID + data: str | bytes + metadata: dict + + +class RawChunk(R2RSerializable): + text: str + + +class IngestionMode(str, Enum): + hi_res = "hi-res" + fast = "fast" + custom = "custom" + + +class ChunkEnrichmentSettings(R2RSerializable): + """Settings for chunk enrichment.""" + + enable_chunk_enrichment: bool = Field( + default=False, + description="Whether to enable chunk enrichment or not", + ) + n_chunks: int = Field( + default=2, + description="The number of preceding and succeeding chunks to include. 
Defaults to 2.", + ) + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="The generation config to use for chunk enrichment", + ) + chunk_enrichment_prompt: Optional[str] = Field( + default="chunk_enrichment", + description="The prompt to use for chunk enrichment", + ) + + +class IngestionConfig(R2RSerializable): + provider: str = "r2r" + excluded_parsers: list[str] = ["mp4"] + chunking_strategy: str = "recursive" + chunk_enrichment_settings: ChunkEnrichmentSettings = ( + ChunkEnrichmentSettings() + ) + extra_parsers: dict[str, Any] = {} + + audio_transcription_model: str = "" + + vision_img_prompt_name: str = "vision_img" + + vision_pdf_prompt_name: str = "vision_pdf" + + skip_document_summary: bool = False + document_summary_system_prompt: str = "system" + document_summary_task_prompt: str = "summary" + chunks_for_document_summary: int = 128 + document_summary_model: str = "" + + @property + def supported_providers(self) -> list[str]: + return ["r2r", "unstructured_local", "unstructured_api"] + + def validate_config(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider {self.provider} is not supported.") + + @classmethod + def get_default(cls, mode: str) -> "IngestionConfig": + """Return default ingestion configuration for a given mode.""" + if mode == "hi-res": + # More thorough parsing, no skipping summaries, possibly larger `chunks_for_document_summary`. + return cls( + provider="r2r", + excluded_parsers=["mp4"], + chunk_enrichment_settings=ChunkEnrichmentSettings(), # default + extra_parsers={}, + audio_transcription_model="", + vision_img_prompt_name="vision_img", + vision_pdf_prompt_name="vision_pdf", + skip_document_summary=False, + document_summary_system_prompt="system", + document_summary_task_prompt="summary", + chunks_for_document_summary=256, # larger for hi-res + document_summary_model="", + ) + + elif mode == "fast": + # Skip summaries and other enrichment steps for speed. 
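+ # Relative to "hi-res": skip_document_summary is True and
+ # chunks_for_document_summary drops from 256 to 64.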
+ return cls( + provider="r2r", + excluded_parsers=["mp4"], + chunk_enrichment_settings=ChunkEnrichmentSettings(), # default + extra_parsers={}, + audio_transcription_model="", + vision_img_prompt_name="vision_img", + vision_pdf_prompt_name="vision_pdf", + skip_document_summary=True, # skip summaries + document_summary_system_prompt="system", + document_summary_task_prompt="summary", + chunks_for_document_summary=64, + document_summary_model="", + ) + else: + # For `custom` or any unrecognized mode, return a base config + return cls() diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/embedding.py b/.venv/lib/python3.12/site-packages/shared/abstractions/embedding.py new file mode 100644 index 00000000..6e27da28 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/embedding.py @@ -0,0 +1,31 @@ +from enum import Enum, auto + + +class EmbeddingPurpose(str, Enum): + INDEX = auto() + QUERY = auto() + DOCUMENT = auto() + + +default_embedding_prefixes = { + "nomic-embed-text-v1.5": { + EmbeddingPurpose.INDEX: "", + EmbeddingPurpose.QUERY: "search_query: ", + EmbeddingPurpose.DOCUMENT: "search_document: ", + }, + "nomic-embed-text": { + EmbeddingPurpose.INDEX: "", + EmbeddingPurpose.QUERY: "search_query: ", + EmbeddingPurpose.DOCUMENT: "search_document: ", + }, + "mixedbread-ai/mxbai-embed-large-v1": { + EmbeddingPurpose.INDEX: "", + EmbeddingPurpose.QUERY: "Represent this sentence for searching relevant passages: ", + EmbeddingPurpose.DOCUMENT: "Represent this sentence for searching relevant passages: ", + }, + "mixedbread-ai/mxbai-embed-large": { + EmbeddingPurpose.INDEX: "", + EmbeddingPurpose.QUERY: "Represent this sentence for searching relevant passages: ", + EmbeddingPurpose.DOCUMENT: "Represent this sentence for searching relevant passages: ", + }, +} diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/exception.py b/.venv/lib/python3.12/site-packages/shared/abstractions/exception.py new file mode 100644 index 00000000..3dedfae8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/exception.py @@ -0,0 +1,75 @@ +import textwrap +from typing import Any, Optional +from uuid import UUID + + +class R2RException(Exception): + def __init__( + self, message: str, status_code: int, detail: Optional[Any] = None + ): + self.message = message + self.status_code = status_code + super().__init__(self.message) + + def to_dict(self): + return { + "message": self.message, + "status_code": self.status_code, + "detail": self.detail, + "error_type": self.__class__.__name__, + } + + +class R2RDocumentProcessingError(R2RException): + def __init__( + self, error_message: str, document_id: UUID, status_code: int = 500 + ): + detail = { + "document_id": str(document_id), + "error_type": "document_processing_error", + } + super().__init__(error_message, status_code, detail) + + def to_dict(self): + result = super().to_dict() + result["document_id"] = self.document_id + return result + + +class PDFParsingError(R2RException): + """Custom exception for PDF parsing errors.""" + + def __init__( + self, + message: str, + original_error: Exception | None = None, + status_code: int = 500, + ): + detail = { + "original_error": str(original_error) if original_error else None + } + super().__init__(message, status_code, detail) + + +class PopplerNotFoundError(PDFParsingError): + """Specific error for when Poppler is not installed.""" + + def __init__(self): + installation_instructions = textwrap.dedent(""" + PDF processing requires Poppler to be 
installed. Please install Poppler and ensure it's in your system PATH. + + Installing poppler: + - Ubuntu: sudo apt-get install poppler-utils + - Archlinux: sudo pacman -S poppler + - MacOS: brew install poppler + - Windows: + 1. Download poppler from @oschwartz10612 + 2. Move extracted directory to desired location + 3. Add bin/ directory to PATH + 4. Test by running 'pdftoppm -h' in terminal + """) + super().__init__( + message=installation_instructions, + status_code=422, + original_error=None, + ) diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py b/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py new file mode 100644 index 00000000..3c1cec9e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py @@ -0,0 +1,257 @@ +import json +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from typing import Any, Optional +from uuid import UUID + +from pydantic import Field + +from ..abstractions.llm import GenerationConfig +from .base import R2RSerializable + + +class Entity(R2RSerializable): + """An entity extracted from a document.""" + + name: str + description: Optional[str] = None + category: Optional[str] = None + metadata: Optional[dict[str, Any]] = None + + id: Optional[UUID] = None + parent_id: Optional[UUID] = None # graph_id | document_id + description_embedding: Optional[list[float] | str] = None + chunk_ids: Optional[list[UUID]] = [] + + def __str__(self): + return f"{self.name}:{self.category}" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if isinstance(self.metadata, str): + try: + self.metadata = json.loads(self.metadata) + except json.JSONDecodeError: + self.metadata = self.metadata + + +class Relationship(R2RSerializable): + """A relationship between two entities. + + This is a generic relationship, and can be used to represent any type of + relationship between any two entities. 
+ """ + + id: Optional[UUID] = None + subject: str + predicate: str + object: str + description: Optional[str] = None + subject_id: Optional[UUID] = None + object_id: Optional[UUID] = None + weight: float | None = 1.0 + chunk_ids: Optional[list[UUID]] = [] + parent_id: Optional[UUID] = None + description_embedding: Optional[list[float] | str] = None + metadata: Optional[dict[str, Any] | str] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if isinstance(self.metadata, str): + try: + self.metadata = json.loads(self.metadata) + except json.JSONDecodeError: + self.metadata = self.metadata + + +@dataclass +class Community(R2RSerializable): + name: str = "" + summary: str = "" + level: Optional[int] = None + findings: list[str] = [] + id: Optional[int | UUID] = None + community_id: Optional[UUID] = None + collection_id: Optional[UUID] = None + rating: Optional[float] = None + rating_explanation: Optional[str] = None + description_embedding: Optional[list[float]] = None + attributes: dict[str, Any] | None = None + created_at: datetime = Field( + default_factory=datetime.utcnow, + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, + ) + + def __init__(self, **kwargs): + if isinstance(kwargs.get("attributes", None), str): + kwargs["attributes"] = json.loads(kwargs["attributes"]) + + if isinstance(kwargs.get("embedding", None), str): + kwargs["embedding"] = json.loads(kwargs["embedding"]) + + super().__init__(**kwargs) + + @classmethod + def from_dict(cls, data: dict[str, Any] | str) -> "Community": + parsed_data: dict[str, Any] = ( + json.loads(data) if isinstance(data, str) else data + ) + if isinstance(parsed_data.get("embedding", None), str): + parsed_data["embedding"] = json.loads(parsed_data["embedding"]) + return cls(**parsed_data) + + +class GraphExtraction(R2RSerializable): + """A protocol for a knowledge graph extraction.""" + + entities: list[Entity] + relationships: list[Relationship] + + +class Graph(R2RSerializable): + id: UUID | None = Field() + name: str + description: Optional[str] = None + created_at: datetime = Field( + default_factory=datetime.utcnow, + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, + ) + status: str = "pending" + + class Config: + populate_by_name = True + from_attributes = True + + @classmethod + def from_dict(cls, data: dict[str, Any] | str) -> "Graph": + """Create a Graph instance from a dictionary.""" + # Convert string to dict if needed + parsed_data: dict[str, Any] = ( + json.loads(data) if isinstance(data, str) else data + ) + return cls(**parsed_data) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +class StoreType(str, Enum): + GRAPHS = "graphs" + DOCUMENTS = "documents" + + +class GraphCreationSettings(R2RSerializable): + """Settings for knowledge graph creation.""" + + graph_extraction_prompt: str = Field( + default="graph_extraction", + description="The prompt to use for knowledge graph extraction.", + ) + + graph_entity_description_prompt: str = Field( + default="graph_entity_description", + description="The prompt to use for entity description generation.", + ) + + entity_types: list[str] = Field( + default=[], + description="The types of entities to extract.", + ) + + relation_types: list[str] = Field( + default=[], + description="The types of relations to extract.", + ) + + chunk_merge_count: int = Field( + default=2, + description="""The number of extractions to merge into a single graph + extraction.""", + ) + + max_knowledge_relationships: int = Field( + 
default=100, + description="""The maximum number of knowledge relationships to extract + from each chunk.""", + ) + + max_description_input_length: int = Field( + default=65536, + description="""The maximum length of the description for a node in the + graph.""", + ) + + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="Configuration for text generation during graph enrichment.", + ) + + automatic_deduplication: bool = Field( + default=False, + description="Whether to automatically deduplicate entities.", + ) + + +class GraphEnrichmentSettings(R2RSerializable): + """Settings for knowledge graph enrichment.""" + + force_graph_search_results_enrichment: bool = Field( + default=False, + description="""Force run the enrichment step even if graph creation is + still in progress for some documents.""", + ) + + graph_communities_prompt: str = Field( + default="graph_communities", + description="The prompt to use for knowledge graph enrichment.", + ) + + max_summary_input_length: int = Field( + default=65536, + description="The maximum length of the summary for a community.", + ) + + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="Configuration for text generation during graph enrichment.", + ) + + leiden_params: dict = Field( + default_factory=dict, + description="Parameters for the Leiden algorithm.", + ) + + +class GraphCommunitySettings(R2RSerializable): + """Settings for knowledge graph community enrichment.""" + + force_graph_search_results_enrichment: bool = Field( + default=False, + description="""Force run the enrichment step even if graph creation is + still in progress for some documents.""", + ) + + graph_communities: str = Field( + default="graph_communities", + description="The prompt to use for knowledge graph enrichment.", + ) + + max_summary_input_length: int = Field( + default=65536, + description="The maximum length of the summary for a community.", + ) + + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="Configuration for text generation during graph enrichment.", + ) + + leiden_params: dict = Field( + default_factory=dict, + description="Parameters for the Leiden algorithm.", + ) diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py b/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py new file mode 100644 index 00000000..d71e279e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py @@ -0,0 +1,325 @@ +"""Abstractions for the LLM model.""" + +import json +from enum import Enum +from typing import TYPE_CHECKING, Any, ClassVar, Optional + +from openai.types.chat import ChatCompletionChunk +from pydantic import BaseModel, Field + +from .base import R2RSerializable + +if TYPE_CHECKING: + from .search import AggregateSearchResult + +from typing_extensions import Literal + + +class Function(BaseModel): + arguments: str + """ + The arguments to call the function with, as generated by the model in JSON + format. Note that the model does not always generate valid JSON, and may + hallucinate parameters not defined by your function schema. Validate the + arguments in your code before calling your function. + """ + + name: str + """The name of the function to call.""" + + +class ChatCompletionMessageToolCall(BaseModel): + id: str + """The ID of the tool call.""" + + function: Function + """The function that the model called.""" + + type: Literal["function"] + """The type of the tool. 
Currently, only `function` is supported.""" + + +class FunctionCall(BaseModel): + arguments: str + """ + The arguments to call the function with, as generated by the model in JSON + format. Note that the model does not always generate valid JSON, and may + hallucinate parameters not defined by your function schema. Validate the + arguments in your code before calling your function. + """ + + name: str + """The name of the function to call.""" + + +class ChatCompletionMessage(BaseModel): + content: Optional[str] = None + """The contents of the message.""" + + refusal: Optional[str] = None + """The refusal message generated by the model.""" + + role: Literal["assistant"] + """The role of the author of this message.""" + + # audio: Optional[ChatCompletionAudio] = None + """ + If the audio output modality is requested, this object contains data about the + audio response from the model. + [Learn more](https://platform.openai.com/docs/guides/audio). + """ + + function_call: Optional[FunctionCall] = None + """Deprecated and replaced by `tool_calls`. + + The name and arguments of a function that should be called, as generated by the + model. + """ + + tool_calls: Optional[list[ChatCompletionMessageToolCall]] = None + """The tool calls generated by the model, such as function calls.""" + + structured_content: Optional[list[dict]] = None + + +class Choice(BaseModel): + finish_reason: Literal[ + "stop", + "length", + "tool_calls", + "content_filter", + "function_call", + "max_tokens", + ] + """The reason the model stopped generating tokens. + + This will be `stop` if the model hit a natural stop point or a provided stop + sequence, `length` if the maximum number of tokens specified in the request was + reached, `content_filter` if content was omitted due to a flag from our content + filters, `tool_calls` if the model called a tool, or `function_call` + (deprecated) if the model called a function. + """ + + index: int + """The index of the choice in the list of choices.""" + + # logprobs: Optional[ChoiceLogprobs] = None + """Log probability information for the choice.""" + + message: ChatCompletionMessage + """A chat completion message generated by the model.""" + + +class LLMChatCompletion(BaseModel): + id: str + """A unique identifier for the chat completion.""" + + choices: list[Choice] + """A list of chat completion choices. + + Can be more than one if `n` is greater than 1. + """ + + created: int + """The Unix timestamp (in seconds) of when the chat completion was created.""" + + model: str + """The model used for the chat completion.""" + + object: Literal["chat.completion"] + """The object type, which is always `chat.completion`.""" + + service_tier: Optional[Literal["scale", "default"]] = None + """The service tier used for processing the request.""" + + system_fingerprint: Optional[str] = None + """This fingerprint represents the backend configuration that the model runs with. + + Can be used in conjunction with the `seed` request parameter to understand when + backend changes have been made that might impact determinism. 
+ """ + + usage: Optional[Any] = None + """Usage statistics for the completion request.""" + + +LLMChatCompletionChunk = ChatCompletionChunk + + +class RAGCompletion: + completion: LLMChatCompletion + search_results: "AggregateSearchResult" + + def __init__( + self, + completion: LLMChatCompletion, + search_results: "AggregateSearchResult", + ): + self.completion = completion + self.search_results = search_results + + +class GenerationConfig(R2RSerializable): + _defaults: ClassVar[dict] = { + "model": None, + "temperature": 0.1, + "top_p": 1.0, + "max_tokens_to_sample": 1024, + "stream": False, + "functions": None, + "tools": None, + "add_generation_kwargs": None, + "api_base": None, + "response_format": None, + "extended_thinking": False, + "thinking_budget": None, + "reasoning_effort": None, + } + + model: Optional[str] = Field( + default_factory=lambda: GenerationConfig._defaults["model"] + ) + temperature: float = Field( + default_factory=lambda: GenerationConfig._defaults["temperature"] + ) + top_p: Optional[float] = Field( + default_factory=lambda: GenerationConfig._defaults["top_p"], + ) + max_tokens_to_sample: int = Field( + default_factory=lambda: GenerationConfig._defaults[ + "max_tokens_to_sample" + ], + ) + stream: bool = Field( + default_factory=lambda: GenerationConfig._defaults["stream"] + ) + functions: Optional[list[dict]] = Field( + default_factory=lambda: GenerationConfig._defaults["functions"] + ) + tools: Optional[list[dict]] = Field( + default_factory=lambda: GenerationConfig._defaults["tools"] + ) + add_generation_kwargs: Optional[dict] = Field( + default_factory=lambda: GenerationConfig._defaults[ + "add_generation_kwargs" + ], + ) + api_base: Optional[str] = Field( + default_factory=lambda: GenerationConfig._defaults["api_base"], + ) + response_format: Optional[dict | BaseModel] = None + extended_thinking: bool = Field( + default=False, + description="Flag to enable extended thinking mode (for Anthropic providers)", + ) + thinking_budget: Optional[int] = Field( + default=None, + description=( + "Token budget for internal reasoning when extended thinking mode is enabled. " + "Must be less than max_tokens_to_sample." + ), + ) + reasoning_effort: Optional[str] = Field( + default=None, + description=( + "Effort level for internal reasoning when extended thinking mode is enabled, `low`, `medium`, or `high`." + "Only applicable to OpenAI providers." 
+ ), + ) + + @classmethod + def set_default(cls, **kwargs): + for key, value in kwargs.items(): + if key in cls._defaults: + cls._defaults[key] = value + else: + raise AttributeError( + f"No default attribute '{key}' in GenerationConfig" + ) + + def __init__(self, **data): + # Handle max_tokens mapping to max_tokens_to_sample + if "max_tokens" in data: + # Only set max_tokens_to_sample if it's not already provided + if "max_tokens_to_sample" not in data: + data["max_tokens_to_sample"] = data.pop("max_tokens") + else: + # If both are provided, max_tokens_to_sample takes precedence + data.pop("max_tokens") + + if ( + "response_format" in data + and isinstance(data["response_format"], type) + and issubclass(data["response_format"], BaseModel) + ): + model_class = data["response_format"] + data["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": model_class.__name__, + "schema": model_class.model_json_schema(), + }, + } + + model = data.pop("model", None) + if model is not None: + super().__init__(model=model, **data) + else: + super().__init__(**data) + + def __str__(self): + return json.dumps(self.to_dict()) + + class Config: + populate_by_name = True + json_schema_extra = { + "example": { + "model": "openai/gpt-4o", + "temperature": 0.1, + "top_p": 1.0, + "max_tokens_to_sample": 1024, + "stream": False, + "functions": None, + "tools": None, + "add_generation_kwargs": None, + "api_base": None, + } + } + + +class MessageType(Enum): + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + FUNCTION = "function" + TOOL = "tool" + + def __str__(self): + return self.value + + +class Message(R2RSerializable): + role: MessageType | str + content: Optional[Any] = None + name: Optional[str] = None + function_call: Optional[dict[str, Any]] = None + tool_calls: Optional[list[dict[str, Any]]] = None + tool_call_id: Optional[str] = None + metadata: Optional[dict[str, Any]] = None + structured_content: Optional[list[dict]] = None + image_url: Optional[str] = None # For URL-based images + image_data: Optional[dict[str, str]] = ( + None # For base64 {media_type, data} + ) + + class Config: + populate_by_name = True + json_schema_extra = { + "example": { + "role": "user", + "content": "This is a test message.", + "name": None, + "function_call": None, + "tool_calls": None, + } + } diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/prompt.py b/.venv/lib/python3.12/site-packages/shared/abstractions/prompt.py new file mode 100644 index 00000000..85ab5312 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/prompt.py @@ -0,0 +1,39 @@ +"""Abstraction for a prompt that can be formatted with inputs.""" + +import logging +from datetime import datetime +from typing import Any +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field + +logger = logging.getLogger() + + +class Prompt(BaseModel): + """A prompt that can be formatted with inputs.""" + + id: UUID = Field(default_factory=uuid4) + name: str + template: str + input_types: dict[str, str] + created_at: datetime = Field(default_factory=datetime.utcnow) + updated_at: datetime = Field(default_factory=datetime.utcnow) + + def format_prompt(self, inputs: dict[str, Any]) -> str: + self._validate_inputs(inputs) + return self.template.format(**inputs) + + def _validate_inputs(self, inputs: dict[str, Any]) -> None: + for var, expected_type_name in self.input_types.items(): + expected_type = self._convert_type(expected_type_name) + if var not in inputs: + raise ValueError(f"Missing input: 
{var}") + if not isinstance(inputs[var], expected_type): + raise TypeError( + f"Input '{var}' must be of type {expected_type.__name__}, got {type(inputs[var]).__name__} instead." + ) + + def _convert_type(self, type_name: str) -> type: + type_mapping = {"int": int, "str": str} + return type_mapping.get(type_name, str) diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/search.py b/.venv/lib/python3.12/site-packages/shared/abstractions/search.py new file mode 100644 index 00000000..bf0f650e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/search.py @@ -0,0 +1,614 @@ +"""Abstractions for search functionality.""" + +from copy import copy +from enum import Enum +from typing import Any, Optional +from uuid import NAMESPACE_DNS, UUID, uuid5 + +from pydantic import Field + +from .base import R2RSerializable +from .document import DocumentResponse +from .llm import GenerationConfig +from .vector import IndexMeasure + + +def generate_id_from_label(label) -> UUID: + return uuid5(NAMESPACE_DNS, label) + + +class ChunkSearchResult(R2RSerializable): + """Result of a search operation.""" + + id: UUID + document_id: UUID + owner_id: Optional[UUID] + collection_ids: list[UUID] + score: Optional[float] = None + text: str + metadata: dict[str, Any] + + def __str__(self) -> str: + if self.score: + return ( + f"ChunkSearchResult(score={self.score:.3f}, text={self.text})" + ) + else: + return f"ChunkSearchResult(text={self.text})" + + def __repr__(self) -> str: + return self.__str__() + + def as_dict(self) -> dict: + return { + "id": self.id, + "document_id": self.document_id, + "owner_id": self.owner_id, + "collection_ids": self.collection_ids, + "score": self.score, + "text": self.text, + "metadata": self.metadata, + } + + class Config: + populate_by_name = True + json_schema_extra = { + "example": { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b", + "owner_id": "2acb499e-8428-543b-bd85-0d9098718220", + "collection_ids": [], + "score": 0.23943702876567796, + "text": "Example text from the document", + "metadata": { + "title": "example_document.pdf", + "associated_query": "What is the capital of France?", + }, + } + } + + +class GraphSearchResultType(str, Enum): + ENTITY = "entity" + RELATIONSHIP = "relationship" + COMMUNITY = "community" + + +class GraphEntityResult(R2RSerializable): + id: Optional[UUID] = None + name: str + description: str + metadata: Optional[dict[str, Any]] = None + + class Config: + json_schema_extra = { + "example": { + "name": "Entity Name", + "description": "Entity Description", + "metadata": {}, + } + } + + +class GraphRelationshipResult(R2RSerializable): + id: Optional[UUID] = None + subject: str + predicate: str + object: str + subject_id: Optional[UUID] = None + object_id: Optional[UUID] = None + metadata: Optional[dict[str, Any]] = None + score: Optional[float] = None + description: str | None = None + + class Config: + json_schema_extra = { + "example": { + "name": "Relationship Name", + "description": "Relationship Description", + "metadata": {}, + } + } + + def __str__(self) -> str: + return f"GraphRelationshipResult(subject={self.subject}, predicate={self.predicate}, object={self.object})" + + +class GraphCommunityResult(R2RSerializable): + id: Optional[UUID] = None + name: str + summary: str + metadata: Optional[dict[str, Any]] = None + + class Config: + json_schema_extra = { + "example": { + "name": "Community Name", + "summary": "Community Summary", + "rating": 9, + 
"rating_explanation": "Rating Explanation", + "metadata": {}, + } + } + + def __str__(self) -> str: + return ( + f"GraphCommunityResult(name={self.name}, summary={self.summary})" + ) + + +class GraphSearchResult(R2RSerializable): + content: GraphEntityResult | GraphRelationshipResult | GraphCommunityResult + result_type: Optional[GraphSearchResultType] = None + chunk_ids: Optional[list[UUID]] = None + metadata: dict[str, Any] = {} + score: Optional[float] = None + id: UUID + + def __str__(self) -> str: + return f"GraphSearchResult(content={self.content}, result_type={self.result_type})" + + class Config: + populate_by_name = True + json_schema_extra = { + "example": { + "content": { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "name": "Entity Name", + "description": "Entity Description", + "metadata": {}, + }, + "result_type": "entity", + "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"], + "metadata": { + "associated_query": "What is the capital of France?" + }, + } + } + + +class WebPageSearchResult(R2RSerializable): + title: Optional[str] = None + link: Optional[str] = None + snippet: Optional[str] = None + position: int + type: str = "organic" + date: Optional[str] = None + sitelinks: Optional[list[dict]] = None + id: UUID + + class Config: + json_schema_extra = { + "example": { + "title": "Page Title", + "link": "https://example.com/page", + "snippet": "Page snippet", + "position": 1, + "date": "2021-01-01", + "sitelinks": [ + { + "title": "Sitelink Title", + "link": "https://example.com/sitelink", + } + ], + } + } + + def __str__(self) -> str: + return f"WebPageSearchResult(title={self.title}, link={self.link}, snippet={self.snippet})" + + +class RelatedSearchResult(R2RSerializable): + query: str + type: str = "related" + id: UUID + + +class PeopleAlsoAskResult(R2RSerializable): + question: str + snippet: str + link: str + title: str + id: UUID + type: str = "peopleAlsoAsk" + + +class WebSearchResult(R2RSerializable): + organic_results: list[WebPageSearchResult] = [] + related_searches: list[RelatedSearchResult] = [] + people_also_ask: list[PeopleAlsoAskResult] = [] + + @classmethod + def from_serper_results(cls, results: list[dict]) -> "WebSearchResult": + organic = [] + related = [] + paa = [] + + for result in results: + if result["type"] == "organic": + organic.append( + WebPageSearchResult( + **result, id=generate_id_from_label(result.get("link")) + ) + ) + elif result["type"] == "relatedSearches": + related.append( + RelatedSearchResult( + **result, + id=generate_id_from_label(result.get("query")), + ) + ) + elif result["type"] == "peopleAlsoAsk": + paa.append( + PeopleAlsoAskResult( + **result, id=generate_id_from_label(result.get("link")) + ) + ) + + return cls( + organic_results=organic, + related_searches=related, + people_also_ask=paa, + ) + + +class AggregateSearchResult(R2RSerializable): + """Result of an aggregate search operation.""" + + chunk_search_results: Optional[list[ChunkSearchResult]] = None + graph_search_results: Optional[list[GraphSearchResult]] = None + web_search_results: Optional[list[WebPageSearchResult]] = None + document_search_results: Optional[list[DocumentResponse]] = None + + def __str__(self) -> str: + return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results}, document_search_results={str(self.document_search_results)})" + + def __repr__(self) -> str: + return 
f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results}, document_search_results={str(self.document_search_results)})" + + def as_dict(self) -> dict: + return { + "chunk_search_results": ( + [result.as_dict() for result in self.chunk_search_results] + if self.chunk_search_results + else [] + ), + "graph_search_results": ( + [result.to_dict() for result in self.graph_search_results] + if self.graph_search_results + else [] + ), + "web_search_results": ( + [result.to_dict() for result in self.web_search_results] + if self.web_search_results + else [] + ), + "document_search_results": ( + [cdr.to_dict() for cdr in self.document_search_results] + if self.document_search_results + else [] + ), + } + + class Config: + populate_by_name = True + json_schema_extra = { + "example": { + "chunk_search_results": [ + { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b", + "owner_id": "2acb499e-8428-543b-bd85-0d9098718220", + "collection_ids": [], + "score": 0.23943702876567796, + "text": "Example text from the document", + "metadata": { + "title": "example_document.pdf", + "associated_query": "What is the capital of France?", + }, + } + ], + "graph_search_results": [ + { + "content": { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "name": "Entity Name", + "description": "Entity Description", + "metadata": {}, + }, + "result_type": "entity", + "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"], + "metadata": { + "associated_query": "What is the capital of France?" + }, + } + ], + "web_search_results": [ + { + "title": "Page Title", + "link": "https://example.com/page", + "snippet": "Page snippet", + "position": 1, + "date": "2021-01-01", + "sitelinks": [ + { + "title": "Sitelink Title", + "link": "https://example.com/sitelink", + } + ], + } + ], + "document_search_results": [ + { + "document": { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "title": "Document Title", + "chunks": ["Chunk 1", "Chunk 2"], + "metadata": {}, + }, + } + ], + } + } + + +class HybridSearchSettings(R2RSerializable): + """Settings for hybrid search combining full-text and semantic search.""" + + full_text_weight: float = Field( + default=1.0, description="Weight to apply to full text search" + ) + semantic_weight: float = Field( + default=5.0, description="Weight to apply to semantic search" + ) + full_text_limit: int = Field( + default=200, + description="Maximum number of results to return from full text search", + ) + rrf_k: int = Field( + default=50, description="K-value for RRF (Rank Reciprocal Fusion)" + ) + + +class ChunkSearchSettings(R2RSerializable): + """Settings specific to chunk/vector search.""" + + index_measure: IndexMeasure = Field( + default=IndexMeasure.cosine_distance, + description="The distance measure to use for indexing", + ) + probes: int = Field( + default=10, + description="Number of ivfflat index lists to query. Higher increases accuracy but decreases speed.", + ) + ef_search: int = Field( + default=40, + description="Size of the dynamic candidate list for HNSW index search. 
Higher increases accuracy but decreases speed.", + ) + enabled: bool = Field( + default=True, + description="Whether to enable chunk search", + ) + + +class GraphSearchSettings(R2RSerializable): + """Settings specific to knowledge graph search.""" + + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="Configuration for text generation during graph search.", + ) + max_community_description_length: int = Field( + default=65536, + ) + max_llm_queries_for_global_search: int = Field( + default=250, + ) + limits: dict[str, int] = Field( + default={}, + ) + enabled: bool = Field( + default=True, + description="Whether to enable graph search", + ) + + +class SearchSettings(R2RSerializable): + """Main search settings class that combines shared settings with + specialized settings for chunks and graph.""" + + # Search type flags + use_hybrid_search: bool = Field( + default=False, + description="Whether to perform a hybrid search. This is equivalent to setting `use_semantic_search=True` and `use_fulltext_search=True`, e.g. combining vector and keyword search.", + ) + use_semantic_search: bool = Field( + default=True, + description="Whether to use semantic search", + ) + use_fulltext_search: bool = Field( + default=False, + description="Whether to use full-text search", + ) + + # Common search parameters + filters: dict[str, Any] = Field( + default_factory=dict, + description="""Filters to apply to the search. Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`. + + Commonly seen filters include operations include the following: + + `{"document_id": {"$eq": "9fbe403b-..."}}` + + `{"document_id": {"$in": ["9fbe403b-...", "3e157b3a-..."]}}` + + `{"collection_ids": {"$overlap": ["122fdf6a-...", "..."]}}` + + `{"$and": {"$document_id": ..., "collection_ids": ...}}`""", + ) + limit: int = Field( + default=10, + description="Maximum number of results to return", + ge=1, + le=1_000, + ) + offset: int = Field( + default=0, + ge=0, + description="Offset to paginate search results", + ) + include_metadatas: bool = Field( + default=True, + description="Whether to include element metadata in the search results", + ) + include_scores: bool = Field( + default=True, + description="""Whether to include search score values in the + search results""", + ) + + # Search strategy and settings + search_strategy: str = Field( + default="vanilla", + description="""Search strategy to use + (e.g., 'vanilla', 'query_fusion', 'hyde')""", + ) + hybrid_settings: HybridSearchSettings = Field( + default_factory=HybridSearchSettings, + description="""Settings for hybrid search (only used if + `use_semantic_search` and `use_fulltext_search` are both true)""", + ) + + # Specialized settings + chunk_settings: ChunkSearchSettings = Field( + default_factory=ChunkSearchSettings, + description="Settings specific to chunk/vector search", + ) + graph_settings: GraphSearchSettings = Field( + default_factory=GraphSearchSettings, + description="Settings specific to knowledge graph search", + ) + + # For HyDE or multi-query: + num_sub_queries: int = Field( + default=5, + description="Number of sub-queries/hypothetical docs to generate when using hyde or rag_fusion search strategies.", + ) + + class Config: + populate_by_name = True + json_encoders = {UUID: str} + json_schema_extra = { + "example": { + "use_semantic_search": True, + "use_fulltext_search": False, + "use_hybrid_search": False, + "filters": {"category": "technology"}, + "limit": 20, + "offset": 0, + 
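+ # NOTE: the chunk_settings example below includes "include_metadata",
+ # which is not a declared ChunkSearchSettings field (the top-level
+ # setting is named include_metadatas).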
"search_strategy": "vanilla", + "hybrid_settings": { + "full_text_weight": 1.0, + "semantic_weight": 5.0, + "full_text_limit": 200, + "rrf_k": 50, + }, + "chunk_settings": { + "enabled": True, + "index_measure": "cosine_distance", + "include_metadata": True, + "probes": 10, + "ef_search": 40, + }, + "graph_settings": { + "enabled": True, + "generation_config": GenerationConfig.Config.json_schema_extra, + "max_community_description_length": 65536, + "max_llm_queries_for_global_search": 250, + "limits": { + "entity": 20, + "relationship": 20, + "community": 20, + }, + }, + } + } + + def __init__(self, **data): + # Handle legacy search_filters field + data["filters"] = { + **data.get("filters", {}), + **data.get("search_filters", {}), + } + super().__init__(**data) + + def model_dump(self, *args, **kwargs): + return super().model_dump(*args, **kwargs) + + @classmethod + def get_default(cls, mode: str) -> "SearchSettings": + """Return default search settings for a given mode.""" + if mode == "basic": + # A simpler search that relies primarily on semantic search. + return cls( + use_semantic_search=True, + use_fulltext_search=False, + use_hybrid_search=False, + search_strategy="vanilla", + # Other relevant defaults can be provided here as needed + ) + elif mode == "advanced": + # A more powerful, combined search that leverages both semantic and fulltext. + return cls( + use_semantic_search=True, + use_fulltext_search=True, + use_hybrid_search=True, + search_strategy="hyde", + # Other advanced defaults as needed + ) + else: + # For 'custom' or unrecognized modes, return a basic empty config. + return cls() + + +class SearchMode(str, Enum): + """Search modes for the search endpoint.""" + + basic = "basic" + advanced = "advanced" + custom = "custom" + + +def select_search_filters( + auth_user: Any, + search_settings: SearchSettings, +) -> dict[str, Any]: + filters = copy(search_settings.filters) + selected_collections = None + if not auth_user.is_superuser: + user_collections = set(auth_user.collection_ids) + for key in filters.keys(): + if "collection_ids" in key: + selected_collections = set(map(UUID, filters[key]["$overlap"])) + break + + if selected_collections: + allowed_collections = user_collections.intersection( + selected_collections + ) + else: + allowed_collections = user_collections + # for non-superusers, we filter by user_id and selected & allowed collections + collection_filters = { + "$or": [ + {"owner_id": {"$eq": auth_user.id}}, + {"collection_ids": {"$overlap": list(allowed_collections)}}, + ] # type: ignore + } + + filters.pop("collection_ids", None) + if filters != {}: + filters = {"$and": [collection_filters, filters]} # type: ignore + else: + filters = collection_filters + return filters diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/user.py b/.venv/lib/python3.12/site-packages/shared/abstractions/user.py new file mode 100644 index 00000000..b04ac50b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/user.py @@ -0,0 +1,69 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID + +from pydantic import BaseModel, Field + +from shared.abstractions import R2RSerializable + +from ..utils import generate_default_user_collection_id + + +class Collection(BaseModel): + id: UUID + name: str + description: Optional[str] = None + created_at: datetime = Field( + default_factory=datetime.utcnow, + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, + ) + + class Config: + populate_by_name = True + 
from_attributes = True + + def __init__(self, **data): + super().__init__(**data) + if self.id is None: + self.id = generate_default_user_collection_id(self.name) + + +class Token(BaseModel): + token: str + token_type: str + + +class TokenData(BaseModel): + email: str + token_type: str + exp: datetime + + +class User(R2RSerializable): + id: UUID + email: str + is_active: bool = True + is_superuser: bool = False + created_at: datetime = datetime.now() + updated_at: datetime = datetime.now() + is_verified: bool = False + collection_ids: list[UUID] = [] + graph_ids: list[UUID] = [] + document_ids: list[UUID] = [] + + # Optional fields (to update or set at creation) + limits_overrides: Optional[dict] = None + metadata: Optional[dict] = None + verification_code_expiry: Optional[datetime] = None + name: Optional[str] = None + bio: Optional[str] = None + profile_picture: Optional[str] = None + total_size_in_bytes: Optional[int] = None + num_files: Optional[int] = None + + account_type: str = "password" + hashed_password: Optional[str] = None + google_id: Optional[str] = None + github_id: Optional[str] = None diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/vector.py b/.venv/lib/python3.12/site-packages/shared/abstractions/vector.py new file mode 100644 index 00000000..0b88a765 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/vector.py @@ -0,0 +1,239 @@ +"""Abstraction for a vector that can be stored in the system.""" + +from enum import Enum +from typing import Any, Optional +from uuid import UUID + +from pydantic import BaseModel, Field + +from .base import R2RSerializable + + +class VectorType(str, Enum): + FIXED = "FIXED" + + +class IndexMethod(str, Enum): + """An enum representing the index methods available. + + This class currently only supports the 'ivfflat' method but may + expand in the future. + + Attributes: + auto (str): Automatically choose the best available index method. + ivfflat (str): The ivfflat index method. + hnsw (str): The hnsw index method. + """ + + auto = "auto" + ivfflat = "ivfflat" + hnsw = "hnsw" + + def __str__(self) -> str: + return self.value + + +class IndexMeasure(str, Enum): + """An enum representing the types of distance measures available for + indexing. + + Attributes: + cosine_distance (str): The cosine distance measure for indexing. + l2_distance (str): The Euclidean (L2) distance measure for indexing. + max_inner_product (str): The maximum inner product measure for indexing. 
+ """ + + l2_distance = "l2_distance" + max_inner_product = "max_inner_product" + cosine_distance = "cosine_distance" + l1_distance = "l1_distance" + hamming_distance = "hamming_distance" + jaccard_distance = "jaccard_distance" + + def __str__(self) -> str: + return self.value + + @property + def ops(self) -> str: + return { + IndexMeasure.l2_distance: "_l2_ops", + IndexMeasure.max_inner_product: "_ip_ops", + IndexMeasure.cosine_distance: "_cosine_ops", + IndexMeasure.l1_distance: "_l1_ops", + IndexMeasure.hamming_distance: "_hamming_ops", + IndexMeasure.jaccard_distance: "_jaccard_ops", + }[self] + + @property + def pgvector_repr(self) -> str: + return { + IndexMeasure.l2_distance: "<->", + IndexMeasure.max_inner_product: "<#>", + IndexMeasure.cosine_distance: "<=>", + IndexMeasure.l1_distance: "<+>", + IndexMeasure.hamming_distance: "<~>", + IndexMeasure.jaccard_distance: "<%>", + }[self] + + +class IndexArgsIVFFlat(R2RSerializable): + """A class for arguments that can optionally be supplied to the index + creation method when building an IVFFlat type index. + + Attributes: + nlist (int): The number of IVF centroids that the index should use + """ + + n_lists: int + + +class IndexArgsHNSW(R2RSerializable): + """A class for arguments that can optionally be supplied to the index + creation method when building an HNSW type index. + + Ref: https://github.com/pgvector/pgvector#index-options + + Both attributes are Optional in case the user only wants to specify one and + leave the other as default + + Attributes: + m (int): Maximum number of connections per node per layer (default: 16) + ef_construction (int): Size of the dynamic candidate list for + constructing the graph (default: 64) + """ + + m: Optional[int] = 16 + ef_construction: Optional[int] = 64 + + +class VectorTableName(str, Enum): + """This enum represents the different tables where we store vectors.""" + + CHUNKS = "chunks" + ENTITIES_DOCUMENT = "documents_entities" + GRAPHS_ENTITIES = "graphs_entities" + # TODO: Add support for relationships + # TRIPLES = "relationship" + COMMUNITIES = "graphs_communities" + + def __str__(self) -> str: + return self.value + + +class VectorQuantizationType(str, Enum): + """An enum representing the types of quantization available for vectors. + + Attributes: + FP32 (str): 32-bit floating point quantization. + FP16 (str): 16-bit floating point quantization. + INT1 (str): 1-bit integer quantization. + SPARSE (str): Sparse vector quantization. + """ + + FP32 = "FP32" + FP16 = "FP16" + INT1 = "INT1" + SPARSE = "SPARSE" + + def __str__(self) -> str: + return self.value + + @property + def db_type(self) -> str: + db_type_mapping = { + "FP32": "vector", + "FP16": "halfvec", + "INT1": "bit", + "SPARSE": "sparsevec", + } + return db_type_mapping[self.value] + + +class VectorQuantizationSettings(R2RSerializable): + quantization_type: VectorQuantizationType = Field( + default=VectorQuantizationType.FP32 + ) + + +class Vector(R2RSerializable): + """A vector with the option to fix the number of elements.""" + + data: list[float] + type: VectorType = Field(default=VectorType.FIXED) + length: int = Field(default=-1) + + def __init__(self, **data): + super().__init__(**data) + if ( + self.type == VectorType.FIXED + and self.length > 0 + and len(self.data) != self.length + ): + raise ValueError( + f"Vector must be exactly {self.length} elements long." 
+ ) + + def __repr__(self) -> str: + return ( + f"Vector(data={self.data}, type={self.type}, length={self.length})" + ) + + +class VectorEntry(R2RSerializable): + """A vector entry that can be stored directly in supported vector + databases.""" + + id: UUID + document_id: UUID + owner_id: UUID + collection_ids: list[UUID] + vector: Vector + text: str + metadata: dict[str, Any] + + def __str__(self) -> str: + """Return a string representation of the VectorEntry.""" + return ( + f"VectorEntry(" + f"chunk_id={self.id}, " + f"document_id={self.document_id}, " + f"owner_id={self.owner_id}, " + f"collection_ids={self.collection_ids}, " + f"vector={self.vector}, " + f"text={self.text}, " + f"metadata={self.metadata})" + ) + + def __repr__(self) -> str: + """Return an unambiguous string representation of the VectorEntry.""" + return self.__str__() + + +class StorageResult(R2RSerializable): + """A result of a storage operation.""" + + success: bool + document_id: UUID + num_chunks: int = 0 + error_message: Optional[str] = None + + def __str__(self) -> str: + """Return a string representation of the StorageResult.""" + return f"StorageResult(success={self.success}, error_message={self.error_message})" + + def __repr__(self) -> str: + """Return an unambiguous string representation of the StorageResult.""" + return self.__str__() + + +class IndexConfig(BaseModel): + name: Optional[str] = Field(default=None) + table_name: Optional[str] = Field(default=VectorTableName.CHUNKS) + index_method: Optional[str] = Field(default=IndexMethod.hnsw) + index_measure: Optional[str] = Field(default=IndexMeasure.cosine_distance) + index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = Field( + default=None + ) + index_name: Optional[str] = Field(default=None) + index_column: Optional[str] = Field(default=None) + concurrently: Optional[bool] = Field(default=True) diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/__init__.py b/.venv/lib/python3.12/site-packages/shared/api/models/__init__.py new file mode 100644 index 00000000..2d39dab1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/api/models/__init__.py @@ -0,0 +1,194 @@ +from shared.api.models.auth.responses import ( + TokenResponse, + WrappedTokenResponse, +) +from shared.api.models.base import ( + GenericBooleanResponse, + GenericMessageResponse, + PaginatedR2RResult, + R2RResults, + WrappedBooleanResponse, + WrappedGenericMessageResponse, +) +from shared.api.models.graph.responses import ( + GraphResponse, + WrappedCommunitiesResponse, + WrappedCommunityResponse, + WrappedEntitiesResponse, + WrappedEntityResponse, + WrappedGraphResponse, + WrappedGraphsResponse, + WrappedRelationshipResponse, + WrappedRelationshipsResponse, +) +from shared.api.models.ingestion.responses import ( + IngestionResponse, + WrappedIngestionResponse, + WrappedMetadataUpdateResponse, + WrappedUpdateResponse, + WrappedVectorIndexResponse, + WrappedVectorIndicesResponse, +) +from shared.api.models.management.responses import ( + ChunkResponse, + CollectionResponse, + ConversationResponse, + MessageResponse, + PromptResponse, + ServerStats, + SettingsResponse, + WrappedAPIKeyResponse, + WrappedAPIKeysResponse, + WrappedChunkResponse, + WrappedChunksResponse, + WrappedCollectionResponse, + WrappedCollectionsResponse, + WrappedConversationMessagesResponse, + WrappedConversationResponse, + WrappedConversationsResponse, + WrappedDocumentResponse, + WrappedDocumentsResponse, + WrappedLimitsResponse, + WrappedLoginResponse, + WrappedMessageResponse, + 
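    # Editor's note: illustrative only, not part of the package. IndexConfig
    # (end of shared/abstractions/vector.py, above) bundles the knobs for
    # index creation; a hypothetical HNSW index on the chunks table:
    #     IndexConfig(
    #         table_name=VectorTableName.CHUNKS,
    #         index_method=IndexMethod.hnsw,
    #         index_measure=IndexMeasure.cosine_distance,
    #         index_arguments=IndexArgsHNSW(m=16, ef_construction=64),
    #         concurrently=True,
    #     )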
WrappedPromptResponse, + WrappedPromptsResponse, + WrappedServerStatsResponse, + WrappedSettingsResponse, + WrappedUserResponse, + WrappedUsersResponse, +) +from shared.api.models.retrieval.responses import ( + AgentEvent, + AgentResponse, + AggregateSearchResult, + Citation, + CitationData, + CitationEvent, + Delta, + DeltaPayload, + FinalAnswerData, + FinalAnswerEvent, + MessageData, + MessageDelta, + MessageEvent, + RAGEvent, + RAGResponse, + SearchResultsData, + SearchResultsEvent, + SSEEventBase, + ThinkingData, + ThinkingEvent, + ToolCallData, + ToolCallEvent, + ToolResultData, + ToolResultEvent, + UnknownEvent, + WrappedAgentResponse, + WrappedDocumentSearchResponse, + WrappedEmbeddingResponse, + WrappedLLMChatCompletion, + WrappedRAGResponse, + WrappedSearchResponse, + WrappedVectorSearchResponse, +) + +__all__ = [ + # Generic Responses + "SSEEventBase", + "SearchResultsData", + "SearchResultsEvent", + "MessageDelta", + "MessageData", + "MessageEvent", + "DeltaPayload", + "Delta", + "CitationData", + "CitationEvent", + "FinalAnswerData", + "FinalAnswerEvent", + "ToolCallData", + "ToolCallEvent", + "ToolResultData", + "ToolResultEvent", + "ThinkingData", + "ThinkingEvent", + "AgentEvent", + "RAGEvent", + "UnknownEvent", + # Auth Responses + "GenericMessageResponse", + "TokenResponse", + "WrappedTokenResponse", + "WrappedGenericMessageResponse", + # Ingestion Responses + "IngestionResponse", + "WrappedIngestionResponse", + "WrappedUpdateResponse", + "WrappedVectorIndexResponse", + "WrappedVectorIndicesResponse", + "WrappedMetadataUpdateResponse", + "GraphResponse", + "WrappedGraphResponse", + "WrappedGraphsResponse", + "WrappedEntityResponse", + "WrappedEntitiesResponse", + "WrappedRelationshipResponse", + "WrappedRelationshipsResponse", + "WrappedCommunityResponse", + "WrappedCommunitiesResponse", + # Management Responses + "PromptResponse", + "ServerStats", + "SettingsResponse", + "ChunkResponse", + "CollectionResponse", + "ConversationResponse", + "MessageResponse", + "WrappedServerStatsResponse", + "WrappedSettingsResponse", + # Document Responses + "WrappedDocumentResponse", + "WrappedDocumentsResponse", + # Collection Responses + "WrappedCollectionResponse", + "WrappedCollectionsResponse", + # Prompt Responses + "WrappedPromptResponse", + "WrappedPromptsResponse", + # Chunk Responses + "WrappedChunkResponse", + "WrappedChunksResponse", + # Conversation Responses + "WrappedConversationMessagesResponse", + "WrappedConversationResponse", + "WrappedConversationsResponse", + # User Responses + "WrappedUserResponse", + "WrappedAPIKeyResponse", + "WrappedLimitsResponse", + "WrappedAPIKeysResponse", + "WrappedLoginResponse", + "WrappedUsersResponse", + "WrappedMessageResponse", + # Base Responses + "PaginatedR2RResult", + "R2RResults", + "GenericBooleanResponse", + "GenericMessageResponse", + "WrappedBooleanResponse", + "WrappedGenericMessageResponse", + # TODO: Clean up the following responses + # Retrieval Responses + "RAGResponse", + "Citation", + "WrappedRAGResponse", + "AgentResponse", + "AggregateSearchResult", + "WrappedSearchResponse", + "WrappedDocumentSearchResponse", + "WrappedVectorSearchResponse", + "WrappedAgentResponse", + "WrappedLLMChatCompletion", + "WrappedEmbeddingResponse", +] diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/auth/__init__.py b/.venv/lib/python3.12/site-packages/shared/api/models/auth/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/api/models/auth/__init__.py 
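# Editor's note: the sketch below is illustrative, not part of the diff. It
# shows the "wrapped response" convention used throughout shared.api.models:
# every payload is nested under a top-level `results` key via the generic
# R2RResults / PaginatedR2RResult models defined in shared/api/models/base.py
# just below.

from shared.api.models.base import GenericBooleanResponse, R2RResults

WrappedOk = R2RResults[GenericBooleanResponse]
resp = WrappedOk.model_validate({"results": {"success": True}})
assert resp.results.success is True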
diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/auth/responses.py b/.venv/lib/python3.12/site-packages/shared/api/models/auth/responses.py new file mode 100644 index 00000000..2d448945 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/api/models/auth/responses.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel + +from shared.abstractions import Token +from shared.api.models.base import R2RResults + + +class TokenResponse(BaseModel): + access_token: Token + refresh_token: Token + + +# Create wrapped versions of each response +WrappedTokenResponse = R2RResults[TokenResponse] diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/base.py b/.venv/lib/python3.12/site-packages/shared/api/models/base.py new file mode 100644 index 00000000..e0493d0b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/api/models/base.py @@ -0,0 +1,26 @@ +from typing import Generic, TypeVar + +from pydantic import BaseModel + +T = TypeVar("T") + + +class R2RResults(BaseModel, Generic[T]): + results: T + + +class PaginatedR2RResult(BaseModel, Generic[T]): + results: T + total_entries: int + + +class GenericBooleanResponse(BaseModel): + success: bool + + +class GenericMessageResponse(BaseModel): + message: str + + +WrappedBooleanResponse = R2RResults[GenericBooleanResponse] +WrappedGenericMessageResponse = R2RResults[GenericMessageResponse] diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/graph/__init__.py b/.venv/lib/python3.12/site-packages/shared/api/models/graph/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/api/models/graph/__init__.py diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/graph/responses.py b/.venv/lib/python3.12/site-packages/shared/api/models/graph/responses.py new file mode 100644 index 00000000..a2272833 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/api/models/graph/responses.py @@ -0,0 +1,31 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID + +from pydantic import BaseModel + +from shared.abstractions.graph import Community, Entity, Relationship +from shared.api.models.base import PaginatedR2RResult, R2RResults + +WrappedEntityResponse = R2RResults[Entity] +WrappedEntitiesResponse = PaginatedR2RResult[list[Entity]] +WrappedRelationshipResponse = R2RResults[Relationship] +WrappedRelationshipsResponse = PaginatedR2RResult[list[Relationship]] +WrappedCommunityResponse = R2RResults[Community] +WrappedCommunitiesResponse = PaginatedR2RResult[list[Community]] + + +class GraphResponse(BaseModel): + id: UUID + collection_id: UUID + name: str + description: Optional[str] + status: str + created_at: datetime + updated_at: datetime + document_ids: list[UUID] + + +# Graph Responses +WrappedGraphResponse = R2RResults[GraphResponse] +WrappedGraphsResponse = PaginatedR2RResult[list[GraphResponse]] diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/ingestion/__init__.py b/.venv/lib/python3.12/site-packages/shared/api/models/ingestion/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/api/models/ingestion/__init__.py diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/ingestion/responses.py b/.venv/lib/python3.12/site-packages/shared/api/models/ingestion/responses.py new file mode 100644 index 00000000..091e48e7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/api/models/ingestion/responses.py @@ 
-0,0 +1,72 @@ +from typing import Any, Optional, TypeVar +from uuid import UUID + +from pydantic import BaseModel, Field + +from shared.api.models.base import PaginatedR2RResult, R2RResults + +T = TypeVar("T") + + +class IngestionResponse(BaseModel): + message: str = Field( + ..., + description="A message describing the result of the ingestion request.", + ) + task_id: Optional[UUID] = Field( + None, + description="The task ID of the ingestion request.", + ) + document_id: UUID = Field( + ..., + description="The ID of the document that was ingested.", + ) + + class Config: + json_schema_extra = { + "example": { + "message": "Ingestion task queued successfully.", + "task_id": "c68dc72e-fc23-5452-8f49-d7bd46088a96", + "document_id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + } + } + + +class UpdateResponse(BaseModel): + message: str = Field( + ..., + description="A message describing the result of the ingestion request.", + ) + task_id: Optional[UUID] = Field( + None, + description="The task ID of the ingestion request.", + ) + document_ids: list[UUID] = Field( + ..., + description="The ID of the document that was ingested.", + ) + + class Config: + json_schema_extra = { + "example": { + "message": "Update task queued successfully.", + "task_id": "c68dc72e-fc23-5452-8f49-d7bd46088a96", + "document_ids": ["9fbe403b-c11c-5aae-8ade-ef22980c3ad1"], + } + } + + +class VectorIndexResponse(BaseModel): + index: dict[str, Any] + + +class VectorIndicesResponse(BaseModel): + indices: list[VectorIndexResponse] + + +WrappedIngestionResponse = R2RResults[IngestionResponse] +WrappedMetadataUpdateResponse = R2RResults[IngestionResponse] +WrappedUpdateResponse = R2RResults[UpdateResponse] + +WrappedVectorIndexResponse = R2RResults[VectorIndexResponse] +WrappedVectorIndicesResponse = PaginatedR2RResult[VectorIndicesResponse] diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/management/__init__.py b/.venv/lib/python3.12/site-packages/shared/api/models/management/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/api/models/management/__init__.py diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/management/responses.py b/.venv/lib/python3.12/site-packages/shared/api/models/management/responses.py new file mode 100644 index 00000000..5e8b67a2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/api/models/management/responses.py @@ -0,0 +1,168 @@ +from datetime import datetime +from typing import Any, Optional +from uuid import UUID + +from pydantic import BaseModel + +from shared.abstractions.document import DocumentResponse +from shared.abstractions.llm import Message +from shared.abstractions.user import Token, User +from shared.api.models.base import PaginatedR2RResult, R2RResults + + +class PromptResponse(BaseModel): + id: UUID + name: str + template: str + created_at: datetime + updated_at: datetime + input_types: dict[str, str] + + +class ServerStats(BaseModel): + start_time: datetime + uptime_seconds: float + cpu_usage: float + memory_usage: float + + +class SettingsResponse(BaseModel): + config: dict[str, Any] + prompts: dict[str, Any] + r2r_project_name: str + # r2r_version: str + + +class ChunkResponse(BaseModel): + id: UUID + document_id: UUID + owner_id: UUID + collection_ids: list[UUID] + text: str + metadata: dict[str, Any] + vector: Optional[list[float]] = None + + +class CollectionResponse(BaseModel): + id: UUID + owner_id: Optional[UUID] + name: str + description: Optional[str] + 
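    # Editor's note: illustrative only, not part of the package. List
    # endpoints pair models like this one with the PaginatedR2RResult wrapper,
    # so WrappedCollectionsResponse (aliased at the end of this module) parses
    # payloads shaped like
    #     {"results": [<collection>, ...], "total_entries": 42}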
graph_cluster_status: str + graph_sync_status: str + created_at: datetime + updated_at: datetime + user_count: int + document_count: int + + +class ConversationResponse(BaseModel): + id: UUID + created_at: datetime + user_id: Optional[UUID] = None + name: Optional[str] = None + + +class MessageResponse(BaseModel): + id: UUID + message: Message + metadata: dict[str, Any] = {} + + +class ApiKey(BaseModel): + public_key: str + api_key: str + key_id: str + name: Optional[str] = None + + +class ApiKeyNoPriv(BaseModel): + public_key: str + key_id: str + name: Optional[str] = None + updated_at: datetime + description: Optional[str] = None + + +class LoginResponse(BaseModel): + access_token: Token + refresh_token: Token + + +class UsageLimit(BaseModel): + used: int + limit: int + remaining: int + + +class StorageTypeLimit(BaseModel): + limit: int + used: int + remaining: int + + +class StorageLimits(BaseModel): + chunks: StorageTypeLimit + documents: StorageTypeLimit + collections: StorageTypeLimit + + +class RouteUsage(BaseModel): + route_per_min: UsageLimit + monthly_limit: UsageLimit + + +class Usage(BaseModel): + global_per_min: UsageLimit + monthly_limit: UsageLimit + routes: dict[str, RouteUsage] + + +class SystemDefaults(BaseModel): + global_per_min: int + route_per_min: Optional[int] + monthly_limit: int + + +class LimitsResponse(BaseModel): + storage_limits: StorageLimits + system_defaults: SystemDefaults + user_overrides: dict + effective_limits: SystemDefaults + usage: Usage + + +# Chunk Responses +WrappedChunkResponse = R2RResults[ChunkResponse] +WrappedChunksResponse = PaginatedR2RResult[list[ChunkResponse]] + +# Collection Responses +WrappedCollectionResponse = R2RResults[CollectionResponse] +WrappedCollectionsResponse = PaginatedR2RResult[list[CollectionResponse]] + +# Conversation Responses +WrappedConversationMessagesResponse = R2RResults[list[MessageResponse]] +WrappedConversationResponse = R2RResults[ConversationResponse] +WrappedConversationsResponse = PaginatedR2RResult[list[ConversationResponse]] +WrappedMessageResponse = R2RResults[MessageResponse] +WrappedMessagesResponse = PaginatedR2RResult[list[MessageResponse]] + +# Document Responses +WrappedDocumentResponse = R2RResults[DocumentResponse] +WrappedDocumentsResponse = PaginatedR2RResult[list[DocumentResponse]] + +# Prompt Responses +WrappedPromptResponse = R2RResults[PromptResponse] +WrappedPromptsResponse = PaginatedR2RResult[list[PromptResponse]] + +# System Responses +WrappedSettingsResponse = R2RResults[SettingsResponse] +WrappedServerStatsResponse = R2RResults[ServerStats] + +# User Responses +WrappedUserResponse = R2RResults[User] +WrappedUsersResponse = PaginatedR2RResult[list[User]] +WrappedAPIKeyResponse = R2RResults[ApiKey] +WrappedAPIKeysResponse = PaginatedR2RResult[list[ApiKeyNoPriv]] +WrappedLoginResponse = R2RResults[LoginResponse] +WrappedLimitsResponse = R2RResults[LimitsResponse] diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/retrieval/__init__.py b/.venv/lib/python3.12/site-packages/shared/api/models/retrieval/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/api/models/retrieval/__init__.py diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/retrieval/responses.py b/.venv/lib/python3.12/site-packages/shared/api/models/retrieval/responses.py new file mode 100644 index 00000000..f695ebfb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/api/models/retrieval/responses.py @@ -0,0 +1,604 @@ 
+from typing import Any, Literal, Optional + +from pydantic import BaseModel, Field + +from shared.abstractions import ( + AggregateSearchResult, + ChunkSearchResult, + GraphSearchResult, + LLMChatCompletion, + Message, + WebPageSearchResult, +) +from shared.api.models.base import R2RResults +from shared.api.models.management.responses import DocumentResponse + +from ....abstractions import R2RSerializable + + +class CitationSpan(R2RSerializable): + """Represents a single occurrence of a citation in text.""" + + start_index: int = Field( + ..., description="Starting character index of the citation" + ) + end_index: int = Field( + ..., description="Ending character index of the citation" + ) + context_start: int = Field( + ..., description="Starting index of the surrounding context" + ) + context_end: int = Field( + ..., description="Ending index of the surrounding context" + ) + + +class Citation(R2RSerializable): + """ + Represents a citation reference in the RAG response. + + The first time a citation appears, it includes the full payload. + Subsequent appearances only include the citation ID and span information. + """ + + # Basic identification + id: str = Field( + ..., description="The short ID of the citation (e.g., 'e41ac2d')" + ) + object: str = Field( + "citation", description="The type of object, always 'citation'" + ) + + # Optimize payload delivery + is_new: bool = Field( + True, + description="Whether this is the first occurrence of this citation", + ) + + # Position information + span: Optional[CitationSpan] = Field( + None, description="Position of this citation occurrence in the text" + ) + + # Source information - only included for first occurrence + source_type: Optional[str] = Field( + None, description="Type of source: 'chunk', 'graph', 'web', or 'doc'" + ) + + # Full payload - only included for first occurrence + payload: ( + ChunkSearchResult + | GraphSearchResult + | WebPageSearchResult + | DocumentResponse + | dict[str, Any] + | None + ) = Field( + None, + description="The complete source object (only included for new citations)", + ) + + class Config: + extra = "ignore" + json_schema_extra = { + "example": { + "id": "e41ac2d", + "object": "citation", + "is_new": True, + "span": { + "start_index": 120, + "end_index": 129, + "context_start": 80, + "context_end": 180, + }, + "source_type": "chunk", + "payload": { + "id": "e41ac2d1-full-id", + "text": "The study found significant improvements...", + "metadata": {"title": "Research Paper"}, + }, + } + } + + +# class Citation(R2RSerializable): +# """Represents a single citation reference in the RAG response. + +# Combines both bracket metadata (start/end offsets, snippet range) and the +# mapped source fields (id, doc ID, chunk text, etc.). +# """ + +# # Bracket references +# id: str = Field(..., description="The ID of the citation object") +# object: str = Field( +# ..., +# description="The type of object, e.g. 
`citation`", +# ) +# payload: ( +# ChunkSearchResult +# | GraphSearchResult +# | WebPageSearchResult +# | DocumentResponse +# | None +# ) = Field( +# ..., description="The object payload and it's corresponding type" +# ) + +# class Config: +# extra = "ignore" # This tells Pydantic to ignore extra fields +# json_schema_extra = { +# "example": { +# "id": "cit.abcd123", +# "object": "citation", +# "payload": "ChunkSearchResult(...)", +# } +# } + + +class RAGResponse(R2RSerializable): + generated_answer: str = Field( + ..., description="The generated completion from the RAG process" + ) + search_results: AggregateSearchResult = Field( + ..., description="The search results used for the RAG process" + ) + citations: Optional[list[Citation]] = Field( + None, + description="Structured citation metadata, if you do citation extraction.", + ) + metadata: dict = Field( + default_factory=dict, + description="Additional data returned by the LLM provider", + ) + completion: str = Field( + ..., + description="The generated completion from the RAG process", + # deprecated=True, + ) + + class Config: + json_schema_extra = { + "example": { + "generated_answer": "The capital of France is Paris.", + "search_results": { + "chunk_search_results": [ + { + "index": 1, + "start_index": 25, + "end_index": 28, + "uri": "https://example.com/doc1", + "title": "example_document_1.pdf", + "license": "CC-BY-4.0", + } + ], + "graph_search_results": [ + { + "content": { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "name": "Entity Name", + "description": "Entity Description", + "metadata": {}, + }, + "result_type": "entity", + "chunk_ids": [ + "c68dc72e-fc23-5452-8f49-d7bd46088a96" + ], + "metadata": { + "associated_query": "What is the capital of France?" + }, + } + ], + "web_search_results": [ + { + "title": "Page Title", + "link": "https://example.com/page", + "snippet": "Page snippet", + "position": 1, + "date": "2021-01-01", + "sitelinks": [ + { + "title": "Sitelink Title", + "link": "https://example.com/sitelink", + } + ], + } + ], + "document_search_results": [ + { + "document": { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "title": "Document Title", + "chunks": ["Chunk 1", "Chunk 2"], + "metadata": {}, + }, + } + ], + }, + "citations": [ + { + "index": 1, + "rawIndex": 9, + "startIndex": 393, + "endIndex": 396, + "snippetStartIndex": 320, + "snippetEndIndex": 418, + "sourceType": "chunk", + "id": "e760bb76-1c6e-52eb-910d-0ce5b567011b", + "document_id": "e43864f5-a36f-548e-aacd-6f8d48b30c7f", + "owner_id": "2acb499e-8428-543b-bd85-0d9098718220", + "collection_ids": [ + "122fdf6a-e116-546b-a8f6-e4cb2e2c0a09" + ], + "score": 0.64, + "text": "Document Title: DeepSeek_R1.pdf\n\nText: could achieve an accuracy of ...", + "metadata": { + "title": "DeepSeek_R1.pdf", + "license": "CC-BY-4.0", + "chunk_order": 68, + "document_type": "pdf", + }, + } + ], + "metadata": { + "id": "chatcmpl-example123", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": {"role": "assistant"}, + } + ], + }, + "completion": "TO BE DEPRECATED", + } + } + + +class AgentResponse(R2RSerializable): + messages: list[Message] = Field(..., description="Agent response messages") + conversation_id: str = Field( + ..., description="The conversation ID for the RAG agent response" + ) + + class Config: + json_schema_extra = { + "example": { + "messages": [ + { + "role": "assistant", + "content": """Aristotle (384–322 BC) was an Ancient + Greek philosopher and polymath whose contributions + have had a profound impact on various 
fields of + knowledge. + Here are some key points about his life and work: + \n\n1. **Early Life**: Aristotle was born in 384 BC in + Stagira, Chalcidice, which is near modern-day + Thessaloniki, Greece. His father, Nicomachus, was the + personal physician to King Amyntas of Macedon, which + exposed Aristotle to medical and biological knowledge + from a young age [C].\n\n2. **Education and Career**: + After the death of his parents, Aristotle was sent to + Athens to study at Plato's Academy, where he remained + for about 20 years. After Plato's death, Aristotle + left Athens and eventually became the tutor of + Alexander the Great [C]. + \n\n3. **Philosophical Contributions**: Aristotle + founded the Lyceum in Athens, where he established the + Peripatetic school of philosophy. His works cover a + wide range of subjects, including metaphysics, ethics, + politics, logic, biology, and aesthetics. His writings + laid the groundwork for many modern scientific and + philosophical inquiries [A].\n\n4. **Legacy**: + Aristotle's influence extends beyond philosophy to the + natural sciences, linguistics, economics, and + psychology. His method of systematic observation and + analysis has been foundational to the development of + modern science [A].\n\nAristotle's comprehensive + approach to knowledge and his systematic methodology + have earned him a lasting legacy as one of the + greatest philosophers of all time.\n\nSources: + \n- [A] Aristotle's broad range of writings and + influence on modern science.\n- [C] Details about + Aristotle's early life and education.""", + "name": None, + "function_call": None, + "tool_calls": None, + "metadata": { + "citations": [ + { + "index": 1, + "rawIndex": 9, + "startIndex": 393, + "endIndex": 396, + "snippetStartIndex": 320, + "snippetEndIndex": 418, + "sourceType": "chunk", + "id": "e760bb76-1c6e-52eb-910d-0ce5b567011b", + "document_id": """ + e43864f5-a36f-548e-aacd-6f8d48b30c7f + """, + "owner_id": """ + 2acb499e-8428-543b-bd85-0d9098718220 + """, + "collection_ids": [ + "122fdf6a-e116-546b-a8f6-e4cb2e2c0a09" + ], + "score": 0.64, + "text": """ + Document Title: DeepSeek_R1.pdf + \n\nText: could achieve an accuracy of ... + """, + "metadata": { + "title": "DeepSeek_R1.pdf", + "license": "CC-BY-4.0", + "chunk_order": 68, + "document_type": "pdf", + }, + } + ], + "aggregated_search_results": { + "chunk_search_results": [ + { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b", + "owner_id": "2acb499e-8428-543b-bd85-0d9098718220", + "collection_ids": [], + "score": 0.23943702876567796, + "text": "Example text from the document", + "metadata": { + "title": "example_document.pdf", + "associated_query": "What is the capital of France?", + }, + } + ], + "graph_search_results": [ + { + "content": { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "name": "Entity Name", + "description": "Entity Description", + "metadata": {}, + }, + "result_type": "entity", + "chunk_ids": [ + "c68dc72e-fc23-5452-8f49-d7bd46088a96" + ], + "metadata": { + "associated_query": "What is the capital of France?" 
+ }, + } + ], + "web_search_results": [ + { + "title": "Page Title", + "link": "https://example.com/page", + "snippet": "Page snippet", + "position": 1, + "date": "2021-01-01", + "sitelinks": [ + { + "title": "Sitelink Title", + "link": "https://example.com/sitelink", + } + ], + } + ], + "document_search_results": [ + { + "document": { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "title": "Document Title", + "chunks": ["Chunk 1", "Chunk 2"], + "metadata": {}, + }, + } + ], + }, + }, + }, + ], + "conversation_id": "a32b4c5d-6e7f-8a9b-0c1d-2e3f4a5b6c7d", + } + } + + +class DocumentSearchResult(BaseModel): + document_id: str = Field( + ..., + description="The document ID", + ) + metadata: Optional[dict] = Field( + None, + description="The metadata of the document", + ) + score: float = Field( + ..., + description="The score of the document", + ) + + +# A generic base model for SSE events +class SSEEventBase(BaseModel): + event: str + data: Any + + +# Model for the search results event +class SearchResultsData(BaseModel): + id: str + object: str + data: AggregateSearchResult + + +class SearchResultsEvent(SSEEventBase): + event: Literal["search_results"] + data: SearchResultsData + + +class DeltaPayload(BaseModel): + value: str + annotations: list[Any] + + +# Model for message events (partial tokens) +class MessageDelta(BaseModel): + type: str + payload: DeltaPayload + + +class Delta(BaseModel): + content: list[MessageDelta] + + +class MessageData(BaseModel): + id: str + object: str + delta: Delta + + +class MessageEvent(SSEEventBase): + event: Literal["message"] + data: MessageData + + +# Update CitationSpan model for SSE events +class CitationSpanData(BaseModel): + start: int = Field( + ..., description="Starting character index of the citation" + ) + end: int = Field(..., description="Ending character index of the citation") + context_start: Optional[int] = Field( + None, description="Starting index of surrounding context" + ) + context_end: Optional[int] = Field( + None, description="Ending index of surrounding context" + ) + + +# Update CitationData model +class CitationData(BaseModel): + id: str = Field( + ..., description="The short ID of the citation (e.g., 'e41ac2d')" + ) + object: str = Field( + "citation", description="The type of object, always 'citation'" + ) + + # New fields from the enhanced Citation model + is_new: Optional[bool] = Field( + None, + description="Whether this is the first occurrence of this citation", + ) + + span: Optional[CitationSpanData] = Field( + None, description="Position of this citation occurrence in the text" + ) + + source_type: Optional[str] = Field( + None, description="Type of source: 'chunk', 'graph', 'web', or 'doc'" + ) + + # Optional payload field, only for first occurrence + payload: Optional[Any] = Field( + None, + description="The complete source object (only included for new citations)", + ) + + # For backward compatibility, maintain the existing fields + class Config: + populate_by_name = True + extra = "ignore" + + +# CitationEvent remains the same, but now using the updated CitationData +class CitationEvent(SSEEventBase): + event: Literal["citation"] + data: CitationData + + +# Model for the final answer event +class FinalAnswerData(BaseModel): + generated_answer: str + citations: list[Citation] # refine if you have a citation model + + +class FinalAnswerEvent(SSEEventBase): + event: Literal["final_answer"] + data: FinalAnswerData + + +# "tool_call" event +class ToolCallData(BaseModel): + tool_call_id: str + name: str + arguments: 
Any  # If JSON arguments, use dict[str, Any], or str if needed
+
+
+class ToolCallEvent(SSEEventBase):
+    event: Literal["tool_call"]
+    data: ToolCallData
+
+
+# "tool_result" event
+class ToolResultData(BaseModel):
+    tool_call_id: str
+    role: Literal["tool", "function"]
+    content: str
+
+
+class ToolResultEvent(SSEEventBase):
+    event: Literal["tool_result"]
+    data: ToolResultData
+
+
+# Fallback model for unrecognized events
+class UnknownEvent(SSEEventBase):
+    pass
+
+
+# ThinkingEvent type for reasoning deltas
+class ThinkingData(BaseModel):
+    id: str
+    object: str
+    delta: Delta
+
+
+class ThinkingEvent(SSEEventBase):
+    event: str = "thinking"
+    data: ThinkingData
+
+
+# Union type for all RAG events; each event type is listed exactly once
+# (data payload models such as ToolResultData do not belong in the union)
+RAGEvent = (
+    SearchResultsEvent
+    | MessageEvent
+    | CitationEvent
+    | FinalAnswerEvent
+    | UnknownEvent
+    | ToolCallEvent
+    | ToolResultEvent
+)
+
+AgentEvent = (
+    ThinkingEvent
+    | SearchResultsEvent
+    | MessageEvent
+    | CitationEvent
+    | FinalAnswerEvent
+    | ToolCallEvent
+    | ToolResultEvent
+    | UnknownEvent
+)
+
+WrappedCompletionResponse = R2RResults[LLMChatCompletion]
+# Create wrapped versions of the responses
+WrappedVectorSearchResponse = R2RResults[list[ChunkSearchResult]]
+WrappedSearchResponse = R2RResults[AggregateSearchResult]
+# FIXME: This is returning DocumentResponse, but should be DocumentSearchResult
+WrappedDocumentSearchResponse = R2RResults[list[DocumentResponse]]
+WrappedRAGResponse = R2RResults[RAGResponse]
+WrappedAgentResponse = R2RResults[AgentResponse]
+WrappedLLMChatCompletion = R2RResults[LLMChatCompletion]
+WrappedEmbeddingResponse = R2RResults[list[float]]
diff --git a/.venv/lib/python3.12/site-packages/shared/utils/__init__.py b/.venv/lib/python3.12/site-packages/shared/utils/__init__.py
new file mode 100644
index 00000000..eb037e22
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/utils/__init__.py
@@ -0,0 +1,46 @@
+from .base_utils import (
+    _decorate_vector_type,
+    _get_vector_column_str,
+    decrement_version,
+    deep_update,
+    dump_collector,
+    dump_obj,
+    format_search_results_for_llm,
+    generate_default_prompt_id,
+    generate_default_user_collection_id,
+    generate_document_id,
+    generate_entity_document_id,
+    generate_extraction_id,
+    generate_id,
+    generate_user_id,
+    increment_version,
+    validate_uuid,
+    yield_sse_event,
+)
+from .splitter.text import RecursiveCharacterTextSplitter, TextSplitter
+
+__all__ = [
+    "format_search_results_for_llm",
+    # ID generation
+    "generate_id",
+    "generate_document_id",
+    "generate_extraction_id",
+    "generate_default_user_collection_id",
+    "generate_user_id",
+    "generate_default_prompt_id",
+    "generate_entity_document_id",
+    # Other
+    "increment_version",
+    "decrement_version",
+    "validate_uuid",
+    "deep_update",
+    # Text splitter
+    "RecursiveCharacterTextSplitter",
+    "TextSplitter",
+    # Vector utils
+    "_decorate_vector_type",
+    "_get_vector_column_str",
+    "yield_sse_event",
+    "dump_collector",
+    "dump_obj",
+]
diff --git a/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py b/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py
new file mode 100644
index 00000000..1864d0b4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py
@@ -0,0 +1,783 @@
+import json
+import logging
+import math
+import uuid
+from abc import ABCMeta
+from copy import deepcopy
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Optional, Tuple, TypeVar
+from uuid import NAMESPACE_DNS, UUID, uuid4,
uuid5 + +import tiktoken + +from ..abstractions import ( + AggregateSearchResult, + AsyncSyncMeta, + GraphCommunityResult, + GraphEntityResult, + GraphRelationshipResult, +) +from ..abstractions.vector import VectorQuantizationType + +if TYPE_CHECKING: + pass + + +logger = logging.getLogger() + + +def id_to_shorthand(id: str | UUID): + return str(id)[:7] + + +def format_search_results_for_llm( + results: AggregateSearchResult, + collector: Any, # SearchResultsCollector +) -> str: + """ + Instead of resetting 'source_counter' to 1, we: + - For each chunk / graph / web / doc in `results`, + - Find the aggregator index from the collector, + - Print 'Source [X]:' with that aggregator index. + """ + lines = [] + + # We'll build a quick helper to locate aggregator indices for each object: + # Or you can rely on the fact that we've added them to the collector + # in the same order. But let's do a "lookup aggregator index" approach: + + # 1) Chunk search + if results.chunk_search_results: + lines.append("Vector Search Results:") + for c in results.chunk_search_results: + lines.append(f"Source ID [{id_to_shorthand(c.id)}]:") + lines.append(c.text or "") # or c.text[:200] to truncate + + # 2) Graph search + if results.graph_search_results: + lines.append("Graph Search Results:") + for g in results.graph_search_results: + lines.append(f"Source ID [{id_to_shorthand(g.id)}]:") + if isinstance(g.content, GraphCommunityResult): + lines.append(f"Community Name: {g.content.name}") + lines.append(f"ID: {g.content.id}") + lines.append(f"Summary: {g.content.summary}") + # etc. ... + elif isinstance(g.content, GraphEntityResult): + lines.append(f"Entity Name: {g.content.name}") + lines.append(f"Description: {g.content.description}") + elif isinstance(g.content, GraphRelationshipResult): + lines.append( + f"Relationship: {g.content.subject}-{g.content.predicate}-{g.content.object}" + ) + # Add metadata if needed + + # 3) Web search + if results.web_search_results: + lines.append("Web Search Results:") + for w in results.web_search_results: + lines.append(f"Source ID [{id_to_shorthand(w.id)}]:") + lines.append(f"Title: {w.title}") + lines.append(f"Link: {w.link}") + lines.append(f"Snippet: {w.snippet}") + + # 4) Local context docs + if results.document_search_results: + lines.append("Local Context Documents:") + for doc_result in results.document_search_results: + doc_title = doc_result.title or "Untitled Document" + doc_id = doc_result.id + summary = doc_result.summary + + lines.append(f"Full Document ID: {doc_id}") + lines.append(f"Shortened Document ID: {id_to_shorthand(doc_id)}") + lines.append(f"Document Title: {doc_title}") + if summary: + lines.append(f"Summary: {summary}") + + if doc_result.chunks: + # Then each chunk inside: + for chunk in doc_result.chunks: + lines.append( + f"\nChunk ID {id_to_shorthand(chunk['id'])}:\n{chunk['text']}" + ) + + result = "\n".join(lines) + return result + + +def _generate_id_from_label(label) -> UUID: + return uuid5(NAMESPACE_DNS, label) + + +def generate_id(label: Optional[str] = None) -> UUID: + """Generates a unique run id.""" + return _generate_id_from_label( + label if label is not None else str(uuid4()) + ) + + +def generate_document_id(filename: str, user_id: UUID) -> UUID: + """Generates a unique document id from a given filename and user id.""" + safe_filename = filename.replace("/", "_") + return _generate_id_from_label(f"{safe_filename}-{str(user_id)}") + + +def generate_extraction_id( + document_id: UUID, iteration: int = 0, version: str = "0" +) -> UUID: + 
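    # Editor's note: illustrative only, not part of the package. Because the
    # helpers in this module derive IDs via uuid5 over NAMESPACE_DNS, they are
    # deterministic: generate_document_id("report.pdf", user_id) returns the
    # same UUID for the same filename/user pair, so re-ingesting a file maps
    # to the same document.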
"""Generates a unique extraction id from a given document id and + iteration.""" + return _generate_id_from_label(f"{str(document_id)}-{iteration}-{version}") + + +def generate_default_user_collection_id(user_id: UUID) -> UUID: + """Generates a unique collection id from a given user id.""" + return _generate_id_from_label(str(user_id)) + + +def generate_user_id(email: str) -> UUID: + """Generates a unique user id from a given email.""" + return _generate_id_from_label(email) + + +def generate_default_prompt_id(prompt_name: str) -> UUID: + """Generates a unique prompt id.""" + return _generate_id_from_label(prompt_name) + + +def generate_entity_document_id() -> UUID: + """Generates a unique document id inserting entities into a graph.""" + generation_time = datetime.now().isoformat() + return _generate_id_from_label(f"entity-{generation_time}") + + +def increment_version(version: str) -> str: + prefix = version[:-1] + suffix = int(version[-1]) + return f"{prefix}{suffix + 1}" + + +def decrement_version(version: str) -> str: + prefix = version[:-1] + suffix = int(version[-1]) + return f"{prefix}{max(0, suffix - 1)}" + + +def validate_uuid(uuid_str: str) -> UUID: + return UUID(uuid_str) + + +def update_settings_from_dict(server_settings, settings_dict: dict): + """Updates a settings object with values from a dictionary.""" + settings = deepcopy(server_settings) + for key, value in settings_dict.items(): + if value is not None: + if isinstance(value, dict): + for k, v in value.items(): + if isinstance(getattr(settings, key), dict): + getattr(settings, key)[k] = v + else: + setattr(getattr(settings, key), k, v) + else: + setattr(settings, key, value) + + return settings + + +def _decorate_vector_type( + input_str: str, + quantization_type: VectorQuantizationType = VectorQuantizationType.FP32, +) -> str: + return f"{quantization_type.db_type}{input_str}" + + +def _get_vector_column_str( + dimension: int | float, quantization_type: VectorQuantizationType +) -> str: + """Returns a string representation of a vector column type. + + Explicitly handles the case where the dimension is not a valid number meant + to support embedding models that do not allow for specifying the dimension. 
+ """ + if math.isnan(dimension) or dimension <= 0: + vector_dim = "" # Allows for Postgres to handle any dimension + else: + vector_dim = f"({dimension})" + return _decorate_vector_type(vector_dim, quantization_type) + + +KeyType = TypeVar("KeyType") + + +def deep_update( + mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any] +) -> dict[KeyType, Any]: + """ + Taken from Pydantic v1: + https://github.com/pydantic/pydantic/blob/fd2991fe6a73819b48c906e3c3274e8e47d0f761/pydantic/utils.py#L200 + """ + updated_mapping = mapping.copy() + for updating_mapping in updating_mappings: + for k, v in updating_mapping.items(): + if ( + k in updated_mapping + and isinstance(updated_mapping[k], dict) + and isinstance(v, dict) + ): + updated_mapping[k] = deep_update(updated_mapping[k], v) + else: + updated_mapping[k] = v + return updated_mapping + + +def tokens_count_for_message(message, encoding): + """Return the number of tokens used by a single message.""" + tokens_per_message = 3 + + num_tokens = 0 + num_tokens += tokens_per_message + if message.get("function_call"): + num_tokens += len(encoding.encode(message["function_call"]["name"])) + num_tokens += len( + encoding.encode(message["function_call"]["arguments"]) + ) + elif message.get("tool_calls"): + for tool_call in message["tool_calls"]: + num_tokens += len(encoding.encode(tool_call["function"]["name"])) + num_tokens += len( + encoding.encode(tool_call["function"]["arguments"]) + ) + else: + if "content" in message: + num_tokens += len(encoding.encode(message["content"])) + + return num_tokens + + +def num_tokens_from_messages(messages, model="gpt-4o"): + """Return the number of tokens used by a list of messages for both user and assistant.""" + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + + tokens = 0 + for message_ in messages: + tokens += tokens_count_for_message(message_, encoding) + + tokens += 3 # every reply is primed with assistant + return tokens + + +class SearchResultsCollector: + """ + Collects search results in the form (source_type, result_obj). + Handles both object-oriented and dictionary-based search results. + """ + + def __init__(self): + # We'll store a list of (source_type, result_obj) + self._results_in_order = [] + + @property + def results(self): + """Get the results list""" + return self._results_in_order + + @results.setter + def results(self, value): + """ + Set the results directly, with automatic type detection for 'unknown' items + Handles the format: [('unknown', {...}), ('unknown', {...})] + """ + self._results_in_order = [] + + if isinstance(value, list): + for item in value: + if isinstance(item, tuple) and len(item) == 2: + source_type, result_obj = item + + # Only auto-detect if the source type is "unknown" + if source_type == "unknown": + detected_type = self._detect_result_type(result_obj) + self._results_in_order.append( + (detected_type, result_obj) + ) + else: + self._results_in_order.append( + (source_type, result_obj) + ) + else: + # If not a tuple, detect and add + detected_type = self._detect_result_type(item) + self._results_in_order.append((detected_type, item)) + else: + raise ValueError("Results must be a list") + + def add_aggregate_result(self, agg): + """ + Flatten the chunk_search_results, graph_search_results, web_search_results, + and document_search_results into the collector, including nested chunks. 
+ """ + if hasattr(agg, "chunk_search_results") and agg.chunk_search_results: + for c in agg.chunk_search_results: + self._results_in_order.append(("chunk", c)) + + if hasattr(agg, "graph_search_results") and agg.graph_search_results: + for g in agg.graph_search_results: + self._results_in_order.append(("graph", g)) + + if hasattr(agg, "web_search_results") and agg.web_search_results: + for w in agg.web_search_results: + self._results_in_order.append(("web", w)) + + # Add documents and extract their chunks + if ( + hasattr(agg, "document_search_results") + and agg.document_search_results + ): + for doc in agg.document_search_results: + # Add the document itself + self._results_in_order.append(("doc", doc)) + + # Extract and add chunks from the document + chunks = None + if isinstance(doc, dict): + chunks = doc.get("chunks", []) + elif hasattr(doc, "chunks") and doc.chunks is not None: + chunks = doc.chunks + + if chunks: + for chunk in chunks: + # Ensure each chunk has the minimum required attributes + if isinstance(chunk, dict) and "id" in chunk: + # Add the chunk directly to results for citation lookup + self._results_in_order.append(("chunk", chunk)) + elif hasattr(chunk, "id"): + self._results_in_order.append(("chunk", chunk)) + + def add_result(self, result_obj, source_type=None): + """ + Add a single result object to the collector. + If source_type is not provided, automatically detect the type. + """ + if source_type: + self._results_in_order.append((source_type, result_obj)) + return source_type + + detected_type = self._detect_result_type(result_obj) + self._results_in_order.append((detected_type, result_obj)) + return detected_type + + def _detect_result_type(self, obj): + """ + Detect the type of a result object based on its properties. + Works with both object attributes and dictionary keys. 
+ """ + # Handle dictionary types first (common for web search results) + if isinstance(obj, dict): + # Web search pattern + if all(k in obj for k in ["title", "link"]) and any( + k in obj for k in ["snippet", "description"] + ): + return "web" + + # Check for graph dictionary patterns + if "content" in obj and isinstance(obj["content"], dict): + content = obj["content"] + if all(k in content for k in ["name", "description"]): + return "graph" # Entity + if all( + k in content for k in ["subject", "predicate", "object"] + ): + return "graph" # Relationship + if all(k in content for k in ["name", "summary"]): + return "graph" # Community + + # Chunk pattern + if all(k in obj for k in ["text", "id"]) and any( + k in obj for k in ["score", "metadata"] + ): + return "chunk" + + # Context document pattern + if "document" in obj and "chunks" in obj: + return "doc" + + # Check for explicit type indicator + if "type" in obj: + type_val = str(obj["type"]).lower() + if any(t in type_val for t in ["web", "organic"]): + return "web" + if "graph" in type_val: + return "graph" + if "chunk" in type_val: + return "chunk" + if "document" in type_val: + return "doc" + + # Handle object attributes for OOP-style results + if hasattr(obj, "result_type"): + result_type = str(obj.result_type).lower() + if result_type in ["entity", "relationship", "community"]: + return "graph" + + # Check class name hints + class_name = obj.__class__.__name__ + if "Graph" in class_name: + return "graph" + if "Chunk" in class_name: + return "chunk" + if "Web" in class_name: + return "web" + if "Document" in class_name: + return "doc" + + # Check for object attribute patterns + if hasattr(obj, "content"): + content = obj.content + if hasattr(content, "name") and hasattr(content, "description"): + return "graph" # Entity + if hasattr(content, "subject") and hasattr(content, "predicate"): + return "graph" # Relationship + if hasattr(content, "name") and hasattr(content, "summary"): + return "graph" # Community + + if ( + hasattr(obj, "text") + and hasattr(obj, "id") + and (hasattr(obj, "score") or hasattr(obj, "metadata")) + ): + return "chunk" + + if ( + hasattr(obj, "title") + and hasattr(obj, "link") + and hasattr(obj, "snippet") + ): + return "web" + + if hasattr(obj, "document") and hasattr(obj, "chunks"): + return "doc" + + # Default when type can't be determined + return "unknown" + + def find_by_short_id(self, short_id): + """Find a result by its short ID prefix with better chunk handling""" + if not short_id: + return None + + # First try direct lookup using regular iteration + for _, result_obj in self._results_in_order: + # Check dictionary objects + if isinstance(result_obj, dict) and "id" in result_obj: + result_id = str(result_obj["id"]) + if result_id.startswith(short_id): + return result_obj + + # Check object with id attribute + elif hasattr(result_obj, "id"): + obj_id = getattr(result_obj, "id", None) + if obj_id and str(obj_id).startswith(short_id): + # Convert to dict if possible + if hasattr(result_obj, "as_dict"): + return result_obj.as_dict() + elif hasattr(result_obj, "model_dump"): + return result_obj.model_dump() + elif hasattr(result_obj, "dict"): + return result_obj.dict() + else: + return result_obj + + # If not found, look for chunks inside documents that weren't extracted properly + for source_type, result_obj in self._results_in_order: + if source_type == "doc": + # Try various ways to access chunks + chunks = None + if isinstance(result_obj, dict) and "chunks" in result_obj: + chunks = 
result_obj["chunks"] + elif ( + hasattr(result_obj, "chunks") + and result_obj.chunks is not None + ): + chunks = result_obj.chunks + + if chunks: + for chunk in chunks: + # Try each chunk + chunk_id = None + if isinstance(chunk, dict) and "id" in chunk: + chunk_id = chunk["id"] + elif hasattr(chunk, "id"): + chunk_id = chunk.id + + if chunk_id and str(chunk_id).startswith(short_id): + return chunk + + return None + + def get_results_by_type(self, type_name): + """Get all results of a specific type""" + return [ + result_obj + for source_type, result_obj in self._results_in_order + if source_type == type_name + ] + + def __repr__(self): + """String representation showing counts by type""" + type_counts = {} + for source_type, _ in self._results_in_order: + type_counts[source_type] = type_counts.get(source_type, 0) + 1 + + return f"SearchResultsCollector with {len(self._results_in_order)} results: {type_counts}" + + def get_all_results(self) -> list[Tuple[str, Any]]: + """ + Return list of (source_type, result_obj, aggregator_index), + in the order appended. + """ + return self._results_in_order + + +def convert_nonserializable_objects(obj): + if hasattr(obj, "model_dump"): + obj = obj.model_dump() + if hasattr(obj, "as_dict"): + obj = obj.as_dict() + if hasattr(obj, "to_dict"): + obj = obj.to_dict() + + if isinstance(obj, dict): + new_obj = {} + for key, value in obj.items(): + # Convert key to string if it is a UUID or not already a string. + new_key = str(key) if not isinstance(key, str) else key + new_obj[new_key] = convert_nonserializable_objects(value) + return new_obj + elif isinstance(obj, list): + return [convert_nonserializable_objects(item) for item in obj] + elif isinstance(obj, tuple): + return tuple(convert_nonserializable_objects(item) for item in obj) + elif isinstance(obj, set): + return {convert_nonserializable_objects(item) for item in obj} + elif isinstance(obj, uuid.UUID): + return str(obj) + elif isinstance(obj, datetime): + return obj.isoformat() # Convert datetime to ISO formatted string + else: + return obj + + +def dump_obj(obj) -> list[dict[str, Any]]: + if hasattr(obj, "model_dump"): + obj = obj.model_dump() + elif hasattr(obj, "dict"): + obj = obj.dict() + elif hasattr(obj, "as_dict"): + obj = obj.as_dict() + elif hasattr(obj, "to_dict"): + obj = obj.to_dict() + obj = convert_nonserializable_objects(obj) + + return obj + + +def dump_collector(collector: SearchResultsCollector) -> list[dict[str, Any]]: + dumped = [] + for source_type, result_obj in collector.get_all_results(): + # Get the dictionary from the result object + if hasattr(result_obj, "model_dump"): + result_dict = result_obj.model_dump() + elif hasattr(result_obj, "dict"): + result_dict = result_obj.dict() + elif hasattr(result_obj, "as_dict"): + result_dict = result_obj.as_dict() + elif hasattr(result_obj, "to_dict"): + result_dict = result_obj.to_dict() + else: + result_dict = ( + result_obj # Fallback if no conversion method is available + ) + + # Use the recursive conversion on the entire dictionary + result_dict = convert_nonserializable_objects(result_dict) + + dumped.append( + { + "source_type": source_type, + "result": result_dict, + } + ) + return dumped + + +def num_tokens(text, model="gpt-4o"): + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + encoding = tiktoken.get_encoding("cl100k_base") + + """Return the number of tokens used by a list of messages for both user and assistant.""" + return len(encoding.encode(text, disallowed_special=())) + + +class 
CombinedMeta(AsyncSyncMeta, ABCMeta): + pass + + +async def yield_sse_event(event_name: str, payload: dict, chunk_size=1024): + """ + Helper that yields a single SSE event in properly chunked lines. + + e.g. event: event_name + data: (partial JSON 1) + data: (partial JSON 2) + ... + [blank line to end event] + """ + + # SSE: first the "event: ..." + yield f"event: {event_name}\n" + + # Convert payload to JSON + content_str = json.dumps(payload, default=str) + + # data + yield f"data: {content_str}\n" + + # blank line signals end of SSE event + yield "\n" + + +class SSEFormatter: + """ + Enhanced formatter for Server-Sent Events (SSE) with citation tracking. + Extends the existing SSEFormatter with improved citation handling. + """ + + @staticmethod + async def yield_citation_event( + citation_data: dict, + ): + """ + Emits a citation event with optimized payload. + + Args: + citation_id: The short ID of the citation (e.g., 'abc1234') + span: (start, end) position tuple for this occurrence + payload: Source object (included only for first occurrence) + is_new: Whether this is the first time we've seen this citation + citation_id_counter: Optional counter for citation occurrences + + Yields: + Formatted SSE event lines + """ + + # Include the full payload only for new citations + if not citation_data.get("is_new") or "payload" not in citation_data: + citation_data["payload"] = None + + # Yield the event + async for line in yield_sse_event("citation", citation_data): + yield line + + @staticmethod + async def yield_final_answer_event( + final_data: dict, + ): + # Yield the event + async for line in yield_sse_event("final_answer", final_data): + yield line + + # Include other existing SSEFormatter methods for compatibility + @staticmethod + async def yield_message_event(text_segment, msg_id=None): + msg_id = msg_id or f"msg_{uuid.uuid4().hex[:8]}" + msg_payload = { + "id": msg_id, + "object": "agent.message.delta", + "delta": { + "content": [ + { + "type": "text", + "payload": { + "value": text_segment, + "annotations": [], + }, + } + ] + }, + } + async for line in yield_sse_event("message", msg_payload): + yield line + + @staticmethod + async def yield_thinking_event(text_segment, thinking_id=None): + thinking_id = thinking_id or f"think_{uuid.uuid4().hex[:8]}" + thinking_data = { + "id": thinking_id, + "object": "agent.thinking.delta", + "delta": { + "content": [ + { + "type": "text", + "payload": { + "value": text_segment, + "annotations": [], + }, + } + ] + }, + } + async for line in yield_sse_event("thinking", thinking_data): + yield line + + @staticmethod + def yield_done_event(): + return "event: done\ndata: [DONE]\n\n" + + @staticmethod + async def yield_error_event(error_message, error_id=None): + error_id = error_id or f"err_{uuid.uuid4().hex[:8]}" + error_payload = { + "id": error_id, + "object": "agent.error", + "error": {"message": error_message, "type": "agent_error"}, + } + async for line in yield_sse_event("error", error_payload): + yield line + + @staticmethod + async def yield_tool_call_event(tool_call_data): + from ..api.models.retrieval.responses import ToolCallEvent + + tc_event = ToolCallEvent(event="tool_call", data=tool_call_data) + async for line in yield_sse_event( + "tool_call", tc_event.dict()["data"] + ): + yield line + + # New helper for emitting search results: + @staticmethod + async def yield_search_results_event(aggregated_results): + payload = { + "id": "search_1", + "object": "rag.search_results", + "data": aggregated_results.as_dict(), + } + async for line 
diff --git a/.venv/lib/python3.12/site-packages/shared/utils/splitter/__init__.py b/.venv/lib/python3.12/site-packages/shared/utils/splitter/__init__.py new file mode 100644 index 00000000..07a9f554 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/utils/splitter/__init__.py @@ -0,0 +1,3 @@ +from .text import RecursiveCharacterTextSplitter + +__all__ = ["RecursiveCharacterTextSplitter"] diff --git a/.venv/lib/python3.12/site-packages/shared/utils/splitter/text.py b/.venv/lib/python3.12/site-packages/shared/utils/splitter/text.py new file mode 100644 index 00000000..92a7c81b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/utils/splitter/text.py @@ -0,0 +1,2000 @@ +# Source - LangChain +# URL: https://github.com/langchain-ai/langchain/blob/6a5b084704afa22ca02f78d0464f35aed75d1ff2/libs/langchain/langchain/text_splitter.py#L851 +"""**Text Splitters** are classes for splitting text. + +**Class hierarchy:** + +.. code-block:: + + BaseDocumentTransformer --> TextSplitter --> <name>TextSplitter # Example: CharacterTextSplitter + RecursiveCharacterTextSplitter --> <name>TextSplitter + +Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter** do not derive from TextSplitter. + + +**Main helpers:** + +.. code-block:: + + Document, Tokenizer, Language, LineType, HeaderType +""" # noqa: E501 + +from __future__ import annotations + +import copy +import json +import logging +import pathlib +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from io import BytesIO, StringIO +from typing import ( + AbstractSet, + Any, + Callable, + Collection, + Iterable, + Literal, + Optional, + Sequence, + Tuple, + Type, + TypedDict, + TypeVar, + cast, +) + +import requests +from pydantic import BaseModel, Field, PrivateAttr +from typing_extensions import NotRequired + +logger = logging.getLogger(__name__) + +TS = TypeVar("TS", bound="TextSplitter") + + +class BaseSerialized(TypedDict): + """Base class for serialized objects.""" + + lc: int + id: list[str] + name: NotRequired[str] + graph: NotRequired[dict[str, Any]] + + +class SerializedConstructor(BaseSerialized): + """Serialized constructor.""" + + type: Literal["constructor"] + kwargs: dict[str, Any] + + +class SerializedSecret(BaseSerialized): + """Serialized secret.""" + + type: Literal["secret"] + + +class SerializedNotImplemented(BaseSerialized): + """Serialized not implemented.""" + + type: Literal["not_implemented"] + repr: Optional[str] + + +def try_neq_default(value: Any, key: str, model: BaseModel) -> bool: + """Try to determine if a value is different from the default. + + Args: + value: The value. + key: The key. + model: The model. + + Returns: + Whether the value is different from the default. + """ + try: + return model.__fields__[key].get_default() != value + except Exception: + return True + + +class Serializable(BaseModel, ABC): + """Serializable base class.""" + + @classmethod + def is_lc_serializable(cls) -> bool: + """Is this class serializable?""" + return False + + @classmethod + def get_lc_namespace(cls) -> list[str]: + """Get the namespace of the langchain object.
+ + For example, if the class is `langchain.llms.openai.OpenAI`, then the + namespace is ["langchain", "llms", "openai"] + """ + return cls.__module__.split(".") + + @property + def lc_secrets(self) -> dict[str, str]: + """A map of constructor argument names to secret ids. + + For example, {"openai_api_key": "OPENAI_API_KEY"} + """ + return {} + + @property + def lc_attributes(self) -> dict: + """List of attribute names that should be included in the serialized + kwargs. + + These attributes must be accepted by the constructor. + """ + return {} + + @classmethod + def lc_id(cls) -> list[str]: + """A unique identifier for this class for serialization purposes. + + The unique identifier is a list of strings that describes the path to + the object. + """ + return [*cls.get_lc_namespace(), cls.__name__] + + class Config: + extra = "ignore" + + def __repr_args__(self) -> Any: + return [ + (k, v) + for k, v in super().__repr_args__() + if (k not in self.__fields__ or try_neq_default(v, k, self)) + ] + + _lc_kwargs: dict[str, Any] = PrivateAttr(default_factory=dict) + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._lc_kwargs = kwargs + + def to_json( + self, + ) -> SerializedConstructor | SerializedNotImplemented: + if not self.is_lc_serializable(): + return self.to_json_not_implemented() + + secrets = dict() + # Get latest values for kwargs if there is an attribute with same name + lc_kwargs = { + k: getattr(self, k, v) + for k, v in self._lc_kwargs.items() + if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore + } + + # Merge the lc_secrets and lc_attributes from every class in the MRO + for cls in [None, *self.__class__.mro()]: + # Once we get to Serializable, we're done + if cls is Serializable: + break + + if cls: + deprecated_attributes = [ + "lc_namespace", + "lc_serializable", + ] + + for attr in deprecated_attributes: + if hasattr(cls, attr): + raise ValueError( + f"Class {self.__class__} has a deprecated " + f"attribute {attr}. Please use the corresponding " + f"classmethod instead." + ) + + # Get a reference to self bound to each class in the MRO + this = cast( + Serializable, self if cls is None else super(cls, self) + ) + + secrets.update(this.lc_secrets) + # Now also add the aliases for the secrets + # This ensures known secret aliases are hidden. + # Note: this does NOT hide any other extra kwargs + # that are not present in the fields. 
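The remainder of `to_json` (continued below) folds `lc_attributes` into the kwargs and routes every collected secret through `_replace_secrets`, which swaps the raw value for a typed stub. To make the resulting wire format concrete, a minimal sketch; the subclass `FakeLLM` and the secret id `FAKE_API_KEY` are invented for this example:

```python
# Hypothetical subclass; `lc_secrets` maps a constructor argument
# to the secret id that replaces its value in the serialized form.
class FakeLLM(Serializable):
    model: str = "demo"
    api_key: str = ""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @property
    def lc_secrets(self) -> dict[str, str]:
        return {"api_key": "FAKE_API_KEY"}


print(FakeLLM(model="demo", api_key="sk-123").to_json())
# Roughly:
# {'lc': 1, 'type': 'constructor', 'id': ['__main__', 'FakeLLM'],
#  'kwargs': {'model': 'demo',
#             'api_key': {'lc': 1, 'type': 'secret', 'id': ['FAKE_API_KEY']}}}
```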
+ for key in list(secrets): + value = secrets[key] + if key in this.__fields__: + secrets[this.__fields__[key].alias] = value # type: ignore + lc_kwargs.update(this.lc_attributes) + + # include all secrets, even if not specified in kwargs + # as these secrets may be passed as an environment variable instead + for key in secrets.keys(): + secret_value = getattr(self, key, None) or lc_kwargs.get(key) + if secret_value is not None: + lc_kwargs.update({key: secret_value}) + + return { + "lc": 1, + "type": "constructor", + "id": self.lc_id(), + "kwargs": ( + lc_kwargs + if not secrets + else _replace_secrets(lc_kwargs, secrets) + ), + } + + def to_json_not_implemented(self) -> SerializedNotImplemented: + return to_json_not_implemented(self) + + +def _replace_secrets( + root: dict[Any, Any], secrets_map: dict[str, str] +) -> dict[Any, Any]: + result = root.copy() + for path, secret_id in secrets_map.items(): + [*parts, last] = path.split(".") + current = result + for part in parts: + if part not in current: + break + current[part] = current[part].copy() + current = current[part] + if last in current: + current[last] = { + "lc": 1, + "type": "secret", + "id": [secret_id], + } + return result + + +def to_json_not_implemented(obj: object) -> SerializedNotImplemented: + """Serialize a "not implemented" object. + + Args: + obj: object to serialize + + Returns: + SerializedNotImplemented + """ + _id: list[str] = [] + try: + if hasattr(obj, "__name__"): + _id = [*obj.__module__.split("."), obj.__name__] + elif hasattr(obj, "__class__"): + _id = [ + *obj.__class__.__module__.split("."), + obj.__class__.__name__, + ] + except Exception: + pass + + result: SerializedNotImplemented = { + "lc": 1, + "type": "not_implemented", + "id": _id, + "repr": None, + } + try: + result["repr"] = repr(obj) + except Exception: + pass + return result + + +class SplitterDocument(Serializable): + """Class for storing a piece of text and associated metadata.""" + + page_content: str + """String text.""" + metadata: dict = Field(default_factory=dict) + """Arbitrary metadata about the page content (e.g., source, relationships + to other documents, etc.).""" + type: Literal["Document"] = "Document" + + def __init__(self, page_content: str, **kwargs: Any) -> None: + """Pass page_content in as positional or named arg.""" + super().__init__(page_content=page_content, **kwargs) + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this class is serializable.""" + return True + + @classmethod + def get_lc_namespace(cls) -> list[str]: + """Get the namespace of the langchain object.""" + return ["langchain", "schema", "document"] + + +class BaseDocumentTransformer(ABC): + """Abstract base class for document transformation systems. + + A document transformation system takes a sequence of Documents and returns a + sequence of transformed Documents. + + Example: + .. 
code-block:: python + + class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): + embeddings: Embeddings + similarity_fn: Callable = cosine_similarity + similarity_threshold: float = 0.95 + + class Config: + arbitrary_types_allowed = True + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + stateful_documents = get_stateful_documents(documents) + embedded_documents = _get_embeddings_from_stateful_docs( + self.embeddings, stateful_documents + ) + included_idxs = _filter_similar_embeddings( + embedded_documents, self.similarity_fn, self.similarity_threshold + ) + return [stateful_documents[i] for i in sorted(included_idxs)] + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + raise NotImplementedError + """ # noqa: E501 + + @abstractmethod + def transform_documents( + self, documents: Sequence[SplitterDocument], **kwargs: Any + ) -> Sequence[SplitterDocument]: + """Transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ + + async def atransform_documents( + self, documents: Sequence[SplitterDocument], **kwargs: Any + ) -> Sequence[SplitterDocument]: + """Asynchronously transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ + raise NotImplementedError("This method is not implemented.") + # return await langchain_core.runnables.config.run_in_executor( + # None, self.transform_documents, documents, **kwargs + # ) + + +def _make_spacy_pipe_for_splitting( + pipe: str, *, max_length: int = 1_000_000 +) -> Any: # avoid importing spacy + try: + import spacy + except ImportError: + raise ImportError( + "Spacy is not installed, run `pip install spacy`." + ) from None + if pipe == "sentencizer": + from spacy.lang.en import English + + sentencizer = English() + sentencizer.add_pipe("sentencizer") + else: + sentencizer = spacy.load(pipe, exclude=["ner", "tagger"]) + sentencizer.max_length = max_length + return sentencizer + + +def _split_text_with_regex( + text: str, separator: str, keep_separator: bool +) -> list[str]: + # Now that we have the separator, split the text + if separator: + if keep_separator: + # The parentheses in the pattern keep the delimiters in the result. + _splits = re.split(f"({separator})", text) + splits = [ + _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2) + ] + if len(_splits) % 2 == 0: + splits += _splits[-1:] + splits = [_splits[0]] + splits + else: + splits = re.split(separator, text) + else: + splits = list(text) + return [s for s in splits if s != ""] + + +class TextSplitter(BaseDocumentTransformer, ABC): + """Interface for splitting text into chunks.""" + + def __init__( + self, + chunk_size: int = 4000, + chunk_overlap: int = 200, + length_function: Callable[[str], int] = len, + keep_separator: bool = False, + add_start_index: bool = False, + strip_whitespace: bool = True, + ) -> None: + """Create a new TextSplitter. 
+ + Args: + chunk_size: Maximum size of chunks to return + chunk_overlap: Overlap in characters between chunks + length_function: Function that measures the length of given chunks + keep_separator: Whether to keep the separator in the chunks + add_start_index: If `True`, includes chunk's start index in + metadata + strip_whitespace: If `True`, strips whitespace from the start and + end of every document + """ + if chunk_overlap > chunk_size: + raise ValueError( + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " + f"({chunk_size}), should be smaller." + ) + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._length_function = length_function + self._keep_separator = keep_separator + self._add_start_index = add_start_index + self._strip_whitespace = strip_whitespace + + @abstractmethod + def split_text(self, text: str) -> list[str]: + """Split text into multiple components.""" + + def create_documents( + self, texts: list[str], metadatas: Optional[list[dict]] = None + ) -> list[SplitterDocument]: + """Create documents from a list of texts.""" + _metadatas = metadatas or [{}] * len(texts) + documents = [] + for i, text in enumerate(texts): + index = 0 + previous_chunk_len = 0 + for chunk in self.split_text(text): + metadata = copy.deepcopy(_metadatas[i]) + if self._add_start_index: + offset = index + previous_chunk_len - self._chunk_overlap + index = text.find(chunk, max(0, offset)) + metadata["start_index"] = index + previous_chunk_len = len(chunk) + new_doc = SplitterDocument( + page_content=chunk, metadata=metadata + ) + documents.append(new_doc) + return documents + + def split_documents( + self, documents: Iterable[SplitterDocument] + ) -> list[SplitterDocument]: + """Split documents.""" + texts, metadatas = [], [] + for doc in documents: + texts.append(doc.page_content) + metadatas.append(doc.metadata) + return self.create_documents(texts, metadatas=metadatas) + + def _join_docs(self, docs: list[str], separator: str) -> Optional[str]: + text = separator.join(docs) + if self._strip_whitespace: + text = text.strip() + if text == "": + return None + else: + return text + + def _merge_splits( + self, splits: Iterable[str], separator: str + ) -> list[str]: + # We now want to combine these smaller pieces into medium size + # chunks to send to the LLM. 
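(The merging loop resumes below.) To ground the base-class knobs, a small usage sketch with `CharacterTextSplitter`, which is defined later in this file; the separator and sizes are arbitrary example values:

```python
# Sketch: merge word-level splits into <=20-character chunks with 5
# characters of overlap, recording each chunk's offset into the source.
splitter = CharacterTextSplitter(
    separator=" ",
    chunk_size=20,       # max chunk length, per length_function (len here)
    chunk_overlap=5,     # characters carried over between adjacent chunks
    add_start_index=True,
)
docs = splitter.create_documents(["one two three four five six seven"])
for doc in docs:
    print(doc.metadata["start_index"], repr(doc.page_content))
```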
+ separator_len = self._length_function(separator) + + docs = [] + current_doc: list[str] = [] + total = 0 + for d in splits: + _len = self._length_function(d) + if ( + total + _len + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + ): + if total > self._chunk_size: + logger.warning( + f"Created a chunk of size {total}, " + f"which is longer than the specified {self._chunk_size}" + ) + if len(current_doc) > 0: + doc = self._join_docs(current_doc, separator) + if doc is not None: + docs.append(doc) + # Keep on popping if: + # - we have a larger chunk than in the chunk overlap + # - or if we still have any chunks and the length is long + while total > self._chunk_overlap or ( + total + + _len + + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + and total > 0 + ): + total -= self._length_function(current_doc[0]) + ( + separator_len if len(current_doc) > 1 else 0 + ) + current_doc = current_doc[1:] + current_doc.append(d) + total += _len + (separator_len if len(current_doc) > 1 else 0) + doc = self._join_docs(current_doc, separator) + if doc is not None: + docs.append(doc) + return docs + + @classmethod + def from_huggingface_tokenizer( + cls, tokenizer: Any, **kwargs: Any + ) -> TextSplitter: + """Text splitter that uses HuggingFace tokenizer to count length.""" + try: + from transformers import PreTrainedTokenizerBase + + if not isinstance(tokenizer, PreTrainedTokenizerBase): + raise ValueError( + "Tokenizer received was not an instance of PreTrainedTokenizerBase" + ) + + def _huggingface_tokenizer_length(text: str) -> int: + return len(tokenizer.encode(text)) + + except ImportError: + raise ValueError( + "Could not import transformers python package. " + "Please install it with `pip install transformers`." + ) from None + return cls(length_function=_huggingface_tokenizer_length, **kwargs) + + @classmethod + def from_tiktoken_encoder( + cls: Type[TS], + encoding_name: str = "gpt2", + model: Optional[str] = None, + allowed_special: Literal["all"] | AbstractSet[str] = set(), + disallowed_special: Literal["all"] | Collection[str] = "all", + **kwargs: Any, + ) -> TS: + """Text splitter that uses tiktoken encoder to count length.""" + try: + import tiktoken + except ImportError: + raise ImportError("""Could not import tiktoken python package. + This is needed in order to calculate max_tokens_for_prompt. 
+ Please install it with `pip install tiktoken`.""") from None + + if model is not None: + enc = tiktoken.encoding_for_model(model) + else: + enc = tiktoken.get_encoding(encoding_name) + + def _tiktoken_encoder(text: str) -> int: + return len( + enc.encode( + text, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + + if issubclass(cls, TokenTextSplitter): + extra_kwargs = { + "encoding_name": encoding_name, + "model": model, + "allowed_special": allowed_special, + "disallowed_special": disallowed_special, + } + kwargs = {**kwargs, **extra_kwargs} + + return cls(length_function=_tiktoken_encoder, **kwargs) + + def transform_documents( + self, documents: Sequence[SplitterDocument], **kwargs: Any + ) -> Sequence[SplitterDocument]: + """Transform sequence of documents by splitting them.""" + return self.split_documents(list(documents)) + + +class CharacterTextSplitter(TextSplitter): + """Splitting text that looks at characters.""" + + DEFAULT_SEPARATOR: str = "\n\n" + + def __init__( + self, + separator: str = DEFAULT_SEPARATOR, + is_separator_regex: bool = False, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs) + self._separator = separator + self._is_separator_regex = is_separator_regex + + def split_text(self, text: str) -> list[str]: + """Split incoming text and return chunks.""" + # First we naively split the large input into a bunch of smaller ones. + separator = ( + self._separator + if self._is_separator_regex + else re.escape(self._separator) + ) + splits = _split_text_with_regex(text, separator, self._keep_separator) + _separator = "" if self._keep_separator else self._separator + return self._merge_splits(splits, _separator) + + +class LineType(TypedDict): + """Line type as typed dict.""" + + metadata: dict[str, str] + content: str + + +class HeaderType(TypedDict): + """Header type as typed dict.""" + + level: int + name: str + data: str + + +class MarkdownHeaderTextSplitter: + """Splitting markdown files based on specified headers.""" + + def __init__( + self, + headers_to_split_on: list[Tuple[str, str]], + return_each_line: bool = False, + strip_headers: bool = True, + ): + """Create a new MarkdownHeaderTextSplitter. 
+ + Args: + headers_to_split_on: Headers we want to track + return_each_line: Return each line w/ associated headers + strip_headers: Whether to strip split headers from the chunk's content + """ + # Output line-by-line or aggregated into chunks w/ common headers + self.return_each_line = return_each_line + # Given the headers we want to split on, + # (e.g., "#, ##, etc") order by length + self.headers_to_split_on = sorted( + headers_to_split_on, key=lambda split: len(split[0]), reverse=True + ) + # Whether to strip split headers from the chunk's content + self.strip_headers = strip_headers + + def aggregate_lines_to_chunks( + self, lines: list[LineType] + ) -> list[SplitterDocument]: + """Combine lines with common metadata into chunks. + + Args: + lines: lines of text with associated header metadata + """ + aggregated_chunks: list[LineType] = [] + + for line in lines: + if ( + aggregated_chunks + and aggregated_chunks[-1]["metadata"] == line["metadata"] + ): + # If the last line in the aggregated list + # has the same metadata as the current line, + # append the current content to the last line's content + aggregated_chunks[-1]["content"] += " \n" + line["content"] + elif ( + aggregated_chunks + and aggregated_chunks[-1]["metadata"] != line["metadata"] + # may be issues if other metadata is present + and len(aggregated_chunks[-1]["metadata"]) + < len(line["metadata"]) + and aggregated_chunks[-1]["content"].split("\n")[-1][0] == "#" + and not self.strip_headers + ): + # If the last line in the aggregated list + # has different metadata than the current line, + # and has a shallower header level than the current line, + # and the last line is a header, + # and we are not stripping headers, + # append the current content to the last line's content + aggregated_chunks[-1]["content"] += " \n" + line["content"] + # and update the last line's metadata + aggregated_chunks[-1]["metadata"] = line["metadata"] + else: + # Otherwise, append the current line to the aggregated list + aggregated_chunks.append(line) + + return [ + SplitterDocument( + page_content=chunk["content"], metadata=chunk["metadata"] + ) + for chunk in aggregated_chunks + ] + + def split_text(self, text: str) -> list[SplitterDocument]: + """Split a markdown text. + + Args: + text: the markdown text to split""" + + # Split the input text by newline character ("\n").
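Before the line-by-line walk through `split_text` continues below, a brief usage sketch of this class; the input text and header mapping are invented for illustration:

```python
# Sketch: split a markdown document on H1/H2 and carry headers as metadata.
md = """# Title

Intro paragraph.

## Section A

Body of A.
"""
splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")]
)
for doc in splitter.split_text(md):
    print(doc.metadata, "->", doc.page_content)
# {'Header 1': 'Title'} -> Intro paragraph.
# {'Header 1': 'Title', 'Header 2': 'Section A'} -> Body of A.
```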
+ lines = text.split("\n") + # Final output + lines_with_metadata: list[LineType] = [] + # Content and metadata of the chunk currently being processed + current_content: list[str] = [] + current_metadata: dict[str, str] = {} + # Keep track of the nested header structure + # header_stack: list[dict[str, int | str]] = [] + header_stack: list[HeaderType] = [] + initial_metadata: dict[str, str] = {} + + in_code_block = False + opening_fence = "" + + for line in lines: + stripped_line = line.strip() + + if not in_code_block: + # Exclude inline code spans + if ( + stripped_line.startswith("```") + and stripped_line.count("```") == 1 + ): + in_code_block = True + opening_fence = "```" + elif stripped_line.startswith("~~~"): + in_code_block = True + opening_fence = "~~~" + else: + if stripped_line.startswith(opening_fence): + in_code_block = False + opening_fence = "" + + if in_code_block: + current_content.append(stripped_line) + continue + + # Check each line against each of the header types (e.g., #, ##) + for sep, name in self.headers_to_split_on: + # Check if line starts with a header that we intend to split on + if stripped_line.startswith(sep) and ( + # Header with no text OR header is followed by space + # Both are valid conditions that sep is being used a header + len(stripped_line) == len(sep) + or stripped_line[len(sep)] == " " + ): + # Ensure we are tracking the header as metadata + if name is not None: + # Get the current header level + current_header_level = sep.count("#") + + # Pop out headers of lower or same level from the stack + while ( + header_stack + and header_stack[-1]["level"] + >= current_header_level + ): + # We have encountered a new header + # at the same or higher level + popped_header = header_stack.pop() + # Clear the metadata for the + # popped header in initial_metadata + if popped_header["name"] in initial_metadata: + initial_metadata.pop(popped_header["name"]) + + # Push the current header to the stack + header: HeaderType = { + "level": current_header_level, + "name": name, + "data": stripped_line[len(sep) :].strip(), + } + header_stack.append(header) + # Update initial_metadata with the current header + initial_metadata[name] = header["data"] + + # Add the previous line to the lines_with_metadata + # only if current_content is not empty + if current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata.copy(), + } + ) + current_content.clear() + + if not self.strip_headers: + current_content.append(stripped_line) + + break + else: + if stripped_line: + current_content.append(stripped_line) + elif current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata.copy(), + } + ) + current_content.clear() + + current_metadata = initial_metadata.copy() + + if current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata, + } + ) + + # lines_with_metadata has each line with associated header metadata + # aggregate these into chunks based on common metadata + if not self.return_each_line: + return self.aggregate_lines_to_chunks(lines_with_metadata) + else: + return [ + SplitterDocument( + page_content=chunk["content"], metadata=chunk["metadata"] + ) + for chunk in lines_with_metadata + ] + + +class ElementType(TypedDict): + """Element type as typed dict.""" + + url: str + xpath: str + content: str + metadata: dict[str, str] + + +class HTMLHeaderTextSplitter: + """Splitting HTML files based 
on specified headers. + + Requires lxml package. + """ + + def __init__( + self, + headers_to_split_on: list[Tuple[str, str]], + return_each_element: bool = False, + ): + """Create a new HTMLHeaderTextSplitter. + + Args: + headers_to_split_on: list of tuples of headers we want to track + mapped to (arbitrary) keys for metadata. Allowed header values: + h1, h2, h3, h4, h5, h6 + e.g. [("h1", "Header 1"), ("h2", "Header 2)]. + return_each_element: Return each element w/ associated headers. + """ + # Output element-by-element or aggregated into chunks w/ common headers + self.return_each_element = return_each_element + self.headers_to_split_on = sorted(headers_to_split_on) + + def aggregate_elements_to_chunks( + self, elements: list[ElementType] + ) -> list[SplitterDocument]: + """Combine elements with common metadata into chunks. + + Args: + elements: HTML element content with associated identifying + info and metadata + """ + aggregated_chunks: list[ElementType] = [] + + for element in elements: + if ( + aggregated_chunks + and aggregated_chunks[-1]["metadata"] == element["metadata"] + ): + # If the last element in the aggregated list + # has the same metadata as the current element, + # append the current content to the last element's content + aggregated_chunks[-1]["content"] += " \n" + element["content"] + else: + # Otherwise, append the current element to the aggregated list + aggregated_chunks.append(element) + + return [ + SplitterDocument( + page_content=chunk["content"], metadata=chunk["metadata"] + ) + for chunk in aggregated_chunks + ] + + def split_text_from_url(self, url: str) -> list[SplitterDocument]: + """Split HTML from web URL. + + Args: + url: web URL + """ + r = requests.get(url) + return self.split_text_from_file(BytesIO(r.content)) + + def split_text(self, text: str) -> list[SplitterDocument]: + """Split HTML text string. + + Args: + text: HTML text + """ + return self.split_text_from_file(StringIO(text)) + + def split_text_from_file(self, file: Any) -> list[SplitterDocument]: + """Split HTML file. + + Args: + file: HTML file + """ + try: + from lxml import etree + except ImportError: + raise ImportError( + "Unable to import lxml, run `pip install lxml`." + ) from None + # use lxml library to parse html document and return xml ElementTree + # Explicitly encoding in utf-8 allows non-English + # html files to be processed without garbled characters + parser = etree.HTMLParser(encoding="utf-8") + tree = etree.parse(file, parser) + + # document transformation for "structure-aware" chunking is handled + # with xsl. See comments in html_chunks_with_headers.xslt for more + # detailed information. 
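(The XSLT-driven traversal continues below.) A usage sketch, with an invented HTML snippet; the exact chunk boundaries depend on the bundled html_chunks_with_headers.xslt stylesheet, and lxml must be installed:

```python
# Sketch: tag header text onto each chunk's metadata.
html = (
    "<html><body>"
    "<h1>Intro</h1><p>Hello.</p>"
    "<h2>Details</h2><p>More.</p>"
    "</body></html>"
)
splitter = HTMLHeaderTextSplitter(
    headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")]
)
for doc in splitter.split_text(html):
    print(doc.metadata, "->", doc.page_content)
```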
+ xslt_path = ( + pathlib.Path(__file__).parent + / "document_transformers/xsl/html_chunks_with_headers.xslt" + ) + xslt_tree = etree.parse(xslt_path) + transform = etree.XSLT(xslt_tree) + result = transform(tree) + result_dom = etree.fromstring(str(result)) + + # create filter and mapping for header metadata + header_filter = [header[0] for header in self.headers_to_split_on] + header_mapping = dict(self.headers_to_split_on) + + # map xhtml namespace prefix + ns_map = {"h": "http://www.w3.org/1999/xhtml"} + + # build list of elements from DOM + elements = [] + for element in result_dom.findall("*//*", ns_map): + if element.findall("*[@class='headers']") or element.findall( + "*[@class='chunk']" + ): + elements.append( + ElementType( + url=file, + xpath="".join( + [ + node.text + for node in element.findall( + "*[@class='xpath']", ns_map + ) + ] + ), + content="".join( + [ + node.text + for node in element.findall( + "*[@class='chunk']", ns_map + ) + ] + ), + metadata={ + # Add text of specified headers to + # metadata using header mapping. + header_mapping[node.tag]: node.text + for node in filter( + lambda x: x.tag in header_filter, + element.findall( + "*[@class='headers']/*", ns_map + ), + ) + }, + ) + ) + + if not self.return_each_element: + return self.aggregate_elements_to_chunks(elements) + else: + return [ + SplitterDocument( + page_content=chunk["content"], metadata=chunk["metadata"] + ) + for chunk in elements + ] + + +# should be in newer Python versions (3.11+) +# @dataclass(frozen=True, kw_only=True, slots=True) +@dataclass(frozen=True) +class Tokenizer: + """Tokenizer data class.""" + + chunk_overlap: int + """Overlap in tokens between chunks.""" + tokens_per_chunk: int + """Maximum number of tokens per chunk.""" + decode: Callable[[list[int]], str] + """Function to decode a list of token ids to a string.""" + encode: Callable[[str], list[int]] + """Function to encode a string to a list of token ids.""" + + +def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]: + """Split incoming text and return chunks using tokenizer.""" + splits: list[str] = [] + input_ids = tokenizer.encode(text) + start_idx = 0 + cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + while start_idx < len(input_ids): + splits.append(tokenizer.decode(chunk_ids)) + if cur_idx == len(input_ids): + break + start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap + cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + return splits + + +class TokenTextSplitter(TextSplitter): + """Splitting text to tokens using model tokenizer.""" + + def __init__( + self, + encoding_name: str = "gpt2", + model: Optional[str] = None, + allowed_special: Literal["all"] | AbstractSet[str] = set(), + disallowed_special: Literal["all"] | Collection[str] = "all", + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs) + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to for TokenTextSplitter. " + "Please install it with `pip install tiktoken`." 
) from None + + if model is not None: + enc = tiktoken.encoding_for_model(model) + else: + enc = tiktoken.get_encoding(encoding_name) + self._tokenizer = enc + self._allowed_special = allowed_special + self._disallowed_special = disallowed_special + + def split_text(self, text: str) -> list[str]: + def _encode(_text: str) -> list[int]: + return self._tokenizer.encode( + _text, + allowed_special=self._allowed_special, + disallowed_special=self._disallowed_special, + ) + + tokenizer = Tokenizer( + chunk_overlap=self._chunk_overlap, + tokens_per_chunk=self._chunk_size, + decode=self._tokenizer.decode, + encode=_encode, + ) + + return split_text_on_tokens(text=text, tokenizer=tokenizer) + + +class SentenceTransformersTokenTextSplitter(TextSplitter): + """Splitting text to tokens using a sentence-transformers model tokenizer.""" + + def __init__( + self, + chunk_overlap: int = 50, + model: str = "sentence-transformers/all-mpnet-base-v2", + tokens_per_chunk: Optional[int] = None, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs, chunk_overlap=chunk_overlap) + + try: + from sentence_transformers import SentenceTransformer + except ImportError: + raise ImportError( + """Could not import the sentence_transformers python package. + It is needed for SentenceTransformersTokenTextSplitter. + Please install it with `pip install sentence-transformers`. + """ + ) from None + + self.model = model + self._model = SentenceTransformer(self.model, trust_remote_code=True) + self.tokenizer = self._model.tokenizer + self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk) + + def _initialize_chunk_configuration( + self, *, tokens_per_chunk: Optional[int] + ) -> None: + self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length) + + if tokens_per_chunk is None: + self.tokens_per_chunk = self.maximum_tokens_per_chunk + else: + self.tokens_per_chunk = tokens_per_chunk + + if self.tokens_per_chunk > self.maximum_tokens_per_chunk: + raise ValueError( + f"The token limit of the model '{self.model}'" + f" is: {self.maximum_tokens_per_chunk}." + f" Argument tokens_per_chunk={self.tokens_per_chunk}" + f" > maximum token limit."
) + + def split_text(self, text: str) -> list[str]: + def encode_strip_start_and_stop_token_ids(text: str) -> list[int]: + return self._encode(text)[1:-1] + + tokenizer = Tokenizer( + chunk_overlap=self._chunk_overlap, + tokens_per_chunk=self.tokens_per_chunk, + decode=self.tokenizer.decode, + encode=encode_strip_start_and_stop_token_ids, + ) + + return split_text_on_tokens(text=text, tokenizer=tokenizer) + + def count_tokens(self, *, text: str) -> int: + return len(self._encode(text)) + + _max_length_equal_32_bit_integer: int = 2**32 + + def _encode(self, text: str) -> list[int]: + token_ids_with_start_and_end_token_ids = self.tokenizer.encode( + text, + max_length=self._max_length_equal_32_bit_integer, + truncation="do_not_truncate", + ) + return token_ids_with_start_and_end_token_ids + + +class Language(str, Enum): + """Enum of the programming languages.""" + + CPP = "cpp" + GO = "go" + JAVA = "java" + KOTLIN = "kotlin" + JS = "js" + TS = "ts" + PHP = "php" + PROTO = "proto" + PYTHON = "python" + RST = "rst" + RUBY = "ruby" + RUST = "rust" + SCALA = "scala" + SWIFT = "swift" + MARKDOWN = "markdown" + LATEX = "latex" + HTML = "html" + SOL = "sol" + CSHARP = "csharp" + COBOL = "cobol" + C = "c" + LUA = "lua" + PERL = "perl" + + +class RecursiveCharacterTextSplitter(TextSplitter): + """Splitting text by recursively looking at characters. + + Recursively tries to split by different characters to find one that works. + """ + + def __init__( + self, + separators: Optional[list[str]] = None, + keep_separator: bool = True, + is_separator_regex: bool = False, + chunk_size: int = 4000, + chunk_overlap: int = 200, + **kwargs: Any, + ) -> None: + """Create a new RecursiveCharacterTextSplitter.""" + super().__init__( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + keep_separator=keep_separator, + **kwargs, + ) + self._separators = separators or ["\n\n", "\n", " ", ""] + self._is_separator_regex = is_separator_regex + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + def _split_text(self, text: str, separators: list[str]) -> list[str]: + """Split incoming text and return chunks.""" + final_chunks = [] + # Get appropriate separator to use + separator = separators[-1] + new_separators = [] + for i, _s in enumerate(separators): + _separator = _s if self._is_separator_regex else re.escape(_s) + if _s == "": + separator = _s + break + if re.search(_separator, text): + separator = _s + new_separators = separators[i + 1 :] + break + + _separator = ( + separator if self._is_separator_regex else re.escape(separator) + ) + splits = _split_text_with_regex(text, _separator, self._keep_separator) + + # Now go merging things, recursively splitting longer texts.
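(The merge-and-recurse loop continues below.) From the caller's side the recursion is invisible; a minimal sketch with arbitrary sizes:

```python
# Sketch: try "\n\n" first, then "\n", then " ", then "" until pieces fit.
splitter = RecursiveCharacterTextSplitter(chunk_size=30, chunk_overlap=0)
text = "First paragraph.\n\nSecond paragraph that is noticeably longer."
print(splitter.split_text(text))

# The language-aware entry point reuses the same machinery with
# language-specific separators (see get_separators_for_language below):
py_splitter = RecursiveCharacterTextSplitter.from_language(
    Language.PYTHON, chunk_size=60, chunk_overlap=0
)
```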
+ _good_splits = [] + _separator = "" if self._keep_separator else separator + for s in splits: + if self._length_function(s) < self._chunk_size: + _good_splits.append(s) + else: + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + _good_splits = [] + if not new_separators: + final_chunks.append(s) + else: + other_info = self._split_text(s, new_separators) + final_chunks.extend(other_info) + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + return final_chunks + + def split_text(self, text: str) -> list[str]: + return self._split_text(text, self._separators) + + @classmethod + def from_language( + cls, language: Language, **kwargs: Any + ) -> RecursiveCharacterTextSplitter: + separators = cls.get_separators_for_language(language) + return cls(separators=separators, is_separator_regex=True, **kwargs) + + @staticmethod + def get_separators_for_language(language: Language) -> list[str]: + if language == Language.CPP: + return [ + # Split along class definitions + "\nclass ", + # Split along function definitions + "\nvoid ", + "\nint ", + "\nfloat ", + "\ndouble ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.GO: + return [ + # Split along function definitions + "\nfunc ", + "\nvar ", + "\nconst ", + "\ntype ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.JAVA: + return [ + # Split along class definitions + "\nclass ", + # Split along method definitions + "\npublic ", + "\nprotected ", + "\nprivate ", + "\nstatic ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.KOTLIN: + return [ + # Split along class definitions + "\nclass ", + # Split along method definitions + "\npublic ", + "\nprotected ", + "\nprivate ", + "\ninternal ", + "\ncompanion ", + "\nfun ", + "\nval ", + "\nvar ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nwhen ", + "\ncase ", + "\nelse ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.JS: + return [ + # Split along function definitions + "\nfunction ", + "\nconst ", + "\nlet ", + "\nvar ", + "\nclass ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\ndefault ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.TS: + return [ + "\nenum ", + "\ninterface ", + "\nnamespace ", + "\ntype ", + # Split along class definitions + "\nclass ", + # Split along function definitions + "\nfunction ", + "\nconst ", + "\nlet ", + "\nvar ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\ndefault ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PHP: + return [ + # Split along function definitions + "\nfunction ", + # Split along class definitions + "\nclass ", + # Split along control flow statements + "\nif ", + "\nforeach ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + # Split by the 
normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PROTO: + return [ + # Split along message definitions + "\nmessage ", + # Split along service definitions + "\nservice ", + # Split along enum definitions + "\nenum ", + # Split along option definitions + "\noption ", + # Split along import statements + "\nimport ", + # Split along syntax declarations + "\nsyntax ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PYTHON: + return [ + # First, try to split along class definitions + "\nclass ", + "\ndef ", + "\n\tdef ", + # Now split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RST: + return [ + # Split along section titles + "\n=+\n", + "\n-+\n", + "\n\\*+\n", + # Split along directive markers + "\n\n.. *\n\n", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RUBY: + return [ + # Split along method definitions + "\ndef ", + "\nclass ", + # Split along control flow statements + "\nif ", + "\nunless ", + "\nwhile ", + "\nfor ", + "\ndo ", + "\nbegin ", + "\nrescue ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RUST: + return [ + # Split along function definitions + "\nfn ", + "\nconst ", + "\nlet ", + # Split along control flow statements + "\nif ", + "\nwhile ", + "\nfor ", + "\nloop ", + "\nmatch ", + "\nconst ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SCALA: + return [ + # Split along class definitions + "\nclass ", + "\nobject ", + # Split along method definitions + "\ndef ", + "\nval ", + "\nvar ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nmatch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SWIFT: + return [ + # Split along function definitions + "\nfunc ", + # Split along class definitions + "\nclass ", + "\nstruct ", + "\nenum ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.MARKDOWN: + return [ + # First, try to split along Markdown headings + "\n#{1,6} ", + # Note: the alternative Setext syntax for headings + # (a line underlined with === or ---) is not handled here + # End of code block + "```\n", + # Horizontal lines (three or more of ***, ---, or ___) + "\n\\*\\*\\*+\n", + "\n---+\n", + "\n___+\n", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.LATEX: + return [ + # First, try to split along Latex sections + "\n\\\\chapter{", + "\n\\\\section{", + "\n\\\\subsection{", + "\n\\\\subsubsection{", + # Now split by environments + "\n\\\\begin{enumerate}", + "\n\\\\begin{itemize}", + "\n\\\\begin{description}", + "\n\\\\begin{list}", + "\n\\\\begin{quote}", + "\n\\\\begin{quotation}", + "\n\\\\begin{verse}", + "\n\\\\begin{verbatim}", + # Now split by math environments + "\n\\\\begin{align}", + "$$", + "$", + # Now split by the normal type of lines + " ", + "", + ] + elif language == Language.HTML: + return [ + # First, try to split along HTML tags + "<body", + "<div", + "<p", + "<br", + "<li", + "<h1", + "<h2", + "<h3", + "<h4", + "<h5", + "<h6", + "<span", +
"<table", + "<tr", + "<td", + "<th", + "<ul", + "<ol", + "<header", + "<footer", + "<nav", + # Head + "<head", + "<style", + "<script", + "<meta", + "<title", + "", + ] + elif language == Language.CSHARP: + return [ + "\ninterface ", + "\nenum ", + "\nimplements ", + "\ndelegate ", + "\nevent ", + # Split along class definitions + "\nclass ", + "\nabstract ", + # Split along method definitions + "\npublic ", + "\nprotected ", + "\nprivate ", + "\nstatic ", + "\nreturn ", + # Split along control flow statements + "\nif ", + "\ncontinue ", + "\nfor ", + "\nforeach ", + "\nwhile ", + "\nswitch ", + "\nbreak ", + "\ncase ", + "\nelse ", + # Split by exceptions + "\ntry ", + "\nthrow ", + "\nfinally ", + "\ncatch ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SOL: + return [ + # Split along compiler information definitions + "\npragma ", + "\nusing ", + # Split along contract definitions + "\ncontract ", + "\ninterface ", + "\nlibrary ", + # Split along method definitions + "\nconstructor ", + "\ntype ", + "\nfunction ", + "\nevent ", + "\nmodifier ", + "\nerror ", + "\nstruct ", + "\nenum ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\ndo while ", + "\nassembly ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.COBOL: + return [ + # Split along divisions + "\nIDENTIFICATION DIVISION.", + "\nENVIRONMENT DIVISION.", + "\nDATA DIVISION.", + "\nPROCEDURE DIVISION.", + # Split along sections within DATA DIVISION + "\nWORKING-STORAGE SECTION.", + "\nLINKAGE SECTION.", + "\nFILE SECTION.", + # Split along sections within PROCEDURE DIVISION + "\nINPUT-OUTPUT SECTION.", + # Split along paragraphs and common statements + "\nOPEN ", + "\nCLOSE ", + "\nREAD ", + "\nWRITE ", + "\nIF ", + "\nELSE ", + "\nMOVE ", + "\nPERFORM ", + "\nUNTIL ", + "\nVARYING ", + "\nACCEPT ", + "\nDISPLAY ", + "\nSTOP RUN.", + # Split by the normal type of lines + "\n", + " ", + "", + ] + + else: + raise ValueError( + f"Language {language} is not supported! " + f"Please choose from {list(Language)}" + ) + + +class NLTKTextSplitter(TextSplitter): + """Splitting text using NLTK package.""" + + def __init__( + self, separator: str = "\n\n", language: str = "english", **kwargs: Any + ) -> None: + """Initialize the NLTK splitter.""" + super().__init__(**kwargs) + try: + from nltk.tokenize import sent_tokenize + + self._tokenizer = sent_tokenize + except ImportError: + raise ImportError("""NLTK is not installed, please install it with + `pip install nltk`.""") from None + self._separator = separator + self._language = language + + def split_text(self, text: str) -> list[str]: + """Split incoming text and return chunks.""" + # First we naively split the large input into a bunch of smaller ones. + splits = self._tokenizer(text, language=self._language) + return self._merge_splits(splits, self._separator) + + +class SpacyTextSplitter(TextSplitter): + """Splitting text using Spacy package. + + Per default, Spacy's `en_core_web_sm` model is used and + its default max_length is 1000000 (it is the length of maximum character + this model takes which can be increased for large files). For a faster, + but potentially less accurate splitting, you can use `pipe='sentencizer'`. 
+ """ + + def __init__( + self, + separator: str = "\n\n", + pipe: str = "en_core_web_sm", + max_length: int = 1_000_000, + **kwargs: Any, + ) -> None: + """Initialize the spacy text splitter.""" + super().__init__(**kwargs) + self._tokenizer = _make_spacy_pipe_for_splitting( + pipe, max_length=max_length + ) + self._separator = separator + + def split_text(self, text: str) -> list[str]: + """Split incoming text and return chunks.""" + splits = (s.text for s in self._tokenizer(text).sents) + return self._merge_splits(splits, self._separator) + + +class KonlpyTextSplitter(TextSplitter): + """Splitting text using Konlpy package. + + It is good for splitting Korean text. + """ + + def __init__( + self, + separator: str = "\n\n", + **kwargs: Any, + ) -> None: + """Initialize the Konlpy text splitter.""" + super().__init__(**kwargs) + self._separator = separator + try: + from konlpy.tag import Kkma + except ImportError: + raise ImportError(""" + Konlpy is not installed, please install it with + `pip install konlpy` + """) from None + self.kkma = Kkma() + + def split_text(self, text: str) -> list[str]: + """Split incoming text and return chunks.""" + splits = self.kkma.sentences(text) + return self._merge_splits(splits, self._separator) + + +# For backwards compatibility +class PythonCodeTextSplitter(RecursiveCharacterTextSplitter): + """Attempts to split the text along Python syntax.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize a PythonCodeTextSplitter.""" + separators = self.get_separators_for_language(Language.PYTHON) + super().__init__(separators=separators, **kwargs) + + +class MarkdownTextSplitter(RecursiveCharacterTextSplitter): + """Attempts to split the text along Markdown-formatted headings.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize a MarkdownTextSplitter.""" + separators = self.get_separators_for_language(Language.MARKDOWN) + super().__init__(separators=separators, **kwargs) + + +class LatexTextSplitter(RecursiveCharacterTextSplitter): + """Attempts to split the text along Latex-formatted layout elements.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize a LatexTextSplitter.""" + separators = self.get_separators_for_language(Language.LATEX) + super().__init__(separators=separators, **kwargs) + + +class RecursiveJsonSplitter: + def __init__( + self, max_chunk_size: int = 2000, min_chunk_size: Optional[int] = None + ): + super().__init__() + self.max_chunk_size = max_chunk_size + self.min_chunk_size = ( + min_chunk_size + if min_chunk_size is not None + else max(max_chunk_size - 200, 50) + ) + + @staticmethod + def _json_size(data: dict) -> int: + """Calculate the size of the serialized JSON object.""" + return len(json.dumps(data)) + + @staticmethod + def _set_nested_dict(d: dict, path: list[str], value: Any) -> None: + """Set a value in a nested dictionary based on the given path.""" + for key in path[:-1]: + d = d.setdefault(key, {}) + d[path[-1]] = value + + def _list_to_dict_preprocessing(self, data: Any) -> Any: + if isinstance(data, dict): + # Process each key-value pair in the dictionary + return { + k: self._list_to_dict_preprocessing(v) for k, v in data.items() + } + elif isinstance(data, list): + # Convert the list to a dictionary with index-based keys + return { + str(i): self._list_to_dict_preprocessing(item) + for i, item in enumerate(data) + } + else: + # The item is neither a dict nor a list, return unchanged + return data + + def _json_split( + self, + data: dict[str, Any], + current_path: list[str] | None = 
None, + chunks: list[dict] | None = None, + ) -> list[dict]: + """Split json into maximum size dictionaries while preserving + structure.""" + if current_path is None: + current_path = [] + if chunks is None: + chunks = [{}] + + if isinstance(data, dict): + for key, value in data.items(): + new_path = current_path + [key] + chunk_size = self._json_size(chunks[-1]) + size = self._json_size({key: value}) + remaining = self.max_chunk_size - chunk_size + + if size < remaining: + # Add item to current chunk + self._set_nested_dict(chunks[-1], new_path, value) + else: + if chunk_size >= self.min_chunk_size: + # Chunk is big enough, start a new chunk + chunks.append({}) + + # Iterate + self._json_split(value, new_path, chunks) + else: + # handle single item + self._set_nested_dict(chunks[-1], current_path, data) + return chunks + + def split_json( + self, + json_data: dict[str, Any], + convert_lists: bool = False, + ) -> list[dict]: + """Splits JSON into a list of JSON chunks.""" + + if convert_lists: + chunks = self._json_split( + self._list_to_dict_preprocessing(json_data) + ) + else: + chunks = self._json_split(json_data) + + # Remove the last chunk if it's empty + if not chunks[-1]: + chunks.pop() + return chunks + + def split_text( + self, json_data: dict[str, Any], convert_lists: bool = False + ) -> list[str]: + """Splits JSON into a list of JSON formatted strings.""" + + chunks = self.split_json( + json_data=json_data, convert_lists=convert_lists + ) + + # Convert to string + return [json.dumps(chunk) for chunk in chunks] + + def create_documents( + self, + texts: list[dict], + convert_lists: bool = False, + metadatas: Optional[list[dict]] = None, + ) -> list[SplitterDocument]: + """Create documents from a list of json objects (dict).""" + _metadatas = metadatas or [{}] * len(texts) + documents = [] + for i, text in enumerate(texts): + for chunk in self.split_text( + json_data=text, convert_lists=convert_lists + ): + metadata = copy.deepcopy(_metadatas[i]) + new_doc = SplitterDocument( + page_content=chunk, metadata=metadata + ) + documents.append(new_doc) + return documents
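To close, a usage sketch for `RecursiveJsonSplitter`; the payload and the deliberately small `max_chunk_size` are invented so a split may actually occur:

```python
# Sketch: split nested JSON into serialized chunks of bounded size.
splitter = RecursiveJsonSplitter(max_chunk_size=60)
data = {"a": {"x": 1, "y": 2}, "b": [1, 2, 3], "c": "text"}

# convert_lists=True first rewrites lists as index-keyed dicts
# ({"0": 1, "1": 2, ...}) so they can be split like any other mapping.
for chunk in splitter.split_json(data, convert_lists=True):
    print(chunk)

# The same data as ready-to-ingest documents:
docs = splitter.create_documents([data], convert_lists=True)
```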