author     S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch, HEAD -> master)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/shared
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)

    two versions of R2R are here
Diffstat (limited to '.venv/lib/python3.12/site-packages/shared')
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/__init__.py | 7
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/abstractions/__init__.py | 146
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/abstractions/base.py | 145
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/abstractions/document.py | 377
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/abstractions/embedding.py | 31
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/abstractions/exception.py | 75
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/abstractions/graph.py | 257
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/abstractions/llm.py | 325
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/abstractions/prompt.py | 39
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/abstractions/search.py | 614
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/abstractions/user.py | 69
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/abstractions/vector.py | 239
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/api/models/__init__.py | 194
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/api/models/auth/__init__.py | 0
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/api/models/auth/responses.py | 13
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/api/models/base.py | 26
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/api/models/graph/__init__.py | 0
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/api/models/graph/responses.py | 31
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/api/models/ingestion/__init__.py | 0
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/api/models/ingestion/responses.py | 72
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/api/models/management/__init__.py | 0
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/api/models/management/responses.py | 168
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/api/models/retrieval/__init__.py | 0
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/api/models/retrieval/responses.py | 604
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/utils/__init__.py | 46
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/utils/base_utils.py | 783
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/utils/splitter/__init__.py | 3
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/utils/splitter/text.py | 2000
28 files changed, 6264 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/__init__.py b/.venv/lib/python3.12/site-packages/shared/__init__.py
new file mode 100644
index 00000000..5abf8a6a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/__init__.py
@@ -0,0 +1,7 @@
+from .abstractions import *
+from .abstractions import __all__ as abstractions_all
+from .api.models import *
+from .api.models import __all__ as api_models_all
+from .utils import *
+
+__all__ = abstractions_all + api_models_all
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/__init__.py b/.venv/lib/python3.12/site-packages/shared/abstractions/__init__.py
new file mode 100644
index 00000000..da33ddd7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/__init__.py
@@ -0,0 +1,146 @@
+from .base import AsyncSyncMeta, R2RSerializable, syncable
+from .document import (
+ Document,
+ DocumentChunk,
+ DocumentResponse,
+ DocumentType,
+ GraphConstructionStatus,
+ GraphExtractionStatus,
+ IngestionMode,
+ IngestionStatus,
+ RawChunk,
+ UnprocessedChunk,
+)
+from .embedding import EmbeddingPurpose, default_embedding_prefixes
+from .exception import (
+ PDFParsingError,
+ PopplerNotFoundError,
+ R2RDocumentProcessingError,
+ R2RException,
+)
+from .graph import (
+ Community,
+ Entity,
+ GraphCommunitySettings,
+ GraphCreationSettings,
+ GraphEnrichmentSettings,
+ GraphExtraction,
+ Relationship,
+ StoreType,
+)
+from .llm import (
+ GenerationConfig,
+ LLMChatCompletion,
+ LLMChatCompletionChunk,
+ Message,
+ MessageType,
+ RAGCompletion,
+)
+from .prompt import Prompt
+from .search import (
+ AggregateSearchResult,
+ ChunkSearchResult,
+ ChunkSearchSettings,
+ GraphCommunityResult,
+ GraphEntityResult,
+ GraphRelationshipResult,
+ GraphSearchResult,
+ GraphSearchResultType,
+ GraphSearchSettings,
+ HybridSearchSettings,
+ SearchMode,
+ SearchSettings,
+ WebPageSearchResult,
+ select_search_filters,
+)
+from .user import Token, TokenData, User
+from .vector import (
+ IndexArgsHNSW,
+ IndexArgsIVFFlat,
+ IndexMeasure,
+ IndexMethod,
+ StorageResult,
+ Vector,
+ VectorEntry,
+ VectorQuantizationType,
+ VectorTableName,
+ VectorType,
+)
+
+__all__ = [
+ # Base abstractions
+ "R2RSerializable",
+ "AsyncSyncMeta",
+ "syncable",
+ # Completion abstractions
+ "MessageType",
+ # Document abstractions
+ "Document",
+ "DocumentChunk",
+ "DocumentResponse",
+ "IngestionMode",
+ "IngestionStatus",
+ "GraphExtractionStatus",
+ "GraphConstructionStatus",
+ "DocumentType",
+ "RawChunk",
+ "UnprocessedChunk",
+ # Embedding abstractions
+ "EmbeddingPurpose",
+ "default_embedding_prefixes",
+ # Exception abstractions
+ "R2RDocumentProcessingError",
+ "R2RException",
+ "PDFParsingError",
+ "PopplerNotFoundError",
+ # Graph abstractions
+ "Entity",
+ "Community",
+ "Community",
+ "GraphExtraction",
+ "Relationship",
+ "StoreType",
+ # LLM abstractions
+ "GenerationConfig",
+ "LLMChatCompletion",
+ "LLMChatCompletionChunk",
+ "Message",
+ "RAGCompletion",
+ # Prompt abstractions
+ "Prompt",
+ # Search abstractions
+ "AggregateSearchResult",
+ "GraphSearchResult",
+ "WebPageSearchResult",
+ "GraphSearchResultType",
+ "GraphEntityResult",
+ "GraphRelationshipResult",
+ "GraphCommunityResult",
+ "GraphSearchSettings",
+ "ChunkSearchSettings",
+ "ChunkSearchResult",
+ "SearchSettings",
+ "select_search_filters",
+ "HybridSearchSettings",
+ "SearchMode",
+ # Graph settings abstractions
+ "GraphCreationSettings",
+ "GraphEnrichmentSettings",
+ "GraphCommunitySettings",
+ # User abstractions
+ "Token",
+ "TokenData",
+ "User",
+ # Vector abstractions
+ "Vector",
+ "VectorEntry",
+ "VectorType",
+ "IndexMethod",
+ "IndexMeasure",
+ "IndexArgsIVFFlat",
+ "IndexArgsHNSW",
+ "VectorTableName",
+ "VectorQuantizationType",
+ "StorageResult",
+]
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/base.py b/.venv/lib/python3.12/site-packages/shared/abstractions/base.py
new file mode 100644
index 00000000..d90ba400
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/base.py
@@ -0,0 +1,145 @@
+import asyncio
+import json
+from datetime import datetime
+from enum import Enum
+from typing import Any, Type, TypeVar
+from uuid import UUID
+
+from pydantic import BaseModel
+
+T = TypeVar("T", bound="R2RSerializable")
+
+
+class R2RSerializable(BaseModel):
+ @classmethod
+ def from_dict(cls: Type[T], data: dict[str, Any] | str) -> T:
+ if isinstance(data, str):
+ try:
+ data_dict = json.loads(data)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Invalid JSON string: {e}") from e
+ else:
+ data_dict = data
+ return cls(**data_dict)
+
+ def as_dict(self) -> dict[str, Any]:
+ data = self.model_dump(exclude_unset=True)
+ return self._serialize_values(data)
+
+ def to_dict(self) -> dict[str, Any]:
+ data = self.model_dump(exclude_unset=True)
+ return self._serialize_values(data)
+
+ def to_json(self) -> str:
+ data = self.to_dict()
+ return json.dumps(data)
+
+ @classmethod
+ def from_json(cls: Type[T], json_str: str) -> T:
+ return cls.model_validate_json(json_str)
+
+ @staticmethod
+ def _serialize_values(data: Any) -> Any:
+ if isinstance(data, dict):
+ return {
+ k: R2RSerializable._serialize_values(v)
+ for k, v in data.items()
+ }
+ elif isinstance(data, list):
+ return [R2RSerializable._serialize_values(v) for v in data]
+ elif isinstance(data, UUID):
+ return str(data)
+ elif isinstance(data, Enum):
+ return data.value
+ elif isinstance(data, datetime):
+ return data.isoformat()
+ else:
+ return data
+
+ class Config:
+ arbitrary_types_allowed = True
+ json_encoders = {
+ UUID: str,
+ bytes: lambda v: v.decode("utf-8", errors="ignore"),
+ }
+
+
+class AsyncSyncMeta(type):
+ _event_loop = None # Class-level shared event loop
+
+ @classmethod
+ def get_event_loop(cls):
+ if cls._event_loop is None or cls._event_loop.is_closed():
+ cls._event_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(cls._event_loop)
+ return cls._event_loop
+
+ def __new__(cls, name, bases, dct):
+ new_cls = super().__new__(cls, name, bases, dct)
+ for attr_name, attr_value in dct.items():
+ if asyncio.iscoroutinefunction(attr_value) and getattr(
+ attr_value, "_syncable", False
+ ):
+ sync_method_name = attr_name[
+ 1:
+ ] # Remove leading 'a' for sync method
+ async_method = attr_value
+
+ def make_sync_method(async_method):
+ def sync_wrapper(self, *args, **kwargs):
+ loop = cls.get_event_loop()
+ if not loop.is_running():
+ # Setup to run the loop in a background thread if necessary
+ # to prevent blocking the main thread in a synchronous call environment
+ from threading import Thread
+
+ result = None
+ exception = None
+
+ def run():
+ nonlocal result, exception
+ try:
+ asyncio.set_event_loop(loop)
+ result = loop.run_until_complete(
+ async_method(self, *args, **kwargs)
+ )
+ except Exception as e:
+ exception = e
+ finally:
+ generation_config = kwargs.get(
+ "rag_generation_config", None
+ )
+ if (
+ not generation_config
+ or not generation_config.stream
+ ):
+ loop.run_until_complete(
+ loop.shutdown_asyncgens()
+ )
+ loop.close()
+
+ thread = Thread(target=run)
+ thread.start()
+ thread.join()
+ if exception:
+ raise exception
+ return result
+ else:
+ # If there's already a running loop, schedule and execute the coroutine
+ future = asyncio.run_coroutine_threadsafe(
+ async_method(self, *args, **kwargs), loop
+ )
+ return future.result()
+
+ return sync_wrapper
+
+ setattr(
+ new_cls, sync_method_name, make_sync_method(async_method)
+ )
+ return new_cls
+
+
+def syncable(func):
+ """Decorator to mark methods for synchronous wrapper creation."""
+ func._syncable = True
+ return func
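
A minimal sketch of how these primitives compose, assuming the package is importable as `shared` once installed in site-packages: `R2RSerializable` round-trips UUIDs, enums, and datetimes into JSON-safe values, and `AsyncSyncMeta` with `@syncable` derives a blocking wrapper by stripping the leading `a` from the coroutine's name.

    from uuid import UUID, uuid4

    from shared.abstractions.base import AsyncSyncMeta, R2RSerializable, syncable

    class Item(R2RSerializable):
        id: UUID
        name: str

    # The UUID is serialized to a plain string by _serialize_values.
    print(Item(id=uuid4(), name="example").to_json())

    class Client(metaclass=AsyncSyncMeta):
        @syncable
        async def aget_answer(self) -> str:
            return "42"

    # The metaclass exposes a blocking `get_answer` that runs the coroutine.
    print(Client().get_answer())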
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/document.py b/.venv/lib/python3.12/site-packages/shared/abstractions/document.py
new file mode 100644
index 00000000..513392f8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/document.py
@@ -0,0 +1,377 @@
+"""Abstractions for documents and their extractions."""
+
+import json
+import logging
+from datetime import datetime
+from enum import Enum
+from typing import Any, Optional
+from uuid import UUID, uuid4
+
+from pydantic import Field
+
+from .base import R2RSerializable
+from .llm import GenerationConfig
+
+logger = logging.getLogger()
+
+
+class DocumentType(str, Enum):
+ """Types of documents that can be stored."""
+
+ # Audio
+ MP3 = "mp3"
+
+ # CSV
+ CSV = "csv"
+
+ # Email
+ EML = "eml"
+ MSG = "msg"
+ P7S = "p7s"
+
+ # EPUB
+ EPUB = "epub"
+
+ # Excel
+ XLS = "xls"
+ XLSX = "xlsx"
+
+ # HTML
+ HTML = "html"
+ HTM = "htm"
+
+ # Image
+ BMP = "bmp"
+ HEIC = "heic"
+ JPEG = "jpeg"
+ PNG = "png"
+ TIFF = "tiff"
+ JPG = "jpg"
+ SVG = "svg"
+
+ # Markdown
+ MD = "md"
+
+ # Org Mode
+ ORG = "org"
+
+ # Open Office
+ ODT = "odt"
+
+ # PDF
+ PDF = "pdf"
+
+ # Plain text
+ TXT = "txt"
+ JSON = "json"
+
+ # PowerPoint
+ PPT = "ppt"
+ PPTX = "pptx"
+
+ # reStructured Text
+ RST = "rst"
+
+ # Rich Text
+ RTF = "rtf"
+
+ # TSV
+ TSV = "tsv"
+
+ # Video/GIF
+ GIF = "gif"
+
+ # Word
+ DOC = "doc"
+ DOCX = "docx"
+
+ # XML
+ XML = "xml"
+
+
+class Document(R2RSerializable):
+ id: UUID = Field(default_factory=uuid4)
+ collection_ids: list[UUID]
+ owner_id: UUID
+ document_type: DocumentType
+ metadata: dict
+
+ class Config:
+ arbitrary_types_allowed = True
+ ignore_extra = False
+ json_encoders = {
+ UUID: str,
+ }
+ populate_by_name = True
+
+
+class IngestionStatus(str, Enum):
+ """Status of document processing."""
+
+ PENDING = "pending"
+ PARSING = "parsing"
+ EXTRACTING = "extracting"
+ CHUNKING = "chunking"
+ EMBEDDING = "embedding"
+ AUGMENTING = "augmenting"
+ STORING = "storing"
+ ENRICHING = "enriching"
+
+ FAILED = "failed"
+ SUCCESS = "success"
+
+ def __str__(self):
+ return self.value
+
+ @classmethod
+ def table_name(cls) -> str:
+ return "documents"
+
+ @classmethod
+ def id_column(cls) -> str:
+ return "document_id"
+
+
+class GraphExtractionStatus(str, Enum):
+ """Status of graph creation per document."""
+
+ PENDING = "pending"
+ PROCESSING = "processing"
+ SUCCESS = "success"
+ ENRICHED = "enriched"
+ FAILED = "failed"
+
+ def __str__(self):
+ return self.value
+
+ @classmethod
+ def table_name(cls) -> str:
+ return "documents"
+
+ @classmethod
+ def id_column(cls) -> str:
+ return "id"
+
+
+class GraphConstructionStatus(str, Enum):
+ """Status of graph enrichment per collection."""
+
+ PENDING = "pending"
+ PROCESSING = "processing"
+ OUTDATED = "outdated"
+ SUCCESS = "success"
+ FAILED = "failed"
+
+ def __str__(self):
+ return self.value
+
+ @classmethod
+ def table_name(cls) -> str:
+ return "collections"
+
+ @classmethod
+ def id_column(cls) -> str:
+ return "id"
+
+
+class DocumentResponse(R2RSerializable):
+ """Base class for document information handling."""
+
+ id: UUID
+ collection_ids: list[UUID]
+ owner_id: UUID
+ document_type: DocumentType
+ metadata: dict
+ title: Optional[str] = None
+ version: str
+ size_in_bytes: Optional[int]
+ ingestion_status: IngestionStatus = IngestionStatus.PENDING
+ extraction_status: GraphExtractionStatus = GraphExtractionStatus.PENDING
+ created_at: Optional[datetime] = None
+ updated_at: Optional[datetime] = None
+ ingestion_attempt_number: Optional[int] = None
+ summary: Optional[str] = None
+ summary_embedding: Optional[list[float]] = None # Add optional embedding
+ total_tokens: Optional[int] = None
+ chunks: Optional[list] = None
+
+ def convert_to_db_entry(self):
+ """Prepare the document info for database entry, extracting certain
+ fields from metadata."""
+ now = datetime.now()
+
+ # Format the embedding properly for Postgres vector type
+ embedding = None
+ if self.summary_embedding is not None:
+ embedding = f"[{','.join(str(x) for x in self.summary_embedding)}]"
+
+ return {
+ "id": self.id,
+ "collection_ids": self.collection_ids,
+ "owner_id": self.owner_id,
+ "document_type": self.document_type,
+ "metadata": json.dumps(self.metadata),
+ "title": self.title or "N/A",
+ "version": self.version,
+ "size_in_bytes": self.size_in_bytes,
+ "ingestion_status": self.ingestion_status.value,
+ "extraction_status": self.extraction_status.value,
+ "created_at": self.created_at or now,
+ "updated_at": self.updated_at or now,
+ "ingestion_attempt_number": self.ingestion_attempt_number or 0,
+ "summary": self.summary,
+ "summary_embedding": embedding,
+ "total_tokens": self.total_tokens or 0, # ensure we pass 0 if None
+ }
+
+ class Config:
+ json_schema_extra = {
+ "example": {
+ "id": "123e4567-e89b-12d3-a456-426614174000",
+ "collection_ids": ["123e4567-e89b-12d3-a456-426614174000"],
+ "owner_id": "123e4567-e89b-12d3-a456-426614174000",
+ "document_type": "pdf",
+ "metadata": {"title": "Sample Document"},
+ "title": "Sample Document",
+ "version": "1.0",
+ "size_in_bytes": 123456,
+ "ingestion_status": "pending",
+ "extraction_status": "pending",
+ "created_at": "2021-01-01T00:00:00",
+ "updated_at": "2021-01-01T00:00:00",
+ "ingestion_attempt_number": 0,
+ "summary": "A summary of the document",
+ "summary_embedding": [0.1, 0.2, 0.3],
+ "total_tokens": 1000,
+ }
+ }
+
+
+class UnprocessedChunk(R2RSerializable):
+ """An extraction from a document."""
+
+ id: Optional[UUID] = None
+ document_id: Optional[UUID] = None
+ collection_ids: list[UUID] = []
+ metadata: dict = {}
+ text: str
+
+
+class UpdateChunk(R2RSerializable):
+ """An extraction from a document."""
+
+ id: UUID
+ metadata: Optional[dict] = None
+ text: str
+
+
+class DocumentChunk(R2RSerializable):
+ """An extraction from a document."""
+
+ id: UUID
+ document_id: UUID
+ collection_ids: list[UUID]
+ owner_id: UUID
+ data: str | bytes
+ metadata: dict
+
+
+class RawChunk(R2RSerializable):
+ text: str
+
+
+class IngestionMode(str, Enum):
+ hi_res = "hi-res"
+ fast = "fast"
+ custom = "custom"
+
+
+class ChunkEnrichmentSettings(R2RSerializable):
+ """Settings for chunk enrichment."""
+
+ enable_chunk_enrichment: bool = Field(
+ default=False,
+ description="Whether to enable chunk enrichment or not",
+ )
+ n_chunks: int = Field(
+ default=2,
+ description="The number of preceding and succeeding chunks to include. Defaults to 2.",
+ )
+ generation_config: Optional[GenerationConfig] = Field(
+ default=None,
+ description="The generation config to use for chunk enrichment",
+ )
+ chunk_enrichment_prompt: Optional[str] = Field(
+ default="chunk_enrichment",
+ description="The prompt to use for chunk enrichment",
+ )
+
+
+class IngestionConfig(R2RSerializable):
+ provider: str = "r2r"
+ excluded_parsers: list[str] = ["mp4"]
+ chunking_strategy: str = "recursive"
+ chunk_enrichment_settings: ChunkEnrichmentSettings = (
+ ChunkEnrichmentSettings()
+ )
+ extra_parsers: dict[str, Any] = {}
+
+ audio_transcription_model: str = ""
+
+ vision_img_prompt_name: str = "vision_img"
+
+ vision_pdf_prompt_name: str = "vision_pdf"
+
+ skip_document_summary: bool = False
+ document_summary_system_prompt: str = "system"
+ document_summary_task_prompt: str = "summary"
+ chunks_for_document_summary: int = 128
+ document_summary_model: str = ""
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["r2r", "unstructured_local", "unstructured_api"]
+
+ def validate_config(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider {self.provider} is not supported.")
+
+ @classmethod
+ def get_default(cls, mode: str) -> "IngestionConfig":
+ """Return default ingestion configuration for a given mode."""
+ if mode == "hi-res":
+ # More thorough parsing, no skipping summaries, possibly larger `chunks_for_document_summary`.
+ return cls(
+ provider="r2r",
+ excluded_parsers=["mp4"],
+ chunk_enrichment_settings=ChunkEnrichmentSettings(), # default
+ extra_parsers={},
+ audio_transcription_model="",
+ vision_img_prompt_name="vision_img",
+ vision_pdf_prompt_name="vision_pdf",
+ skip_document_summary=False,
+ document_summary_system_prompt="system",
+ document_summary_task_prompt="summary",
+ chunks_for_document_summary=256, # larger for hi-res
+ document_summary_model="",
+ )
+
+ elif mode == "fast":
+ # Skip summaries and other enrichment steps for speed.
+ return cls(
+ provider="r2r",
+ excluded_parsers=["mp4"],
+ chunk_enrichment_settings=ChunkEnrichmentSettings(), # default
+ extra_parsers={},
+ audio_transcription_model="",
+ vision_img_prompt_name="vision_img",
+ vision_pdf_prompt_name="vision_pdf",
+ skip_document_summary=True, # skip summaries
+ document_summary_system_prompt="system",
+ document_summary_task_prompt="summary",
+ chunks_for_document_summary=64,
+ document_summary_model="",
+ )
+ else:
+ # For `custom` or any unrecognized mode, return a base config
+ # For `custom` or any unrecognized mode, return the base config.
+ return cls()
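
The mode-based defaults can be exercised directly; a short sketch, assuming the installed `shared` import path:

    from uuid import uuid4

    from shared.abstractions.document import Document, DocumentType, IngestionConfig

    # "hi-res" keeps summaries and raises chunks_for_document_summary to 256;
    # "fast" skips document summaries entirely.
    config = IngestionConfig.get_default("hi-res")
    config.validate_config()  # raises ValueError for an unsupported provider

    doc = Document(
        collection_ids=[],
        owner_id=uuid4(),
        document_type=DocumentType.PDF,
        metadata={"title": "example.pdf"},
    )  # id defaults to a fresh uuid4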
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/embedding.py b/.venv/lib/python3.12/site-packages/shared/abstractions/embedding.py
new file mode 100644
index 00000000..6e27da28
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/embedding.py
@@ -0,0 +1,31 @@
+from enum import Enum, auto
+
+
+class EmbeddingPurpose(str, Enum):
+ INDEX = auto()
+ QUERY = auto()
+ DOCUMENT = auto()
+
+
+default_embedding_prefixes = {
+ "nomic-embed-text-v1.5": {
+ EmbeddingPurpose.INDEX: "",
+ EmbeddingPurpose.QUERY: "search_query: ",
+ EmbeddingPurpose.DOCUMENT: "search_document: ",
+ },
+ "nomic-embed-text": {
+ EmbeddingPurpose.INDEX: "",
+ EmbeddingPurpose.QUERY: "search_query: ",
+ EmbeddingPurpose.DOCUMENT: "search_document: ",
+ },
+ "mixedbread-ai/mxbai-embed-large-v1": {
+ EmbeddingPurpose.INDEX: "",
+ EmbeddingPurpose.QUERY: "Represent this sentence for searching relevant passages: ",
+ EmbeddingPurpose.DOCUMENT: "Represent this sentence for searching relevant passages: ",
+ },
+ "mixedbread-ai/mxbai-embed-large": {
+ EmbeddingPurpose.INDEX: "",
+ EmbeddingPurpose.QUERY: "Represent this sentence for searching relevant passages: ",
+ EmbeddingPurpose.DOCUMENT: "Represent this sentence for searching relevant passages: ",
+ },
+}
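
The prefix table is keyed first by model name and then by purpose; callers prepend the prefix before embedding. A sketch, assuming the installed import path:

    from shared.abstractions.embedding import (
        EmbeddingPurpose,
        default_embedding_prefixes,
    )

    prefixes = default_embedding_prefixes["nomic-embed-text"]
    query_text = prefixes[EmbeddingPurpose.QUERY] + "What is the capital of France?"
    # -> "search_query: What is the capital of France?"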
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/exception.py b/.venv/lib/python3.12/site-packages/shared/abstractions/exception.py
new file mode 100644
index 00000000..3dedfae8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/exception.py
@@ -0,0 +1,75 @@
+import textwrap
+from typing import Any, Optional
+from uuid import UUID
+
+
+class R2RException(Exception):
+ def __init__(
+ self, message: str, status_code: int, detail: Optional[Any] = None
+ ):
+ self.message = message
+ self.status_code = status_code
+ self.detail = detail
+ super().__init__(self.message)
+
+ def to_dict(self):
+ return {
+ "message": self.message,
+ "status_code": self.status_code,
+ "detail": self.detail,
+ "error_type": self.__class__.__name__,
+ }
+
+
+class R2RDocumentProcessingError(R2RException):
+ def __init__(
+ self, error_message: str, document_id: UUID, status_code: int = 500
+ ):
+ self.document_id = document_id
+ detail = {
+ "document_id": str(document_id),
+ "error_type": "document_processing_error",
+ }
+ super().__init__(error_message, status_code, detail)
+
+ def to_dict(self):
+ result = super().to_dict()
+ result["document_id"] = self.document_id
+ return result
+
+
+class PDFParsingError(R2RException):
+ """Custom exception for PDF parsing errors."""
+
+ def __init__(
+ self,
+ message: str,
+ original_error: Exception | None = None,
+ status_code: int = 500,
+ ):
+ detail = {
+ "original_error": str(original_error) if original_error else None
+ }
+ super().__init__(message, status_code, detail)
+
+
+class PopplerNotFoundError(PDFParsingError):
+ """Specific error for when Poppler is not installed."""
+
+ def __init__(self):
+ installation_instructions = textwrap.dedent("""
+ PDF processing requires Poppler to be installed. Please install Poppler and ensure it's in your system PATH.
+
+ Installing poppler:
+ - Ubuntu: sudo apt-get install poppler-utils
+ - Arch Linux: sudo pacman -S poppler
+ - macOS: brew install poppler
+ - Windows:
+ 1. Download poppler from @oschwartz10612
+ 2. Move extracted directory to desired location
+ 3. Add bin/ directory to PATH
+ 4. Test by running 'pdftoppm -h' in terminal
+ """)
+ super().__init__(
+ message=installation_instructions,
+ status_code=422,
+ original_error=None,
+ )
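
Because `R2RDocumentProcessingError` subclasses `R2RException`, callers can catch the base type and serialize the error uniformly; a sketch, assuming the installed import path:

    from uuid import uuid4

    from shared.abstractions.exception import (
        R2RDocumentProcessingError,
        R2RException,
    )

    try:
        raise R2RDocumentProcessingError("parse failed", document_id=uuid4())
    except R2RException as exc:
        payload = exc.to_dict()  # message, status_code, detail, error_type
        print(payload["error_type"])  # "R2RDocumentProcessingError"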
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py b/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py
new file mode 100644
index 00000000..3c1cec9e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py
@@ -0,0 +1,257 @@
+import json
+from datetime import datetime
+from enum import Enum
+from typing import Any, Optional
+from uuid import UUID
+
+from pydantic import Field
+
+from ..abstractions.llm import GenerationConfig
+from .base import R2RSerializable
+
+
+class Entity(R2RSerializable):
+ """An entity extracted from a document."""
+
+ name: str
+ description: Optional[str] = None
+ category: Optional[str] = None
+ metadata: Optional[dict[str, Any]] = None
+
+ id: Optional[UUID] = None
+ parent_id: Optional[UUID] = None # graph_id | document_id
+ description_embedding: Optional[list[float] | str] = None
+ chunk_ids: Optional[list[UUID]] = []
+
+ def __str__(self):
+ return f"{self.name}:{self.category}"
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ if isinstance(self.metadata, str):
+ try:
+ self.metadata = json.loads(self.metadata)
+ except json.JSONDecodeError:
+ # Not valid JSON; keep the raw string as-is.
+ pass
+
+
+class Relationship(R2RSerializable):
+ """A relationship between two entities.
+
+ This is a generic relationship, and can be used to represent any type of
+ relationship between any two entities.
+ """
+
+ id: Optional[UUID] = None
+ subject: str
+ predicate: str
+ object: str
+ description: Optional[str] = None
+ subject_id: Optional[UUID] = None
+ object_id: Optional[UUID] = None
+ weight: float | None = 1.0
+ chunk_ids: Optional[list[UUID]] = []
+ parent_id: Optional[UUID] = None
+ description_embedding: Optional[list[float] | str] = None
+ metadata: Optional[dict[str, Any] | str] = None
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ if isinstance(self.metadata, str):
+ try:
+ self.metadata = json.loads(self.metadata)
+ except json.JSONDecodeError:
+ # Not valid JSON; keep the raw string as-is.
+ pass
+
+
+class Community(R2RSerializable):
+ name: str = ""
+ summary: str = ""
+ level: Optional[int] = None
+ findings: list[str] = []
+ id: Optional[int | UUID] = None
+ community_id: Optional[UUID] = None
+ collection_id: Optional[UUID] = None
+ rating: Optional[float] = None
+ rating_explanation: Optional[str] = None
+ description_embedding: Optional[list[float]] = None
+ attributes: dict[str, Any] | None = None
+ created_at: datetime = Field(
+ default_factory=datetime.utcnow,
+ )
+ updated_at: datetime = Field(
+ default_factory=datetime.utcnow,
+ )
+
+ def __init__(self, **kwargs):
+ if isinstance(kwargs.get("attributes", None), str):
+ kwargs["attributes"] = json.loads(kwargs["attributes"])
+
+ if isinstance(kwargs.get("embedding", None), str):
+ kwargs["embedding"] = json.loads(kwargs["embedding"])
+
+ super().__init__(**kwargs)
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any] | str) -> "Community":
+ parsed_data: dict[str, Any] = (
+ json.loads(data) if isinstance(data, str) else data
+ )
+ if isinstance(parsed_data.get("embedding", None), str):
+ parsed_data["embedding"] = json.loads(parsed_data["embedding"])
+ return cls(**parsed_data)
+
+
+class GraphExtraction(R2RSerializable):
+ """A protocol for a knowledge graph extraction."""
+
+ entities: list[Entity]
+ relationships: list[Relationship]
+
+
+class Graph(R2RSerializable):
+ id: UUID | None = Field()
+ name: str
+ description: Optional[str] = None
+ created_at: datetime = Field(
+ default_factory=datetime.utcnow,
+ )
+ updated_at: datetime = Field(
+ default_factory=datetime.utcnow,
+ )
+ status: str = "pending"
+
+ class Config:
+ populate_by_name = True
+ from_attributes = True
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any] | str) -> "Graph":
+ """Create a Graph instance from a dictionary."""
+ # Convert string to dict if needed
+ parsed_data: dict[str, Any] = (
+ json.loads(data) if isinstance(data, str) else data
+ )
+ return cls(**parsed_data)
+
+
+class StoreType(str, Enum):
+ GRAPHS = "graphs"
+ DOCUMENTS = "documents"
+
+
+class GraphCreationSettings(R2RSerializable):
+ """Settings for knowledge graph creation."""
+
+ graph_extraction_prompt: str = Field(
+ default="graph_extraction",
+ description="The prompt to use for knowledge graph extraction.",
+ )
+
+ graph_entity_description_prompt: str = Field(
+ default="graph_entity_description",
+ description="The prompt to use for entity description generation.",
+ )
+
+ entity_types: list[str] = Field(
+ default=[],
+ description="The types of entities to extract.",
+ )
+
+ relation_types: list[str] = Field(
+ default=[],
+ description="The types of relations to extract.",
+ )
+
+ chunk_merge_count: int = Field(
+ default=2,
+ description="""The number of extractions to merge into a single graph
+ extraction.""",
+ )
+
+ max_knowledge_relationships: int = Field(
+ default=100,
+ description="""The maximum number of knowledge relationships to extract
+ from each chunk.""",
+ )
+
+ max_description_input_length: int = Field(
+ default=65536,
+ description="""The maximum length of the description for a node in the
+ graph.""",
+ )
+
+ generation_config: Optional[GenerationConfig] = Field(
+ default=None,
+ description="Configuration for text generation during graph enrichment.",
+ )
+
+ automatic_deduplication: bool = Field(
+ default=False,
+ description="Whether to automatically deduplicate entities.",
+ )
+
+
+class GraphEnrichmentSettings(R2RSerializable):
+ """Settings for knowledge graph enrichment."""
+
+ force_graph_search_results_enrichment: bool = Field(
+ default=False,
+ description="""Force run the enrichment step even if graph creation is
+ still in progress for some documents.""",
+ )
+
+ graph_communities_prompt: str = Field(
+ default="graph_communities",
+ description="The prompt to use for knowledge graph enrichment.",
+ )
+
+ max_summary_input_length: int = Field(
+ default=65536,
+ description="The maximum length of the summary for a community.",
+ )
+
+ generation_config: Optional[GenerationConfig] = Field(
+ default=None,
+ description="Configuration for text generation during graph enrichment.",
+ )
+
+ leiden_params: dict = Field(
+ default_factory=dict,
+ description="Parameters for the Leiden algorithm.",
+ )
+
+
+class GraphCommunitySettings(R2RSerializable):
+ """Settings for knowledge graph community enrichment."""
+
+ force_graph_search_results_enrichment: bool = Field(
+ default=False,
+ description="""Force run the enrichment step even if graph creation is
+ still in progress for some documents.""",
+ )
+
+ graph_communities: str = Field(
+ default="graph_communities",
+ description="The prompt to use for knowledge graph enrichment.",
+ )
+
+ max_summary_input_length: int = Field(
+ default=65536,
+ description="The maximum length of the summary for a community.",
+ )
+
+ generation_config: Optional[GenerationConfig] = Field(
+ default=None,
+ description="Configuration for text generation during graph enrichment.",
+ )
+
+ leiden_params: dict = Field(
+ default_factory=dict,
+ description="Parameters for the Leiden algorithm.",
+ )
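
A `GraphExtraction` bundles the entities and relationships pulled from a chunk; a minimal sketch, assuming the installed import path:

    from shared.abstractions.graph import Entity, GraphExtraction, Relationship

    paris = Entity(name="Paris", category="City")
    france = Entity(name="France", category="Country")

    extraction = GraphExtraction(
        entities=[paris, france],
        relationships=[
            Relationship(subject="Paris", predicate="is_capital_of", object="France")
        ],
    )
    print(paris)  # "Paris:City"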
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py b/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py
new file mode 100644
index 00000000..d71e279e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py
@@ -0,0 +1,325 @@
+"""Abstractions for the LLM model."""
+
+import json
+from enum import Enum
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
+
+from openai.types.chat import ChatCompletionChunk
+from pydantic import BaseModel, Field
+
+from .base import R2RSerializable
+
+if TYPE_CHECKING:
+ from .search import AggregateSearchResult
+
+from typing_extensions import Literal
+
+
+class Function(BaseModel):
+ arguments: str
+ """
+ The arguments to call the function with, as generated by the model in JSON
+ format. Note that the model does not always generate valid JSON, and may
+ hallucinate parameters not defined by your function schema. Validate the
+ arguments in your code before calling your function.
+ """
+
+ name: str
+ """The name of the function to call."""
+
+
+class ChatCompletionMessageToolCall(BaseModel):
+ id: str
+ """The ID of the tool call."""
+
+ function: Function
+ """The function that the model called."""
+
+ type: Literal["function"]
+ """The type of the tool. Currently, only `function` is supported."""
+
+
+class FunctionCall(BaseModel):
+ arguments: str
+ """
+ The arguments to call the function with, as generated by the model in JSON
+ format. Note that the model does not always generate valid JSON, and may
+ hallucinate parameters not defined by your function schema. Validate the
+ arguments in your code before calling your function.
+ """
+
+ name: str
+ """The name of the function to call."""
+
+
+class ChatCompletionMessage(BaseModel):
+ content: Optional[str] = None
+ """The contents of the message."""
+
+ refusal: Optional[str] = None
+ """The refusal message generated by the model."""
+
+ role: Literal["assistant"]
+ """The role of the author of this message."""
+
+ # audio: Optional[ChatCompletionAudio] = None
+ """
+ If the audio output modality is requested, this object contains data about the
+ audio response from the model.
+ [Learn more](https://platform.openai.com/docs/guides/audio).
+ """
+
+ function_call: Optional[FunctionCall] = None
+ """Deprecated and replaced by `tool_calls`.
+
+ The name and arguments of a function that should be called, as generated by the
+ model.
+ """
+
+ tool_calls: Optional[list[ChatCompletionMessageToolCall]] = None
+ """The tool calls generated by the model, such as function calls."""
+
+ structured_content: Optional[list[dict]] = None
+
+
+class Choice(BaseModel):
+ finish_reason: Literal[
+ "stop",
+ "length",
+ "tool_calls",
+ "content_filter",
+ "function_call",
+ "max_tokens",
+ ]
+ """The reason the model stopped generating tokens.
+
+ This will be `stop` if the model hit a natural stop point or a provided stop
+ sequence, `length` if the maximum number of tokens specified in the request was
+ reached, `content_filter` if content was omitted due to a flag from our content
+ filters, `tool_calls` if the model called a tool, or `function_call`
+ (deprecated) if the model called a function.
+ """
+
+ index: int
+ """The index of the choice in the list of choices."""
+
+ # logprobs: Optional[ChoiceLogprobs] = None
+ """Log probability information for the choice."""
+
+ message: ChatCompletionMessage
+ """A chat completion message generated by the model."""
+
+
+class LLMChatCompletion(BaseModel):
+ id: str
+ """A unique identifier for the chat completion."""
+
+ choices: list[Choice]
+ """A list of chat completion choices.
+
+ Can be more than one if `n` is greater than 1.
+ """
+
+ created: int
+ """The Unix timestamp (in seconds) of when the chat completion was created."""
+
+ model: str
+ """The model used for the chat completion."""
+
+ object: Literal["chat.completion"]
+ """The object type, which is always `chat.completion`."""
+
+ service_tier: Optional[Literal["scale", "default"]] = None
+ """The service tier used for processing the request."""
+
+ system_fingerprint: Optional[str] = None
+ """This fingerprint represents the backend configuration that the model runs with.
+
+ Can be used in conjunction with the `seed` request parameter to understand when
+ backend changes have been made that might impact determinism.
+ """
+
+ usage: Optional[Any] = None
+ """Usage statistics for the completion request."""
+
+
+LLMChatCompletionChunk = ChatCompletionChunk
+
+
+class RAGCompletion:
+ completion: LLMChatCompletion
+ search_results: "AggregateSearchResult"
+
+ def __init__(
+ self,
+ completion: LLMChatCompletion,
+ search_results: "AggregateSearchResult",
+ ):
+ self.completion = completion
+ self.search_results = search_results
+
+
+class GenerationConfig(R2RSerializable):
+ _defaults: ClassVar[dict] = {
+ "model": None,
+ "temperature": 0.1,
+ "top_p": 1.0,
+ "max_tokens_to_sample": 1024,
+ "stream": False,
+ "functions": None,
+ "tools": None,
+ "add_generation_kwargs": None,
+ "api_base": None,
+ "response_format": None,
+ "extended_thinking": False,
+ "thinking_budget": None,
+ "reasoning_effort": None,
+ }
+
+ model: Optional[str] = Field(
+ default_factory=lambda: GenerationConfig._defaults["model"]
+ )
+ temperature: float = Field(
+ default_factory=lambda: GenerationConfig._defaults["temperature"]
+ )
+ top_p: Optional[float] = Field(
+ default_factory=lambda: GenerationConfig._defaults["top_p"],
+ )
+ max_tokens_to_sample: int = Field(
+ default_factory=lambda: GenerationConfig._defaults[
+ "max_tokens_to_sample"
+ ],
+ )
+ stream: bool = Field(
+ default_factory=lambda: GenerationConfig._defaults["stream"]
+ )
+ functions: Optional[list[dict]] = Field(
+ default_factory=lambda: GenerationConfig._defaults["functions"]
+ )
+ tools: Optional[list[dict]] = Field(
+ default_factory=lambda: GenerationConfig._defaults["tools"]
+ )
+ add_generation_kwargs: Optional[dict] = Field(
+ default_factory=lambda: GenerationConfig._defaults[
+ "add_generation_kwargs"
+ ],
+ )
+ api_base: Optional[str] = Field(
+ default_factory=lambda: GenerationConfig._defaults["api_base"],
+ )
+ response_format: Optional[dict | BaseModel] = None
+ extended_thinking: bool = Field(
+ default=False,
+ description="Flag to enable extended thinking mode (for Anthropic providers)",
+ )
+ thinking_budget: Optional[int] = Field(
+ default=None,
+ description=(
+ "Token budget for internal reasoning when extended thinking mode is enabled. "
+ "Must be less than max_tokens_to_sample."
+ ),
+ )
+ reasoning_effort: Optional[str] = Field(
+ default=None,
+ description=(
+ "Effort level for internal reasoning when extended thinking mode is enabled, `low`, `medium`, or `high`."
+ "Only applicable to OpenAI providers."
+ ),
+ )
+
+ @classmethod
+ def set_default(cls, **kwargs):
+ for key, value in kwargs.items():
+ if key in cls._defaults:
+ cls._defaults[key] = value
+ else:
+ raise AttributeError(
+ f"No default attribute '{key}' in GenerationConfig"
+ )
+
+ def __init__(self, **data):
+ # Handle max_tokens mapping to max_tokens_to_sample
+ if "max_tokens" in data:
+ # Only set max_tokens_to_sample if it's not already provided
+ if "max_tokens_to_sample" not in data:
+ data["max_tokens_to_sample"] = data.pop("max_tokens")
+ else:
+ # If both are provided, max_tokens_to_sample takes precedence
+ data.pop("max_tokens")
+
+ if (
+ "response_format" in data
+ and isinstance(data["response_format"], type)
+ and issubclass(data["response_format"], BaseModel)
+ ):
+ model_class = data["response_format"]
+ data["response_format"] = {
+ "type": "json_schema",
+ "json_schema": {
+ "name": model_class.__name__,
+ "schema": model_class.model_json_schema(),
+ },
+ }
+
+ model = data.pop("model", None)
+ if model is not None:
+ super().__init__(model=model, **data)
+ else:
+ super().__init__(**data)
+
+ def __str__(self):
+ return json.dumps(self.to_dict())
+
+ class Config:
+ populate_by_name = True
+ json_schema_extra = {
+ "example": {
+ "model": "openai/gpt-4o",
+ "temperature": 0.1,
+ "top_p": 1.0,
+ "max_tokens_to_sample": 1024,
+ "stream": False,
+ "functions": None,
+ "tools": None,
+ "add_generation_kwargs": None,
+ "api_base": None,
+ }
+ }
+
+
+class MessageType(Enum):
+ SYSTEM = "system"
+ USER = "user"
+ ASSISTANT = "assistant"
+ FUNCTION = "function"
+ TOOL = "tool"
+
+ def __str__(self):
+ return self.value
+
+
+class Message(R2RSerializable):
+ role: MessageType | str
+ content: Optional[Any] = None
+ name: Optional[str] = None
+ function_call: Optional[dict[str, Any]] = None
+ tool_calls: Optional[list[dict[str, Any]]] = None
+ tool_call_id: Optional[str] = None
+ metadata: Optional[dict[str, Any]] = None
+ structured_content: Optional[list[dict]] = None
+ image_url: Optional[str] = None # For URL-based images
+ image_data: Optional[dict[str, str]] = (
+ None # For base64 {media_type, data}
+ )
+
+ class Config:
+ populate_by_name = True
+ json_schema_extra = {
+ "example": {
+ "role": "user",
+ "content": "This is a test message.",
+ "name": None,
+ "function_call": None,
+ "tool_calls": None,
+ }
+ }
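
`GenerationConfig` keeps class-level defaults that `set_default` mutates, and its constructor remaps the OpenAI-style `max_tokens` argument onto `max_tokens_to_sample`; a sketch, assuming the installed import path:

    from shared.abstractions.llm import GenerationConfig, Message

    GenerationConfig.set_default(temperature=0.0)  # process-wide default

    config = GenerationConfig(model="openai/gpt-4o", max_tokens=256)
    assert config.max_tokens_to_sample == 256  # max_tokens was remapped

    message = Message(role="user", content="What is the capital of France?")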
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/prompt.py b/.venv/lib/python3.12/site-packages/shared/abstractions/prompt.py
new file mode 100644
index 00000000..85ab5312
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/prompt.py
@@ -0,0 +1,39 @@
+"""Abstraction for a prompt that can be formatted with inputs."""
+
+import logging
+from datetime import datetime
+from typing import Any
+from uuid import UUID, uuid4
+
+from pydantic import BaseModel, Field
+
+logger = logging.getLogger()
+
+
+class Prompt(BaseModel):
+ """A prompt that can be formatted with inputs."""
+
+ id: UUID = Field(default_factory=uuid4)
+ name: str
+ template: str
+ input_types: dict[str, str]
+ created_at: datetime = Field(default_factory=datetime.utcnow)
+ updated_at: datetime = Field(default_factory=datetime.utcnow)
+
+ def format_prompt(self, inputs: dict[str, Any]) -> str:
+ self._validate_inputs(inputs)
+ return self.template.format(**inputs)
+
+ def _validate_inputs(self, inputs: dict[str, Any]) -> None:
+ for var, expected_type_name in self.input_types.items():
+ expected_type = self._convert_type(expected_type_name)
+ if var not in inputs:
+ raise ValueError(f"Missing input: {var}")
+ if not isinstance(inputs[var], expected_type):
+ raise TypeError(
+ f"Input '{var}' must be of type {expected_type.__name__}, got {type(inputs[var]).__name__} instead."
+ )
+
+ def _convert_type(self, type_name: str) -> type:
+ type_mapping = {"int": int, "str": str}
+ return type_mapping.get(type_name, str)
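
`format_prompt` validates the supplied inputs against `input_types` before substituting them into the template; a sketch, assuming the installed import path:

    from shared.abstractions.prompt import Prompt

    prompt = Prompt(
        name="summarize",
        template="Summarize {title} in {n} sentences.",
        input_types={"title": "str", "n": "int"},
    )
    print(prompt.format_prompt({"title": "the report", "n": 2}))
    # A missing key raises ValueError; a mismatched type raises TypeError.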
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/search.py b/.venv/lib/python3.12/site-packages/shared/abstractions/search.py
new file mode 100644
index 00000000..bf0f650e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/search.py
@@ -0,0 +1,614 @@
+"""Abstractions for search functionality."""
+
+from copy import copy
+from enum import Enum
+from typing import Any, Optional
+from uuid import NAMESPACE_DNS, UUID, uuid5
+
+from pydantic import Field
+
+from .base import R2RSerializable
+from .document import DocumentResponse
+from .llm import GenerationConfig
+from .vector import IndexMeasure
+
+
+def generate_id_from_label(label) -> UUID:
+ return uuid5(NAMESPACE_DNS, label)
+
+
+class ChunkSearchResult(R2RSerializable):
+ """Result of a search operation."""
+
+ id: UUID
+ document_id: UUID
+ owner_id: Optional[UUID]
+ collection_ids: list[UUID]
+ score: Optional[float] = None
+ text: str
+ metadata: dict[str, Any]
+
+ def __str__(self) -> str:
+ if self.score:
+ return (
+ f"ChunkSearchResult(score={self.score:.3f}, text={self.text})"
+ )
+ else:
+ return f"ChunkSearchResult(text={self.text})"
+
+ def __repr__(self) -> str:
+ return self.__str__()
+
+ def as_dict(self) -> dict:
+ return {
+ "id": self.id,
+ "document_id": self.document_id,
+ "owner_id": self.owner_id,
+ "collection_ids": self.collection_ids,
+ "score": self.score,
+ "text": self.text,
+ "metadata": self.metadata,
+ }
+
+ class Config:
+ populate_by_name = True
+ json_schema_extra = {
+ "example": {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
+ "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
+ "collection_ids": [],
+ "score": 0.23943702876567796,
+ "text": "Example text from the document",
+ "metadata": {
+ "title": "example_document.pdf",
+ "associated_query": "What is the capital of France?",
+ },
+ }
+ }
+
+
+class GraphSearchResultType(str, Enum):
+ ENTITY = "entity"
+ RELATIONSHIP = "relationship"
+ COMMUNITY = "community"
+
+
+class GraphEntityResult(R2RSerializable):
+ id: Optional[UUID] = None
+ name: str
+ description: str
+ metadata: Optional[dict[str, Any]] = None
+
+ class Config:
+ json_schema_extra = {
+ "example": {
+ "name": "Entity Name",
+ "description": "Entity Description",
+ "metadata": {},
+ }
+ }
+
+
+class GraphRelationshipResult(R2RSerializable):
+ id: Optional[UUID] = None
+ subject: str
+ predicate: str
+ object: str
+ subject_id: Optional[UUID] = None
+ object_id: Optional[UUID] = None
+ metadata: Optional[dict[str, Any]] = None
+ score: Optional[float] = None
+ description: str | None = None
+
+ class Config:
+ json_schema_extra = {
+ "example": {
+ "name": "Relationship Name",
+ "description": "Relationship Description",
+ "metadata": {},
+ }
+ }
+
+ def __str__(self) -> str:
+ return f"GraphRelationshipResult(subject={self.subject}, predicate={self.predicate}, object={self.object})"
+
+
+class GraphCommunityResult(R2RSerializable):
+ id: Optional[UUID] = None
+ name: str
+ summary: str
+ metadata: Optional[dict[str, Any]] = None
+
+ class Config:
+ json_schema_extra = {
+ "example": {
+ "name": "Community Name",
+ "summary": "Community Summary",
+ "rating": 9,
+ "rating_explanation": "Rating Explanation",
+ "metadata": {},
+ }
+ }
+
+ def __str__(self) -> str:
+ return (
+ f"GraphCommunityResult(name={self.name}, summary={self.summary})"
+ )
+
+
+class GraphSearchResult(R2RSerializable):
+ content: GraphEntityResult | GraphRelationshipResult | GraphCommunityResult
+ result_type: Optional[GraphSearchResultType] = None
+ chunk_ids: Optional[list[UUID]] = None
+ metadata: dict[str, Any] = {}
+ score: Optional[float] = None
+ id: UUID
+
+ def __str__(self) -> str:
+ return f"GraphSearchResult(content={self.content}, result_type={self.result_type})"
+
+ class Config:
+ populate_by_name = True
+ json_schema_extra = {
+ "example": {
+ "content": {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "name": "Entity Name",
+ "description": "Entity Description",
+ "metadata": {},
+ },
+ "result_type": "entity",
+ "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
+ "metadata": {
+ "associated_query": "What is the capital of France?"
+ },
+ }
+ }
+
+
+class WebPageSearchResult(R2RSerializable):
+ title: Optional[str] = None
+ link: Optional[str] = None
+ snippet: Optional[str] = None
+ position: int
+ type: str = "organic"
+ date: Optional[str] = None
+ sitelinks: Optional[list[dict]] = None
+ id: UUID
+
+ class Config:
+ json_schema_extra = {
+ "example": {
+ "title": "Page Title",
+ "link": "https://example.com/page",
+ "snippet": "Page snippet",
+ "position": 1,
+ "date": "2021-01-01",
+ "sitelinks": [
+ {
+ "title": "Sitelink Title",
+ "link": "https://example.com/sitelink",
+ }
+ ],
+ }
+ }
+
+ def __str__(self) -> str:
+ return f"WebPageSearchResult(title={self.title}, link={self.link}, snippet={self.snippet})"
+
+
+class RelatedSearchResult(R2RSerializable):
+ query: str
+ type: str = "related"
+ id: UUID
+
+
+class PeopleAlsoAskResult(R2RSerializable):
+ question: str
+ snippet: str
+ link: str
+ title: str
+ id: UUID
+ type: str = "peopleAlsoAsk"
+
+
+class WebSearchResult(R2RSerializable):
+ organic_results: list[WebPageSearchResult] = []
+ related_searches: list[RelatedSearchResult] = []
+ people_also_ask: list[PeopleAlsoAskResult] = []
+
+ @classmethod
+ def from_serper_results(cls, results: list[dict]) -> "WebSearchResult":
+ organic = []
+ related = []
+ paa = []
+
+ for result in results:
+ if result["type"] == "organic":
+ organic.append(
+ WebPageSearchResult(
+ **result, id=generate_id_from_label(result.get("link"))
+ )
+ )
+ elif result["type"] == "relatedSearches":
+ related.append(
+ RelatedSearchResult(
+ **result,
+ id=generate_id_from_label(result.get("query")),
+ )
+ )
+ elif result["type"] == "peopleAlsoAsk":
+ paa.append(
+ PeopleAlsoAskResult(
+ **result, id=generate_id_from_label(result.get("link"))
+ )
+ )
+
+ return cls(
+ organic_results=organic,
+ related_searches=related,
+ people_also_ask=paa,
+ )
+
+
+class AggregateSearchResult(R2RSerializable):
+ """Result of an aggregate search operation."""
+
+ chunk_search_results: Optional[list[ChunkSearchResult]] = None
+ graph_search_results: Optional[list[GraphSearchResult]] = None
+ web_search_results: Optional[list[WebPageSearchResult]] = None
+ document_search_results: Optional[list[DocumentResponse]] = None
+
+ def __str__(self) -> str:
+ return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results}, document_search_results={str(self.document_search_results)})"
+
+ def __repr__(self) -> str:
+ return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results}, document_search_results={str(self.document_search_results)})"
+
+ def as_dict(self) -> dict:
+ return {
+ "chunk_search_results": (
+ [result.as_dict() for result in self.chunk_search_results]
+ if self.chunk_search_results
+ else []
+ ),
+ "graph_search_results": (
+ [result.to_dict() for result in self.graph_search_results]
+ if self.graph_search_results
+ else []
+ ),
+ "web_search_results": (
+ [result.to_dict() for result in self.web_search_results]
+ if self.web_search_results
+ else []
+ ),
+ "document_search_results": (
+ [cdr.to_dict() for cdr in self.document_search_results]
+ if self.document_search_results
+ else []
+ ),
+ }
+
+ class Config:
+ populate_by_name = True
+ json_schema_extra = {
+ "example": {
+ "chunk_search_results": [
+ {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
+ "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
+ "collection_ids": [],
+ "score": 0.23943702876567796,
+ "text": "Example text from the document",
+ "metadata": {
+ "title": "example_document.pdf",
+ "associated_query": "What is the capital of France?",
+ },
+ }
+ ],
+ "graph_search_results": [
+ {
+ "content": {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "name": "Entity Name",
+ "description": "Entity Description",
+ "metadata": {},
+ },
+ "result_type": "entity",
+ "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
+ "metadata": {
+ "associated_query": "What is the capital of France?"
+ },
+ }
+ ],
+ "web_search_results": [
+ {
+ "title": "Page Title",
+ "link": "https://example.com/page",
+ "snippet": "Page snippet",
+ "position": 1,
+ "date": "2021-01-01",
+ "sitelinks": [
+ {
+ "title": "Sitelink Title",
+ "link": "https://example.com/sitelink",
+ }
+ ],
+ }
+ ],
+ "document_search_results": [
+ {
+ "document": {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "title": "Document Title",
+ "chunks": ["Chunk 1", "Chunk 2"],
+ "metadata": {},
+ },
+ }
+ ],
+ }
+ }
+
+
+class HybridSearchSettings(R2RSerializable):
+ """Settings for hybrid search combining full-text and semantic search."""
+
+ full_text_weight: float = Field(
+ default=1.0, description="Weight to apply to full text search"
+ )
+ semantic_weight: float = Field(
+ default=5.0, description="Weight to apply to semantic search"
+ )
+ full_text_limit: int = Field(
+ default=200,
+ description="Maximum number of results to return from full text search",
+ )
+ rrf_k: int = Field(
+ default=50, description="K-value for RRF (Reciprocal Rank Fusion)"
+ )
+
+
+class ChunkSearchSettings(R2RSerializable):
+ """Settings specific to chunk/vector search."""
+
+ index_measure: IndexMeasure = Field(
+ default=IndexMeasure.cosine_distance,
+ description="The distance measure to use for indexing",
+ )
+ probes: int = Field(
+ default=10,
+ description="Number of ivfflat index lists to query. Higher increases accuracy but decreases speed.",
+ )
+ ef_search: int = Field(
+ default=40,
+ description="Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed.",
+ )
+ enabled: bool = Field(
+ default=True,
+ description="Whether to enable chunk search",
+ )
+
+
+class GraphSearchSettings(R2RSerializable):
+ """Settings specific to knowledge graph search."""
+
+ generation_config: Optional[GenerationConfig] = Field(
+ default=None,
+ description="Configuration for text generation during graph search.",
+ )
+ max_community_description_length: int = Field(
+ default=65536,
+ )
+ max_llm_queries_for_global_search: int = Field(
+ default=250,
+ )
+ limits: dict[str, int] = Field(
+ default={},
+ )
+ enabled: bool = Field(
+ default=True,
+ description="Whether to enable graph search",
+ )
+
+
+class SearchSettings(R2RSerializable):
+ """Main search settings class that combines shared settings with
+ specialized settings for chunks and graph."""
+
+ # Search type flags
+ use_hybrid_search: bool = Field(
+ default=False,
+ description="Whether to perform a hybrid search. This is equivalent to setting `use_semantic_search=True` and `use_fulltext_search=True`, e.g. combining vector and keyword search.",
+ )
+ use_semantic_search: bool = Field(
+ default=True,
+ description="Whether to use semantic search",
+ )
+ use_fulltext_search: bool = Field(
+ default=False,
+ description="Whether to use full-text search",
+ )
+
+ # Common search parameters
+ filters: dict[str, Any] = Field(
+ default_factory=dict,
+ description="""Filters to apply to the search. Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.
+
+ Commonly seen filters include operations include the following:
+
+ `{"document_id": {"$eq": "9fbe403b-..."}}`
+
+ `{"document_id": {"$in": ["9fbe403b-...", "3e157b3a-..."]}}`
+
+ `{"collection_ids": {"$overlap": ["122fdf6a-...", "..."]}}`
+
+ `{"$and": {"$document_id": ..., "collection_ids": ...}}`""",
+ )
+ limit: int = Field(
+ default=10,
+ description="Maximum number of results to return",
+ ge=1,
+ le=1_000,
+ )
+ offset: int = Field(
+ default=0,
+ ge=0,
+ description="Offset to paginate search results",
+ )
+ include_metadatas: bool = Field(
+ default=True,
+ description="Whether to include element metadata in the search results",
+ )
+ include_scores: bool = Field(
+ default=True,
+ description="""Whether to include search score values in the
+ search results""",
+ )
+
+ # Search strategy and settings
+ search_strategy: str = Field(
+ default="vanilla",
+ description="""Search strategy to use
+ (e.g., 'vanilla', 'query_fusion', 'hyde')""",
+ )
+ hybrid_settings: HybridSearchSettings = Field(
+ default_factory=HybridSearchSettings,
+ description="""Settings for hybrid search (only used if
+ `use_semantic_search` and `use_fulltext_search` are both true)""",
+ )
+
+ # Specialized settings
+ chunk_settings: ChunkSearchSettings = Field(
+ default_factory=ChunkSearchSettings,
+ description="Settings specific to chunk/vector search",
+ )
+ graph_settings: GraphSearchSettings = Field(
+ default_factory=GraphSearchSettings,
+ description="Settings specific to knowledge graph search",
+ )
+
+ # For HyDE or multi-query:
+ num_sub_queries: int = Field(
+ default=5,
+ description="Number of sub-queries/hypothetical docs to generate when using hyde or rag_fusion search strategies.",
+ )
+
+ class Config:
+ populate_by_name = True
+ json_encoders = {UUID: str}
+ json_schema_extra = {
+ "example": {
+ "use_semantic_search": True,
+ "use_fulltext_search": False,
+ "use_hybrid_search": False,
+ "filters": {"category": "technology"},
+ "limit": 20,
+ "offset": 0,
+ "search_strategy": "vanilla",
+ "hybrid_settings": {
+ "full_text_weight": 1.0,
+ "semantic_weight": 5.0,
+ "full_text_limit": 200,
+ "rrf_k": 50,
+ },
+ "chunk_settings": {
+ "enabled": True,
+ "index_measure": "cosine_distance",
+ "include_metadata": True,
+ "probes": 10,
+ "ef_search": 40,
+ },
+ "graph_settings": {
+ "enabled": True,
+ "generation_config": GenerationConfig.Config.json_schema_extra,
+ "max_community_description_length": 65536,
+ "max_llm_queries_for_global_search": 250,
+ "limits": {
+ "entity": 20,
+ "relationship": 20,
+ "community": 20,
+ },
+ },
+ }
+ }
+
+ def __init__(self, **data):
+ # Handle legacy search_filters field
+ data["filters"] = {
+ **data.get("filters", {}),
+ **data.get("search_filters", {}),
+ }
+ super().__init__(**data)
+
+ @classmethod
+ def get_default(cls, mode: str) -> "SearchSettings":
+ """Return default search settings for a given mode."""
+ if mode == "basic":
+ # A simpler search that relies primarily on semantic search.
+ return cls(
+ use_semantic_search=True,
+ use_fulltext_search=False,
+ use_hybrid_search=False,
+ search_strategy="vanilla",
+ # Other relevant defaults can be provided here as needed
+ )
+ elif mode == "advanced":
+ # A more powerful, combined search that leverages both semantic and fulltext.
+ return cls(
+ use_semantic_search=True,
+ use_fulltext_search=True,
+ use_hybrid_search=True,
+ search_strategy="hyde",
+ # Other advanced defaults as needed
+ )
+ else:
+ # For 'custom' or unrecognized modes, return a basic empty config.
+ return cls()
+
+
+class SearchMode(str, Enum):
+ """Search modes for the search endpoint."""
+
+ basic = "basic"
+ advanced = "advanced"
+ custom = "custom"
+
+
+def select_search_filters(
+ auth_user: Any,
+ search_settings: SearchSettings,
+) -> dict[str, Any]:
+ filters = copy(search_settings.filters)
+ selected_collections = None
+ if not auth_user.is_superuser:
+ user_collections = set(auth_user.collection_ids)
+ for key in filters.keys():
+ if "collection_ids" in key:
+ selected_collections = set(map(UUID, filters[key]["$overlap"]))
+ break
+
+ if selected_collections:
+ allowed_collections = user_collections.intersection(
+ selected_collections
+ )
+ else:
+ allowed_collections = user_collections
+        # For non-superusers, filter by owner_id and the selected & allowed
+        # collections; superusers fall through and keep their filters as-is.
+        collection_filters = {
+            "$or": [
+                {"owner_id": {"$eq": auth_user.id}},
+                {"collection_ids": {"$overlap": list(allowed_collections)}},
+            ]  # type: ignore
+        }
+
+        filters.pop("collection_ids", None)
+        if filters != {}:
+            filters = {"$and": [collection_filters, filters]}  # type: ignore
+        else:
+            filters = collection_filters
+    return filters
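+
+# Illustrative sketch (an assumption): for a non-superuser whose request
+# carries filters={"category": {"$eq": "tech"}}, the function above returns
+# roughly:
+#
+#     {"$and": [
+#         {"$or": [
+#             {"owner_id": {"$eq": <user id>}},
+#             {"collection_ids": {"$overlap": [<allowed collection ids>]}},
+#         ]},
+#         {"category": {"$eq": "tech"}},
+#     ]}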
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/user.py b/.venv/lib/python3.12/site-packages/shared/abstractions/user.py
new file mode 100644
index 00000000..b04ac50b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/user.py
@@ -0,0 +1,69 @@
+from datetime import datetime
+from typing import Optional
+from uuid import UUID
+
+from pydantic import BaseModel, Field
+
+from shared.abstractions import R2RSerializable
+
+from ..utils import generate_default_user_collection_id
+
+
+class Collection(BaseModel):
+    id: Optional[UUID] = None
+ name: str
+ description: Optional[str] = None
+ created_at: datetime = Field(
+ default_factory=datetime.utcnow,
+ )
+ updated_at: datetime = Field(
+ default_factory=datetime.utcnow,
+ )
+
+ class Config:
+ populate_by_name = True
+ from_attributes = True
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        # Derive a deterministic collection id from the name when none is given.
+        if self.id is None:
+            self.id = generate_default_user_collection_id(self.name)
+
+
+class Token(BaseModel):
+ token: str
+ token_type: str
+
+
+class TokenData(BaseModel):
+ email: str
+ token_type: str
+ exp: datetime
+
+
+class User(R2RSerializable):
+ id: UUID
+ email: str
+ is_active: bool = True
+ is_superuser: bool = False
+    created_at: datetime = Field(default_factory=datetime.now)
+    updated_at: datetime = Field(default_factory=datetime.now)
+ is_verified: bool = False
+ collection_ids: list[UUID] = []
+ graph_ids: list[UUID] = []
+ document_ids: list[UUID] = []
+
+ # Optional fields (to update or set at creation)
+ limits_overrides: Optional[dict] = None
+ metadata: Optional[dict] = None
+ verification_code_expiry: Optional[datetime] = None
+ name: Optional[str] = None
+ bio: Optional[str] = None
+ profile_picture: Optional[str] = None
+ total_size_in_bytes: Optional[int] = None
+ num_files: Optional[int] = None
+
+ account_type: str = "password"
+ hashed_password: Optional[str] = None
+ google_id: Optional[str] = None
+ github_id: Optional[str] = None
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/vector.py b/.venv/lib/python3.12/site-packages/shared/abstractions/vector.py
new file mode 100644
index 00000000..0b88a765
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/vector.py
@@ -0,0 +1,239 @@
+"""Abstraction for a vector that can be stored in the system."""
+
+from enum import Enum
+from typing import Any, Optional
+from uuid import UUID
+
+from pydantic import BaseModel, Field
+
+from .base import R2RSerializable
+
+
+class VectorType(str, Enum):
+ FIXED = "FIXED"
+
+
+class IndexMethod(str, Enum):
+ """An enum representing the index methods available.
+
+    Supports the 'ivfflat' and 'hnsw' methods, as well as automatic
+    selection via 'auto'.
+
+ Attributes:
+ auto (str): Automatically choose the best available index method.
+ ivfflat (str): The ivfflat index method.
+ hnsw (str): The hnsw index method.
+ """
+
+ auto = "auto"
+ ivfflat = "ivfflat"
+ hnsw = "hnsw"
+
+ def __str__(self) -> str:
+ return self.value
+
+
+class IndexMeasure(str, Enum):
+ """An enum representing the types of distance measures available for
+ indexing.
+
+    Attributes:
+        l2_distance (str): The Euclidean (L2) distance measure for indexing.
+        max_inner_product (str): The maximum inner product measure for indexing.
+        cosine_distance (str): The cosine distance measure for indexing.
+        l1_distance (str): The Manhattan (L1) distance measure for indexing.
+        hamming_distance (str): The Hamming distance measure for indexing.
+        jaccard_distance (str): The Jaccard distance measure for indexing.
+    """
+
+ l2_distance = "l2_distance"
+ max_inner_product = "max_inner_product"
+ cosine_distance = "cosine_distance"
+ l1_distance = "l1_distance"
+ hamming_distance = "hamming_distance"
+ jaccard_distance = "jaccard_distance"
+
+ def __str__(self) -> str:
+ return self.value
+
+ @property
+ def ops(self) -> str:
+ return {
+ IndexMeasure.l2_distance: "_l2_ops",
+ IndexMeasure.max_inner_product: "_ip_ops",
+ IndexMeasure.cosine_distance: "_cosine_ops",
+ IndexMeasure.l1_distance: "_l1_ops",
+ IndexMeasure.hamming_distance: "_hamming_ops",
+ IndexMeasure.jaccard_distance: "_jaccard_ops",
+ }[self]
+
+ @property
+ def pgvector_repr(self) -> str:
+ return {
+ IndexMeasure.l2_distance: "<->",
+ IndexMeasure.max_inner_product: "<#>",
+ IndexMeasure.cosine_distance: "<=>",
+ IndexMeasure.l1_distance: "<+>",
+ IndexMeasure.hamming_distance: "<~>",
+ IndexMeasure.jaccard_distance: "<%>",
+ }[self]
+
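+# Illustrative sketch (an assumption): how `ops` and `pgvector_repr` are used
+# when building pgvector SQL. Table and column names here are hypothetical.
+#
+#     measure = IndexMeasure.cosine_distance
+#     sql = f"SELECT id FROM chunks ORDER BY vec {measure.pgvector_repr} $1 LIMIT 10"
+#     # -> "SELECT id FROM chunks ORDER BY vec <=> $1 LIMIT 10"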
+
+class IndexArgsIVFFlat(R2RSerializable):
+ """A class for arguments that can optionally be supplied to the index
+ creation method when building an IVFFlat type index.
+
+ Attributes:
+        n_lists (int): The number of IVF centroids that the index should use
+ """
+
+ n_lists: int
+
+
+class IndexArgsHNSW(R2RSerializable):
+ """A class for arguments that can optionally be supplied to the index
+ creation method when building an HNSW type index.
+
+ Ref: https://github.com/pgvector/pgvector#index-options
+
+ Both attributes are Optional in case the user only wants to specify one and
+ leave the other as default
+
+ Attributes:
+ m (int): Maximum number of connections per node per layer (default: 16)
+ ef_construction (int): Size of the dynamic candidate list for
+ constructing the graph (default: 64)
+ """
+
+ m: Optional[int] = 16
+ ef_construction: Optional[int] = 64
+
+
+class VectorTableName(str, Enum):
+ """This enum represents the different tables where we store vectors."""
+
+ CHUNKS = "chunks"
+ ENTITIES_DOCUMENT = "documents_entities"
+ GRAPHS_ENTITIES = "graphs_entities"
+ # TODO: Add support for relationships
+ # TRIPLES = "relationship"
+ COMMUNITIES = "graphs_communities"
+
+ def __str__(self) -> str:
+ return self.value
+
+
+class VectorQuantizationType(str, Enum):
+ """An enum representing the types of quantization available for vectors.
+
+ Attributes:
+ FP32 (str): 32-bit floating point quantization.
+ FP16 (str): 16-bit floating point quantization.
+ INT1 (str): 1-bit integer quantization.
+ SPARSE (str): Sparse vector quantization.
+ """
+
+ FP32 = "FP32"
+ FP16 = "FP16"
+ INT1 = "INT1"
+ SPARSE = "SPARSE"
+
+ def __str__(self) -> str:
+ return self.value
+
+ @property
+ def db_type(self) -> str:
+ db_type_mapping = {
+ "FP32": "vector",
+ "FP16": "halfvec",
+ "INT1": "bit",
+ "SPARSE": "sparsevec",
+ }
+ return db_type_mapping[self.value]
+
+
+class VectorQuantizationSettings(R2RSerializable):
+ quantization_type: VectorQuantizationType = Field(
+ default=VectorQuantizationType.FP32
+ )
+
+
+class Vector(R2RSerializable):
+ """A vector with the option to fix the number of elements."""
+
+ data: list[float]
+ type: VectorType = Field(default=VectorType.FIXED)
+ length: int = Field(default=-1)
+
+ def __init__(self, **data):
+ super().__init__(**data)
+ if (
+ self.type == VectorType.FIXED
+ and self.length > 0
+ and len(self.data) != self.length
+ ):
+ raise ValueError(
+ f"Vector must be exactly {self.length} elements long."
+ )
+
+ def __repr__(self) -> str:
+ return (
+ f"Vector(data={self.data}, type={self.type}, length={self.length})"
+ )
+
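+# Illustrative sketch (an assumption): the fixed-length check above in action.
+#
+#     Vector(data=[0.1, 0.2, 0.3], length=3)  # ok
+#     Vector(data=[0.1, 0.2], length=3)       # raises ValueError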
+
+class VectorEntry(R2RSerializable):
+ """A vector entry that can be stored directly in supported vector
+ databases."""
+
+ id: UUID
+ document_id: UUID
+ owner_id: UUID
+ collection_ids: list[UUID]
+ vector: Vector
+ text: str
+ metadata: dict[str, Any]
+
+ def __str__(self) -> str:
+ """Return a string representation of the VectorEntry."""
+ return (
+ f"VectorEntry("
+ f"chunk_id={self.id}, "
+ f"document_id={self.document_id}, "
+ f"owner_id={self.owner_id}, "
+ f"collection_ids={self.collection_ids}, "
+ f"vector={self.vector}, "
+ f"text={self.text}, "
+ f"metadata={self.metadata})"
+ )
+
+ def __repr__(self) -> str:
+ """Return an unambiguous string representation of the VectorEntry."""
+ return self.__str__()
+
+
+class StorageResult(R2RSerializable):
+ """A result of a storage operation."""
+
+ success: bool
+ document_id: UUID
+ num_chunks: int = 0
+ error_message: Optional[str] = None
+
+ def __str__(self) -> str:
+ """Return a string representation of the StorageResult."""
+ return f"StorageResult(success={self.success}, error_message={self.error_message})"
+
+ def __repr__(self) -> str:
+ """Return an unambiguous string representation of the StorageResult."""
+ return self.__str__()
+
+
+class IndexConfig(BaseModel):
+ name: Optional[str] = Field(default=None)
+ table_name: Optional[str] = Field(default=VectorTableName.CHUNKS)
+ index_method: Optional[str] = Field(default=IndexMethod.hnsw)
+ index_measure: Optional[str] = Field(default=IndexMeasure.cosine_distance)
+ index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = Field(
+ default=None
+ )
+ index_name: Optional[str] = Field(default=None)
+ index_column: Optional[str] = Field(default=None)
+ concurrently: Optional[bool] = Field(default=True)
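+
+# Illustrative sketch (an assumption): a typical HNSW index configuration
+# built from the models above.
+#
+#     IndexConfig(
+#         table_name=VectorTableName.CHUNKS,
+#         index_method=IndexMethod.hnsw,
+#         index_measure=IndexMeasure.cosine_distance,
+#         index_arguments=IndexArgsHNSW(m=16, ef_construction=64),
+#     )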
diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/__init__.py b/.venv/lib/python3.12/site-packages/shared/api/models/__init__.py
new file mode 100644
index 00000000..2d39dab1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/api/models/__init__.py
@@ -0,0 +1,194 @@
+from shared.api.models.auth.responses import (
+ TokenResponse,
+ WrappedTokenResponse,
+)
+from shared.api.models.base import (
+ GenericBooleanResponse,
+ GenericMessageResponse,
+ PaginatedR2RResult,
+ R2RResults,
+ WrappedBooleanResponse,
+ WrappedGenericMessageResponse,
+)
+from shared.api.models.graph.responses import (
+ GraphResponse,
+ WrappedCommunitiesResponse,
+ WrappedCommunityResponse,
+ WrappedEntitiesResponse,
+ WrappedEntityResponse,
+ WrappedGraphResponse,
+ WrappedGraphsResponse,
+ WrappedRelationshipResponse,
+ WrappedRelationshipsResponse,
+)
+from shared.api.models.ingestion.responses import (
+ IngestionResponse,
+ WrappedIngestionResponse,
+ WrappedMetadataUpdateResponse,
+ WrappedUpdateResponse,
+ WrappedVectorIndexResponse,
+ WrappedVectorIndicesResponse,
+)
+from shared.api.models.management.responses import (
+ ChunkResponse,
+ CollectionResponse,
+ ConversationResponse,
+ MessageResponse,
+ PromptResponse,
+ ServerStats,
+ SettingsResponse,
+ WrappedAPIKeyResponse,
+ WrappedAPIKeysResponse,
+ WrappedChunkResponse,
+ WrappedChunksResponse,
+ WrappedCollectionResponse,
+ WrappedCollectionsResponse,
+ WrappedConversationMessagesResponse,
+ WrappedConversationResponse,
+ WrappedConversationsResponse,
+ WrappedDocumentResponse,
+ WrappedDocumentsResponse,
+ WrappedLimitsResponse,
+ WrappedLoginResponse,
+ WrappedMessageResponse,
+ WrappedPromptResponse,
+ WrappedPromptsResponse,
+ WrappedServerStatsResponse,
+ WrappedSettingsResponse,
+ WrappedUserResponse,
+ WrappedUsersResponse,
+)
+from shared.api.models.retrieval.responses import (
+ AgentEvent,
+ AgentResponse,
+ AggregateSearchResult,
+ Citation,
+ CitationData,
+ CitationEvent,
+ Delta,
+ DeltaPayload,
+ FinalAnswerData,
+ FinalAnswerEvent,
+ MessageData,
+ MessageDelta,
+ MessageEvent,
+ RAGEvent,
+ RAGResponse,
+ SearchResultsData,
+ SearchResultsEvent,
+ SSEEventBase,
+ ThinkingData,
+ ThinkingEvent,
+ ToolCallData,
+ ToolCallEvent,
+ ToolResultData,
+ ToolResultEvent,
+ UnknownEvent,
+ WrappedAgentResponse,
+ WrappedDocumentSearchResponse,
+ WrappedEmbeddingResponse,
+ WrappedLLMChatCompletion,
+ WrappedRAGResponse,
+ WrappedSearchResponse,
+ WrappedVectorSearchResponse,
+)
+
+__all__ = [
+ # Generic Responses
+ "SSEEventBase",
+ "SearchResultsData",
+ "SearchResultsEvent",
+ "MessageDelta",
+ "MessageData",
+ "MessageEvent",
+ "DeltaPayload",
+ "Delta",
+ "CitationData",
+ "CitationEvent",
+ "FinalAnswerData",
+ "FinalAnswerEvent",
+ "ToolCallData",
+ "ToolCallEvent",
+ "ToolResultData",
+ "ToolResultEvent",
+ "ThinkingData",
+ "ThinkingEvent",
+ "AgentEvent",
+ "RAGEvent",
+ "UnknownEvent",
+    # Auth Responses
+    "TokenResponse",
+    "WrappedTokenResponse",
+ # Ingestion Responses
+ "IngestionResponse",
+ "WrappedIngestionResponse",
+ "WrappedUpdateResponse",
+ "WrappedVectorIndexResponse",
+ "WrappedVectorIndicesResponse",
+ "WrappedMetadataUpdateResponse",
+ "GraphResponse",
+ "WrappedGraphResponse",
+ "WrappedGraphsResponse",
+ "WrappedEntityResponse",
+ "WrappedEntitiesResponse",
+ "WrappedRelationshipResponse",
+ "WrappedRelationshipsResponse",
+ "WrappedCommunityResponse",
+ "WrappedCommunitiesResponse",
+ # Management Responses
+ "PromptResponse",
+ "ServerStats",
+ "SettingsResponse",
+ "ChunkResponse",
+ "CollectionResponse",
+ "ConversationResponse",
+ "MessageResponse",
+ "WrappedServerStatsResponse",
+ "WrappedSettingsResponse",
+ # Document Responses
+ "WrappedDocumentResponse",
+ "WrappedDocumentsResponse",
+ # Collection Responses
+ "WrappedCollectionResponse",
+ "WrappedCollectionsResponse",
+ # Prompt Responses
+ "WrappedPromptResponse",
+ "WrappedPromptsResponse",
+ # Chunk Responses
+ "WrappedChunkResponse",
+ "WrappedChunksResponse",
+ # Conversation Responses
+ "WrappedConversationMessagesResponse",
+ "WrappedConversationResponse",
+ "WrappedConversationsResponse",
+ # User Responses
+ "WrappedUserResponse",
+ "WrappedAPIKeyResponse",
+ "WrappedLimitsResponse",
+ "WrappedAPIKeysResponse",
+ "WrappedLoginResponse",
+ "WrappedUsersResponse",
+ "WrappedMessageResponse",
+ # Base Responses
+ "PaginatedR2RResult",
+ "R2RResults",
+ "GenericBooleanResponse",
+ "GenericMessageResponse",
+ "WrappedBooleanResponse",
+ "WrappedGenericMessageResponse",
+ # TODO: Clean up the following responses
+ # Retrieval Responses
+ "RAGResponse",
+ "Citation",
+ "WrappedRAGResponse",
+ "AgentResponse",
+ "AggregateSearchResult",
+ "WrappedSearchResponse",
+ "WrappedDocumentSearchResponse",
+ "WrappedVectorSearchResponse",
+ "WrappedAgentResponse",
+ "WrappedLLMChatCompletion",
+ "WrappedEmbeddingResponse",
+]
diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/auth/__init__.py b/.venv/lib/python3.12/site-packages/shared/api/models/auth/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/api/models/auth/__init__.py
diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/auth/responses.py b/.venv/lib/python3.12/site-packages/shared/api/models/auth/responses.py
new file mode 100644
index 00000000..2d448945
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/api/models/auth/responses.py
@@ -0,0 +1,13 @@
+from pydantic import BaseModel
+
+from shared.abstractions import Token
+from shared.api.models.base import R2RResults
+
+
+class TokenResponse(BaseModel):
+ access_token: Token
+ refresh_token: Token
+
+
+# Create wrapped versions of each response
+WrappedTokenResponse = R2RResults[TokenResponse]
diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/base.py b/.venv/lib/python3.12/site-packages/shared/api/models/base.py
new file mode 100644
index 00000000..e0493d0b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/api/models/base.py
@@ -0,0 +1,26 @@
+from typing import Generic, TypeVar
+
+from pydantic import BaseModel
+
+T = TypeVar("T")
+
+
+class R2RResults(BaseModel, Generic[T]):
+ results: T
+
+
+class PaginatedR2RResult(BaseModel, Generic[T]):
+ results: T
+ total_entries: int
+
+
+class GenericBooleanResponse(BaseModel):
+ success: bool
+
+
+class GenericMessageResponse(BaseModel):
+ message: str
+
+
+WrappedBooleanResponse = R2RResults[GenericBooleanResponse]
+WrappedGenericMessageResponse = R2RResults[GenericMessageResponse]
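+
+# Illustrative sketch (an assumption): how the generic wrappers parse a raw
+# API payload with Pydantic v2.
+#
+#     raw = {"results": {"message": "ok"}}
+#     resp = WrappedGenericMessageResponse.model_validate(raw)
+#     assert resp.results.message == "ok"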
diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/graph/__init__.py b/.venv/lib/python3.12/site-packages/shared/api/models/graph/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/api/models/graph/__init__.py
diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/graph/responses.py b/.venv/lib/python3.12/site-packages/shared/api/models/graph/responses.py
new file mode 100644
index 00000000..a2272833
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/api/models/graph/responses.py
@@ -0,0 +1,31 @@
+from datetime import datetime
+from typing import Optional
+from uuid import UUID
+
+from pydantic import BaseModel
+
+from shared.abstractions.graph import Community, Entity, Relationship
+from shared.api.models.base import PaginatedR2RResult, R2RResults
+
+WrappedEntityResponse = R2RResults[Entity]
+WrappedEntitiesResponse = PaginatedR2RResult[list[Entity]]
+WrappedRelationshipResponse = R2RResults[Relationship]
+WrappedRelationshipsResponse = PaginatedR2RResult[list[Relationship]]
+WrappedCommunityResponse = R2RResults[Community]
+WrappedCommunitiesResponse = PaginatedR2RResult[list[Community]]
+
+
+class GraphResponse(BaseModel):
+ id: UUID
+ collection_id: UUID
+ name: str
+ description: Optional[str]
+ status: str
+ created_at: datetime
+ updated_at: datetime
+ document_ids: list[UUID]
+
+
+# Graph Responses
+WrappedGraphResponse = R2RResults[GraphResponse]
+WrappedGraphsResponse = PaginatedR2RResult[list[GraphResponse]]
diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/ingestion/__init__.py b/.venv/lib/python3.12/site-packages/shared/api/models/ingestion/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/api/models/ingestion/__init__.py
diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/ingestion/responses.py b/.venv/lib/python3.12/site-packages/shared/api/models/ingestion/responses.py
new file mode 100644
index 00000000..091e48e7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/api/models/ingestion/responses.py
@@ -0,0 +1,72 @@
+from typing import Any, Optional, TypeVar
+from uuid import UUID
+
+from pydantic import BaseModel, Field
+
+from shared.api.models.base import PaginatedR2RResult, R2RResults
+
+T = TypeVar("T")
+
+
+class IngestionResponse(BaseModel):
+ message: str = Field(
+ ...,
+ description="A message describing the result of the ingestion request.",
+ )
+ task_id: Optional[UUID] = Field(
+ None,
+ description="The task ID of the ingestion request.",
+ )
+ document_id: UUID = Field(
+ ...,
+ description="The ID of the document that was ingested.",
+ )
+
+ class Config:
+ json_schema_extra = {
+ "example": {
+ "message": "Ingestion task queued successfully.",
+ "task_id": "c68dc72e-fc23-5452-8f49-d7bd46088a96",
+ "document_id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ }
+ }
+
+
+class UpdateResponse(BaseModel):
+    message: str = Field(
+        ...,
+        description="A message describing the result of the update request.",
+    )
+    task_id: Optional[UUID] = Field(
+        None,
+        description="The task ID of the update request.",
+    )
+    document_ids: list[UUID] = Field(
+        ...,
+        description="The IDs of the documents that were updated.",
+    )
+
+ class Config:
+ json_schema_extra = {
+ "example": {
+ "message": "Update task queued successfully.",
+ "task_id": "c68dc72e-fc23-5452-8f49-d7bd46088a96",
+ "document_ids": ["9fbe403b-c11c-5aae-8ade-ef22980c3ad1"],
+ }
+ }
+
+
+class VectorIndexResponse(BaseModel):
+ index: dict[str, Any]
+
+
+class VectorIndicesResponse(BaseModel):
+ indices: list[VectorIndexResponse]
+
+
+WrappedIngestionResponse = R2RResults[IngestionResponse]
+WrappedMetadataUpdateResponse = R2RResults[IngestionResponse]
+WrappedUpdateResponse = R2RResults[UpdateResponse]
+
+WrappedVectorIndexResponse = R2RResults[VectorIndexResponse]
+WrappedVectorIndicesResponse = PaginatedR2RResult[VectorIndicesResponse]
diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/management/__init__.py b/.venv/lib/python3.12/site-packages/shared/api/models/management/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/api/models/management/__init__.py
diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/management/responses.py b/.venv/lib/python3.12/site-packages/shared/api/models/management/responses.py
new file mode 100644
index 00000000..5e8b67a2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/api/models/management/responses.py
@@ -0,0 +1,168 @@
+from datetime import datetime
+from typing import Any, Optional
+from uuid import UUID
+
+from pydantic import BaseModel
+
+from shared.abstractions.document import DocumentResponse
+from shared.abstractions.llm import Message
+from shared.abstractions.user import Token, User
+from shared.api.models.base import PaginatedR2RResult, R2RResults
+
+
+class PromptResponse(BaseModel):
+ id: UUID
+ name: str
+ template: str
+ created_at: datetime
+ updated_at: datetime
+ input_types: dict[str, str]
+
+
+class ServerStats(BaseModel):
+ start_time: datetime
+ uptime_seconds: float
+ cpu_usage: float
+ memory_usage: float
+
+
+class SettingsResponse(BaseModel):
+ config: dict[str, Any]
+ prompts: dict[str, Any]
+ r2r_project_name: str
+ # r2r_version: str
+
+
+class ChunkResponse(BaseModel):
+ id: UUID
+ document_id: UUID
+ owner_id: UUID
+ collection_ids: list[UUID]
+ text: str
+ metadata: dict[str, Any]
+ vector: Optional[list[float]] = None
+
+
+class CollectionResponse(BaseModel):
+ id: UUID
+ owner_id: Optional[UUID]
+ name: str
+ description: Optional[str]
+ graph_cluster_status: str
+ graph_sync_status: str
+ created_at: datetime
+ updated_at: datetime
+ user_count: int
+ document_count: int
+
+
+class ConversationResponse(BaseModel):
+ id: UUID
+ created_at: datetime
+ user_id: Optional[UUID] = None
+ name: Optional[str] = None
+
+
+class MessageResponse(BaseModel):
+ id: UUID
+ message: Message
+ metadata: dict[str, Any] = {}
+
+
+class ApiKey(BaseModel):
+ public_key: str
+ api_key: str
+ key_id: str
+ name: Optional[str] = None
+
+
+class ApiKeyNoPriv(BaseModel):
+ public_key: str
+ key_id: str
+ name: Optional[str] = None
+ updated_at: datetime
+ description: Optional[str] = None
+
+
+class LoginResponse(BaseModel):
+ access_token: Token
+ refresh_token: Token
+
+
+class UsageLimit(BaseModel):
+ used: int
+ limit: int
+ remaining: int
+
+
+class StorageTypeLimit(BaseModel):
+ limit: int
+ used: int
+ remaining: int
+
+
+class StorageLimits(BaseModel):
+ chunks: StorageTypeLimit
+ documents: StorageTypeLimit
+ collections: StorageTypeLimit
+
+
+class RouteUsage(BaseModel):
+ route_per_min: UsageLimit
+ monthly_limit: UsageLimit
+
+
+class Usage(BaseModel):
+ global_per_min: UsageLimit
+ monthly_limit: UsageLimit
+ routes: dict[str, RouteUsage]
+
+
+class SystemDefaults(BaseModel):
+ global_per_min: int
+ route_per_min: Optional[int]
+ monthly_limit: int
+
+
+class LimitsResponse(BaseModel):
+ storage_limits: StorageLimits
+ system_defaults: SystemDefaults
+ user_overrides: dict
+ effective_limits: SystemDefaults
+ usage: Usage
+
+
+# Chunk Responses
+WrappedChunkResponse = R2RResults[ChunkResponse]
+WrappedChunksResponse = PaginatedR2RResult[list[ChunkResponse]]
+
+# Collection Responses
+WrappedCollectionResponse = R2RResults[CollectionResponse]
+WrappedCollectionsResponse = PaginatedR2RResult[list[CollectionResponse]]
+
+# Conversation Responses
+WrappedConversationMessagesResponse = R2RResults[list[MessageResponse]]
+WrappedConversationResponse = R2RResults[ConversationResponse]
+WrappedConversationsResponse = PaginatedR2RResult[list[ConversationResponse]]
+WrappedMessageResponse = R2RResults[MessageResponse]
+WrappedMessagesResponse = PaginatedR2RResult[list[MessageResponse]]
+
+# Document Responses
+WrappedDocumentResponse = R2RResults[DocumentResponse]
+WrappedDocumentsResponse = PaginatedR2RResult[list[DocumentResponse]]
+
+# Prompt Responses
+WrappedPromptResponse = R2RResults[PromptResponse]
+WrappedPromptsResponse = PaginatedR2RResult[list[PromptResponse]]
+
+# System Responses
+WrappedSettingsResponse = R2RResults[SettingsResponse]
+WrappedServerStatsResponse = R2RResults[ServerStats]
+
+# User Responses
+WrappedUserResponse = R2RResults[User]
+WrappedUsersResponse = PaginatedR2RResult[list[User]]
+WrappedAPIKeyResponse = R2RResults[ApiKey]
+WrappedAPIKeysResponse = PaginatedR2RResult[list[ApiKeyNoPriv]]
+WrappedLoginResponse = R2RResults[LoginResponse]
+WrappedLimitsResponse = R2RResults[LimitsResponse]
diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/retrieval/__init__.py b/.venv/lib/python3.12/site-packages/shared/api/models/retrieval/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/api/models/retrieval/__init__.py
diff --git a/.venv/lib/python3.12/site-packages/shared/api/models/retrieval/responses.py b/.venv/lib/python3.12/site-packages/shared/api/models/retrieval/responses.py
new file mode 100644
index 00000000..f695ebfb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/api/models/retrieval/responses.py
@@ -0,0 +1,604 @@
+from typing import Any, Literal, Optional
+
+from pydantic import BaseModel, Field
+
+from shared.abstractions import (
+ AggregateSearchResult,
+ ChunkSearchResult,
+ GraphSearchResult,
+ LLMChatCompletion,
+ Message,
+ WebPageSearchResult,
+)
+from shared.api.models.base import R2RResults
+from shared.api.models.management.responses import DocumentResponse
+
+from ....abstractions import R2RSerializable
+
+
+class CitationSpan(R2RSerializable):
+ """Represents a single occurrence of a citation in text."""
+
+ start_index: int = Field(
+ ..., description="Starting character index of the citation"
+ )
+ end_index: int = Field(
+ ..., description="Ending character index of the citation"
+ )
+ context_start: int = Field(
+ ..., description="Starting index of the surrounding context"
+ )
+ context_end: int = Field(
+ ..., description="Ending index of the surrounding context"
+ )
+
+
+class Citation(R2RSerializable):
+ """
+ Represents a citation reference in the RAG response.
+
+ The first time a citation appears, it includes the full payload.
+ Subsequent appearances only include the citation ID and span information.
+ """
+
+ # Basic identification
+ id: str = Field(
+ ..., description="The short ID of the citation (e.g., 'e41ac2d')"
+ )
+ object: str = Field(
+ "citation", description="The type of object, always 'citation'"
+ )
+
+ # Optimize payload delivery
+ is_new: bool = Field(
+ True,
+ description="Whether this is the first occurrence of this citation",
+ )
+
+ # Position information
+ span: Optional[CitationSpan] = Field(
+ None, description="Position of this citation occurrence in the text"
+ )
+
+ # Source information - only included for first occurrence
+ source_type: Optional[str] = Field(
+ None, description="Type of source: 'chunk', 'graph', 'web', or 'doc'"
+ )
+
+ # Full payload - only included for first occurrence
+ payload: (
+ ChunkSearchResult
+ | GraphSearchResult
+ | WebPageSearchResult
+ | DocumentResponse
+ | dict[str, Any]
+ | None
+ ) = Field(
+ None,
+ description="The complete source object (only included for new citations)",
+ )
+
+ class Config:
+ extra = "ignore"
+ json_schema_extra = {
+ "example": {
+ "id": "e41ac2d",
+ "object": "citation",
+ "is_new": True,
+ "span": {
+ "start_index": 120,
+ "end_index": 129,
+ "context_start": 80,
+ "context_end": 180,
+ },
+ "source_type": "chunk",
+ "payload": {
+ "id": "e41ac2d1-full-id",
+ "text": "The study found significant improvements...",
+ "metadata": {"title": "Research Paper"},
+ },
+ }
+ }
+
+
+# class Citation(R2RSerializable):
+# """Represents a single citation reference in the RAG response.
+
+# Combines both bracket metadata (start/end offsets, snippet range) and the
+# mapped source fields (id, doc ID, chunk text, etc.).
+# """
+
+# # Bracket references
+# id: str = Field(..., description="The ID of the citation object")
+# object: str = Field(
+# ...,
+# description="The type of object, e.g. `citation`",
+# )
+# payload: (
+# ChunkSearchResult
+# | GraphSearchResult
+# | WebPageSearchResult
+# | DocumentResponse
+# | None
+# ) = Field(
+#         ..., description="The object payload and its corresponding type"
+# )
+
+# class Config:
+# extra = "ignore" # This tells Pydantic to ignore extra fields
+# json_schema_extra = {
+# "example": {
+# "id": "cit.abcd123",
+# "object": "citation",
+# "payload": "ChunkSearchResult(...)",
+# }
+# }
+
+
+class RAGResponse(R2RSerializable):
+ generated_answer: str = Field(
+ ..., description="The generated completion from the RAG process"
+ )
+ search_results: AggregateSearchResult = Field(
+ ..., description="The search results used for the RAG process"
+ )
+ citations: Optional[list[Citation]] = Field(
+ None,
+ description="Structured citation metadata, if you do citation extraction.",
+ )
+ metadata: dict = Field(
+ default_factory=dict,
+ description="Additional data returned by the LLM provider",
+ )
+    completion: str = Field(
+        ...,
+        description="The generated completion from the RAG process. Deprecated: use `generated_answer` instead.",
+    )
+
+ class Config:
+ json_schema_extra = {
+ "example": {
+ "generated_answer": "The capital of France is Paris.",
+ "search_results": {
+ "chunk_search_results": [
+ {
+ "index": 1,
+ "start_index": 25,
+ "end_index": 28,
+ "uri": "https://example.com/doc1",
+ "title": "example_document_1.pdf",
+ "license": "CC-BY-4.0",
+ }
+ ],
+ "graph_search_results": [
+ {
+ "content": {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "name": "Entity Name",
+ "description": "Entity Description",
+ "metadata": {},
+ },
+ "result_type": "entity",
+ "chunk_ids": [
+ "c68dc72e-fc23-5452-8f49-d7bd46088a96"
+ ],
+ "metadata": {
+ "associated_query": "What is the capital of France?"
+ },
+ }
+ ],
+ "web_search_results": [
+ {
+ "title": "Page Title",
+ "link": "https://example.com/page",
+ "snippet": "Page snippet",
+ "position": 1,
+ "date": "2021-01-01",
+ "sitelinks": [
+ {
+ "title": "Sitelink Title",
+ "link": "https://example.com/sitelink",
+ }
+ ],
+ }
+ ],
+ "document_search_results": [
+ {
+ "document": {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "title": "Document Title",
+ "chunks": ["Chunk 1", "Chunk 2"],
+ "metadata": {},
+ },
+ }
+ ],
+ },
+ "citations": [
+ {
+ "index": 1,
+ "rawIndex": 9,
+ "startIndex": 393,
+ "endIndex": 396,
+ "snippetStartIndex": 320,
+ "snippetEndIndex": 418,
+ "sourceType": "chunk",
+ "id": "e760bb76-1c6e-52eb-910d-0ce5b567011b",
+ "document_id": "e43864f5-a36f-548e-aacd-6f8d48b30c7f",
+ "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
+ "collection_ids": [
+ "122fdf6a-e116-546b-a8f6-e4cb2e2c0a09"
+ ],
+ "score": 0.64,
+ "text": "Document Title: DeepSeek_R1.pdf\n\nText: could achieve an accuracy of ...",
+ "metadata": {
+ "title": "DeepSeek_R1.pdf",
+ "license": "CC-BY-4.0",
+ "chunk_order": 68,
+ "document_type": "pdf",
+ },
+ }
+ ],
+ "metadata": {
+ "id": "chatcmpl-example123",
+ "choices": [
+ {
+ "finish_reason": "stop",
+ "index": 0,
+ "message": {"role": "assistant"},
+ }
+ ],
+ },
+ "completion": "TO BE DEPRECATED",
+ }
+ }
+
+
+class AgentResponse(R2RSerializable):
+ messages: list[Message] = Field(..., description="Agent response messages")
+ conversation_id: str = Field(
+ ..., description="The conversation ID for the RAG agent response"
+ )
+
+ class Config:
+ json_schema_extra = {
+ "example": {
+ "messages": [
+ {
+ "role": "assistant",
+ "content": """Aristotle (384–322 BC) was an Ancient
+ Greek philosopher and polymath whose contributions
+ have had a profound impact on various fields of
+ knowledge.
+ Here are some key points about his life and work:
+ \n\n1. **Early Life**: Aristotle was born in 384 BC in
+ Stagira, Chalcidice, which is near modern-day
+ Thessaloniki, Greece. His father, Nicomachus, was the
+ personal physician to King Amyntas of Macedon, which
+ exposed Aristotle to medical and biological knowledge
+ from a young age [C].\n\n2. **Education and Career**:
+ After the death of his parents, Aristotle was sent to
+ Athens to study at Plato's Academy, where he remained
+ for about 20 years. After Plato's death, Aristotle
+ left Athens and eventually became the tutor of
+ Alexander the Great [C].
+ \n\n3. **Philosophical Contributions**: Aristotle
+ founded the Lyceum in Athens, where he established the
+ Peripatetic school of philosophy. His works cover a
+ wide range of subjects, including metaphysics, ethics,
+ politics, logic, biology, and aesthetics. His writings
+ laid the groundwork for many modern scientific and
+ philosophical inquiries [A].\n\n4. **Legacy**:
+ Aristotle's influence extends beyond philosophy to the
+ natural sciences, linguistics, economics, and
+ psychology. His method of systematic observation and
+ analysis has been foundational to the development of
+ modern science [A].\n\nAristotle's comprehensive
+ approach to knowledge and his systematic methodology
+ have earned him a lasting legacy as one of the
+ greatest philosophers of all time.\n\nSources:
+ \n- [A] Aristotle's broad range of writings and
+ influence on modern science.\n- [C] Details about
+ Aristotle's early life and education.""",
+ "name": None,
+ "function_call": None,
+ "tool_calls": None,
+ "metadata": {
+ "citations": [
+ {
+ "index": 1,
+ "rawIndex": 9,
+ "startIndex": 393,
+ "endIndex": 396,
+ "snippetStartIndex": 320,
+ "snippetEndIndex": 418,
+ "sourceType": "chunk",
+ "id": "e760bb76-1c6e-52eb-910d-0ce5b567011b",
+ "document_id": """
+ e43864f5-a36f-548e-aacd-6f8d48b30c7f
+ """,
+ "owner_id": """
+ 2acb499e-8428-543b-bd85-0d9098718220
+ """,
+ "collection_ids": [
+ "122fdf6a-e116-546b-a8f6-e4cb2e2c0a09"
+ ],
+ "score": 0.64,
+ "text": """
+ Document Title: DeepSeek_R1.pdf
+ \n\nText: could achieve an accuracy of ...
+ """,
+ "metadata": {
+ "title": "DeepSeek_R1.pdf",
+ "license": "CC-BY-4.0",
+ "chunk_order": 68,
+ "document_type": "pdf",
+ },
+ }
+ ],
+ "aggregated_search_results": {
+ "chunk_search_results": [
+ {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
+ "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
+ "collection_ids": [],
+ "score": 0.23943702876567796,
+ "text": "Example text from the document",
+ "metadata": {
+ "title": "example_document.pdf",
+ "associated_query": "What is the capital of France?",
+ },
+ }
+ ],
+ "graph_search_results": [
+ {
+ "content": {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "name": "Entity Name",
+ "description": "Entity Description",
+ "metadata": {},
+ },
+ "result_type": "entity",
+ "chunk_ids": [
+ "c68dc72e-fc23-5452-8f49-d7bd46088a96"
+ ],
+ "metadata": {
+ "associated_query": "What is the capital of France?"
+ },
+ }
+ ],
+ "web_search_results": [
+ {
+ "title": "Page Title",
+ "link": "https://example.com/page",
+ "snippet": "Page snippet",
+ "position": 1,
+ "date": "2021-01-01",
+ "sitelinks": [
+ {
+ "title": "Sitelink Title",
+ "link": "https://example.com/sitelink",
+ }
+ ],
+ }
+ ],
+ "document_search_results": [
+ {
+ "document": {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "title": "Document Title",
+ "chunks": ["Chunk 1", "Chunk 2"],
+ "metadata": {},
+ },
+ }
+ ],
+ },
+ },
+ },
+ ],
+ "conversation_id": "a32b4c5d-6e7f-8a9b-0c1d-2e3f4a5b6c7d",
+ }
+ }
+
+
+class DocumentSearchResult(BaseModel):
+ document_id: str = Field(
+ ...,
+ description="The document ID",
+ )
+ metadata: Optional[dict] = Field(
+ None,
+ description="The metadata of the document",
+ )
+ score: float = Field(
+ ...,
+ description="The score of the document",
+ )
+
+
+# A generic base model for SSE events
+class SSEEventBase(BaseModel):
+ event: str
+ data: Any
+
+
+# Model for the search results event
+class SearchResultsData(BaseModel):
+ id: str
+ object: str
+ data: AggregateSearchResult
+
+
+class SearchResultsEvent(SSEEventBase):
+ event: Literal["search_results"]
+ data: SearchResultsData
+
+
+class DeltaPayload(BaseModel):
+ value: str
+ annotations: list[Any]
+
+
+# Model for message events (partial tokens)
+class MessageDelta(BaseModel):
+ type: str
+ payload: DeltaPayload
+
+
+class Delta(BaseModel):
+ content: list[MessageDelta]
+
+
+class MessageData(BaseModel):
+ id: str
+ object: str
+ delta: Delta
+
+
+class MessageEvent(SSEEventBase):
+ event: Literal["message"]
+ data: MessageData
+
+
+# Update CitationSpan model for SSE events
+class CitationSpanData(BaseModel):
+ start: int = Field(
+ ..., description="Starting character index of the citation"
+ )
+ end: int = Field(..., description="Ending character index of the citation")
+ context_start: Optional[int] = Field(
+ None, description="Starting index of surrounding context"
+ )
+ context_end: Optional[int] = Field(
+ None, description="Ending index of surrounding context"
+ )
+
+
+# Update CitationData model
+class CitationData(BaseModel):
+ id: str = Field(
+ ..., description="The short ID of the citation (e.g., 'e41ac2d')"
+ )
+ object: str = Field(
+ "citation", description="The type of object, always 'citation'"
+ )
+
+ # New fields from the enhanced Citation model
+ is_new: Optional[bool] = Field(
+ None,
+ description="Whether this is the first occurrence of this citation",
+ )
+
+ span: Optional[CitationSpanData] = Field(
+ None, description="Position of this citation occurrence in the text"
+ )
+
+ source_type: Optional[str] = Field(
+ None, description="Type of source: 'chunk', 'graph', 'web', or 'doc'"
+ )
+
+ # Optional payload field, only for first occurrence
+ payload: Optional[Any] = Field(
+ None,
+ description="The complete source object (only included for new citations)",
+ )
+
+ # For backward compatibility, maintain the existing fields
+ class Config:
+ populate_by_name = True
+ extra = "ignore"
+
+
+# CitationEvent remains the same, but now using the updated CitationData
+class CitationEvent(SSEEventBase):
+ event: Literal["citation"]
+ data: CitationData
+
+
+# Model for the final answer event
+class FinalAnswerData(BaseModel):
+ generated_answer: str
+    citations: list[Citation]
+
+
+class FinalAnswerEvent(SSEEventBase):
+ event: Literal["final_answer"]
+ data: FinalAnswerData
+
+
+# "tool_call" event
+class ToolCallData(BaseModel):
+ tool_call_id: str
+ name: str
+ arguments: Any # If JSON arguments, use dict[str, Any], or str if needed
+
+
+class ToolCallEvent(SSEEventBase):
+ event: Literal["tool_call"]
+ data: ToolCallData
+
+
+# "tool_result" event
+class ToolResultData(BaseModel):
+ tool_call_id: str
+ role: Literal["tool", "function"]
+ content: str
+
+
+class ToolResultEvent(SSEEventBase):
+ event: Literal["tool_result"]
+ data: ToolResultData
+
+
+# Optionally, define a fallback model for unrecognized events
+class UnknownEvent(SSEEventBase):
+ pass
+
+
+# Event type for the model's "thinking" deltas
+class ThinkingData(BaseModel):
+ id: str
+ object: str
+ delta: Delta
+
+
+class ThinkingEvent(SSEEventBase):
+ event: str = "thinking"
+ data: ThinkingData
+
+
+# Create a union type for all RAG events
+RAGEvent = (
+    SearchResultsEvent
+    | MessageEvent
+    | CitationEvent
+    | FinalAnswerEvent
+    | ToolCallEvent
+    | ToolResultEvent
+    | UnknownEvent
+)
+
+AgentEvent = (
+ ThinkingEvent
+ | SearchResultsEvent
+ | MessageEvent
+ | CitationEvent
+ | FinalAnswerEvent
+ | ToolCallEvent
+ | ToolResultEvent
+ | UnknownEvent
+)
+
+WrappedCompletionResponse = R2RResults[LLMChatCompletion]
+# Create wrapped versions of the responses
+WrappedVectorSearchResponse = R2RResults[list[ChunkSearchResult]]
+WrappedSearchResponse = R2RResults[AggregateSearchResult]
+# FIXME: This is returning DocumentResponse, but should be DocumentSearchResult
+WrappedDocumentSearchResponse = R2RResults[list[DocumentResponse]]
+WrappedRAGResponse = R2RResults[RAGResponse]
+WrappedAgentResponse = R2RResults[AgentResponse]
+WrappedLLMChatCompletion = R2RResults[LLMChatCompletion]
+WrappedEmbeddingResponse = R2RResults[list[float]]
diff --git a/.venv/lib/python3.12/site-packages/shared/utils/__init__.py b/.venv/lib/python3.12/site-packages/shared/utils/__init__.py
new file mode 100644
index 00000000..eb037e22
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/utils/__init__.py
@@ -0,0 +1,46 @@
+from .base_utils import (
+ _decorate_vector_type,
+ _get_vector_column_str,
+ decrement_version,
+ deep_update,
+ dump_collector,
+ dump_obj,
+ format_search_results_for_llm,
+ generate_default_prompt_id,
+ generate_default_user_collection_id,
+ generate_document_id,
+ generate_entity_document_id,
+ generate_extraction_id,
+ generate_id,
+ generate_user_id,
+ increment_version,
+ validate_uuid,
+ yield_sse_event,
+)
+from .splitter.text import RecursiveCharacterTextSplitter, TextSplitter
+
+__all__ = [
+ "format_search_results_for_llm",
+ # ID generation
+ "generate_id",
+ "generate_document_id",
+ "generate_extraction_id",
+ "generate_default_user_collection_id",
+ "generate_user_id",
+ "generate_default_prompt_id",
+ "generate_entity_document_id",
+ # Other
+ "increment_version",
+ "decrement_version",
+ "validate_uuid",
+ "deep_update",
+ # Text splitter
+ "RecursiveCharacterTextSplitter",
+ "TextSplitter",
+ # Vector utils
+ "_decorate_vector_type",
+ "_get_vector_column_str",
+ "yield_sse_event",
+ "dump_collector",
+ "dump_obj",
+]
diff --git a/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py b/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py
new file mode 100644
index 00000000..1864d0b4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py
@@ -0,0 +1,783 @@
+import json
+import logging
+import math
+import uuid
+from abc import ABCMeta
+from copy import deepcopy
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Optional, Tuple, TypeVar
+from uuid import NAMESPACE_DNS, UUID, uuid4, uuid5
+
+import tiktoken
+
+from ..abstractions import (
+ AggregateSearchResult,
+ AsyncSyncMeta,
+ GraphCommunityResult,
+ GraphEntityResult,
+ GraphRelationshipResult,
+)
+from ..abstractions.vector import VectorQuantizationType
+
+if TYPE_CHECKING:
+ pass
+
+
+logger = logging.getLogger()
+
+
+def id_to_shorthand(id: str | UUID) -> str:
+    """Return the first seven characters of an id as a citation shorthand."""
+    return str(id)[:7]
+
+
+def format_search_results_for_llm(
+ results: AggregateSearchResult,
+ collector: Any, # SearchResultsCollector
+) -> str:
+ """
+ Instead of resetting 'source_counter' to 1, we:
+ - For each chunk / graph / web / doc in `results`,
+ - Find the aggregator index from the collector,
+ - Print 'Source [X]:' with that aggregator index.
+ """
+ lines = []
+
+ # We'll build a quick helper to locate aggregator indices for each object:
+ # Or you can rely on the fact that we've added them to the collector
+ # in the same order. But let's do a "lookup aggregator index" approach:
+
+ # 1) Chunk search
+ if results.chunk_search_results:
+ lines.append("Vector Search Results:")
+ for c in results.chunk_search_results:
+ lines.append(f"Source ID [{id_to_shorthand(c.id)}]:")
+ lines.append(c.text or "") # or c.text[:200] to truncate
+
+ # 2) Graph search
+ if results.graph_search_results:
+ lines.append("Graph Search Results:")
+ for g in results.graph_search_results:
+ lines.append(f"Source ID [{id_to_shorthand(g.id)}]:")
+ if isinstance(g.content, GraphCommunityResult):
+ lines.append(f"Community Name: {g.content.name}")
+ lines.append(f"ID: {g.content.id}")
+ lines.append(f"Summary: {g.content.summary}")
+ # etc. ...
+ elif isinstance(g.content, GraphEntityResult):
+ lines.append(f"Entity Name: {g.content.name}")
+ lines.append(f"Description: {g.content.description}")
+ elif isinstance(g.content, GraphRelationshipResult):
+ lines.append(
+ f"Relationship: {g.content.subject}-{g.content.predicate}-{g.content.object}"
+ )
+ # Add metadata if needed
+
+ # 3) Web search
+ if results.web_search_results:
+ lines.append("Web Search Results:")
+ for w in results.web_search_results:
+ lines.append(f"Source ID [{id_to_shorthand(w.id)}]:")
+ lines.append(f"Title: {w.title}")
+ lines.append(f"Link: {w.link}")
+ lines.append(f"Snippet: {w.snippet}")
+
+ # 4) Local context docs
+ if results.document_search_results:
+ lines.append("Local Context Documents:")
+ for doc_result in results.document_search_results:
+ doc_title = doc_result.title or "Untitled Document"
+ doc_id = doc_result.id
+ summary = doc_result.summary
+
+ lines.append(f"Full Document ID: {doc_id}")
+ lines.append(f"Shortened Document ID: {id_to_shorthand(doc_id)}")
+ lines.append(f"Document Title: {doc_title}")
+ if summary:
+ lines.append(f"Summary: {summary}")
+
+ if doc_result.chunks:
+ # Then each chunk inside:
+ for chunk in doc_result.chunks:
+ lines.append(
+ f"\nChunk ID {id_to_shorthand(chunk['id'])}:\n{chunk['text']}"
+ )
+
+ result = "\n".join(lines)
+ return result
+
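+# Illustrative sketch (an assumption): the formatted block produced above
+# looks roughly like:
+#
+#     Vector Search Results:
+#     Source ID [e760bb7]:
+#     <chunk text>
+#     Web Search Results:
+#     Source ID [a1b2c3d]:
+#     Title: Page Title
+#     Link: https://example.com/page
+#     Snippet: Page snippet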
+
+def _generate_id_from_label(label) -> UUID:
+ return uuid5(NAMESPACE_DNS, label)
+
+
+def generate_id(label: Optional[str] = None) -> UUID:
+ """Generates a unique run id."""
+ return _generate_id_from_label(
+ label if label is not None else str(uuid4())
+ )
+
+
+def generate_document_id(filename: str, user_id: UUID) -> UUID:
+ """Generates a unique document id from a given filename and user id."""
+ safe_filename = filename.replace("/", "_")
+ return _generate_id_from_label(f"{safe_filename}-{str(user_id)}")
+
+
+def generate_extraction_id(
+ document_id: UUID, iteration: int = 0, version: str = "0"
+) -> UUID:
+ """Generates a unique extraction id from a given document id and
+ iteration."""
+ return _generate_id_from_label(f"{str(document_id)}-{iteration}-{version}")
+
+
+def generate_default_user_collection_id(user_id: UUID) -> UUID:
+ """Generates a unique collection id from a given user id."""
+ return _generate_id_from_label(str(user_id))
+
+
+def generate_user_id(email: str) -> UUID:
+ """Generates a unique user id from a given email."""
+ return _generate_id_from_label(email)
+
+
+def generate_default_prompt_id(prompt_name: str) -> UUID:
+ """Generates a unique prompt id."""
+ return _generate_id_from_label(prompt_name)
+
+
+def generate_entity_document_id() -> UUID:
+ """Generates a unique document id inserting entities into a graph."""
+ generation_time = datetime.now().isoformat()
+ return _generate_id_from_label(f"entity-{generation_time}")
+
+
+def increment_version(version: str) -> str:
+    """Increments the final digit of a version string, e.g. 'v1' -> 'v2'."""
+    prefix = version[:-1]
+    suffix = int(version[-1])
+    return f"{prefix}{suffix + 1}"
+
+
+def decrement_version(version: str) -> str:
+    """Decrements the final digit of a version string, bottoming out at 0."""
+    prefix = version[:-1]
+    suffix = int(version[-1])
+    return f"{prefix}{max(0, suffix - 1)}"
+
+
+def validate_uuid(uuid_str: str) -> UUID:
+ return UUID(uuid_str)
+
+
+def update_settings_from_dict(server_settings, settings_dict: dict):
+ """Updates a settings object with values from a dictionary."""
+ settings = deepcopy(server_settings)
+ for key, value in settings_dict.items():
+ if value is not None:
+ if isinstance(value, dict):
+ for k, v in value.items():
+ if isinstance(getattr(settings, key), dict):
+ getattr(settings, key)[k] = v
+ else:
+ setattr(getattr(settings, key), k, v)
+ else:
+ setattr(settings, key, value)
+
+ return settings
+
+
+def _decorate_vector_type(
+ input_str: str,
+ quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
+) -> str:
+ return f"{quantization_type.db_type}{input_str}"
+
+
+def _get_vector_column_str(
+ dimension: int | float, quantization_type: VectorQuantizationType
+) -> str:
+ """Returns a string representation of a vector column type.
+
+ Explicitly handles the case where the dimension is not a valid number meant
+ to support embedding models that do not allow for specifying the dimension.
+ """
+ if math.isnan(dimension) or dimension <= 0:
+ vector_dim = "" # Allows for Postgres to handle any dimension
+ else:
+ vector_dim = f"({dimension})"
+ return _decorate_vector_type(vector_dim, quantization_type)
+
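+# Illustrative sketch (an assumption): column type strings produced by the
+# helpers above for a pgvector schema.
+#
+#     _get_vector_column_str(768, VectorQuantizationType.FP32)  # -> "vector(768)"
+#     _get_vector_column_str(768, VectorQuantizationType.FP16)  # -> "halfvec(768)"
+#     _get_vector_column_str(float("nan"), VectorQuantizationType.FP32)  # -> "vector"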
+
+KeyType = TypeVar("KeyType")
+
+
+def deep_update(
+ mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]
+) -> dict[KeyType, Any]:
+ """
+ Taken from Pydantic v1:
+ https://github.com/pydantic/pydantic/blob/fd2991fe6a73819b48c906e3c3274e8e47d0f761/pydantic/utils.py#L200
+ """
+ updated_mapping = mapping.copy()
+ for updating_mapping in updating_mappings:
+ for k, v in updating_mapping.items():
+ if (
+ k in updated_mapping
+ and isinstance(updated_mapping[k], dict)
+ and isinstance(v, dict)
+ ):
+ updated_mapping[k] = deep_update(updated_mapping[k], v)
+ else:
+ updated_mapping[k] = v
+ return updated_mapping
+
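+# Illustrative sketch: `deep_update` merges nested dicts key-by-key instead of
+# replacing whole subtrees.
+#
+#     base = {"llm": {"model": "gpt-4o", "temperature": 0.2}}
+#     deep_update(base, {"llm": {"temperature": 0.7}})
+#     # -> {"llm": {"model": "gpt-4o", "temperature": 0.7}}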
+
+def tokens_count_for_message(message, encoding):
+ """Return the number of tokens used by a single message."""
+ tokens_per_message = 3
+
+ num_tokens = 0
+ num_tokens += tokens_per_message
+ if message.get("function_call"):
+ num_tokens += len(encoding.encode(message["function_call"]["name"]))
+ num_tokens += len(
+ encoding.encode(message["function_call"]["arguments"])
+ )
+ elif message.get("tool_calls"):
+ for tool_call in message["tool_calls"]:
+ num_tokens += len(encoding.encode(tool_call["function"]["name"]))
+ num_tokens += len(
+ encoding.encode(tool_call["function"]["arguments"])
+ )
+ else:
+ if "content" in message:
+ num_tokens += len(encoding.encode(message["content"]))
+
+ return num_tokens
+
+
+def num_tokens_from_messages(messages, model="gpt-4o"):
+ """Return the number of tokens used by a list of messages for both user and assistant."""
+ try:
+ encoding = tiktoken.encoding_for_model(model)
+ except KeyError:
+        logger.warning("Model %s not found. Using cl100k_base encoding.", model)
+ encoding = tiktoken.get_encoding("cl100k_base")
+
+ tokens = 0
+ for message_ in messages:
+ tokens += tokens_count_for_message(message_, encoding)
+
+ tokens += 3 # every reply is primed with assistant
+ return tokens
+
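+# Illustrative sketch (an assumption): counting tokens for a short chat.
+#
+#     messages = [
+#         {"role": "user", "content": "What is the capital of France?"},
+#         {"role": "assistant", "content": "Paris."},
+#     ]
+#     n = num_tokens_from_messages(messages, model="gpt-4o")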
+
+class SearchResultsCollector:
+ """
+ Collects search results in the form (source_type, result_obj).
+ Handles both object-oriented and dictionary-based search results.
+ """
+
+ def __init__(self):
+ # We'll store a list of (source_type, result_obj)
+ self._results_in_order = []
+
+ @property
+ def results(self):
+ """Get the results list"""
+ return self._results_in_order
+
+ @results.setter
+ def results(self, value):
+ """
+ Set the results directly, with automatic type detection for 'unknown' items
+ Handles the format: [('unknown', {...}), ('unknown', {...})]
+ """
+ self._results_in_order = []
+
+ if isinstance(value, list):
+ for item in value:
+ if isinstance(item, tuple) and len(item) == 2:
+ source_type, result_obj = item
+
+ # Only auto-detect if the source type is "unknown"
+ if source_type == "unknown":
+ detected_type = self._detect_result_type(result_obj)
+ self._results_in_order.append(
+ (detected_type, result_obj)
+ )
+ else:
+ self._results_in_order.append(
+ (source_type, result_obj)
+ )
+ else:
+ # If not a tuple, detect and add
+ detected_type = self._detect_result_type(item)
+ self._results_in_order.append((detected_type, item))
+ else:
+ raise ValueError("Results must be a list")
+
+ def add_aggregate_result(self, agg):
+ """
+ Flatten the chunk_search_results, graph_search_results, web_search_results,
+ and document_search_results into the collector, including nested chunks.
+ """
+ if hasattr(agg, "chunk_search_results") and agg.chunk_search_results:
+ for c in agg.chunk_search_results:
+ self._results_in_order.append(("chunk", c))
+
+ if hasattr(agg, "graph_search_results") and agg.graph_search_results:
+ for g in agg.graph_search_results:
+ self._results_in_order.append(("graph", g))
+
+ if hasattr(agg, "web_search_results") and agg.web_search_results:
+ for w in agg.web_search_results:
+ self._results_in_order.append(("web", w))
+
+ # Add documents and extract their chunks
+ if (
+ hasattr(agg, "document_search_results")
+ and agg.document_search_results
+ ):
+ for doc in agg.document_search_results:
+ # Add the document itself
+ self._results_in_order.append(("doc", doc))
+
+ # Extract and add chunks from the document
+ chunks = None
+ if isinstance(doc, dict):
+ chunks = doc.get("chunks", [])
+ elif hasattr(doc, "chunks") and doc.chunks is not None:
+ chunks = doc.chunks
+
+ if chunks:
+ for chunk in chunks:
+ # Ensure each chunk has the minimum required attributes
+ if isinstance(chunk, dict) and "id" in chunk:
+ # Add the chunk directly to results for citation lookup
+ self._results_in_order.append(("chunk", chunk))
+ elif hasattr(chunk, "id"):
+ self._results_in_order.append(("chunk", chunk))
+
+ def add_result(self, result_obj, source_type=None):
+ """
+ Add a single result object to the collector.
+ If source_type is not provided, automatically detect the type.
+ """
+ if source_type:
+ self._results_in_order.append((source_type, result_obj))
+ return source_type
+
+ detected_type = self._detect_result_type(result_obj)
+ self._results_in_order.append((detected_type, result_obj))
+ return detected_type
+
+ def _detect_result_type(self, obj):
+ """
+ Detect the type of a result object based on its properties.
+ Works with both object attributes and dictionary keys.
+ """
+ # Handle dictionary types first (common for web search results)
+ if isinstance(obj, dict):
+ # Web search pattern
+ if all(k in obj for k in ["title", "link"]) and any(
+ k in obj for k in ["snippet", "description"]
+ ):
+ return "web"
+
+ # Check for graph dictionary patterns
+ if "content" in obj and isinstance(obj["content"], dict):
+ content = obj["content"]
+ if all(k in content for k in ["name", "description"]):
+ return "graph" # Entity
+ if all(
+ k in content for k in ["subject", "predicate", "object"]
+ ):
+ return "graph" # Relationship
+ if all(k in content for k in ["name", "summary"]):
+ return "graph" # Community
+
+ # Chunk pattern
+ if all(k in obj for k in ["text", "id"]) and any(
+ k in obj for k in ["score", "metadata"]
+ ):
+ return "chunk"
+
+ # Context document pattern
+ if "document" in obj and "chunks" in obj:
+ return "doc"
+
+ # Check for explicit type indicator
+ if "type" in obj:
+ type_val = str(obj["type"]).lower()
+ if any(t in type_val for t in ["web", "organic"]):
+ return "web"
+ if "graph" in type_val:
+ return "graph"
+ if "chunk" in type_val:
+ return "chunk"
+ if "document" in type_val:
+ return "doc"
+
+ # Handle object attributes for OOP-style results
+ if hasattr(obj, "result_type"):
+ result_type = str(obj.result_type).lower()
+ if result_type in ["entity", "relationship", "community"]:
+ return "graph"
+
+ # Check class name hints
+ class_name = obj.__class__.__name__
+ if "Graph" in class_name:
+ return "graph"
+ if "Chunk" in class_name:
+ return "chunk"
+ if "Web" in class_name:
+ return "web"
+ if "Document" in class_name:
+ return "doc"
+
+ # Check for object attribute patterns
+ if hasattr(obj, "content"):
+ content = obj.content
+ if hasattr(content, "name") and hasattr(content, "description"):
+ return "graph" # Entity
+ if hasattr(content, "subject") and hasattr(content, "predicate"):
+ return "graph" # Relationship
+ if hasattr(content, "name") and hasattr(content, "summary"):
+ return "graph" # Community
+
+ if (
+ hasattr(obj, "text")
+ and hasattr(obj, "id")
+ and (hasattr(obj, "score") or hasattr(obj, "metadata"))
+ ):
+ return "chunk"
+
+ if (
+ hasattr(obj, "title")
+ and hasattr(obj, "link")
+ and hasattr(obj, "snippet")
+ ):
+ return "web"
+
+ if hasattr(obj, "document") and hasattr(obj, "chunks"):
+ return "doc"
+
+ # Default when type can't be determined
+ return "unknown"
+
+ def find_by_short_id(self, short_id):
+ """Find a result by its short ID prefix with better chunk handling"""
+ if not short_id:
+ return None
+
+ # First try direct lookup using regular iteration
+ for _, result_obj in self._results_in_order:
+ # Check dictionary objects
+ if isinstance(result_obj, dict) and "id" in result_obj:
+ result_id = str(result_obj["id"])
+ if result_id.startswith(short_id):
+ return result_obj
+
+ # Check object with id attribute
+ elif hasattr(result_obj, "id"):
+ obj_id = getattr(result_obj, "id", None)
+ if obj_id and str(obj_id).startswith(short_id):
+ # Convert to dict if possible
+ if hasattr(result_obj, "as_dict"):
+ return result_obj.as_dict()
+ elif hasattr(result_obj, "model_dump"):
+ return result_obj.model_dump()
+ elif hasattr(result_obj, "dict"):
+ return result_obj.dict()
+ else:
+ return result_obj
+
+ # If not found, look for chunks inside documents that weren't extracted properly
+ for source_type, result_obj in self._results_in_order:
+ if source_type == "doc":
+ # Try various ways to access chunks
+ chunks = None
+ if isinstance(result_obj, dict) and "chunks" in result_obj:
+ chunks = result_obj["chunks"]
+ elif (
+ hasattr(result_obj, "chunks")
+ and result_obj.chunks is not None
+ ):
+ chunks = result_obj.chunks
+
+ if chunks:
+ for chunk in chunks:
+ # Try each chunk
+ chunk_id = None
+ if isinstance(chunk, dict) and "id" in chunk:
+ chunk_id = chunk["id"]
+ elif hasattr(chunk, "id"):
+ chunk_id = chunk.id
+
+ if chunk_id and str(chunk_id).startswith(short_id):
+ return chunk
+
+ return None
+
+ def get_results_by_type(self, type_name):
+ """Get all results of a specific type"""
+ return [
+ result_obj
+ for source_type, result_obj in self._results_in_order
+ if source_type == type_name
+ ]
+
+ def __repr__(self):
+ """String representation showing counts by type"""
+ type_counts = {}
+ for source_type, _ in self._results_in_order:
+ type_counts[source_type] = type_counts.get(source_type, 0) + 1
+
+ return f"SearchResultsCollector with {len(self._results_in_order)} results: {type_counts}"
+
+ def get_all_results(self) -> list[Tuple[str, Any]]:
+ """
+ Return list of (source_type, result_obj) tuples,
+ in the order appended.
+ """
+ return self._results_in_order
+
+
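+# A minimal usage sketch (illustrative; assumes `collector` is a
+# SearchResultsCollector that already holds appended results):
+def _collector_usage_example(collector: SearchResultsCollector) -> None:
+    # Look a result up by the first characters of its UUID
+    hit = collector.find_by_short_id("abc1234")
+    if hit is not None:
+        print("found:", hit)
+    # Pull out only the graph-typed results
+    print(len(collector.get_results_by_type("graph")), "graph results")
+
+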
+def convert_nonserializable_objects(obj):
+ if hasattr(obj, "model_dump"):
+ obj = obj.model_dump()
+ if hasattr(obj, "as_dict"):
+ obj = obj.as_dict()
+ if hasattr(obj, "to_dict"):
+ obj = obj.to_dict()
+
+ if isinstance(obj, dict):
+ new_obj = {}
+ for key, value in obj.items():
+ # Convert key to string if it is a UUID or not already a string.
+ new_key = str(key) if not isinstance(key, str) else key
+ new_obj[new_key] = convert_nonserializable_objects(value)
+ return new_obj
+ elif isinstance(obj, list):
+ return [convert_nonserializable_objects(item) for item in obj]
+ elif isinstance(obj, tuple):
+ return tuple(convert_nonserializable_objects(item) for item in obj)
+ elif isinstance(obj, set):
+ return {convert_nonserializable_objects(item) for item in obj}
+ elif isinstance(obj, uuid.UUID):
+ return str(obj)
+ elif isinstance(obj, datetime):
+ return obj.isoformat() # Convert datetime to ISO formatted string
+ else:
+ return obj
+
+
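+# Worked example (illustrative): UUID keys become strings and datetimes
+# become ISO strings, recursively; sets keep their type with converted
+# elements, so `default=str` is still useful when JSON-encoding.
+def _conversion_example() -> str:
+    data = {uuid.uuid4(): {"at": datetime(2025, 1, 1), "tags": {"a", "b"}}}
+    return json.dumps(convert_nonserializable_objects(data), default=str)
+
+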
+def dump_obj(obj) -> dict[str, Any]:
+ if hasattr(obj, "model_dump"):
+ obj = obj.model_dump()
+ elif hasattr(obj, "dict"):
+ obj = obj.dict()
+ elif hasattr(obj, "as_dict"):
+ obj = obj.as_dict()
+ elif hasattr(obj, "to_dict"):
+ obj = obj.to_dict()
+ obj = convert_nonserializable_objects(obj)
+
+ return obj
+
+
+def dump_collector(collector: SearchResultsCollector) -> list[dict[str, Any]]:
+ dumped = []
+ for source_type, result_obj in collector.get_all_results():
+ # Get the dictionary from the result object
+ if hasattr(result_obj, "model_dump"):
+ result_dict = result_obj.model_dump()
+ elif hasattr(result_obj, "dict"):
+ result_dict = result_obj.dict()
+ elif hasattr(result_obj, "as_dict"):
+ result_dict = result_obj.as_dict()
+ elif hasattr(result_obj, "to_dict"):
+ result_dict = result_obj.to_dict()
+ else:
+ result_dict = (
+ result_obj # Fallback if no conversion method is available
+ )
+
+ # Use the recursive conversion on the entire dictionary
+ result_dict = convert_nonserializable_objects(result_dict)
+
+ dumped.append(
+ {
+ "source_type": source_type,
+ "result": result_dict,
+ }
+ )
+ return dumped
+
+
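+# Sketch (illustrative; assumes a populated collector): serialize an
+# entire collector to a JSON string for logging or transport.
+def _dump_collector_example(collector: SearchResultsCollector) -> str:
+    return json.dumps(dump_collector(collector), default=str)
+
+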
+def num_tokens(text, model="gpt-4o"):
+ """Return the number of tokens in `text` under the given model's encoding."""
+ try:
+ encoding = tiktoken.encoding_for_model(model)
+ except KeyError:
+ encoding = tiktoken.get_encoding("cl100k_base")
+
+ return len(encoding.encode(text, disallowed_special=()))
+
+
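+# Example (illustrative): rough token budgeting with num_tokens; counts
+# depend on the model's encoding, falling back to cl100k_base above.
+def _fits_in_context(text: str, budget: int = 8192) -> bool:
+    return num_tokens(text) <= budget
+
+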
+class CombinedMeta(AsyncSyncMeta, ABCMeta):
+ pass
+
+
+async def yield_sse_event(event_name: str, payload: dict, chunk_size=1024):
+ """
+ Helper that yields a single SSE event in properly chunked lines.
+
+ e.g. event: event_name
+ data: (partial JSON 1)
+ data: (partial JSON 2)
+ ...
+ [blank line to end event]
+ """
+
+ # SSE: first the "event: ..."
+ yield f"event: {event_name}\n"
+
+ # Convert payload to JSON
+ content_str = json.dumps(payload, default=str)
+
+ # data
+ yield f"data: {content_str}\n"
+
+ # blank line signals end of SSE event
+ yield "\n"
+
+
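+# What a consumer sees on the wire (illustrative): joining the yielded
+# lines for one event gives the standard SSE framing.
+async def _sse_wire_example() -> str:
+    lines = [line async for line in yield_sse_event("ping", {"ok": True})]
+    return "".join(lines)  # 'event: ping\ndata: {"ok": true}\n\n'
+
+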
+class SSEFormatter:
+ """
+ Enhanced formatter for Server-Sent Events (SSE) with citation tracking.
+ Extends the existing SSEFormatter with improved citation handling.
+ """
+
+ @staticmethod
+ async def yield_citation_event(
+ citation_data: dict,
+ ):
+ """
+ Emits a citation event with optimized payload.
+
+ Args:
+ citation_id: The short ID of the citation (e.g., 'abc1234')
+ span: (start, end) position tuple for this occurrence
+ payload: Source object (included only for first occurrence)
+ is_new: Whether this is the first time we've seen this citation
+ citation_id_counter: Optional counter for citation occurrences
+
+ Yields:
+ Formatted SSE event lines
+ """
+
+ # Include the full payload only for new citations
+ if not citation_data.get("is_new") or "payload" not in citation_data:
+ citation_data["payload"] = None
+
+ # Yield the event
+ async for line in yield_sse_event("citation", citation_data):
+ yield line
+
+ @staticmethod
+ async def yield_final_answer_event(
+ final_data: dict,
+ ):
+ # Yield the event
+ async for line in yield_sse_event("final_answer", final_data):
+ yield line
+
+ # Include other existing SSEFormatter methods for compatibility
+ @staticmethod
+ async def yield_message_event(text_segment, msg_id=None):
+ msg_id = msg_id or f"msg_{uuid.uuid4().hex[:8]}"
+ msg_payload = {
+ "id": msg_id,
+ "object": "agent.message.delta",
+ "delta": {
+ "content": [
+ {
+ "type": "text",
+ "payload": {
+ "value": text_segment,
+ "annotations": [],
+ },
+ }
+ ]
+ },
+ }
+ async for line in yield_sse_event("message", msg_payload):
+ yield line
+
+ @staticmethod
+ async def yield_thinking_event(text_segment, thinking_id=None):
+ thinking_id = thinking_id or f"think_{uuid.uuid4().hex[:8]}"
+ thinking_data = {
+ "id": thinking_id,
+ "object": "agent.thinking.delta",
+ "delta": {
+ "content": [
+ {
+ "type": "text",
+ "payload": {
+ "value": text_segment,
+ "annotations": [],
+ },
+ }
+ ]
+ },
+ }
+ async for line in yield_sse_event("thinking", thinking_data):
+ yield line
+
+ @staticmethod
+ def yield_done_event():
+ return "event: done\ndata: [DONE]\n\n"
+
+ @staticmethod
+ async def yield_error_event(error_message, error_id=None):
+ error_id = error_id or f"err_{uuid.uuid4().hex[:8]}"
+ error_payload = {
+ "id": error_id,
+ "object": "agent.error",
+ "error": {"message": error_message, "type": "agent_error"},
+ }
+ async for line in yield_sse_event("error", error_payload):
+ yield line
+
+ @staticmethod
+ async def yield_tool_call_event(tool_call_data):
+ from ..api.models.retrieval.responses import ToolCallEvent
+
+ tc_event = ToolCallEvent(event="tool_call", data=tool_call_data)
+ async for line in yield_sse_event(
+ "tool_call", tc_event.dict()["data"]
+ ):
+ yield line
+
+ # New helper for emitting search results:
+ @staticmethod
+ async def yield_search_results_event(aggregated_results):
+ payload = {
+ "id": "search_1",
+ "object": "rag.search_results",
+ "data": aggregated_results.as_dict(),
+ }
+ async for line in yield_sse_event("search_results", payload):
+ yield line
+
+ @staticmethod
+ async def yield_tool_result_event(tool_result_data):
+ from ..api.models.retrieval.responses import ToolResultEvent
+
+ tr_event = ToolResultEvent(event="tool_result", data=tool_result_data)
+ async for line in yield_sse_event(
+ "tool_result", tr_event.dict()["data"]
+ ):
+ yield line
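+
+
+# End-to-end sketch (illustrative): compose a minimal SSE stream from
+# the formatter helpers, e.g. as the body of a streaming HTTP response.
+async def _stream_example():
+    async for line in SSEFormatter.yield_message_event("Hello"):
+        yield line
+    async for line in SSEFormatter.yield_final_answer_event(
+        {"answer": "Hello"}
+    ):
+        yield line
+    yield SSEFormatter.yield_done_event()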
diff --git a/.venv/lib/python3.12/site-packages/shared/utils/splitter/__init__.py b/.venv/lib/python3.12/site-packages/shared/utils/splitter/__init__.py
new file mode 100644
index 00000000..07a9f554
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/utils/splitter/__init__.py
@@ -0,0 +1,3 @@
+from .text import RecursiveCharacterTextSplitter
+
+__all__ = ["RecursiveCharacterTextSplitter"]
diff --git a/.venv/lib/python3.12/site-packages/shared/utils/splitter/text.py b/.venv/lib/python3.12/site-packages/shared/utils/splitter/text.py
new file mode 100644
index 00000000..92a7c81b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/utils/splitter/text.py
@@ -0,0 +1,2000 @@
+# Source - LangChain
+# URL: https://github.com/langchain-ai/langchain/blob/6a5b084704afa22ca02f78d0464f35aed75d1ff2/libs/langchain/langchain/text_splitter.py#L851
+"""**Text Splitters** are classes for splitting text.
+
+**Class hierarchy:**
+
+.. code-block::
+
+ BaseDocumentTransformer --> TextSplitter --> <name>TextSplitter # Example: CharacterTextSplitter
+ RecursiveCharacterTextSplitter --> <name>TextSplitter
+
+Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter** do not derive from TextSplitter.
+
+
+**Main helpers:**
+
+.. code-block::
+
+ Document, Tokenizer, Language, LineType, HeaderType
+""" # noqa: E501
+
+from __future__ import annotations
+
+import copy
+import json
+import logging
+import pathlib
+import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from io import BytesIO, StringIO
+from typing import (
+ AbstractSet,
+ Any,
+ Callable,
+ Collection,
+ Iterable,
+ Literal,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TypedDict,
+ TypeVar,
+ cast,
+)
+
+import requests
+from pydantic import BaseModel, Field, PrivateAttr
+from typing_extensions import NotRequired
+
+logger = logging.getLogger()
+
+TS = TypeVar("TS", bound="TextSplitter")
+
+
+class BaseSerialized(TypedDict):
+ """Base class for serialized objects."""
+
+ lc: int
+ id: list[str]
+ name: NotRequired[str]
+ graph: NotRequired[dict[str, Any]]
+
+
+class SerializedConstructor(BaseSerialized):
+ """Serialized constructor."""
+
+ type: Literal["constructor"]
+ kwargs: dict[str, Any]
+
+
+class SerializedSecret(BaseSerialized):
+ """Serialized secret."""
+
+ type: Literal["secret"]
+
+
+class SerializedNotImplemented(BaseSerialized):
+ """Serialized not implemented."""
+
+ type: Literal["not_implemented"]
+ repr: Optional[str]
+
+
+def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
+ """Try to determine if a value is different from the default.
+
+ Args:
+ value: The value.
+ key: The key.
+ model: The model.
+
+ Returns:
+ Whether the value is different from the default.
+ """
+ try:
+ return model.__fields__[key].get_default() != value
+ except Exception:
+ return True
+
+
+class Serializable(BaseModel, ABC):
+ """Serializable base class."""
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ """Is this class serializable?"""
+ return False
+
+ @classmethod
+ def get_lc_namespace(cls) -> list[str]:
+ """Get the namespace of the langchain object.
+
+ For example, if the class is `langchain.llms.openai.OpenAI`, then the
+ namespace is ["langchain", "llms", "openai"]
+ """
+ return cls.__module__.split(".")
+
+ @property
+ def lc_secrets(self) -> dict[str, str]:
+ """A map of constructor argument names to secret ids.
+
+ For example, {"openai_api_key": "OPENAI_API_KEY"}
+ """
+ return {}
+
+ @property
+ def lc_attributes(self) -> dict:
+ """List of attribute names that should be included in the serialized
+ kwargs.
+
+ These attributes must be accepted by the constructor.
+ """
+ return {}
+
+ @classmethod
+ def lc_id(cls) -> list[str]:
+ """A unique identifier for this class for serialization purposes.
+
+ The unique identifier is a list of strings that describes the path to
+ the object.
+ """
+ return [*cls.get_lc_namespace(), cls.__name__]
+
+ class Config:
+ extra = "ignore"
+
+ def __repr_args__(self) -> Any:
+ return [
+ (k, v)
+ for k, v in super().__repr_args__()
+ if (k not in self.__fields__ or try_neq_default(v, k, self))
+ ]
+
+ _lc_kwargs: dict[str, Any] = PrivateAttr(default_factory=dict)
+
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
+ self._lc_kwargs = kwargs
+
+ def to_json(
+ self,
+ ) -> SerializedConstructor | SerializedNotImplemented:
+ if not self.is_lc_serializable():
+ return self.to_json_not_implemented()
+
+ secrets = dict()
+ # Get latest values for kwargs if there is an attribute with same name
+ lc_kwargs = {
+ k: getattr(self, k, v)
+ for k, v in self._lc_kwargs.items()
+ if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
+ }
+
+ # Merge the lc_secrets and lc_attributes from every class in the MRO
+ for cls in [None, *self.__class__.mro()]:
+ # Once we get to Serializable, we're done
+ if cls is Serializable:
+ break
+
+ if cls:
+ deprecated_attributes = [
+ "lc_namespace",
+ "lc_serializable",
+ ]
+
+ for attr in deprecated_attributes:
+ if hasattr(cls, attr):
+ raise ValueError(
+ f"Class {self.__class__} has a deprecated "
+ f"attribute {attr}. Please use the corresponding "
+ f"classmethod instead."
+ )
+
+ # Get a reference to self bound to each class in the MRO
+ this = cast(
+ Serializable, self if cls is None else super(cls, self)
+ )
+
+ secrets.update(this.lc_secrets)
+ # Now also add the aliases for the secrets
+ # This ensures known secret aliases are hidden.
+ # Note: this does NOT hide any other extra kwargs
+ # that are not present in the fields.
+ for key in list(secrets):
+ value = secrets[key]
+ if key in this.__fields__:
+ secrets[this.__fields__[key].alias] = value # type: ignore
+ lc_kwargs.update(this.lc_attributes)
+
+ # include all secrets, even if not specified in kwargs
+ # as these secrets may be passed as an environment variable instead
+ for key in secrets.keys():
+ secret_value = getattr(self, key, None) or lc_kwargs.get(key)
+ if secret_value is not None:
+ lc_kwargs.update({key: secret_value})
+
+ return {
+ "lc": 1,
+ "type": "constructor",
+ "id": self.lc_id(),
+ "kwargs": (
+ lc_kwargs
+ if not secrets
+ else _replace_secrets(lc_kwargs, secrets)
+ ),
+ }
+
+ def to_json_not_implemented(self) -> SerializedNotImplemented:
+ return to_json_not_implemented(self)
+
+
+def _replace_secrets(
+ root: dict[Any, Any], secrets_map: dict[str, str]
+) -> dict[Any, Any]:
+ result = root.copy()
+ for path, secret_id in secrets_map.items():
+ [*parts, last] = path.split(".")
+ current = result
+ for part in parts:
+ if part not in current:
+ break
+ current[part] = current[part].copy()
+ current = current[part]
+ if last in current:
+ current[last] = {
+ "lc": 1,
+ "type": "secret",
+ "id": [secret_id],
+ }
+ return result
+
+
+def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
+ """Serialize a "not implemented" object.
+
+ Args:
+ obj: object to serialize
+
+ Returns:
+ SerializedNotImplemented
+ """
+ _id: list[str] = []
+ try:
+ if hasattr(obj, "__name__"):
+ _id = [*obj.__module__.split("."), obj.__name__]
+ elif hasattr(obj, "__class__"):
+ _id = [
+ *obj.__class__.__module__.split("."),
+ obj.__class__.__name__,
+ ]
+ except Exception:
+ pass
+
+ result: SerializedNotImplemented = {
+ "lc": 1,
+ "type": "not_implemented",
+ "id": _id,
+ "repr": None,
+ }
+ try:
+ result["repr"] = repr(obj)
+ except Exception:
+ pass
+ return result
+
+
+class SplitterDocument(Serializable):
+ """Class for storing a piece of text and associated metadata."""
+
+ page_content: str
+ """String text."""
+ metadata: dict = Field(default_factory=dict)
+ """Arbitrary metadata about the page content (e.g., source, relationships
+ to other documents, etc.)."""
+ type: Literal["Document"] = "Document"
+
+ def __init__(self, page_content: str, **kwargs: Any) -> None:
+ """Pass page_content in as positional or named arg."""
+ super().__init__(page_content=page_content, **kwargs)
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ """Return whether this class is serializable."""
+ return True
+
+ @classmethod
+ def get_lc_namespace(cls) -> list[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "schema", "document"]
+
+
+class BaseDocumentTransformer(ABC):
+ """Abstract base class for document transformation systems.
+
+ A document transformation system takes a sequence of Documents and returns a
+ sequence of transformed Documents.
+
+ Example:
+ .. code-block:: python
+
+ class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
+ embeddings: Embeddings
+ similarity_fn: Callable = cosine_similarity
+ similarity_threshold: float = 0.95
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ stateful_documents = get_stateful_documents(documents)
+ embedded_documents = _get_embeddings_from_stateful_docs(
+ self.embeddings, stateful_documents
+ )
+ included_idxs = _filter_similar_embeddings(
+ embedded_documents, self.similarity_fn, self.similarity_threshold
+ )
+ return [stateful_documents[i] for i in sorted(included_idxs)]
+
+ async def atransform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ raise NotImplementedError
+ """ # noqa: E501
+
+ @abstractmethod
+ def transform_documents(
+ self, documents: Sequence[SplitterDocument], **kwargs: Any
+ ) -> Sequence[SplitterDocument]:
+ """Transform a list of documents.
+
+ Args:
+ documents: A sequence of Documents to be transformed.
+
+ Returns:
+ A list of transformed Documents.
+ """
+
+ async def atransform_documents(
+ self, documents: Sequence[SplitterDocument], **kwargs: Any
+ ) -> Sequence[SplitterDocument]:
+ """Asynchronously transform a list of documents.
+
+ Args:
+ documents: A sequence of Documents to be transformed.
+
+ Returns:
+ A list of transformed Documents.
+ """
+ raise NotImplementedError("This method is not implemented.")
+ # return await langchain_core.runnables.config.run_in_executor(
+ # None, self.transform_documents, documents, **kwargs
+ # )
+
+
+def _make_spacy_pipe_for_splitting(
+ pipe: str, *, max_length: int = 1_000_000
+) -> Any: # avoid importing spacy
+ try:
+ import spacy
+ except ImportError:
+ raise ImportError(
+ "Spacy is not installed, run `pip install spacy`."
+ ) from None
+ if pipe == "sentencizer":
+ from spacy.lang.en import English
+
+ sentencizer = English()
+ sentencizer.add_pipe("sentencizer")
+ else:
+ sentencizer = spacy.load(pipe, exclude=["ner", "tagger"])
+ sentencizer.max_length = max_length
+ return sentencizer
+
+
+def _split_text_with_regex(
+ text: str, separator: str, keep_separator: bool
+) -> list[str]:
+ # Now that we have the separator, split the text
+ if separator:
+ if keep_separator:
+ # The parentheses in the pattern keep the delimiters in the result.
+ _splits = re.split(f"({separator})", text)
+ splits = [
+ _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)
+ ]
+ if len(_splits) % 2 == 0:
+ splits += _splits[-1:]
+ splits = [_splits[0]] + splits
+ else:
+ splits = re.split(separator, text)
+ else:
+ splits = list(text)
+ return [s for s in splits if s != ""]
+
+
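+# Worked example (illustrative): with keep_separator=True the delimiter
+# is re-attached to the piece that follows it, so no text is lost.
+def _regex_split_example() -> None:
+    assert _split_text_with_regex("a.b.c", r"\.", True) == ["a", ".b", ".c"]
+    assert _split_text_with_regex("a.b.c", r"\.", False) == ["a", "b", "c"]
+
+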
+class TextSplitter(BaseDocumentTransformer, ABC):
+ """Interface for splitting text into chunks."""
+
+ def __init__(
+ self,
+ chunk_size: int = 4000,
+ chunk_overlap: int = 200,
+ length_function: Callable[[str], int] = len,
+ keep_separator: bool = False,
+ add_start_index: bool = False,
+ strip_whitespace: bool = True,
+ ) -> None:
+ """Create a new TextSplitter.
+
+ Args:
+ chunk_size: Maximum size of chunks to return
+ chunk_overlap: Overlap in characters between chunks
+ length_function: Function that measures the length of given chunks
+ keep_separator: Whether to keep the separator in the chunks
+ add_start_index: If `True`, includes chunk's start index in
+ metadata
+ strip_whitespace: If `True`, strips whitespace from the start and
+ end of every document
+ """
+ if chunk_overlap > chunk_size:
+ raise ValueError(
+ f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
+ f"({chunk_size}), should be smaller."
+ )
+ self._chunk_size = chunk_size
+ self._chunk_overlap = chunk_overlap
+ self._length_function = length_function
+ self._keep_separator = keep_separator
+ self._add_start_index = add_start_index
+ self._strip_whitespace = strip_whitespace
+
+ @abstractmethod
+ def split_text(self, text: str) -> list[str]:
+ """Split text into multiple components."""
+
+ def create_documents(
+ self, texts: list[str], metadatas: Optional[list[dict]] = None
+ ) -> list[SplitterDocument]:
+ """Create documents from a list of texts."""
+ _metadatas = metadatas or [{}] * len(texts)
+ documents = []
+ for i, text in enumerate(texts):
+ index = 0
+ previous_chunk_len = 0
+ for chunk in self.split_text(text):
+ metadata = copy.deepcopy(_metadatas[i])
+ if self._add_start_index:
+ offset = index + previous_chunk_len - self._chunk_overlap
+ index = text.find(chunk, max(0, offset))
+ metadata["start_index"] = index
+ previous_chunk_len = len(chunk)
+ new_doc = SplitterDocument(
+ page_content=chunk, metadata=metadata
+ )
+ documents.append(new_doc)
+ return documents
+
+ def split_documents(
+ self, documents: Iterable[SplitterDocument]
+ ) -> list[SplitterDocument]:
+ """Split documents."""
+ texts, metadatas = [], []
+ for doc in documents:
+ texts.append(doc.page_content)
+ metadatas.append(doc.metadata)
+ return self.create_documents(texts, metadatas=metadatas)
+
+ def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
+ text = separator.join(docs)
+ if self._strip_whitespace:
+ text = text.strip()
+ if text == "":
+ return None
+ else:
+ return text
+
+ def _merge_splits(
+ self, splits: Iterable[str], separator: str
+ ) -> list[str]:
+ # We now want to combine these smaller pieces into medium size
+ # chunks to send to the LLM.
+ separator_len = self._length_function(separator)
+
+ docs = []
+ current_doc: list[str] = []
+ total = 0
+ for d in splits:
+ _len = self._length_function(d)
+ if (
+ total + _len + (separator_len if len(current_doc) > 0 else 0)
+ > self._chunk_size
+ ):
+ if total > self._chunk_size:
+ logger.warning(
+ f"Created a chunk of size {total}, "
+ f"which is longer than the specified {self._chunk_size}"
+ )
+ if len(current_doc) > 0:
+ doc = self._join_docs(current_doc, separator)
+ if doc is not None:
+ docs.append(doc)
+ # Keep popping splits off the front while:
+ # - the retained total is larger than the chunk overlap, or
+ # - adding the next split would still overflow the chunk size
+ while total > self._chunk_overlap or (
+ total
+ + _len
+ + (separator_len if len(current_doc) > 0 else 0)
+ > self._chunk_size
+ and total > 0
+ ):
+ total -= self._length_function(current_doc[0]) + (
+ separator_len if len(current_doc) > 1 else 0
+ )
+ current_doc = current_doc[1:]
+ current_doc.append(d)
+ total += _len + (separator_len if len(current_doc) > 1 else 0)
+ doc = self._join_docs(current_doc, separator)
+ if doc is not None:
+ docs.append(doc)
+ return docs
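+
+    # Worked trace (illustrative): with chunk_size=10, chunk_overlap=5,
+    # separator=" " and splits ["aaaa", "bbbb", "cccc"]: "cccc" would
+    # overflow "aaaa bbbb" (9 chars), so that chunk is emitted, "aaaa"
+    # is popped until the retained total <= overlap, and the next chunk
+    # becomes "bbbb cccc", with "bbbb" carrying the overlap.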
+
+ @classmethod
+ def from_huggingface_tokenizer(
+ cls, tokenizer: Any, **kwargs: Any
+ ) -> TextSplitter:
+ """Text splitter that uses HuggingFace tokenizer to count length."""
+ try:
+ from transformers import PreTrainedTokenizerBase
+
+ if not isinstance(tokenizer, PreTrainedTokenizerBase):
+ raise ValueError(
+ "Tokenizer received was not an instance of PreTrainedTokenizerBase"
+ )
+
+ def _huggingface_tokenizer_length(text: str) -> int:
+ return len(tokenizer.encode(text))
+
+ except ImportError:
+ raise ValueError(
+ "Could not import transformers python package. "
+ "Please install it with `pip install transformers`."
+ ) from None
+ return cls(length_function=_huggingface_tokenizer_length, **kwargs)
+
+ @classmethod
+ def from_tiktoken_encoder(
+ cls: Type[TS],
+ encoding_name: str = "gpt2",
+ model: Optional[str] = None,
+ allowed_special: Literal["all"] | AbstractSet[str] = set(),
+ disallowed_special: Literal["all"] | Collection[str] = "all",
+ **kwargs: Any,
+ ) -> TS:
+ """Text splitter that uses tiktoken encoder to count length."""
+ try:
+ import tiktoken
+ except ImportError:
+ raise ImportError("""Could not import tiktoken python package.
+ This is needed in order to calculate max_tokens_for_prompt.
+ Please install it with `pip install tiktoken`.""") from None
+
+ if model is not None:
+ enc = tiktoken.encoding_for_model(model)
+ else:
+ enc = tiktoken.get_encoding(encoding_name)
+
+ def _tiktoken_encoder(text: str) -> int:
+ return len(
+ enc.encode(
+ text,
+ allowed_special=allowed_special,
+ disallowed_special=disallowed_special,
+ )
+ )
+
+ if issubclass(cls, TokenTextSplitter):
+ extra_kwargs = {
+ "encoding_name": encoding_name,
+ "model": model,
+ "allowed_special": allowed_special,
+ "disallowed_special": disallowed_special,
+ }
+ kwargs = {**kwargs, **extra_kwargs}
+
+ return cls(length_function=_tiktoken_encoder, **kwargs)
+
+ def transform_documents(
+ self, documents: Sequence[SplitterDocument], **kwargs: Any
+ ) -> Sequence[SplitterDocument]:
+ """Transform sequence of documents by splitting them."""
+ return self.split_documents(list(documents))
+
+
+class CharacterTextSplitter(TextSplitter):
+ """Splitting text that looks at characters."""
+
+ DEFAULT_SEPARATOR: str = "\n\n"
+
+ def __init__(
+ self,
+ separator: str = DEFAULT_SEPARATOR,
+ is_separator_regex: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ """Create a new TextSplitter."""
+ super().__init__(**kwargs)
+ self._separator = separator
+ self._is_separator_regex = is_separator_regex
+
+ def split_text(self, text: str) -> list[str]:
+ """Split incoming text and return chunks."""
+ # First we naively split the large input into a bunch of smaller ones.
+ separator = (
+ self._separator
+ if self._is_separator_regex
+ else re.escape(self._separator)
+ )
+ splits = _split_text_with_regex(text, separator, self._keep_separator)
+ _separator = "" if self._keep_separator else self._separator
+ return self._merge_splits(splits, _separator)
+
+
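+# Usage sketch (illustrative): split on blank lines, then merge pieces
+# back into ~100-character chunks with 20 characters of overlap.
+def _character_split_example(text: str) -> list[str]:
+    splitter = CharacterTextSplitter(
+        separator="\n\n", chunk_size=100, chunk_overlap=20
+    )
+    return splitter.split_text(text)
+
+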
+class LineType(TypedDict):
+ """Line type as typed dict."""
+
+ metadata: dict[str, str]
+ content: str
+
+
+class HeaderType(TypedDict):
+ """Header type as typed dict."""
+
+ level: int
+ name: str
+ data: str
+
+
+class MarkdownHeaderTextSplitter:
+ """Splitting markdown files based on specified headers."""
+
+ def __init__(
+ self,
+ headers_to_split_on: list[Tuple[str, str]],
+ return_each_line: bool = False,
+ strip_headers: bool = True,
+ ):
+ """Create a new MarkdownHeaderTextSplitter.
+
+ Args:
+ headers_to_split_on: Headers we want to track
+ return_each_line: Return each line w/ associated headers
+ strip_headers: Strip split headers from the content of the chunk
+ """
+ # Output line-by-line or aggregated into chunks w/ common headers
+ self.return_each_line = return_each_line
+ # Given the headers we want to split on,
+ # (e.g., "#, ##, etc") order by length
+ self.headers_to_split_on = sorted(
+ headers_to_split_on, key=lambda split: len(split[0]), reverse=True
+ )
+ # Strip split headers from the content of the chunk
+ self.strip_headers = strip_headers
+
+ def aggregate_lines_to_chunks(
+ self, lines: list[LineType]
+ ) -> list[SplitterDocument]:
+ """Combine lines with common metadata into chunks
+ Args:
+ lines: Line of text / associated header metadata
+ """
+ aggregated_chunks: list[LineType] = []
+
+ for line in lines:
+ if (
+ aggregated_chunks
+ and aggregated_chunks[-1]["metadata"] == line["metadata"]
+ ):
+ # If the last line in the aggregated list
+ # has the same metadata as the current line,
+ # append the current content to the last line's content
+ aggregated_chunks[-1]["content"] += " \n" + line["content"]
+ elif (
+ aggregated_chunks
+ and aggregated_chunks[-1]["metadata"] != line["metadata"]
+ # may be issues if other metadata is present
+ and len(aggregated_chunks[-1]["metadata"])
+ < len(line["metadata"])
+ and aggregated_chunks[-1]["content"].split("\n")[-1][0] == "#"
+ and not self.strip_headers
+ ):
+ # If the last line in the aggregated list
+ # has different metadata than the current line,
+ # and has shallower header level than the current line,
+ # and the last line is a header,
+ # and we are not stripping headers,
+ # append the current content to the last line's content
+ aggregated_chunks[-1]["content"] += " \n" + line["content"]
+ # and update the last line's metadata
+ aggregated_chunks[-1]["metadata"] = line["metadata"]
+ else:
+ # Otherwise, append the current line to the aggregated list
+ aggregated_chunks.append(line)
+
+ return [
+ SplitterDocument(
+ page_content=chunk["content"], metadata=chunk["metadata"]
+ )
+ for chunk in aggregated_chunks
+ ]
+
+ def split_text(self, text: str) -> list[SplitterDocument]:
+ """Split markdown file
+ Args:
+ text: Markdown file"""
+
+ # Split the input text by newline character ("\n").
+ lines = text.split("\n")
+ # Final output
+ lines_with_metadata: list[LineType] = []
+ # Content and metadata of the chunk currently being processed
+ current_content: list[str] = []
+ current_metadata: dict[str, str] = {}
+ # Keep track of the nested header structure
+ # header_stack: list[dict[str, int | str]] = []
+ header_stack: list[HeaderType] = []
+ initial_metadata: dict[str, str] = {}
+
+ in_code_block = False
+ opening_fence = ""
+
+ for line in lines:
+ stripped_line = line.strip()
+
+ if not in_code_block:
+ # Exclude inline code spans
+ if (
+ stripped_line.startswith("```")
+ and stripped_line.count("```") == 1
+ ):
+ in_code_block = True
+ opening_fence = "```"
+ elif stripped_line.startswith("~~~"):
+ in_code_block = True
+ opening_fence = "~~~"
+ else:
+ if stripped_line.startswith(opening_fence):
+ in_code_block = False
+ opening_fence = ""
+
+ if in_code_block:
+ current_content.append(stripped_line)
+ continue
+
+ # Check each line against each of the header types (e.g., #, ##)
+ for sep, name in self.headers_to_split_on:
+ # Check if line starts with a header that we intend to split on
+ if stripped_line.startswith(sep) and (
+ # Header with no text OR header is followed by space
+ # Both are valid conditions that sep is being used as a header
+ len(stripped_line) == len(sep)
+ or stripped_line[len(sep)] == " "
+ ):
+ # Ensure we are tracking the header as metadata
+ if name is not None:
+ # Get the current header level
+ current_header_level = sep.count("#")
+
+ # Pop out headers of lower or same level from the stack
+ while (
+ header_stack
+ and header_stack[-1]["level"]
+ >= current_header_level
+ ):
+ # We have encountered a new header
+ # at the same or higher level
+ popped_header = header_stack.pop()
+ # Clear the metadata for the
+ # popped header in initial_metadata
+ if popped_header["name"] in initial_metadata:
+ initial_metadata.pop(popped_header["name"])
+
+ # Push the current header to the stack
+ header: HeaderType = {
+ "level": current_header_level,
+ "name": name,
+ "data": stripped_line[len(sep) :].strip(),
+ }
+ header_stack.append(header)
+ # Update initial_metadata with the current header
+ initial_metadata[name] = header["data"]
+
+ # Add the previous line to the lines_with_metadata
+ # only if current_content is not empty
+ if current_content:
+ lines_with_metadata.append(
+ {
+ "content": "\n".join(current_content),
+ "metadata": current_metadata.copy(),
+ }
+ )
+ current_content.clear()
+
+ if not self.strip_headers:
+ current_content.append(stripped_line)
+
+ break
+ else:
+ if stripped_line:
+ current_content.append(stripped_line)
+ elif current_content:
+ lines_with_metadata.append(
+ {
+ "content": "\n".join(current_content),
+ "metadata": current_metadata.copy(),
+ }
+ )
+ current_content.clear()
+
+ current_metadata = initial_metadata.copy()
+
+ if current_content:
+ lines_with_metadata.append(
+ {
+ "content": "\n".join(current_content),
+ "metadata": current_metadata,
+ }
+ )
+
+ # lines_with_metadata has each line with associated header metadata
+ # aggregate these into chunks based on common metadata
+ if not self.return_each_line:
+ return self.aggregate_lines_to_chunks(lines_with_metadata)
+ else:
+ return [
+ SplitterDocument(
+ page_content=chunk["content"], metadata=chunk["metadata"]
+ )
+ for chunk in lines_with_metadata
+ ]
+
+
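+# Usage sketch (illustrative): track H1/H2 headings as metadata.
+def _markdown_split_example() -> list[SplitterDocument]:
+    md = "# Title\n\nIntro text\n\n## Section\n\nBody text"
+    splitter = MarkdownHeaderTextSplitter(
+        headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")]
+    )
+    docs = splitter.split_text(md)
+    # docs[0]: "Intro text" with {"Header 1": "Title"}
+    # docs[1]: "Body text" with both "Header 1" and "Header 2" set
+    return docs
+
+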
+class ElementType(TypedDict):
+ """Element type as typed dict."""
+
+ url: str
+ xpath: str
+ content: str
+ metadata: dict[str, str]
+
+
+class HTMLHeaderTextSplitter:
+ """Splitting HTML files based on specified headers.
+
+ Requires lxml package.
+ """
+
+ def __init__(
+ self,
+ headers_to_split_on: list[Tuple[str, str]],
+ return_each_element: bool = False,
+ ):
+ """Create a new HTMLHeaderTextSplitter.
+
+ Args:
+ headers_to_split_on: list of tuples of headers we want to track
+ mapped to (arbitrary) keys for metadata. Allowed header values:
+ h1, h2, h3, h4, h5, h6
+ e.g. [("h1", "Header 1"), ("h2", "Header 2)].
+ return_each_element: Return each element w/ associated headers.
+ """
+ # Output element-by-element or aggregated into chunks w/ common headers
+ self.return_each_element = return_each_element
+ self.headers_to_split_on = sorted(headers_to_split_on)
+
+ def aggregate_elements_to_chunks(
+ self, elements: list[ElementType]
+ ) -> list[SplitterDocument]:
+ """Combine elements with common metadata into chunks.
+
+ Args:
+ elements: HTML element content with associated identifying
+ info and metadata
+ """
+ aggregated_chunks: list[ElementType] = []
+
+ for element in elements:
+ if (
+ aggregated_chunks
+ and aggregated_chunks[-1]["metadata"] == element["metadata"]
+ ):
+ # If the last element in the aggregated list
+ # has the same metadata as the current element,
+ # append the current content to the last element's content
+ aggregated_chunks[-1]["content"] += " \n" + element["content"]
+ else:
+ # Otherwise, append the current element to the aggregated list
+ aggregated_chunks.append(element)
+
+ return [
+ SplitterDocument(
+ page_content=chunk["content"], metadata=chunk["metadata"]
+ )
+ for chunk in aggregated_chunks
+ ]
+
+ def split_text_from_url(self, url: str) -> list[SplitterDocument]:
+ """Split HTML from web URL.
+
+ Args:
+ url: web URL
+ """
+ r = requests.get(url)
+ return self.split_text_from_file(BytesIO(r.content))
+
+ def split_text(self, text: str) -> list[SplitterDocument]:
+ """Split HTML text string.
+
+ Args:
+ text: HTML text
+ """
+ return self.split_text_from_file(StringIO(text))
+
+ def split_text_from_file(self, file: Any) -> list[SplitterDocument]:
+ """Split HTML file.
+
+ Args:
+ file: HTML file
+ """
+ try:
+ from lxml import etree
+ except ImportError:
+ raise ImportError(
+ "Unable to import lxml, run `pip install lxml`."
+ ) from None
+ # use lxml library to parse html document and return xml ElementTree
+ # Explicitly encoding in utf-8 allows non-English
+ # html files to be processed without garbled characters
+ parser = etree.HTMLParser(encoding="utf-8")
+ tree = etree.parse(file, parser)
+
+ # document transformation for "structure-aware" chunking is handled
+ # with xsl. See comments in html_chunks_with_headers.xslt for more
+ # detailed information.
+ xslt_path = (
+ pathlib.Path(__file__).parent
+ / "document_transformers/xsl/html_chunks_with_headers.xslt"
+ )
+ xslt_tree = etree.parse(xslt_path)
+ transform = etree.XSLT(xslt_tree)
+ result = transform(tree)
+ result_dom = etree.fromstring(str(result))
+
+ # create filter and mapping for header metadata
+ header_filter = [header[0] for header in self.headers_to_split_on]
+ header_mapping = dict(self.headers_to_split_on)
+
+ # map xhtml namespace prefix
+ ns_map = {"h": "http://www.w3.org/1999/xhtml"}
+
+ # build list of elements from DOM
+ elements = []
+ for element in result_dom.findall("*//*", ns_map):
+ if element.findall("*[@class='headers']") or element.findall(
+ "*[@class='chunk']"
+ ):
+ elements.append(
+ ElementType(
+ url=file,
+ xpath="".join(
+ [
+ node.text
+ for node in element.findall(
+ "*[@class='xpath']", ns_map
+ )
+ ]
+ ),
+ content="".join(
+ [
+ node.text
+ for node in element.findall(
+ "*[@class='chunk']", ns_map
+ )
+ ]
+ ),
+ metadata={
+ # Add text of specified headers to
+ # metadata using header mapping.
+ header_mapping[node.tag]: node.text
+ for node in filter(
+ lambda x: x.tag in header_filter,
+ element.findall(
+ "*[@class='headers']/*", ns_map
+ ),
+ )
+ },
+ )
+ )
+
+ if not self.return_each_element:
+ return self.aggregate_elements_to_chunks(elements)
+ else:
+ return [
+ SplitterDocument(
+ page_content=chunk["content"], metadata=chunk["metadata"]
+ )
+ for chunk in elements
+ ]
+
+
+# kw_only=True and slots=True require newer Python versions (3.10+)
+# @dataclass(frozen=True, kw_only=True, slots=True)
+@dataclass(frozen=True)
+class Tokenizer:
+ """Tokenizer data class."""
+
+ chunk_overlap: int
+ """Overlap in tokens between chunks."""
+ tokens_per_chunk: int
+ """Maximum number of tokens per chunk."""
+ decode: Callable[[list[int]], str]
+ """Function to decode a list of token ids to a string."""
+ encode: Callable[[str], list[int]]
+ """Function to encode a string to a list of token ids."""
+
+
+def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
+ """Split incoming text and return chunks using tokenizer."""
+ splits: list[str] = []
+ input_ids = tokenizer.encode(text)
+ start_idx = 0
+ cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+ chunk_ids = input_ids[start_idx:cur_idx]
+ while start_idx < len(input_ids):
+ splits.append(tokenizer.decode(chunk_ids))
+ if cur_idx == len(input_ids):
+ break
+ start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
+ cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+ chunk_ids = input_ids[start_idx:cur_idx]
+ return splits
+
+
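+# Worked example with a toy tokenizer (illustrative): one token per
+# character, five tokens per chunk, two tokens of overlap.
+def _token_window_example() -> list[str]:
+    toy = Tokenizer(
+        chunk_overlap=2,
+        tokens_per_chunk=5,
+        decode=lambda ids: "".join(chr(i) for i in ids),
+        encode=lambda s: [ord(c) for c in s],
+    )
+    chunks = split_text_on_tokens(text="abcdefghij", tokenizer=toy)
+    assert chunks == ["abcde", "defgh", "ghij"]
+    return chunks
+
+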
+class TokenTextSplitter(TextSplitter):
+ """Splitting text to tokens using model tokenizer."""
+
+ def __init__(
+ self,
+ encoding_name: str = "gpt2",
+ model: Optional[str] = None,
+ allowed_special: Literal["all"] | AbstractSet[str] = set(),
+ disallowed_special: Literal["all"] | Collection[str] = "all",
+ **kwargs: Any,
+ ) -> None:
+ """Create a new TextSplitter."""
+ super().__init__(**kwargs)
+ try:
+ import tiktoken
+ except ImportError:
+ raise ImportError(
+ "Could not import tiktoken python package. "
+ "This is needed in order to for TokenTextSplitter. "
+ "Please install it with `pip install tiktoken`."
+ ) from None
+
+ if model is not None:
+ enc = tiktoken.encoding_for_model(model)
+ else:
+ enc = tiktoken.get_encoding(encoding_name)
+ self._tokenizer = enc
+ self._allowed_special = allowed_special
+ self._disallowed_special = disallowed_special
+
+ def split_text(self, text: str) -> list[str]:
+ def _encode(_text: str) -> list[int]:
+ return self._tokenizer.encode(
+ _text,
+ allowed_special=self._allowed_special,
+ disallowed_special=self._disallowed_special,
+ )
+
+ tokenizer = Tokenizer(
+ chunk_overlap=self._chunk_overlap,
+ tokens_per_chunk=self._chunk_size,
+ decode=self._tokenizer.decode,
+ encode=_encode,
+ )
+
+ return split_text_on_tokens(text=text, tokenizer=tokenizer)
+
+
+class SentenceTransformersTokenTextSplitter(TextSplitter):
+ """Splitting text to tokens using sentence model tokenizer."""
+
+ def __init__(
+ self,
+ chunk_overlap: int = 50,
+ model: str = "sentence-transformers/all-mpnet-base-v2",
+ tokens_per_chunk: Optional[int] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Create a new TextSplitter."""
+ super().__init__(**kwargs, chunk_overlap=chunk_overlap)
+
+ try:
+ from sentence_transformers import SentenceTransformer
+ except ImportError:
+ raise ImportError(
+ """Could not import sentence_transformer python package.
+ This is needed in order to for
+ SentenceTransformersTokenTextSplitter.
+ Please install it with `pip install sentence-transformers`.
+ """
+ ) from None
+
+ self.model = model
+ self._model = SentenceTransformer(self.model, trust_remote_code=True)
+ self.tokenizer = self._model.tokenizer
+ self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
+
+ def _initialize_chunk_configuration(
+ self, *, tokens_per_chunk: Optional[int]
+ ) -> None:
+ self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)
+
+ if tokens_per_chunk is None:
+ self.tokens_per_chunk = self.maximum_tokens_per_chunk
+ else:
+ self.tokens_per_chunk = tokens_per_chunk
+
+ if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
+ raise ValueError(
+ f"The token limit of the models '{self.model}'"
+ f" is: {self.maximum_tokens_per_chunk}."
+ f" Argument tokens_per_chunk={self.tokens_per_chunk}"
+ f" > maximum token limit."
+ )
+
+ def split_text(self, text: str) -> list[str]:
+ def encode_strip_start_and_stop_token_ids(text: str) -> list[int]:
+ return self._encode(text)[1:-1]
+
+ tokenizer = Tokenizer(
+ chunk_overlap=self._chunk_overlap,
+ tokens_per_chunk=self.tokens_per_chunk,
+ decode=self.tokenizer.decode,
+ encode=encode_strip_start_and_stop_token_ids,
+ )
+
+ return split_text_on_tokens(text=text, tokenizer=tokenizer)
+
+ def count_tokens(self, *, text: str) -> int:
+ return len(self._encode(text))
+
+ _max_length_equal_32_bit_integer: int = 2**32
+
+ def _encode(self, text: str) -> list[int]:
+ token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
+ text,
+ max_length=self._max_length_equal_32_bit_integer,
+ truncation="do_not_truncate",
+ )
+ return token_ids_with_start_and_end_token_ids
+
+
+class Language(str, Enum):
+ """Enum of the programming languages."""
+
+ CPP = "cpp"
+ GO = "go"
+ JAVA = "java"
+ KOTLIN = "kotlin"
+ JS = "js"
+ TS = "ts"
+ PHP = "php"
+ PROTO = "proto"
+ PYTHON = "python"
+ RST = "rst"
+ RUBY = "ruby"
+ RUST = "rust"
+ SCALA = "scala"
+ SWIFT = "swift"
+ MARKDOWN = "markdown"
+ LATEX = "latex"
+ HTML = "html"
+ SOL = "sol"
+ CSHARP = "csharp"
+ COBOL = "cobol"
+ C = "c"
+ LUA = "lua"
+ PERL = "perl"
+
+
+class RecursiveCharacterTextSplitter(TextSplitter):
+ """Splitting text by recursively look at characters.
+
+ Recursively tries to split by different characters to find one that works.
+ """
+
+ def __init__(
+ self,
+ separators: Optional[list[str]] = None,
+ keep_separator: bool = True,
+ is_separator_regex: bool = False,
+ chunk_size: int = 4000,
+ chunk_overlap: int = 200,
+ **kwargs: Any,
+ ) -> None:
+ """Create a new TextSplitter."""
+ super().__init__(
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ keep_separator=keep_separator,
+ **kwargs,
+ )
+ self._separators = separators or ["\n\n", "\n", " ", ""]
+ self._is_separator_regex = is_separator_regex
+ self.chunk_size = chunk_size
+ self.chunk_overlap = chunk_overlap
+
+ def _split_text(self, text: str, separators: list[str]) -> list[str]:
+ """Split incoming text and return chunks."""
+ final_chunks = []
+ # Get appropriate separator to use
+ separator = separators[-1]
+ new_separators = []
+ for i, _s in enumerate(separators):
+ _separator = _s if self._is_separator_regex else re.escape(_s)
+ if _s == "":
+ separator = _s
+ break
+ if re.search(_separator, text):
+ separator = _s
+ new_separators = separators[i + 1 :]
+ break
+
+ _separator = (
+ separator if self._is_separator_regex else re.escape(separator)
+ )
+ splits = _split_text_with_regex(text, _separator, self._keep_separator)
+
+ # Now go merging things, recursively splitting longer texts.
+ _good_splits = []
+ _separator = "" if self._keep_separator else separator
+ for s in splits:
+ if self._length_function(s) < self._chunk_size:
+ _good_splits.append(s)
+ else:
+ if _good_splits:
+ merged_text = self._merge_splits(_good_splits, _separator)
+ final_chunks.extend(merged_text)
+ _good_splits = []
+ if not new_separators:
+ final_chunks.append(s)
+ else:
+ other_info = self._split_text(s, new_separators)
+ final_chunks.extend(other_info)
+ if _good_splits:
+ merged_text = self._merge_splits(_good_splits, _separator)
+ final_chunks.extend(merged_text)
+ return final_chunks
+
+ def split_text(self, text: str) -> list[str]:
+ return self._split_text(text, self._separators)
+
+ @classmethod
+ def from_language(
+ cls, language: Language, **kwargs: Any
+ ) -> RecursiveCharacterTextSplitter:
+ separators = cls.get_separators_for_language(language)
+ return cls(separators=separators, is_separator_regex=True, **kwargs)
+
+ @staticmethod
+ def get_separators_for_language(language: Language) -> list[str]:
+ if language == Language.CPP:
+ return [
+ # Split along class definitions
+ "\nclass ",
+ # Split along function definitions
+ "\nvoid ",
+ "\nint ",
+ "\nfloat ",
+ "\ndouble ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.GO:
+ return [
+ # Split along function definitions
+ "\nfunc ",
+ "\nvar ",
+ "\nconst ",
+ "\ntype ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.JAVA:
+ return [
+ # Split along class definitions
+ "\nclass ",
+ # Split along method definitions
+ "\npublic ",
+ "\nprotected ",
+ "\nprivate ",
+ "\nstatic ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.KOTLIN:
+ return [
+ # Split along class definitions
+ "\nclass ",
+ # Split along method definitions
+ "\npublic ",
+ "\nprotected ",
+ "\nprivate ",
+ "\ninternal ",
+ "\ncompanion ",
+ "\nfun ",
+ "\nval ",
+ "\nvar ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nwhen ",
+ "\ncase ",
+ "\nelse ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.JS:
+ return [
+ # Split along function definitions
+ "\nfunction ",
+ "\nconst ",
+ "\nlet ",
+ "\nvar ",
+ "\nclass ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nswitch ",
+ "\ncase ",
+ "\ndefault ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.TS:
+ return [
+ "\nenum ",
+ "\ninterface ",
+ "\nnamespace ",
+ "\ntype ",
+ # Split along class definitions
+ "\nclass ",
+ # Split along function definitions
+ "\nfunction ",
+ "\nconst ",
+ "\nlet ",
+ "\nvar ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nswitch ",
+ "\ncase ",
+ "\ndefault ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.PHP:
+ return [
+ # Split along function definitions
+ "\nfunction ",
+ # Split along class definitions
+ "\nclass ",
+ # Split along control flow statements
+ "\nif ",
+ "\nforeach ",
+ "\nwhile ",
+ "\ndo ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.PROTO:
+ return [
+ # Split along message definitions
+ "\nmessage ",
+ # Split along service definitions
+ "\nservice ",
+ # Split along enum definitions
+ "\nenum ",
+ # Split along option definitions
+ "\noption ",
+ # Split along import statements
+ "\nimport ",
+ # Split along syntax declarations
+ "\nsyntax ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.PYTHON:
+ return [
+ # First, try to split along class definitions
+ "\nclass ",
+ "\ndef ",
+ "\n\tdef ",
+ # Now split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.RST:
+ return [
+ # Split along section titles
+ "\n=+\n",
+ "\n-+\n",
+ "\n\\*+\n",
+ # Split along directive markers
+ "\n\n.. *\n\n",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.RUBY:
+ return [
+ # Split along method definitions
+ "\ndef ",
+ "\nclass ",
+ # Split along control flow statements
+ "\nif ",
+ "\nunless ",
+ "\nwhile ",
+ "\nfor ",
+ "\ndo ",
+ "\nbegin ",
+ "\nrescue ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.RUST:
+ return [
+ # Split along function definitions
+ "\nfn ",
+ "\nconst ",
+ "\nlet ",
+ # Split along control flow statements
+ "\nif ",
+ "\nwhile ",
+ "\nfor ",
+ "\nloop ",
+ "\nmatch ",
+ "\nconst ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.SCALA:
+ return [
+ # Split along class definitions
+ "\nclass ",
+ "\nobject ",
+ # Split along method definitions
+ "\ndef ",
+ "\nval ",
+ "\nvar ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nmatch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.SWIFT:
+ return [
+ # Split along function definitions
+ "\nfunc ",
+ # Split along class definitions
+ "\nclass ",
+ "\nstruct ",
+ "\nenum ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\ndo ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.MARKDOWN:
+ return [
+ # First, try to split along Markdown headings
+ # (levels one through six)
+ "\n#{1,6} ",
+ # Note the alternative syntax for headings (below)
+ # is not handled here
+ # Heading level 2
+ # ---------------
+ # End of code block
+ "```\n",
+ # Horizontal lines
+ "\n\\*\\*\\*+\n",
+ "\n---+\n",
+ "\n___+\n",
+ # Note that the horizontal-line patterns above
+ # require runs of *three or more* of
+ # ***, ---, or ___
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.LATEX:
+ return [
+ # First, try to split along Latex sections
+ "\n\\\\chapter{",
+ "\n\\\\section{",
+ "\n\\\\subsection{",
+ "\n\\\\subsubsection{",
+ # Now split by environments
+ "\n\\\\begin{enumerate}",
+ "\n\\\\begin{itemize}",
+ "\n\\\\begin{description}",
+ "\n\\\\begin{list}",
+ "\n\\\\begin{quote}",
+ "\n\\\\begin{quotation}",
+ "\n\\\\begin{verse}",
+ "\n\\\\begin{verbatim}",
+ # Now split by math environments
+ "\n\\\begin{align}",
+ "$$",
+ "$",
+ # Now split by the normal type of lines
+ " ",
+ "",
+ ]
+ elif language == Language.HTML:
+ return [
+ # First, try to split along HTML tags
+ "<body",
+ "<div",
+ "<p",
+ "<br",
+ "<li",
+ "<h1",
+ "<h2",
+ "<h3",
+ "<h4",
+ "<h5",
+ "<h6",
+ "<span",
+ "<table",
+ "<tr",
+ "<td",
+ "<th",
+ "<ul",
+ "<ol",
+ "<header",
+ "<footer",
+ "<nav",
+ # Head
+ "<head",
+ "<style",
+ "<script",
+ "<meta",
+ "<title",
+ "",
+ ]
+ elif language == Language.CSHARP:
+ return [
+ "\ninterface ",
+ "\nenum ",
+ "\nimplements ",
+ "\ndelegate ",
+ "\nevent ",
+ # Split along class definitions
+ "\nclass ",
+ "\nabstract ",
+ # Split along method definitions
+ "\npublic ",
+ "\nprotected ",
+ "\nprivate ",
+ "\nstatic ",
+ "\nreturn ",
+ # Split along control flow statements
+ "\nif ",
+ "\ncontinue ",
+ "\nfor ",
+ "\nforeach ",
+ "\nwhile ",
+ "\nswitch ",
+ "\nbreak ",
+ "\ncase ",
+ "\nelse ",
+ # Split by exceptions
+ "\ntry ",
+ "\nthrow ",
+ "\nfinally ",
+ "\ncatch ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.SOL:
+ return [
+ # Split along compiler information definitions
+ "\npragma ",
+ "\nusing ",
+ # Split along contract definitions
+ "\ncontract ",
+ "\ninterface ",
+ "\nlibrary ",
+ # Split along method definitions
+ "\nconstructor ",
+ "\ntype ",
+ "\nfunction ",
+ "\nevent ",
+ "\nmodifier ",
+ "\nerror ",
+ "\nstruct ",
+ "\nenum ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\ndo while ",
+ "\nassembly ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.COBOL:
+ return [
+ # Split along divisions
+ "\nIDENTIFICATION DIVISION.",
+ "\nENVIRONMENT DIVISION.",
+ "\nDATA DIVISION.",
+ "\nPROCEDURE DIVISION.",
+ # Split along sections within DATA DIVISION
+ "\nWORKING-STORAGE SECTION.",
+ "\nLINKAGE SECTION.",
+ "\nFILE SECTION.",
+ # Split along sections within PROCEDURE DIVISION
+ "\nINPUT-OUTPUT SECTION.",
+ # Split along paragraphs and common statements
+ "\nOPEN ",
+ "\nCLOSE ",
+ "\nREAD ",
+ "\nWRITE ",
+ "\nIF ",
+ "\nELSE ",
+ "\nMOVE ",
+ "\nPERFORM ",
+ "\nUNTIL ",
+ "\nVARYING ",
+ "\nACCEPT ",
+ "\nDISPLAY ",
+ "\nSTOP RUN.",
+ # Split by the normal type of lines
+ "\n",
+ " ",
+ "",
+ ]
+
+ else:
+ raise ValueError(
+ f"Language {language} is not supported! "
+ f"Please choose from {list(Language)}"
+ )
+
+
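+# Usage sketch (illustrative): language-aware splitting of Python
+# source using the separators listed above.
+def _python_code_split_example(source: str) -> list[str]:
+    splitter = RecursiveCharacterTextSplitter.from_language(
+        Language.PYTHON, chunk_size=512, chunk_overlap=0
+    )
+    return splitter.split_text(source)
+
+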
+class NLTKTextSplitter(TextSplitter):
+ """Splitting text using NLTK package."""
+
+ def __init__(
+ self, separator: str = "\n\n", language: str = "english", **kwargs: Any
+ ) -> None:
+ """Initialize the NLTK splitter."""
+ super().__init__(**kwargs)
+ try:
+ from nltk.tokenize import sent_tokenize
+
+ self._tokenizer = sent_tokenize
+ except ImportError:
+ raise ImportError("""NLTK is not installed, please install it with
+ `pip install nltk`.""") from None
+ self._separator = separator
+ self._language = language
+
+ def split_text(self, text: str) -> list[str]:
+ """Split incoming text and return chunks."""
+ # First we naively split the large input into a bunch of smaller ones.
+ splits = self._tokenizer(text, language=self._language)
+ return self._merge_splits(splits, self._separator)
+
+
+class SpacyTextSplitter(TextSplitter):
+ """Splitting text using Spacy package.
+
+ By default, Spacy's `en_core_web_sm` model is used, with a default
+ max_length of 1,000,000 characters (the maximum input length the
+ pipeline accepts, which can be raised for large files). For faster,
+ but potentially less accurate splitting, you can use `pipe='sentencizer'`.
+ """
+
+ def __init__(
+ self,
+ separator: str = "\n\n",
+ pipe: str = "en_core_web_sm",
+ max_length: int = 1_000_000,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the spacy text splitter."""
+ super().__init__(**kwargs)
+ self._tokenizer = _make_spacy_pipe_for_splitting(
+ pipe, max_length=max_length
+ )
+ self._separator = separator
+
+ def split_text(self, text: str) -> list[str]:
+ """Split incoming text and return chunks."""
+ splits = (s.text for s in self._tokenizer(text).sents)
+ return self._merge_splits(splits, self._separator)
+
+
+class KonlpyTextSplitter(TextSplitter):
+ """Splitting text using Konlpy package.
+
+ It is good for splitting Korean text.
+ """
+
+ def __init__(
+ self,
+ separator: str = "\n\n",
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the Konlpy text splitter."""
+ super().__init__(**kwargs)
+ self._separator = separator
+ try:
+ from konlpy.tag import Kkma
+ except ImportError:
+ raise ImportError("""
+ Konlpy is not installed, please install it with
+ `pip install konlpy`
+ """) from None
+ self.kkma = Kkma()
+
+ def split_text(self, text: str) -> list[str]:
+ """Split incoming text and return chunks."""
+ splits = self.kkma.sentences(text)
+ return self._merge_splits(splits, self._separator)
+
+
+# For backwards compatibility
+class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
+ """Attempts to split the text along Python syntax."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize a PythonCodeTextSplitter."""
+ separators = self.get_separators_for_language(Language.PYTHON)
+ super().__init__(separators=separators, **kwargs)
+
+
+class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
+ """Attempts to split the text along Markdown-formatted headings."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize a MarkdownTextSplitter."""
+ separators = self.get_separators_for_language(Language.MARKDOWN)
+ super().__init__(separators=separators, **kwargs)
+
+
+class LatexTextSplitter(RecursiveCharacterTextSplitter):
+ """Attempts to split the text along Latex-formatted layout elements."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize a LatexTextSplitter."""
+ separators = self.get_separators_for_language(Language.LATEX)
+ super().__init__(separators=separators, **kwargs)
+
+
+class RecursiveJsonSplitter:
+ def __init__(
+ self, max_chunk_size: int = 2000, min_chunk_size: Optional[int] = None
+ ):
+ super().__init__()
+ self.max_chunk_size = max_chunk_size
+ self.min_chunk_size = (
+ min_chunk_size
+ if min_chunk_size is not None
+ else max(max_chunk_size - 200, 50)
+ )
+
+ @staticmethod
+ def _json_size(data: dict) -> int:
+ """Calculate the size of the serialized JSON object."""
+ return len(json.dumps(data))
+
+ @staticmethod
+ def _set_nested_dict(d: dict, path: list[str], value: Any) -> None:
+ """Set a value in a nested dictionary based on the given path."""
+ for key in path[:-1]:
+ d = d.setdefault(key, {})
+ d[path[-1]] = value
+
+ def _list_to_dict_preprocessing(self, data: Any) -> Any:
+ if isinstance(data, dict):
+ # Process each key-value pair in the dictionary
+ return {
+ k: self._list_to_dict_preprocessing(v) for k, v in data.items()
+ }
+ elif isinstance(data, list):
+ # Convert the list to a dictionary with index-based keys
+ return {
+ str(i): self._list_to_dict_preprocessing(item)
+ for i, item in enumerate(data)
+ }
+ else:
+ # The item is neither a dict nor a list, return unchanged
+ return data
+
+ def _json_split(
+ self,
+ data: dict[str, Any],
+ current_path: list[str] | None = None,
+ chunks: list[dict] | None = None,
+ ) -> list[dict]:
+ """Split json into maximum size dictionaries while preserving
+ structure."""
+ if current_path is None:
+ current_path = []
+ if chunks is None:
+ chunks = [{}]
+
+ if isinstance(data, dict):
+ for key, value in data.items():
+ new_path = current_path + [key]
+ chunk_size = self._json_size(chunks[-1])
+ size = self._json_size({key: value})
+ remaining = self.max_chunk_size - chunk_size
+
+ if size < remaining:
+ # Add item to current chunk
+ self._set_nested_dict(chunks[-1], new_path, value)
+ else:
+ if chunk_size >= self.min_chunk_size:
+ # Chunk is big enough, start a new chunk
+ chunks.append({})
+
+ # Iterate
+ self._json_split(value, new_path, chunks)
+ else:
+ # handle single item
+ self._set_nested_dict(chunks[-1], current_path, data)
+ return chunks
+
+ def split_json(
+ self,
+ json_data: dict[str, Any],
+ convert_lists: bool = False,
+ ) -> list[dict]:
+ """Splits JSON into a list of JSON chunks."""
+
+ if convert_lists:
+ chunks = self._json_split(
+ self._list_to_dict_preprocessing(json_data)
+ )
+ else:
+ chunks = self._json_split(json_data)
+
+ # Remove the last chunk if it's empty
+ if not chunks[-1]:
+ chunks.pop()
+ return chunks
+
+ def split_text(
+ self, json_data: dict[str, Any], convert_lists: bool = False
+ ) -> list[str]:
+ """Splits JSON into a list of JSON formatted strings."""
+
+ chunks = self.split_json(
+ json_data=json_data, convert_lists=convert_lists
+ )
+
+ # Convert to string
+ return [json.dumps(chunk) for chunk in chunks]
+
+ def create_documents(
+ self,
+ texts: list[dict],
+ convert_lists: bool = False,
+ metadatas: Optional[list[dict]] = None,
+ ) -> list[SplitterDocument]:
+ """Create documents from a list of json objects (dict)."""
+ _metadatas = metadatas or [{}] * len(texts)
+ documents = []
+ for i, text in enumerate(texts):
+ for chunk in self.split_text(
+ json_data=text, convert_lists=convert_lists
+ ):
+ metadata = copy.deepcopy(_metadatas[i])
+ new_doc = SplitterDocument(
+ page_content=chunk, metadata=metadata
+ )
+ documents.append(new_doc)
+ return documents
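+
+
+# Usage sketch (illustrative): split a nested JSON object into
+# size-bounded sub-dicts, converting lists to index-keyed dicts first.
+def _json_split_example() -> list[str]:
+    splitter = RecursiveJsonSplitter(max_chunk_size=2000)
+    payload = {"users": [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]}
+    return splitter.split_text(payload, convert_lists=True)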