aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/shared/abstractions
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/abstractions')
-rw-r--r--.venv/lib/python3.12/site-packages/shared/abstractions/__init__.py146
-rw-r--r--.venv/lib/python3.12/site-packages/shared/abstractions/base.py145
-rw-r--r--.venv/lib/python3.12/site-packages/shared/abstractions/document.py377
-rw-r--r--.venv/lib/python3.12/site-packages/shared/abstractions/embedding.py31
-rw-r--r--.venv/lib/python3.12/site-packages/shared/abstractions/exception.py75
-rw-r--r--.venv/lib/python3.12/site-packages/shared/abstractions/graph.py257
-rw-r--r--.venv/lib/python3.12/site-packages/shared/abstractions/llm.py325
-rw-r--r--.venv/lib/python3.12/site-packages/shared/abstractions/prompt.py39
-rw-r--r--.venv/lib/python3.12/site-packages/shared/abstractions/search.py614
-rw-r--r--.venv/lib/python3.12/site-packages/shared/abstractions/user.py69
-rw-r--r--.venv/lib/python3.12/site-packages/shared/abstractions/vector.py239
11 files changed, 2317 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/__init__.py b/.venv/lib/python3.12/site-packages/shared/abstractions/__init__.py
new file mode 100644
index 00000000..da33ddd7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/__init__.py
@@ -0,0 +1,146 @@
+from .base import AsyncSyncMeta, R2RSerializable, syncable
+from .document import (
+ Document,
+ DocumentChunk,
+ DocumentResponse,
+ DocumentType,
+ GraphConstructionStatus,
+ GraphExtractionStatus,
+ IngestionMode,
+ IngestionStatus,
+ RawChunk,
+ UnprocessedChunk,
+)
+from .embedding import EmbeddingPurpose, default_embedding_prefixes
+from .exception import (
+ PDFParsingError,
+ PopplerNotFoundError,
+ R2RDocumentProcessingError,
+ R2RException,
+)
+from .graph import (
+ Community,
+ Entity,
+ GraphCommunitySettings,
+ GraphCreationSettings,
+ GraphEnrichmentSettings,
+ GraphExtraction,
+ Relationship,
+ StoreType,
+)
+from .llm import (
+ GenerationConfig,
+ LLMChatCompletion,
+ LLMChatCompletionChunk,
+ Message,
+ MessageType,
+ RAGCompletion,
+)
+from .prompt import Prompt
+from .search import (
+ AggregateSearchResult,
+ ChunkSearchResult,
+ ChunkSearchSettings,
+ GraphCommunityResult,
+ GraphEntityResult,
+ GraphRelationshipResult,
+ GraphSearchResult,
+ GraphSearchResultType,
+ GraphSearchSettings,
+ HybridSearchSettings,
+ SearchMode,
+ SearchSettings,
+ WebPageSearchResult,
+ select_search_filters,
+)
+from .user import Token, TokenData, User
+from .vector import (
+ IndexArgsHNSW,
+ IndexArgsIVFFlat,
+ IndexMeasure,
+ IndexMethod,
+ StorageResult,
+ Vector,
+ VectorEntry,
+ VectorQuantizationType,
+ VectorTableName,
+ VectorType,
+)
+
+__all__ = [
+ # Base abstractions
+ "R2RSerializable",
+ "AsyncSyncMeta",
+ "syncable",
+ # Completion abstractions
+ "MessageType",
+ # Document abstractions
+ "Document",
+ "DocumentChunk",
+ "DocumentResponse",
+ "IngestionMode",
+ "IngestionStatus",
+ "GraphExtractionStatus",
+ "GraphConstructionStatus",
+ "DocumentType",
+ "RawChunk",
+ "UnprocessedChunk",
+ # Embedding abstractions
+ "EmbeddingPurpose",
+ "default_embedding_prefixes",
+ # Exception abstractions
+ "R2RDocumentProcessingError",
+ "R2RException",
+ "PDFParsingError",
+ "PopplerNotFoundError",
+ # Graph abstractions
+ "Entity",
+ "Community",
+ "Community",
+ "GraphExtraction",
+ "Relationship",
+ "StoreType",
+ # LLM abstractions
+ "GenerationConfig",
+ "LLMChatCompletion",
+ "LLMChatCompletionChunk",
+ "Message",
+ "RAGCompletion",
+ # Prompt abstractions
+ "Prompt",
+ # Search abstractions
+ "AggregateSearchResult",
+ "GraphSearchResult",
+ "WebPageSearchResult",
+ "GraphSearchResultType",
+ "GraphEntityResult",
+ "GraphRelationshipResult",
+ "GraphCommunityResult",
+ "GraphSearchSettings",
+ "ChunkSearchSettings",
+ "ChunkSearchResult",
+ "SearchSettings",
+ "select_search_filters",
+ "HybridSearchSettings",
+ "SearchMode",
+ # graph abstractions
+ "GraphCreationSettings",
+ "GraphEnrichmentSettings",
+ "GraphExtraction",
+ "GraphCommunitySettings",
+ # User abstractions
+ "Token",
+ "TokenData",
+ "User",
+ # Vector abstractions
+ "Vector",
+ "VectorEntry",
+ "VectorType",
+ "IndexMethod",
+ "IndexMeasure",
+ "IndexArgsIVFFlat",
+ "IndexArgsHNSW",
+ "VectorTableName",
+ "VectorQuantizationType",
+ "StorageResult",
+]
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/base.py b/.venv/lib/python3.12/site-packages/shared/abstractions/base.py
new file mode 100644
index 00000000..d90ba400
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/base.py
@@ -0,0 +1,145 @@
+import asyncio
+import json
+from datetime import datetime
+from enum import Enum
+from typing import Any, Type, TypeVar
+from uuid import UUID
+
+from pydantic import BaseModel
+
+T = TypeVar("T", bound="R2RSerializable")
+
+
+class R2RSerializable(BaseModel):
+ @classmethod
+ def from_dict(cls: Type[T], data: dict[str, Any] | str) -> T:
+ if isinstance(data, str):
+ try:
+ data_dict = json.loads(data)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Invalid JSON string: {e}") from e
+ else:
+ data_dict = data
+ return cls(**data_dict)
+
+ def as_dict(self) -> dict[str, Any]:
+ data = self.model_dump(exclude_unset=True)
+ return self._serialize_values(data)
+
+ def to_dict(self) -> dict[str, Any]:
+ data = self.model_dump(exclude_unset=True)
+ return self._serialize_values(data)
+
+ def to_json(self) -> str:
+ data = self.to_dict()
+ return json.dumps(data)
+
+ @classmethod
+ def from_json(cls: Type[T], json_str: str) -> T:
+ return cls.model_validate_json(json_str)
+
+ @staticmethod
+ def _serialize_values(data: Any) -> Any:
+ if isinstance(data, dict):
+ return {
+ k: R2RSerializable._serialize_values(v)
+ for k, v in data.items()
+ }
+ elif isinstance(data, list):
+ return [R2RSerializable._serialize_values(v) for v in data]
+ elif isinstance(data, UUID):
+ return str(data)
+ elif isinstance(data, Enum):
+ return data.value
+ elif isinstance(data, datetime):
+ return data.isoformat()
+ else:
+ return data
+
+ class Config:
+ arbitrary_types_allowed = True
+ json_encoders = {
+ UUID: str,
+ bytes: lambda v: v.decode("utf-8", errors="ignore"),
+ }
+
+
+class AsyncSyncMeta(type):
+ _event_loop = None # Class-level shared event loop
+
+ @classmethod
+ def get_event_loop(cls):
+ if cls._event_loop is None or cls._event_loop.is_closed():
+ cls._event_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(cls._event_loop)
+ return cls._event_loop
+
+ def __new__(cls, name, bases, dct):
+ new_cls = super().__new__(cls, name, bases, dct)
+ for attr_name, attr_value in dct.items():
+ if asyncio.iscoroutinefunction(attr_value) and getattr(
+ attr_value, "_syncable", False
+ ):
+ sync_method_name = attr_name[
+ 1:
+ ] # Remove leading 'a' for sync method
+ async_method = attr_value
+
+ def make_sync_method(async_method):
+ def sync_wrapper(self, *args, **kwargs):
+ loop = cls.get_event_loop()
+ if not loop.is_running():
+ # Setup to run the loop in a background thread if necessary
+ # to prevent blocking the main thread in a synchronous call environment
+ from threading import Thread
+
+ result = None
+ exception = None
+
+ def run():
+ nonlocal result, exception
+ try:
+ asyncio.set_event_loop(loop)
+ result = loop.run_until_complete(
+ async_method(self, *args, **kwargs)
+ )
+ except Exception as e:
+ exception = e
+ finally:
+ generation_config = kwargs.get(
+ "rag_generation_config", None
+ )
+ if (
+ not generation_config
+ or not generation_config.stream
+ ):
+ loop.run_until_complete(
+ loop.shutdown_asyncgens()
+ )
+ loop.close()
+
+ thread = Thread(target=run)
+ thread.start()
+ thread.join()
+ if exception:
+ raise exception
+ return result
+ else:
+ # If there's already a running loop, schedule and execute the coroutine
+ future = asyncio.run_coroutine_threadsafe(
+ async_method(self, *args, **kwargs), loop
+ )
+ return future.result()
+
+ return sync_wrapper
+
+ setattr(
+ new_cls, sync_method_name, make_sync_method(async_method)
+ )
+ return new_cls
+
+
+def syncable(func):
+ """Decorator to mark methods for synchronous wrapper creation."""
+ func._syncable = True
+ return func
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/document.py b/.venv/lib/python3.12/site-packages/shared/abstractions/document.py
new file mode 100644
index 00000000..513392f8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/document.py
@@ -0,0 +1,377 @@
+"""Abstractions for documents and their extractions."""
+
+import json
+import logging
+from datetime import datetime
+from enum import Enum
+from typing import Any, Optional
+from uuid import UUID, uuid4
+
+from pydantic import Field
+
+from .base import R2RSerializable
+from .llm import GenerationConfig
+
+logger = logging.getLogger()
+
+
+class DocumentType(str, Enum):
+ """Types of documents that can be stored."""
+
+ # Audio
+ MP3 = "mp3"
+
+ # CSV
+ CSV = "csv"
+
+ # Email
+ EML = "eml"
+ MSG = "msg"
+ P7S = "p7s"
+
+ # EPUB
+ EPUB = "epub"
+
+ # Excel
+ XLS = "xls"
+ XLSX = "xlsx"
+
+ # HTML
+ HTML = "html"
+ HTM = "htm"
+
+ # Image
+ BMP = "bmp"
+ HEIC = "heic"
+ JPEG = "jpeg"
+ PNG = "png"
+ TIFF = "tiff"
+ JPG = "jpg"
+ SVG = "svg"
+
+ # Markdown
+ MD = "md"
+
+ # Org Mode
+ ORG = "org"
+
+ # Open Office
+ ODT = "odt"
+
+ # PDF
+ PDF = "pdf"
+
+ # Plain text
+ TXT = "txt"
+ JSON = "json"
+
+ # PowerPoint
+ PPT = "ppt"
+ PPTX = "pptx"
+
+ # reStructured Text
+ RST = "rst"
+
+ # Rich Text
+ RTF = "rtf"
+
+ # TSV
+ TSV = "tsv"
+
+ # Video/GIF
+ GIF = "gif"
+
+ # Word
+ DOC = "doc"
+ DOCX = "docx"
+
+ # XML
+ XML = "xml"
+
+
+class Document(R2RSerializable):
+ id: UUID = Field(default_factory=uuid4)
+ collection_ids: list[UUID]
+ owner_id: UUID
+ document_type: DocumentType
+ metadata: dict
+
+ class Config:
+ arbitrary_types_allowed = True
+ ignore_extra = False
+ json_encoders = {
+ UUID: str,
+ }
+ populate_by_name = True
+
+
+class IngestionStatus(str, Enum):
+ """Status of document processing."""
+
+ PENDING = "pending"
+ PARSING = "parsing"
+ EXTRACTING = "extracting"
+ CHUNKING = "chunking"
+ EMBEDDING = "embedding"
+ AUGMENTING = "augmenting"
+ STORING = "storing"
+ ENRICHING = "enriching"
+
+ FAILED = "failed"
+ SUCCESS = "success"
+
+ def __str__(self):
+ return self.value
+
+ @classmethod
+ def table_name(cls) -> str:
+ return "documents"
+
+ @classmethod
+ def id_column(cls) -> str:
+ return "document_id"
+
+
+class GraphExtractionStatus(str, Enum):
+ """Status of graph creation per document."""
+
+ PENDING = "pending"
+ PROCESSING = "processing"
+ SUCCESS = "success"
+ ENRICHED = "enriched"
+ FAILED = "failed"
+
+ def __str__(self):
+ return self.value
+
+ @classmethod
+ def table_name(cls) -> str:
+ return "documents"
+
+ @classmethod
+ def id_column(cls) -> str:
+ return "id"
+
+
+class GraphConstructionStatus(str, Enum):
+ """Status of graph enrichment per collection."""
+
+ PENDING = "pending"
+ PROCESSING = "processing"
+ OUTDATED = "outdated"
+ SUCCESS = "success"
+ FAILED = "failed"
+
+ def __str__(self):
+ return self.value
+
+ @classmethod
+ def table_name(cls) -> str:
+ return "collections"
+
+ @classmethod
+ def id_column(cls) -> str:
+ return "id"
+
+
+class DocumentResponse(R2RSerializable):
+ """Base class for document information handling."""
+
+ id: UUID
+ collection_ids: list[UUID]
+ owner_id: UUID
+ document_type: DocumentType
+ metadata: dict
+ title: Optional[str] = None
+ version: str
+ size_in_bytes: Optional[int]
+ ingestion_status: IngestionStatus = IngestionStatus.PENDING
+ extraction_status: GraphExtractionStatus = GraphExtractionStatus.PENDING
+ created_at: Optional[datetime] = None
+ updated_at: Optional[datetime] = None
+ ingestion_attempt_number: Optional[int] = None
+ summary: Optional[str] = None
+ summary_embedding: Optional[list[float]] = None # Add optional embedding
+ total_tokens: Optional[int] = None
+ chunks: Optional[list] = None
+
+ def convert_to_db_entry(self):
+ """Prepare the document info for database entry, extracting certain
+ fields from metadata."""
+ now = datetime.now()
+
+ # Format the embedding properly for Postgres vector type
+ embedding = None
+ if self.summary_embedding is not None:
+ embedding = f"[{','.join(str(x) for x in self.summary_embedding)}]"
+
+ return {
+ "id": self.id,
+ "collection_ids": self.collection_ids,
+ "owner_id": self.owner_id,
+ "document_type": self.document_type,
+ "metadata": json.dumps(self.metadata),
+ "title": self.title or "N/A",
+ "version": self.version,
+ "size_in_bytes": self.size_in_bytes,
+ "ingestion_status": self.ingestion_status.value,
+ "extraction_status": self.extraction_status.value,
+ "created_at": self.created_at or now,
+ "updated_at": self.updated_at or now,
+ "ingestion_attempt_number": self.ingestion_attempt_number or 0,
+ "summary": self.summary,
+ "summary_embedding": embedding,
+ "total_tokens": self.total_tokens or 0, # ensure we pass 0 if None
+ }
+
+ class Config:
+ json_schema_extra = {
+ "example": {
+ "id": "123e4567-e89b-12d3-a456-426614174000",
+ "collection_ids": ["123e4567-e89b-12d3-a456-426614174000"],
+ "owner_id": "123e4567-e89b-12d3-a456-426614174000",
+ "document_type": "pdf",
+ "metadata": {"title": "Sample Document"},
+ "title": "Sample Document",
+ "version": "1.0",
+ "size_in_bytes": 123456,
+ "ingestion_status": "pending",
+ "extraction_status": "pending",
+ "created_at": "2021-01-01T00:00:00",
+ "updated_at": "2021-01-01T00:00:00",
+ "ingestion_attempt_number": 0,
+ "summary": "A summary of the document",
+ "summary_embedding": [0.1, 0.2, 0.3],
+ "total_tokens": 1000,
+ }
+ }
+
+
+class UnprocessedChunk(R2RSerializable):
+ """An extraction from a document."""
+
+ id: Optional[UUID] = None
+ document_id: Optional[UUID] = None
+ collection_ids: list[UUID] = []
+ metadata: dict = {}
+ text: str
+
+
+class UpdateChunk(R2RSerializable):
+ """An extraction from a document."""
+
+ id: UUID
+ metadata: Optional[dict] = None
+ text: str
+
+
+class DocumentChunk(R2RSerializable):
+ """An extraction from a document."""
+
+ id: UUID
+ document_id: UUID
+ collection_ids: list[UUID]
+ owner_id: UUID
+ data: str | bytes
+ metadata: dict
+
+
+class RawChunk(R2RSerializable):
+ text: str
+
+
+class IngestionMode(str, Enum):
+ hi_res = "hi-res"
+ fast = "fast"
+ custom = "custom"
+
+
+class ChunkEnrichmentSettings(R2RSerializable):
+ """Settings for chunk enrichment."""
+
+ enable_chunk_enrichment: bool = Field(
+ default=False,
+ description="Whether to enable chunk enrichment or not",
+ )
+ n_chunks: int = Field(
+ default=2,
+ description="The number of preceding and succeeding chunks to include. Defaults to 2.",
+ )
+ generation_config: Optional[GenerationConfig] = Field(
+ default=None,
+ description="The generation config to use for chunk enrichment",
+ )
+ chunk_enrichment_prompt: Optional[str] = Field(
+ default="chunk_enrichment",
+ description="The prompt to use for chunk enrichment",
+ )
+
+
+class IngestionConfig(R2RSerializable):
+ provider: str = "r2r"
+ excluded_parsers: list[str] = ["mp4"]
+ chunking_strategy: str = "recursive"
+ chunk_enrichment_settings: ChunkEnrichmentSettings = (
+ ChunkEnrichmentSettings()
+ )
+ extra_parsers: dict[str, Any] = {}
+
+ audio_transcription_model: str = ""
+
+ vision_img_prompt_name: str = "vision_img"
+
+ vision_pdf_prompt_name: str = "vision_pdf"
+
+ skip_document_summary: bool = False
+ document_summary_system_prompt: str = "system"
+ document_summary_task_prompt: str = "summary"
+ chunks_for_document_summary: int = 128
+ document_summary_model: str = ""
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["r2r", "unstructured_local", "unstructured_api"]
+
+ def validate_config(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider {self.provider} is not supported.")
+
+ @classmethod
+ def get_default(cls, mode: str) -> "IngestionConfig":
+ """Return default ingestion configuration for a given mode."""
+ if mode == "hi-res":
+ # More thorough parsing, no skipping summaries, possibly larger `chunks_for_document_summary`.
+ return cls(
+ provider="r2r",
+ excluded_parsers=["mp4"],
+ chunk_enrichment_settings=ChunkEnrichmentSettings(), # default
+ extra_parsers={},
+ audio_transcription_model="",
+ vision_img_prompt_name="vision_img",
+ vision_pdf_prompt_name="vision_pdf",
+ skip_document_summary=False,
+ document_summary_system_prompt="system",
+ document_summary_task_prompt="summary",
+ chunks_for_document_summary=256, # larger for hi-res
+ document_summary_model="",
+ )
+
+ elif mode == "fast":
+ # Skip summaries and other enrichment steps for speed.
+ return cls(
+ provider="r2r",
+ excluded_parsers=["mp4"],
+ chunk_enrichment_settings=ChunkEnrichmentSettings(), # default
+ extra_parsers={},
+ audio_transcription_model="",
+ vision_img_prompt_name="vision_img",
+ vision_pdf_prompt_name="vision_pdf",
+ skip_document_summary=True, # skip summaries
+ document_summary_system_prompt="system",
+ document_summary_task_prompt="summary",
+ chunks_for_document_summary=64,
+ document_summary_model="",
+ )
+ else:
+ # For `custom` or any unrecognized mode, return a base config
+ return cls()
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/embedding.py b/.venv/lib/python3.12/site-packages/shared/abstractions/embedding.py
new file mode 100644
index 00000000..6e27da28
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/embedding.py
@@ -0,0 +1,31 @@
+from enum import Enum, auto
+
+
+class EmbeddingPurpose(str, Enum):
+ INDEX = auto()
+ QUERY = auto()
+ DOCUMENT = auto()
+
+
+default_embedding_prefixes = {
+ "nomic-embed-text-v1.5": {
+ EmbeddingPurpose.INDEX: "",
+ EmbeddingPurpose.QUERY: "search_query: ",
+ EmbeddingPurpose.DOCUMENT: "search_document: ",
+ },
+ "nomic-embed-text": {
+ EmbeddingPurpose.INDEX: "",
+ EmbeddingPurpose.QUERY: "search_query: ",
+ EmbeddingPurpose.DOCUMENT: "search_document: ",
+ },
+ "mixedbread-ai/mxbai-embed-large-v1": {
+ EmbeddingPurpose.INDEX: "",
+ EmbeddingPurpose.QUERY: "Represent this sentence for searching relevant passages: ",
+ EmbeddingPurpose.DOCUMENT: "Represent this sentence for searching relevant passages: ",
+ },
+ "mixedbread-ai/mxbai-embed-large": {
+ EmbeddingPurpose.INDEX: "",
+ EmbeddingPurpose.QUERY: "Represent this sentence for searching relevant passages: ",
+ EmbeddingPurpose.DOCUMENT: "Represent this sentence for searching relevant passages: ",
+ },
+}
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/exception.py b/.venv/lib/python3.12/site-packages/shared/abstractions/exception.py
new file mode 100644
index 00000000..3dedfae8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/exception.py
@@ -0,0 +1,75 @@
+import textwrap
+from typing import Any, Optional
+from uuid import UUID
+
+
+class R2RException(Exception):
+ def __init__(
+ self, message: str, status_code: int, detail: Optional[Any] = None
+ ):
+ self.message = message
+ self.status_code = status_code
+ super().__init__(self.message)
+
+ def to_dict(self):
+ return {
+ "message": self.message,
+ "status_code": self.status_code,
+ "detail": self.detail,
+ "error_type": self.__class__.__name__,
+ }
+
+
+class R2RDocumentProcessingError(R2RException):
+ def __init__(
+ self, error_message: str, document_id: UUID, status_code: int = 500
+ ):
+ detail = {
+ "document_id": str(document_id),
+ "error_type": "document_processing_error",
+ }
+ super().__init__(error_message, status_code, detail)
+
+ def to_dict(self):
+ result = super().to_dict()
+ result["document_id"] = self.document_id
+ return result
+
+
+class PDFParsingError(R2RException):
+ """Custom exception for PDF parsing errors."""
+
+ def __init__(
+ self,
+ message: str,
+ original_error: Exception | None = None,
+ status_code: int = 500,
+ ):
+ detail = {
+ "original_error": str(original_error) if original_error else None
+ }
+ super().__init__(message, status_code, detail)
+
+
+class PopplerNotFoundError(PDFParsingError):
+ """Specific error for when Poppler is not installed."""
+
+ def __init__(self):
+ installation_instructions = textwrap.dedent("""
+ PDF processing requires Poppler to be installed. Please install Poppler and ensure it's in your system PATH.
+
+ Installing poppler:
+ - Ubuntu: sudo apt-get install poppler-utils
+ - Archlinux: sudo pacman -S poppler
+ - MacOS: brew install poppler
+ - Windows:
+ 1. Download poppler from @oschwartz10612
+ 2. Move extracted directory to desired location
+ 3. Add bin/ directory to PATH
+ 4. Test by running 'pdftoppm -h' in terminal
+ """)
+ super().__init__(
+ message=installation_instructions,
+ status_code=422,
+ original_error=None,
+ )
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py b/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py
new file mode 100644
index 00000000..3c1cec9e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py
@@ -0,0 +1,257 @@
+import json
+from dataclasses import dataclass
+from datetime import datetime
+from enum import Enum
+from typing import Any, Optional
+from uuid import UUID
+
+from pydantic import Field
+
+from ..abstractions.llm import GenerationConfig
+from .base import R2RSerializable
+
+
+class Entity(R2RSerializable):
+ """An entity extracted from a document."""
+
+ name: str
+ description: Optional[str] = None
+ category: Optional[str] = None
+ metadata: Optional[dict[str, Any]] = None
+
+ id: Optional[UUID] = None
+ parent_id: Optional[UUID] = None # graph_id | document_id
+ description_embedding: Optional[list[float] | str] = None
+ chunk_ids: Optional[list[UUID]] = []
+
+ def __str__(self):
+ return f"{self.name}:{self.category}"
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ if isinstance(self.metadata, str):
+ try:
+ self.metadata = json.loads(self.metadata)
+ except json.JSONDecodeError:
+ self.metadata = self.metadata
+
+
+class Relationship(R2RSerializable):
+ """A relationship between two entities.
+
+ This is a generic relationship, and can be used to represent any type of
+ relationship between any two entities.
+ """
+
+ id: Optional[UUID] = None
+ subject: str
+ predicate: str
+ object: str
+ description: Optional[str] = None
+ subject_id: Optional[UUID] = None
+ object_id: Optional[UUID] = None
+ weight: float | None = 1.0
+ chunk_ids: Optional[list[UUID]] = []
+ parent_id: Optional[UUID] = None
+ description_embedding: Optional[list[float] | str] = None
+ metadata: Optional[dict[str, Any] | str] = None
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ if isinstance(self.metadata, str):
+ try:
+ self.metadata = json.loads(self.metadata)
+ except json.JSONDecodeError:
+ self.metadata = self.metadata
+
+
+@dataclass
+class Community(R2RSerializable):
+ name: str = ""
+ summary: str = ""
+ level: Optional[int] = None
+ findings: list[str] = []
+ id: Optional[int | UUID] = None
+ community_id: Optional[UUID] = None
+ collection_id: Optional[UUID] = None
+ rating: Optional[float] = None
+ rating_explanation: Optional[str] = None
+ description_embedding: Optional[list[float]] = None
+ attributes: dict[str, Any] | None = None
+ created_at: datetime = Field(
+ default_factory=datetime.utcnow,
+ )
+ updated_at: datetime = Field(
+ default_factory=datetime.utcnow,
+ )
+
+ def __init__(self, **kwargs):
+ if isinstance(kwargs.get("attributes", None), str):
+ kwargs["attributes"] = json.loads(kwargs["attributes"])
+
+ if isinstance(kwargs.get("embedding", None), str):
+ kwargs["embedding"] = json.loads(kwargs["embedding"])
+
+ super().__init__(**kwargs)
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any] | str) -> "Community":
+ parsed_data: dict[str, Any] = (
+ json.loads(data) if isinstance(data, str) else data
+ )
+ if isinstance(parsed_data.get("embedding", None), str):
+ parsed_data["embedding"] = json.loads(parsed_data["embedding"])
+ return cls(**parsed_data)
+
+
+class GraphExtraction(R2RSerializable):
+ """A protocol for a knowledge graph extraction."""
+
+ entities: list[Entity]
+ relationships: list[Relationship]
+
+
+class Graph(R2RSerializable):
+ id: UUID | None = Field()
+ name: str
+ description: Optional[str] = None
+ created_at: datetime = Field(
+ default_factory=datetime.utcnow,
+ )
+ updated_at: datetime = Field(
+ default_factory=datetime.utcnow,
+ )
+ status: str = "pending"
+
+ class Config:
+ populate_by_name = True
+ from_attributes = True
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any] | str) -> "Graph":
+ """Create a Graph instance from a dictionary."""
+ # Convert string to dict if needed
+ parsed_data: dict[str, Any] = (
+ json.loads(data) if isinstance(data, str) else data
+ )
+ return cls(**parsed_data)
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+
+class StoreType(str, Enum):
+ GRAPHS = "graphs"
+ DOCUMENTS = "documents"
+
+
+class GraphCreationSettings(R2RSerializable):
+ """Settings for knowledge graph creation."""
+
+ graph_extraction_prompt: str = Field(
+ default="graph_extraction",
+ description="The prompt to use for knowledge graph extraction.",
+ )
+
+ graph_entity_description_prompt: str = Field(
+ default="graph_entity_description",
+ description="The prompt to use for entity description generation.",
+ )
+
+ entity_types: list[str] = Field(
+ default=[],
+ description="The types of entities to extract.",
+ )
+
+ relation_types: list[str] = Field(
+ default=[],
+ description="The types of relations to extract.",
+ )
+
+ chunk_merge_count: int = Field(
+ default=2,
+ description="""The number of extractions to merge into a single graph
+ extraction.""",
+ )
+
+ max_knowledge_relationships: int = Field(
+ default=100,
+ description="""The maximum number of knowledge relationships to extract
+ from each chunk.""",
+ )
+
+ max_description_input_length: int = Field(
+ default=65536,
+ description="""The maximum length of the description for a node in the
+ graph.""",
+ )
+
+ generation_config: Optional[GenerationConfig] = Field(
+ default=None,
+ description="Configuration for text generation during graph enrichment.",
+ )
+
+ automatic_deduplication: bool = Field(
+ default=False,
+ description="Whether to automatically deduplicate entities.",
+ )
+
+
+class GraphEnrichmentSettings(R2RSerializable):
+ """Settings for knowledge graph enrichment."""
+
+ force_graph_search_results_enrichment: bool = Field(
+ default=False,
+ description="""Force run the enrichment step even if graph creation is
+ still in progress for some documents.""",
+ )
+
+ graph_communities_prompt: str = Field(
+ default="graph_communities",
+ description="The prompt to use for knowledge graph enrichment.",
+ )
+
+ max_summary_input_length: int = Field(
+ default=65536,
+ description="The maximum length of the summary for a community.",
+ )
+
+ generation_config: Optional[GenerationConfig] = Field(
+ default=None,
+ description="Configuration for text generation during graph enrichment.",
+ )
+
+ leiden_params: dict = Field(
+ default_factory=dict,
+ description="Parameters for the Leiden algorithm.",
+ )
+
+
+class GraphCommunitySettings(R2RSerializable):
+ """Settings for knowledge graph community enrichment."""
+
+ force_graph_search_results_enrichment: bool = Field(
+ default=False,
+ description="""Force run the enrichment step even if graph creation is
+ still in progress for some documents.""",
+ )
+
+ graph_communities: str = Field(
+ default="graph_communities",
+ description="The prompt to use for knowledge graph enrichment.",
+ )
+
+ max_summary_input_length: int = Field(
+ default=65536,
+ description="The maximum length of the summary for a community.",
+ )
+
+ generation_config: Optional[GenerationConfig] = Field(
+ default=None,
+ description="Configuration for text generation during graph enrichment.",
+ )
+
+ leiden_params: dict = Field(
+ default_factory=dict,
+ description="Parameters for the Leiden algorithm.",
+ )
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py b/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py
new file mode 100644
index 00000000..d71e279e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py
@@ -0,0 +1,325 @@
+"""Abstractions for the LLM model."""
+
+import json
+from enum import Enum
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
+
+from openai.types.chat import ChatCompletionChunk
+from pydantic import BaseModel, Field
+
+from .base import R2RSerializable
+
+if TYPE_CHECKING:
+ from .search import AggregateSearchResult
+
+from typing_extensions import Literal
+
+
+class Function(BaseModel):
+ arguments: str
+ """
+ The arguments to call the function with, as generated by the model in JSON
+ format. Note that the model does not always generate valid JSON, and may
+ hallucinate parameters not defined by your function schema. Validate the
+ arguments in your code before calling your function.
+ """
+
+ name: str
+ """The name of the function to call."""
+
+
+class ChatCompletionMessageToolCall(BaseModel):
+ id: str
+ """The ID of the tool call."""
+
+ function: Function
+ """The function that the model called."""
+
+ type: Literal["function"]
+ """The type of the tool. Currently, only `function` is supported."""
+
+
+class FunctionCall(BaseModel):
+ arguments: str
+ """
+ The arguments to call the function with, as generated by the model in JSON
+ format. Note that the model does not always generate valid JSON, and may
+ hallucinate parameters not defined by your function schema. Validate the
+ arguments in your code before calling your function.
+ """
+
+ name: str
+ """The name of the function to call."""
+
+
+class ChatCompletionMessage(BaseModel):
+ content: Optional[str] = None
+ """The contents of the message."""
+
+ refusal: Optional[str] = None
+ """The refusal message generated by the model."""
+
+ role: Literal["assistant"]
+ """The role of the author of this message."""
+
+ # audio: Optional[ChatCompletionAudio] = None
+ """
+ If the audio output modality is requested, this object contains data about the
+ audio response from the model.
+ [Learn more](https://platform.openai.com/docs/guides/audio).
+ """
+
+ function_call: Optional[FunctionCall] = None
+ """Deprecated and replaced by `tool_calls`.
+
+ The name and arguments of a function that should be called, as generated by the
+ model.
+ """
+
+ tool_calls: Optional[list[ChatCompletionMessageToolCall]] = None
+ """The tool calls generated by the model, such as function calls."""
+
+ structured_content: Optional[list[dict]] = None
+
+
+class Choice(BaseModel):
+ finish_reason: Literal[
+ "stop",
+ "length",
+ "tool_calls",
+ "content_filter",
+ "function_call",
+ "max_tokens",
+ ]
+ """The reason the model stopped generating tokens.
+
+ This will be `stop` if the model hit a natural stop point or a provided stop
+ sequence, `length` if the maximum number of tokens specified in the request was
+ reached, `content_filter` if content was omitted due to a flag from our content
+ filters, `tool_calls` if the model called a tool, or `function_call`
+ (deprecated) if the model called a function.
+ """
+
+ index: int
+ """The index of the choice in the list of choices."""
+
+ # logprobs: Optional[ChoiceLogprobs] = None
+ """Log probability information for the choice."""
+
+ message: ChatCompletionMessage
+ """A chat completion message generated by the model."""
+
+
+class LLMChatCompletion(BaseModel):
+ id: str
+ """A unique identifier for the chat completion."""
+
+ choices: list[Choice]
+ """A list of chat completion choices.
+
+ Can be more than one if `n` is greater than 1.
+ """
+
+ created: int
+ """The Unix timestamp (in seconds) of when the chat completion was created."""
+
+ model: str
+ """The model used for the chat completion."""
+
+ object: Literal["chat.completion"]
+ """The object type, which is always `chat.completion`."""
+
+ service_tier: Optional[Literal["scale", "default"]] = None
+ """The service tier used for processing the request."""
+
+ system_fingerprint: Optional[str] = None
+ """This fingerprint represents the backend configuration that the model runs with.
+
+ Can be used in conjunction with the `seed` request parameter to understand when
+ backend changes have been made that might impact determinism.
+ """
+
+ usage: Optional[Any] = None
+ """Usage statistics for the completion request."""
+
+
+LLMChatCompletionChunk = ChatCompletionChunk
+
+
+class RAGCompletion:
+ completion: LLMChatCompletion
+ search_results: "AggregateSearchResult"
+
+ def __init__(
+ self,
+ completion: LLMChatCompletion,
+ search_results: "AggregateSearchResult",
+ ):
+ self.completion = completion
+ self.search_results = search_results
+
+
+class GenerationConfig(R2RSerializable):
+ _defaults: ClassVar[dict] = {
+ "model": None,
+ "temperature": 0.1,
+ "top_p": 1.0,
+ "max_tokens_to_sample": 1024,
+ "stream": False,
+ "functions": None,
+ "tools": None,
+ "add_generation_kwargs": None,
+ "api_base": None,
+ "response_format": None,
+ "extended_thinking": False,
+ "thinking_budget": None,
+ "reasoning_effort": None,
+ }
+
+ model: Optional[str] = Field(
+ default_factory=lambda: GenerationConfig._defaults["model"]
+ )
+ temperature: float = Field(
+ default_factory=lambda: GenerationConfig._defaults["temperature"]
+ )
+ top_p: Optional[float] = Field(
+ default_factory=lambda: GenerationConfig._defaults["top_p"],
+ )
+ max_tokens_to_sample: int = Field(
+ default_factory=lambda: GenerationConfig._defaults[
+ "max_tokens_to_sample"
+ ],
+ )
+ stream: bool = Field(
+ default_factory=lambda: GenerationConfig._defaults["stream"]
+ )
+ functions: Optional[list[dict]] = Field(
+ default_factory=lambda: GenerationConfig._defaults["functions"]
+ )
+ tools: Optional[list[dict]] = Field(
+ default_factory=lambda: GenerationConfig._defaults["tools"]
+ )
+ add_generation_kwargs: Optional[dict] = Field(
+ default_factory=lambda: GenerationConfig._defaults[
+ "add_generation_kwargs"
+ ],
+ )
+ api_base: Optional[str] = Field(
+ default_factory=lambda: GenerationConfig._defaults["api_base"],
+ )
+ response_format: Optional[dict | BaseModel] = None
+ extended_thinking: bool = Field(
+ default=False,
+ description="Flag to enable extended thinking mode (for Anthropic providers)",
+ )
+ thinking_budget: Optional[int] = Field(
+ default=None,
+ description=(
+ "Token budget for internal reasoning when extended thinking mode is enabled. "
+ "Must be less than max_tokens_to_sample."
+ ),
+ )
+ reasoning_effort: Optional[str] = Field(
+ default=None,
+ description=(
+ "Effort level for internal reasoning when extended thinking mode is enabled, `low`, `medium`, or `high`."
+ "Only applicable to OpenAI providers."
+ ),
+ )
+
+ @classmethod
+ def set_default(cls, **kwargs):
+ for key, value in kwargs.items():
+ if key in cls._defaults:
+ cls._defaults[key] = value
+ else:
+ raise AttributeError(
+ f"No default attribute '{key}' in GenerationConfig"
+ )
+
+ def __init__(self, **data):
+ # Handle max_tokens mapping to max_tokens_to_sample
+ if "max_tokens" in data:
+ # Only set max_tokens_to_sample if it's not already provided
+ if "max_tokens_to_sample" not in data:
+ data["max_tokens_to_sample"] = data.pop("max_tokens")
+ else:
+ # If both are provided, max_tokens_to_sample takes precedence
+ data.pop("max_tokens")
+
+ if (
+ "response_format" in data
+ and isinstance(data["response_format"], type)
+ and issubclass(data["response_format"], BaseModel)
+ ):
+ model_class = data["response_format"]
+ data["response_format"] = {
+ "type": "json_schema",
+ "json_schema": {
+ "name": model_class.__name__,
+ "schema": model_class.model_json_schema(),
+ },
+ }
+
+ model = data.pop("model", None)
+ if model is not None:
+ super().__init__(model=model, **data)
+ else:
+ super().__init__(**data)
+
+ def __str__(self):
+ return json.dumps(self.to_dict())
+
+ class Config:
+ populate_by_name = True
+ json_schema_extra = {
+ "example": {
+ "model": "openai/gpt-4o",
+ "temperature": 0.1,
+ "top_p": 1.0,
+ "max_tokens_to_sample": 1024,
+ "stream": False,
+ "functions": None,
+ "tools": None,
+ "add_generation_kwargs": None,
+ "api_base": None,
+ }
+ }
+
+
+class MessageType(Enum):
+ SYSTEM = "system"
+ USER = "user"
+ ASSISTANT = "assistant"
+ FUNCTION = "function"
+ TOOL = "tool"
+
+ def __str__(self):
+ return self.value
+
+
+class Message(R2RSerializable):
+ role: MessageType | str
+ content: Optional[Any] = None
+ name: Optional[str] = None
+ function_call: Optional[dict[str, Any]] = None
+ tool_calls: Optional[list[dict[str, Any]]] = None
+ tool_call_id: Optional[str] = None
+ metadata: Optional[dict[str, Any]] = None
+ structured_content: Optional[list[dict]] = None
+ image_url: Optional[str] = None # For URL-based images
+ image_data: Optional[dict[str, str]] = (
+ None # For base64 {media_type, data}
+ )
+
+ class Config:
+ populate_by_name = True
+ json_schema_extra = {
+ "example": {
+ "role": "user",
+ "content": "This is a test message.",
+ "name": None,
+ "function_call": None,
+ "tool_calls": None,
+ }
+ }
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/prompt.py b/.venv/lib/python3.12/site-packages/shared/abstractions/prompt.py
new file mode 100644
index 00000000..85ab5312
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/prompt.py
@@ -0,0 +1,39 @@
+"""Abstraction for a prompt that can be formatted with inputs."""
+
+import logging
+from datetime import datetime
+from typing import Any
+from uuid import UUID, uuid4
+
+from pydantic import BaseModel, Field
+
+logger = logging.getLogger()
+
+
+class Prompt(BaseModel):
+ """A prompt that can be formatted with inputs."""
+
+ id: UUID = Field(default_factory=uuid4)
+ name: str
+ template: str
+ input_types: dict[str, str]
+ created_at: datetime = Field(default_factory=datetime.utcnow)
+ updated_at: datetime = Field(default_factory=datetime.utcnow)
+
+ def format_prompt(self, inputs: dict[str, Any]) -> str:
+ self._validate_inputs(inputs)
+ return self.template.format(**inputs)
+
+ def _validate_inputs(self, inputs: dict[str, Any]) -> None:
+ for var, expected_type_name in self.input_types.items():
+ expected_type = self._convert_type(expected_type_name)
+ if var not in inputs:
+ raise ValueError(f"Missing input: {var}")
+ if not isinstance(inputs[var], expected_type):
+ raise TypeError(
+ f"Input '{var}' must be of type {expected_type.__name__}, got {type(inputs[var]).__name__} instead."
+ )
+
+ def _convert_type(self, type_name: str) -> type:
+ type_mapping = {"int": int, "str": str}
+ return type_mapping.get(type_name, str)
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/search.py b/.venv/lib/python3.12/site-packages/shared/abstractions/search.py
new file mode 100644
index 00000000..bf0f650e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/search.py
@@ -0,0 +1,614 @@
+"""Abstractions for search functionality."""
+
+from copy import copy
+from enum import Enum
+from typing import Any, Optional
+from uuid import NAMESPACE_DNS, UUID, uuid5
+
+from pydantic import Field
+
+from .base import R2RSerializable
+from .document import DocumentResponse
+from .llm import GenerationConfig
+from .vector import IndexMeasure
+
+
+def generate_id_from_label(label) -> UUID:
+ return uuid5(NAMESPACE_DNS, label)
+
+
+class ChunkSearchResult(R2RSerializable):
+ """Result of a search operation."""
+
+ id: UUID
+ document_id: UUID
+ owner_id: Optional[UUID]
+ collection_ids: list[UUID]
+ score: Optional[float] = None
+ text: str
+ metadata: dict[str, Any]
+
+ def __str__(self) -> str:
+ if self.score:
+ return (
+ f"ChunkSearchResult(score={self.score:.3f}, text={self.text})"
+ )
+ else:
+ return f"ChunkSearchResult(text={self.text})"
+
+ def __repr__(self) -> str:
+ return self.__str__()
+
+ def as_dict(self) -> dict:
+ return {
+ "id": self.id,
+ "document_id": self.document_id,
+ "owner_id": self.owner_id,
+ "collection_ids": self.collection_ids,
+ "score": self.score,
+ "text": self.text,
+ "metadata": self.metadata,
+ }
+
+ class Config:
+ populate_by_name = True
+ json_schema_extra = {
+ "example": {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
+ "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
+ "collection_ids": [],
+ "score": 0.23943702876567796,
+ "text": "Example text from the document",
+ "metadata": {
+ "title": "example_document.pdf",
+ "associated_query": "What is the capital of France?",
+ },
+ }
+ }
+
+
+class GraphSearchResultType(str, Enum):
+ ENTITY = "entity"
+ RELATIONSHIP = "relationship"
+ COMMUNITY = "community"
+
+
+class GraphEntityResult(R2RSerializable):
+ id: Optional[UUID] = None
+ name: str
+ description: str
+ metadata: Optional[dict[str, Any]] = None
+
+ class Config:
+ json_schema_extra = {
+ "example": {
+ "name": "Entity Name",
+ "description": "Entity Description",
+ "metadata": {},
+ }
+ }
+
+
+class GraphRelationshipResult(R2RSerializable):
+ id: Optional[UUID] = None
+ subject: str
+ predicate: str
+ object: str
+ subject_id: Optional[UUID] = None
+ object_id: Optional[UUID] = None
+ metadata: Optional[dict[str, Any]] = None
+ score: Optional[float] = None
+ description: str | None = None
+
+ class Config:
+ json_schema_extra = {
+ "example": {
+ "name": "Relationship Name",
+ "description": "Relationship Description",
+ "metadata": {},
+ }
+ }
+
+ def __str__(self) -> str:
+ return f"GraphRelationshipResult(subject={self.subject}, predicate={self.predicate}, object={self.object})"
+
+
+class GraphCommunityResult(R2RSerializable):
+ id: Optional[UUID] = None
+ name: str
+ summary: str
+ metadata: Optional[dict[str, Any]] = None
+
+ class Config:
+ json_schema_extra = {
+ "example": {
+ "name": "Community Name",
+ "summary": "Community Summary",
+ "rating": 9,
+ "rating_explanation": "Rating Explanation",
+ "metadata": {},
+ }
+ }
+
+ def __str__(self) -> str:
+ return (
+ f"GraphCommunityResult(name={self.name}, summary={self.summary})"
+ )
+
+
+class GraphSearchResult(R2RSerializable):
+ content: GraphEntityResult | GraphRelationshipResult | GraphCommunityResult
+ result_type: Optional[GraphSearchResultType] = None
+ chunk_ids: Optional[list[UUID]] = None
+ metadata: dict[str, Any] = {}
+ score: Optional[float] = None
+ id: UUID
+
+ def __str__(self) -> str:
+ return f"GraphSearchResult(content={self.content}, result_type={self.result_type})"
+
+ class Config:
+ populate_by_name = True
+ json_schema_extra = {
+ "example": {
+ "content": {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "name": "Entity Name",
+ "description": "Entity Description",
+ "metadata": {},
+ },
+ "result_type": "entity",
+ "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
+ "metadata": {
+ "associated_query": "What is the capital of France?"
+ },
+ }
+ }
+
+
+class WebPageSearchResult(R2RSerializable):
+ title: Optional[str] = None
+ link: Optional[str] = None
+ snippet: Optional[str] = None
+ position: int
+ type: str = "organic"
+ date: Optional[str] = None
+ sitelinks: Optional[list[dict]] = None
+ id: UUID
+
+ class Config:
+ json_schema_extra = {
+ "example": {
+ "title": "Page Title",
+ "link": "https://example.com/page",
+ "snippet": "Page snippet",
+ "position": 1,
+ "date": "2021-01-01",
+ "sitelinks": [
+ {
+ "title": "Sitelink Title",
+ "link": "https://example.com/sitelink",
+ }
+ ],
+ }
+ }
+
+ def __str__(self) -> str:
+ return f"WebPageSearchResult(title={self.title}, link={self.link}, snippet={self.snippet})"
+
+
+class RelatedSearchResult(R2RSerializable):
+ query: str
+ type: str = "related"
+ id: UUID
+
+
+class PeopleAlsoAskResult(R2RSerializable):
+ question: str
+ snippet: str
+ link: str
+ title: str
+ id: UUID
+ type: str = "peopleAlsoAsk"
+
+
+class WebSearchResult(R2RSerializable):
+ organic_results: list[WebPageSearchResult] = []
+ related_searches: list[RelatedSearchResult] = []
+ people_also_ask: list[PeopleAlsoAskResult] = []
+
+ @classmethod
+ def from_serper_results(cls, results: list[dict]) -> "WebSearchResult":
+ organic = []
+ related = []
+ paa = []
+
+ for result in results:
+ if result["type"] == "organic":
+ organic.append(
+ WebPageSearchResult(
+ **result, id=generate_id_from_label(result.get("link"))
+ )
+ )
+ elif result["type"] == "relatedSearches":
+ related.append(
+ RelatedSearchResult(
+ **result,
+ id=generate_id_from_label(result.get("query")),
+ )
+ )
+ elif result["type"] == "peopleAlsoAsk":
+ paa.append(
+ PeopleAlsoAskResult(
+ **result, id=generate_id_from_label(result.get("link"))
+ )
+ )
+
+ return cls(
+ organic_results=organic,
+ related_searches=related,
+ people_also_ask=paa,
+ )
+
+
+class AggregateSearchResult(R2RSerializable):
+ """Result of an aggregate search operation."""
+
+ chunk_search_results: Optional[list[ChunkSearchResult]] = None
+ graph_search_results: Optional[list[GraphSearchResult]] = None
+ web_search_results: Optional[list[WebPageSearchResult]] = None
+ document_search_results: Optional[list[DocumentResponse]] = None
+
+ def __str__(self) -> str:
+ return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results}, document_search_results={str(self.document_search_results)})"
+
+ def __repr__(self) -> str:
+ return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results}, document_search_results={str(self.document_search_results)})"
+
+ def as_dict(self) -> dict:
+ return {
+ "chunk_search_results": (
+ [result.as_dict() for result in self.chunk_search_results]
+ if self.chunk_search_results
+ else []
+ ),
+ "graph_search_results": (
+ [result.to_dict() for result in self.graph_search_results]
+ if self.graph_search_results
+ else []
+ ),
+ "web_search_results": (
+ [result.to_dict() for result in self.web_search_results]
+ if self.web_search_results
+ else []
+ ),
+ "document_search_results": (
+ [cdr.to_dict() for cdr in self.document_search_results]
+ if self.document_search_results
+ else []
+ ),
+ }
+
+ class Config:
+ populate_by_name = True
+ json_schema_extra = {
+ "example": {
+ "chunk_search_results": [
+ {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
+ "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
+ "collection_ids": [],
+ "score": 0.23943702876567796,
+ "text": "Example text from the document",
+ "metadata": {
+ "title": "example_document.pdf",
+ "associated_query": "What is the capital of France?",
+ },
+ }
+ ],
+ "graph_search_results": [
+ {
+ "content": {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "name": "Entity Name",
+ "description": "Entity Description",
+ "metadata": {},
+ },
+ "result_type": "entity",
+ "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
+ "metadata": {
+ "associated_query": "What is the capital of France?"
+ },
+ }
+ ],
+ "web_search_results": [
+ {
+ "title": "Page Title",
+ "link": "https://example.com/page",
+ "snippet": "Page snippet",
+ "position": 1,
+ "date": "2021-01-01",
+ "sitelinks": [
+ {
+ "title": "Sitelink Title",
+ "link": "https://example.com/sitelink",
+ }
+ ],
+ }
+ ],
+ "document_search_results": [
+ {
+ "document": {
+ "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+ "title": "Document Title",
+ "chunks": ["Chunk 1", "Chunk 2"],
+ "metadata": {},
+ },
+ }
+ ],
+ }
+ }
+
+
+class HybridSearchSettings(R2RSerializable):
+ """Settings for hybrid search combining full-text and semantic search."""
+
+ full_text_weight: float = Field(
+ default=1.0, description="Weight to apply to full text search"
+ )
+ semantic_weight: float = Field(
+ default=5.0, description="Weight to apply to semantic search"
+ )
+ full_text_limit: int = Field(
+ default=200,
+ description="Maximum number of results to return from full text search",
+ )
+ rrf_k: int = Field(
+ default=50, description="K-value for RRF (Rank Reciprocal Fusion)"
+ )
+
+
+class ChunkSearchSettings(R2RSerializable):
+ """Settings specific to chunk/vector search."""
+
+ index_measure: IndexMeasure = Field(
+ default=IndexMeasure.cosine_distance,
+ description="The distance measure to use for indexing",
+ )
+ probes: int = Field(
+ default=10,
+ description="Number of ivfflat index lists to query. Higher increases accuracy but decreases speed.",
+ )
+ ef_search: int = Field(
+ default=40,
+ description="Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed.",
+ )
+ enabled: bool = Field(
+ default=True,
+ description="Whether to enable chunk search",
+ )
+
+
+class GraphSearchSettings(R2RSerializable):
+ """Settings specific to knowledge graph search."""
+
+ generation_config: Optional[GenerationConfig] = Field(
+ default=None,
+ description="Configuration for text generation during graph search.",
+ )
+ max_community_description_length: int = Field(
+ default=65536,
+ )
+ max_llm_queries_for_global_search: int = Field(
+ default=250,
+ )
+ limits: dict[str, int] = Field(
+ default={},
+ )
+ enabled: bool = Field(
+ default=True,
+ description="Whether to enable graph search",
+ )
+
+
+class SearchSettings(R2RSerializable):
+ """Main search settings class that combines shared settings with
+ specialized settings for chunks and graph."""
+
+ # Search type flags
+ use_hybrid_search: bool = Field(
+ default=False,
+ description="Whether to perform a hybrid search. This is equivalent to setting `use_semantic_search=True` and `use_fulltext_search=True`, e.g. combining vector and keyword search.",
+ )
+ use_semantic_search: bool = Field(
+ default=True,
+ description="Whether to use semantic search",
+ )
+ use_fulltext_search: bool = Field(
+ default=False,
+ description="Whether to use full-text search",
+ )
+
+ # Common search parameters
+ filters: dict[str, Any] = Field(
+ default_factory=dict,
+ description="""Filters to apply to the search. Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.
+
+ Commonly seen filters include operations include the following:
+
+ `{"document_id": {"$eq": "9fbe403b-..."}}`
+
+ `{"document_id": {"$in": ["9fbe403b-...", "3e157b3a-..."]}}`
+
+ `{"collection_ids": {"$overlap": ["122fdf6a-...", "..."]}}`
+
+ `{"$and": {"$document_id": ..., "collection_ids": ...}}`""",
+ )
+ limit: int = Field(
+ default=10,
+ description="Maximum number of results to return",
+ ge=1,
+ le=1_000,
+ )
+ offset: int = Field(
+ default=0,
+ ge=0,
+ description="Offset to paginate search results",
+ )
+ include_metadatas: bool = Field(
+ default=True,
+ description="Whether to include element metadata in the search results",
+ )
+ include_scores: bool = Field(
+ default=True,
+ description="""Whether to include search score values in the
+ search results""",
+ )
+
+ # Search strategy and settings
+ search_strategy: str = Field(
+ default="vanilla",
+ description="""Search strategy to use
+ (e.g., 'vanilla', 'query_fusion', 'hyde')""",
+ )
+ hybrid_settings: HybridSearchSettings = Field(
+ default_factory=HybridSearchSettings,
+ description="""Settings for hybrid search (only used if
+ `use_semantic_search` and `use_fulltext_search` are both true)""",
+ )
+
+ # Specialized settings
+ chunk_settings: ChunkSearchSettings = Field(
+ default_factory=ChunkSearchSettings,
+ description="Settings specific to chunk/vector search",
+ )
+ graph_settings: GraphSearchSettings = Field(
+ default_factory=GraphSearchSettings,
+ description="Settings specific to knowledge graph search",
+ )
+
+ # For HyDE or multi-query:
+ num_sub_queries: int = Field(
+ default=5,
+ description="Number of sub-queries/hypothetical docs to generate when using hyde or rag_fusion search strategies.",
+ )
+
+ class Config:
+ populate_by_name = True
+ json_encoders = {UUID: str}
+ json_schema_extra = {
+ "example": {
+ "use_semantic_search": True,
+ "use_fulltext_search": False,
+ "use_hybrid_search": False,
+ "filters": {"category": "technology"},
+ "limit": 20,
+ "offset": 0,
+ "search_strategy": "vanilla",
+ "hybrid_settings": {
+ "full_text_weight": 1.0,
+ "semantic_weight": 5.0,
+ "full_text_limit": 200,
+ "rrf_k": 50,
+ },
+ "chunk_settings": {
+ "enabled": True,
+ "index_measure": "cosine_distance",
+ "include_metadata": True,
+ "probes": 10,
+ "ef_search": 40,
+ },
+ "graph_settings": {
+ "enabled": True,
+ "generation_config": GenerationConfig.Config.json_schema_extra,
+ "max_community_description_length": 65536,
+ "max_llm_queries_for_global_search": 250,
+ "limits": {
+ "entity": 20,
+ "relationship": 20,
+ "community": 20,
+ },
+ },
+ }
+ }
+
+ def __init__(self, **data):
+ # Handle legacy search_filters field
+ data["filters"] = {
+ **data.get("filters", {}),
+ **data.get("search_filters", {}),
+ }
+ super().__init__(**data)
+
+ def model_dump(self, *args, **kwargs):
+ return super().model_dump(*args, **kwargs)
+
+ @classmethod
+ def get_default(cls, mode: str) -> "SearchSettings":
+ """Return default search settings for a given mode."""
+ if mode == "basic":
+ # A simpler search that relies primarily on semantic search.
+ return cls(
+ use_semantic_search=True,
+ use_fulltext_search=False,
+ use_hybrid_search=False,
+ search_strategy="vanilla",
+ # Other relevant defaults can be provided here as needed
+ )
+ elif mode == "advanced":
+ # A more powerful, combined search that leverages both semantic and fulltext.
+ return cls(
+ use_semantic_search=True,
+ use_fulltext_search=True,
+ use_hybrid_search=True,
+ search_strategy="hyde",
+ # Other advanced defaults as needed
+ )
+ else:
+ # For 'custom' or unrecognized modes, return a basic empty config.
+ return cls()
+
+
+class SearchMode(str, Enum):
+ """Search modes for the search endpoint."""
+
+ basic = "basic"
+ advanced = "advanced"
+ custom = "custom"
+
+
+def select_search_filters(
+ auth_user: Any,
+ search_settings: SearchSettings,
+) -> dict[str, Any]:
+ filters = copy(search_settings.filters)
+ selected_collections = None
+ if not auth_user.is_superuser:
+ user_collections = set(auth_user.collection_ids)
+ for key in filters.keys():
+ if "collection_ids" in key:
+ selected_collections = set(map(UUID, filters[key]["$overlap"]))
+ break
+
+ if selected_collections:
+ allowed_collections = user_collections.intersection(
+ selected_collections
+ )
+ else:
+ allowed_collections = user_collections
+ # for non-superusers, we filter by user_id and selected & allowed collections
+ collection_filters = {
+ "$or": [
+ {"owner_id": {"$eq": auth_user.id}},
+ {"collection_ids": {"$overlap": list(allowed_collections)}},
+ ] # type: ignore
+ }
+
+ filters.pop("collection_ids", None)
+ if filters != {}:
+ filters = {"$and": [collection_filters, filters]} # type: ignore
+ else:
+ filters = collection_filters
+ return filters
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/user.py b/.venv/lib/python3.12/site-packages/shared/abstractions/user.py
new file mode 100644
index 00000000..b04ac50b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/user.py
@@ -0,0 +1,69 @@
+from datetime import datetime
+from typing import Optional
+from uuid import UUID
+
+from pydantic import BaseModel, Field
+
+from shared.abstractions import R2RSerializable
+
+from ..utils import generate_default_user_collection_id
+
+
+class Collection(BaseModel):
+ id: UUID
+ name: str
+ description: Optional[str] = None
+ created_at: datetime = Field(
+ default_factory=datetime.utcnow,
+ )
+ updated_at: datetime = Field(
+ default_factory=datetime.utcnow,
+ )
+
+ class Config:
+ populate_by_name = True
+ from_attributes = True
+
+ def __init__(self, **data):
+ super().__init__(**data)
+ if self.id is None:
+ self.id = generate_default_user_collection_id(self.name)
+
+
+class Token(BaseModel):
+ token: str
+ token_type: str
+
+
+class TokenData(BaseModel):
+ email: str
+ token_type: str
+ exp: datetime
+
+
+class User(R2RSerializable):
+ id: UUID
+ email: str
+ is_active: bool = True
+ is_superuser: bool = False
+ created_at: datetime = datetime.now()
+ updated_at: datetime = datetime.now()
+ is_verified: bool = False
+ collection_ids: list[UUID] = []
+ graph_ids: list[UUID] = []
+ document_ids: list[UUID] = []
+
+ # Optional fields (to update or set at creation)
+ limits_overrides: Optional[dict] = None
+ metadata: Optional[dict] = None
+ verification_code_expiry: Optional[datetime] = None
+ name: Optional[str] = None
+ bio: Optional[str] = None
+ profile_picture: Optional[str] = None
+ total_size_in_bytes: Optional[int] = None
+ num_files: Optional[int] = None
+
+ account_type: str = "password"
+ hashed_password: Optional[str] = None
+ google_id: Optional[str] = None
+ github_id: Optional[str] = None
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/vector.py b/.venv/lib/python3.12/site-packages/shared/abstractions/vector.py
new file mode 100644
index 00000000..0b88a765
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/vector.py
@@ -0,0 +1,239 @@
+"""Abstraction for a vector that can be stored in the system."""
+
+from enum import Enum
+from typing import Any, Optional
+from uuid import UUID
+
+from pydantic import BaseModel, Field
+
+from .base import R2RSerializable
+
+
+class VectorType(str, Enum):
+ FIXED = "FIXED"
+
+
+class IndexMethod(str, Enum):
+ """An enum representing the index methods available.
+
+ This class currently only supports the 'ivfflat' method but may
+ expand in the future.
+
+ Attributes:
+ auto (str): Automatically choose the best available index method.
+ ivfflat (str): The ivfflat index method.
+ hnsw (str): The hnsw index method.
+ """
+
+ auto = "auto"
+ ivfflat = "ivfflat"
+ hnsw = "hnsw"
+
+ def __str__(self) -> str:
+ return self.value
+
+
+class IndexMeasure(str, Enum):
+ """An enum representing the types of distance measures available for
+ indexing.
+
+ Attributes:
+ cosine_distance (str): The cosine distance measure for indexing.
+ l2_distance (str): The Euclidean (L2) distance measure for indexing.
+ max_inner_product (str): The maximum inner product measure for indexing.
+ """
+
+ l2_distance = "l2_distance"
+ max_inner_product = "max_inner_product"
+ cosine_distance = "cosine_distance"
+ l1_distance = "l1_distance"
+ hamming_distance = "hamming_distance"
+ jaccard_distance = "jaccard_distance"
+
+ def __str__(self) -> str:
+ return self.value
+
+ @property
+ def ops(self) -> str:
+ return {
+ IndexMeasure.l2_distance: "_l2_ops",
+ IndexMeasure.max_inner_product: "_ip_ops",
+ IndexMeasure.cosine_distance: "_cosine_ops",
+ IndexMeasure.l1_distance: "_l1_ops",
+ IndexMeasure.hamming_distance: "_hamming_ops",
+ IndexMeasure.jaccard_distance: "_jaccard_ops",
+ }[self]
+
+ @property
+ def pgvector_repr(self) -> str:
+ return {
+ IndexMeasure.l2_distance: "<->",
+ IndexMeasure.max_inner_product: "<#>",
+ IndexMeasure.cosine_distance: "<=>",
+ IndexMeasure.l1_distance: "<+>",
+ IndexMeasure.hamming_distance: "<~>",
+ IndexMeasure.jaccard_distance: "<%>",
+ }[self]
+
+
+class IndexArgsIVFFlat(R2RSerializable):
+ """A class for arguments that can optionally be supplied to the index
+ creation method when building an IVFFlat type index.
+
+ Attributes:
+ nlist (int): The number of IVF centroids that the index should use
+ """
+
+ n_lists: int
+
+
+class IndexArgsHNSW(R2RSerializable):
+ """A class for arguments that can optionally be supplied to the index
+ creation method when building an HNSW type index.
+
+ Ref: https://github.com/pgvector/pgvector#index-options
+
+ Both attributes are Optional in case the user only wants to specify one and
+ leave the other as default
+
+ Attributes:
+ m (int): Maximum number of connections per node per layer (default: 16)
+ ef_construction (int): Size of the dynamic candidate list for
+ constructing the graph (default: 64)
+ """
+
+ m: Optional[int] = 16
+ ef_construction: Optional[int] = 64
+
+
+class VectorTableName(str, Enum):
+ """This enum represents the different tables where we store vectors."""
+
+ CHUNKS = "chunks"
+ ENTITIES_DOCUMENT = "documents_entities"
+ GRAPHS_ENTITIES = "graphs_entities"
+ # TODO: Add support for relationships
+ # TRIPLES = "relationship"
+ COMMUNITIES = "graphs_communities"
+
+ def __str__(self) -> str:
+ return self.value
+
+
+class VectorQuantizationType(str, Enum):
+ """An enum representing the types of quantization available for vectors.
+
+ Attributes:
+ FP32 (str): 32-bit floating point quantization.
+ FP16 (str): 16-bit floating point quantization.
+ INT1 (str): 1-bit integer quantization.
+ SPARSE (str): Sparse vector quantization.
+ """
+
+ FP32 = "FP32"
+ FP16 = "FP16"
+ INT1 = "INT1"
+ SPARSE = "SPARSE"
+
+ def __str__(self) -> str:
+ return self.value
+
+ @property
+ def db_type(self) -> str:
+ db_type_mapping = {
+ "FP32": "vector",
+ "FP16": "halfvec",
+ "INT1": "bit",
+ "SPARSE": "sparsevec",
+ }
+ return db_type_mapping[self.value]
+
+
+class VectorQuantizationSettings(R2RSerializable):
+ quantization_type: VectorQuantizationType = Field(
+ default=VectorQuantizationType.FP32
+ )
+
+
+class Vector(R2RSerializable):
+ """A vector with the option to fix the number of elements."""
+
+ data: list[float]
+ type: VectorType = Field(default=VectorType.FIXED)
+ length: int = Field(default=-1)
+
+ def __init__(self, **data):
+ super().__init__(**data)
+ if (
+ self.type == VectorType.FIXED
+ and self.length > 0
+ and len(self.data) != self.length
+ ):
+ raise ValueError(
+ f"Vector must be exactly {self.length} elements long."
+ )
+
+ def __repr__(self) -> str:
+ return (
+ f"Vector(data={self.data}, type={self.type}, length={self.length})"
+ )
+
+
+class VectorEntry(R2RSerializable):
+ """A vector entry that can be stored directly in supported vector
+ databases."""
+
+ id: UUID
+ document_id: UUID
+ owner_id: UUID
+ collection_ids: list[UUID]
+ vector: Vector
+ text: str
+ metadata: dict[str, Any]
+
+ def __str__(self) -> str:
+ """Return a string representation of the VectorEntry."""
+ return (
+ f"VectorEntry("
+ f"chunk_id={self.id}, "
+ f"document_id={self.document_id}, "
+ f"owner_id={self.owner_id}, "
+ f"collection_ids={self.collection_ids}, "
+ f"vector={self.vector}, "
+ f"text={self.text}, "
+ f"metadata={self.metadata})"
+ )
+
+ def __repr__(self) -> str:
+ """Return an unambiguous string representation of the VectorEntry."""
+ return self.__str__()
+
+
+class StorageResult(R2RSerializable):
+ """A result of a storage operation."""
+
+ success: bool
+ document_id: UUID
+ num_chunks: int = 0
+ error_message: Optional[str] = None
+
+ def __str__(self) -> str:
+ """Return a string representation of the StorageResult."""
+ return f"StorageResult(success={self.success}, error_message={self.error_message})"
+
+ def __repr__(self) -> str:
+ """Return an unambiguous string representation of the StorageResult."""
+ return self.__str__()
+
+
+class IndexConfig(BaseModel):
+ name: Optional[str] = Field(default=None)
+ table_name: Optional[str] = Field(default=VectorTableName.CHUNKS)
+ index_method: Optional[str] = Field(default=IndexMethod.hnsw)
+ index_measure: Optional[str] = Field(default=IndexMeasure.cosine_distance)
+ index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = Field(
+ default=None
+ )
+ index_name: Optional[str] = Field(default=None)
+ index_column: Optional[str] = Field(default=None)
+ concurrently: Optional[bool] = Field(default=True)