Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/abstractions')
11 files changed, 2317 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/__init__.py b/.venv/lib/python3.12/site-packages/shared/abstractions/__init__.py new file mode 100644 index 00000000..da33ddd7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/__init__.py @@ -0,0 +1,144 @@ +from .base import AsyncSyncMeta, R2RSerializable, syncable +from .document import ( + Document, + DocumentChunk, + DocumentResponse, + DocumentType, + GraphConstructionStatus, + GraphExtractionStatus, + IngestionMode, + IngestionStatus, + RawChunk, + UnprocessedChunk, +) +from .embedding import EmbeddingPurpose, default_embedding_prefixes +from .exception import ( + PDFParsingError, + PopplerNotFoundError, + R2RDocumentProcessingError, + R2RException, +) +from .graph import ( + Community, + Entity, + GraphCommunitySettings, + GraphCreationSettings, + GraphEnrichmentSettings, + GraphExtraction, + Relationship, + StoreType, +) +from .llm import ( + GenerationConfig, + LLMChatCompletion, + LLMChatCompletionChunk, + Message, + MessageType, + RAGCompletion, +) +from .prompt import Prompt +from .search import ( + AggregateSearchResult, + ChunkSearchResult, + ChunkSearchSettings, + GraphCommunityResult, + GraphEntityResult, + GraphRelationshipResult, + GraphSearchResult, + GraphSearchResultType, + GraphSearchSettings, + HybridSearchSettings, + SearchMode, + SearchSettings, + WebPageSearchResult, + select_search_filters, +) +from .user import Token, TokenData, User +from .vector import ( + IndexArgsHNSW, + IndexArgsIVFFlat, + IndexMeasure, + IndexMethod, + StorageResult, + Vector, + VectorEntry, + VectorQuantizationType, + VectorTableName, + VectorType, +) + +__all__ = [ + # Base abstractions + "R2RSerializable", + "AsyncSyncMeta", + "syncable", + # Completion abstractions + "MessageType", + # Document abstractions + "Document", + "DocumentChunk", + "DocumentResponse", + "IngestionMode", + "IngestionStatus", + "GraphExtractionStatus", + "GraphConstructionStatus", + "DocumentType", + "RawChunk", + "UnprocessedChunk", + # Embedding abstractions + "EmbeddingPurpose", + "default_embedding_prefixes", + # Exception abstractions + "R2RDocumentProcessingError", + "R2RException", + "PDFParsingError", + "PopplerNotFoundError", + # Graph abstractions + "Entity", + "Community", + "GraphExtraction", + "Relationship", + "StoreType", + # LLM abstractions + "GenerationConfig", + "LLMChatCompletion", + "LLMChatCompletionChunk", + "Message", + "RAGCompletion", + # Prompt abstractions + "Prompt", + # Search abstractions + "AggregateSearchResult", + "GraphSearchResult", + "WebPageSearchResult", + "GraphSearchResultType", + "GraphEntityResult", + "GraphRelationshipResult", + "GraphCommunityResult", + "GraphSearchSettings", + "ChunkSearchSettings", + "ChunkSearchResult", + "SearchSettings", + "select_search_filters", + "HybridSearchSettings", + "SearchMode", + # Graph settings abstractions + "GraphCreationSettings", + "GraphEnrichmentSettings", + "GraphCommunitySettings", + # User abstractions + "Token", + "TokenData", + "User", + # Vector abstractions + "Vector", + "VectorEntry", + "VectorType", + "IndexMethod", + "IndexMeasure", + "IndexArgsIVFFlat", + "IndexArgsHNSW", + "VectorTableName", + "VectorQuantizationType", + "StorageResult", +] diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/base.py b/.venv/lib/python3.12/site-packages/shared/abstractions/base.py new file mode 100644 index 00000000..d90ba400 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/shared/abstractions/base.py @@ -0,0 +1,145 @@ +import asyncio +import json +from datetime import datetime +from enum import Enum +from typing import Any, Type, TypeVar +from uuid import UUID + +from pydantic import BaseModel + +T = TypeVar("T", bound="R2RSerializable") + + +class R2RSerializable(BaseModel): + @classmethod + def from_dict(cls: Type[T], data: dict[str, Any] | str) -> T: + if isinstance(data, str): + try: + data_dict = json.loads(data) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON string: {e}") from e + else: + data_dict = data + return cls(**data_dict) + + def as_dict(self) -> dict[str, Any]: + data = self.model_dump(exclude_unset=True) + return self._serialize_values(data) + + def to_dict(self) -> dict[str, Any]: + data = self.model_dump(exclude_unset=True) + return self._serialize_values(data) + + def to_json(self) -> str: + data = self.to_dict() + return json.dumps(data) + + @classmethod + def from_json(cls: Type[T], json_str: str) -> T: + return cls.model_validate_json(json_str) + + @staticmethod + def _serialize_values(data: Any) -> Any: + if isinstance(data, dict): + return { + k: R2RSerializable._serialize_values(v) + for k, v in data.items() + } + elif isinstance(data, list): + return [R2RSerializable._serialize_values(v) for v in data] + elif isinstance(data, UUID): + return str(data) + elif isinstance(data, Enum): + return data.value + elif isinstance(data, datetime): + return data.isoformat() + else: + return data + + class Config: + arbitrary_types_allowed = True + json_encoders = { + UUID: str, + bytes: lambda v: v.decode("utf-8", errors="ignore"), + } + + +class AsyncSyncMeta(type): + _event_loop = None # Class-level shared event loop + + @classmethod + def get_event_loop(cls): + if cls._event_loop is None or cls._event_loop.is_closed(): + cls._event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(cls._event_loop) + return cls._event_loop + + def __new__(cls, name, bases, dct): + new_cls = super().__new__(cls, name, bases, dct) + for attr_name, attr_value in dct.items(): + if asyncio.iscoroutinefunction(attr_value) and getattr( + attr_value, "_syncable", False + ): + sync_method_name = attr_name[ + 1: + ] # Remove leading 'a' for sync method + async_method = attr_value + + def make_sync_method(async_method): + def sync_wrapper(self, *args, **kwargs): + loop = cls.get_event_loop() + if not loop.is_running(): + # Setup to run the loop in a background thread if necessary + # to prevent blocking the main thread in a synchronous call environment + from threading import Thread + + result = None + exception = None + + def run(): + nonlocal result, exception + try: + asyncio.set_event_loop(loop) + result = loop.run_until_complete( + async_method(self, *args, **kwargs) + ) + except Exception as e: + exception = e + finally: + generation_config = kwargs.get( + "rag_generation_config", None + ) + if ( + not generation_config + or not generation_config.stream + ): + loop.run_until_complete( + loop.shutdown_asyncgens() + ) + loop.close() + + thread = Thread(target=run) + thread.start() + thread.join() + if exception: + raise exception + return result + else: + # If there's already a running loop, schedule and execute the coroutine + future = asyncio.run_coroutine_threadsafe( + async_method(self, *args, **kwargs), loop + ) + return future.result() + + return sync_wrapper + + setattr( + new_cls, sync_method_name, make_sync_method(async_method) + ) + return new_cls + + +def syncable(func): + """Decorator to 
mark methods for synchronous wrapper creation.""" + func._syncable = True + return func diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/document.py b/.venv/lib/python3.12/site-packages/shared/abstractions/document.py new file mode 100644 index 00000000..513392f8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/document.py @@ -0,0 +1,377 @@ +"""Abstractions for documents and their extractions.""" + +import json +import logging +from datetime import datetime +from enum import Enum +from typing import Any, Optional +from uuid import UUID, uuid4 + +from pydantic import Field + +from .base import R2RSerializable +from .llm import GenerationConfig + +logger = logging.getLogger() + + +class DocumentType(str, Enum): + """Types of documents that can be stored.""" + + # Audio + MP3 = "mp3" + + # CSV + CSV = "csv" + + # Email + EML = "eml" + MSG = "msg" + P7S = "p7s" + + # EPUB + EPUB = "epub" + + # Excel + XLS = "xls" + XLSX = "xlsx" + + # HTML + HTML = "html" + HTM = "htm" + + # Image + BMP = "bmp" + HEIC = "heic" + JPEG = "jpeg" + PNG = "png" + TIFF = "tiff" + JPG = "jpg" + SVG = "svg" + + # Markdown + MD = "md" + + # Org Mode + ORG = "org" + + # Open Office + ODT = "odt" + + # PDF + PDF = "pdf" + + # Plain text + TXT = "txt" + JSON = "json" + + # PowerPoint + PPT = "ppt" + PPTX = "pptx" + + # reStructured Text + RST = "rst" + + # Rich Text + RTF = "rtf" + + # TSV + TSV = "tsv" + + # Video/GIF + GIF = "gif" + + # Word + DOC = "doc" + DOCX = "docx" + + # XML + XML = "xml" + + +class Document(R2RSerializable): + id: UUID = Field(default_factory=uuid4) + collection_ids: list[UUID] + owner_id: UUID + document_type: DocumentType + metadata: dict + + class Config: + arbitrary_types_allowed = True + ignore_extra = False + json_encoders = { + UUID: str, + } + populate_by_name = True + + +class IngestionStatus(str, Enum): + """Status of document processing.""" + + PENDING = "pending" + PARSING = "parsing" + EXTRACTING = "extracting" + CHUNKING = "chunking" + EMBEDDING = "embedding" + AUGMENTING = "augmenting" + STORING = "storing" + ENRICHING = "enriching" + + FAILED = "failed" + SUCCESS = "success" + + def __str__(self): + return self.value + + @classmethod + def table_name(cls) -> str: + return "documents" + + @classmethod + def id_column(cls) -> str: + return "document_id" + + +class GraphExtractionStatus(str, Enum): + """Status of graph creation per document.""" + + PENDING = "pending" + PROCESSING = "processing" + SUCCESS = "success" + ENRICHED = "enriched" + FAILED = "failed" + + def __str__(self): + return self.value + + @classmethod + def table_name(cls) -> str: + return "documents" + + @classmethod + def id_column(cls) -> str: + return "id" + + +class GraphConstructionStatus(str, Enum): + """Status of graph enrichment per collection.""" + + PENDING = "pending" + PROCESSING = "processing" + OUTDATED = "outdated" + SUCCESS = "success" + FAILED = "failed" + + def __str__(self): + return self.value + + @classmethod + def table_name(cls) -> str: + return "collections" + + @classmethod + def id_column(cls) -> str: + return "id" + + +class DocumentResponse(R2RSerializable): + """Base class for document information handling.""" + + id: UUID + collection_ids: list[UUID] + owner_id: UUID + document_type: DocumentType + metadata: dict + title: Optional[str] = None + version: str + size_in_bytes: Optional[int] + ingestion_status: IngestionStatus = IngestionStatus.PENDING + extraction_status: GraphExtractionStatus = GraphExtractionStatus.PENDING + created_at: 
Optional[datetime] = None + updated_at: Optional[datetime] = None + ingestion_attempt_number: Optional[int] = None + summary: Optional[str] = None + summary_embedding: Optional[list[float]] = None # Add optional embedding + total_tokens: Optional[int] = None + chunks: Optional[list] = None + + def convert_to_db_entry(self): + """Prepare the document info for database entry, extracting certain + fields from metadata.""" + now = datetime.now() + + # Format the embedding properly for Postgres vector type + embedding = None + if self.summary_embedding is not None: + embedding = f"[{','.join(str(x) for x in self.summary_embedding)}]" + + return { + "id": self.id, + "collection_ids": self.collection_ids, + "owner_id": self.owner_id, + "document_type": self.document_type, + "metadata": json.dumps(self.metadata), + "title": self.title or "N/A", + "version": self.version, + "size_in_bytes": self.size_in_bytes, + "ingestion_status": self.ingestion_status.value, + "extraction_status": self.extraction_status.value, + "created_at": self.created_at or now, + "updated_at": self.updated_at or now, + "ingestion_attempt_number": self.ingestion_attempt_number or 0, + "summary": self.summary, + "summary_embedding": embedding, + "total_tokens": self.total_tokens or 0, # ensure we pass 0 if None + } + + class Config: + json_schema_extra = { + "example": { + "id": "123e4567-e89b-12d3-a456-426614174000", + "collection_ids": ["123e4567-e89b-12d3-a456-426614174000"], + "owner_id": "123e4567-e89b-12d3-a456-426614174000", + "document_type": "pdf", + "metadata": {"title": "Sample Document"}, + "title": "Sample Document", + "version": "1.0", + "size_in_bytes": 123456, + "ingestion_status": "pending", + "extraction_status": "pending", + "created_at": "2021-01-01T00:00:00", + "updated_at": "2021-01-01T00:00:00", + "ingestion_attempt_number": 0, + "summary": "A summary of the document", + "summary_embedding": [0.1, 0.2, 0.3], + "total_tokens": 1000, + } + } + + +class UnprocessedChunk(R2RSerializable): + """An extraction from a document.""" + + id: Optional[UUID] = None + document_id: Optional[UUID] = None + collection_ids: list[UUID] = [] + metadata: dict = {} + text: str + + +class UpdateChunk(R2RSerializable): + """An extraction from a document.""" + + id: UUID + metadata: Optional[dict] = None + text: str + + +class DocumentChunk(R2RSerializable): + """An extraction from a document.""" + + id: UUID + document_id: UUID + collection_ids: list[UUID] + owner_id: UUID + data: str | bytes + metadata: dict + + +class RawChunk(R2RSerializable): + text: str + + +class IngestionMode(str, Enum): + hi_res = "hi-res" + fast = "fast" + custom = "custom" + + +class ChunkEnrichmentSettings(R2RSerializable): + """Settings for chunk enrichment.""" + + enable_chunk_enrichment: bool = Field( + default=False, + description="Whether to enable chunk enrichment or not", + ) + n_chunks: int = Field( + default=2, + description="The number of preceding and succeeding chunks to include. 
Defaults to 2.", + ) + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="The generation config to use for chunk enrichment", + ) + chunk_enrichment_prompt: Optional[str] = Field( + default="chunk_enrichment", + description="The prompt to use for chunk enrichment", + ) + + +class IngestionConfig(R2RSerializable): + provider: str = "r2r" + excluded_parsers: list[str] = ["mp4"] + chunking_strategy: str = "recursive" + chunk_enrichment_settings: ChunkEnrichmentSettings = ( + ChunkEnrichmentSettings() + ) + extra_parsers: dict[str, Any] = {} + + audio_transcription_model: str = "" + + vision_img_prompt_name: str = "vision_img" + + vision_pdf_prompt_name: str = "vision_pdf" + + skip_document_summary: bool = False + document_summary_system_prompt: str = "system" + document_summary_task_prompt: str = "summary" + chunks_for_document_summary: int = 128 + document_summary_model: str = "" + + @property + def supported_providers(self) -> list[str]: + return ["r2r", "unstructured_local", "unstructured_api"] + + def validate_config(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider {self.provider} is not supported.") + + @classmethod + def get_default(cls, mode: str) -> "IngestionConfig": + """Return default ingestion configuration for a given mode.""" + if mode == "hi-res": + # More thorough parsing, no skipping summaries, possibly larger `chunks_for_document_summary`. + return cls( + provider="r2r", + excluded_parsers=["mp4"], + chunk_enrichment_settings=ChunkEnrichmentSettings(), # default + extra_parsers={}, + audio_transcription_model="", + vision_img_prompt_name="vision_img", + vision_pdf_prompt_name="vision_pdf", + skip_document_summary=False, + document_summary_system_prompt="system", + document_summary_task_prompt="summary", + chunks_for_document_summary=256, # larger for hi-res + document_summary_model="", + ) + + elif mode == "fast": + # Skip summaries and other enrichment steps for speed. 
+ return cls( + provider="r2r", + excluded_parsers=["mp4"], + chunk_enrichment_settings=ChunkEnrichmentSettings(), # default + extra_parsers={}, + audio_transcription_model="", + vision_img_prompt_name="vision_img", + vision_pdf_prompt_name="vision_pdf", + skip_document_summary=True, # skip summaries + document_summary_system_prompt="system", + document_summary_task_prompt="summary", + chunks_for_document_summary=64, + document_summary_model="", + ) + else: + # For `custom` or any unrecognized mode, return a base config + return cls() diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/embedding.py b/.venv/lib/python3.12/site-packages/shared/abstractions/embedding.py new file mode 100644 index 00000000..6e27da28 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/embedding.py @@ -0,0 +1,31 @@ +from enum import Enum, auto + + +class EmbeddingPurpose(str, Enum): + INDEX = auto() + QUERY = auto() + DOCUMENT = auto() + + +default_embedding_prefixes = { + "nomic-embed-text-v1.5": { + EmbeddingPurpose.INDEX: "", + EmbeddingPurpose.QUERY: "search_query: ", + EmbeddingPurpose.DOCUMENT: "search_document: ", + }, + "nomic-embed-text": { + EmbeddingPurpose.INDEX: "", + EmbeddingPurpose.QUERY: "search_query: ", + EmbeddingPurpose.DOCUMENT: "search_document: ", + }, + "mixedbread-ai/mxbai-embed-large-v1": { + EmbeddingPurpose.INDEX: "", + EmbeddingPurpose.QUERY: "Represent this sentence for searching relevant passages: ", + EmbeddingPurpose.DOCUMENT: "Represent this sentence for searching relevant passages: ", + }, + "mixedbread-ai/mxbai-embed-large": { + EmbeddingPurpose.INDEX: "", + EmbeddingPurpose.QUERY: "Represent this sentence for searching relevant passages: ", + EmbeddingPurpose.DOCUMENT: "Represent this sentence for searching relevant passages: ", + }, +} diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/exception.py b/.venv/lib/python3.12/site-packages/shared/abstractions/exception.py new file mode 100644 index 00000000..3dedfae8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/exception.py @@ -0,0 +1,77 @@ +import textwrap +from typing import Any, Optional +from uuid import UUID + + +class R2RException(Exception): + def __init__( + self, message: str, status_code: int, detail: Optional[Any] = None + ): + self.message = message + self.status_code = status_code + self.detail = detail + super().__init__(self.message) + + def to_dict(self): + return { + "message": self.message, + "status_code": self.status_code, + "detail": self.detail, + "error_type": self.__class__.__name__, + } + + +class R2RDocumentProcessingError(R2RException): + def __init__( + self, error_message: str, document_id: UUID, status_code: int = 500 + ): + detail = { + "document_id": str(document_id), + "error_type": "document_processing_error", + } + super().__init__(error_message, status_code, detail) + self.document_id = document_id + + def to_dict(self): + result = super().to_dict() + result["document_id"] = self.document_id + return result + + +class PDFParsingError(R2RException): + """Custom exception for PDF parsing errors.""" + + def __init__( + self, + message: str, + original_error: Exception | None = None, + status_code: int = 500, + ): + detail = { + "original_error": str(original_error) if original_error else None + } + super().__init__(message, status_code, detail) + + +class PopplerNotFoundError(PDFParsingError): + """Specific error for when Poppler is not installed.""" + + def __init__(self): + installation_instructions = textwrap.dedent(""" + PDF processing requires Poppler to be 
installed. Please install Poppler and ensure it's in your system PATH. + + Installing poppler: + - Ubuntu: sudo apt-get install poppler-utils + - Archlinux: sudo pacman -S poppler + - MacOS: brew install poppler + - Windows: + 1. Download poppler from @oschwartz10612 + 2. Move extracted directory to desired location + 3. Add bin/ directory to PATH + 4. Test by running 'pdftoppm -h' in terminal + """) + super().__init__( + message=installation_instructions, + status_code=422, + original_error=None, + ) diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py b/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py new file mode 100644 index 00000000..3c1cec9e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py @@ -0,0 +1,257 @@ +import json +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from typing import Any, Optional +from uuid import UUID + +from pydantic import Field + +from ..abstractions.llm import GenerationConfig +from .base import R2RSerializable + + +class Entity(R2RSerializable): + """An entity extracted from a document.""" + + name: str + description: Optional[str] = None + category: Optional[str] = None + metadata: Optional[dict[str, Any]] = None + + id: Optional[UUID] = None + parent_id: Optional[UUID] = None # graph_id | document_id + description_embedding: Optional[list[float] | str] = None + chunk_ids: Optional[list[UUID]] = [] + + def __str__(self): + return f"{self.name}:{self.category}" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if isinstance(self.metadata, str): + try: + self.metadata = json.loads(self.metadata) + except json.JSONDecodeError: + self.metadata = self.metadata + + +class Relationship(R2RSerializable): + """A relationship between two entities. + + This is a generic relationship, and can be used to represent any type of + relationship between any two entities. 
+ """ + + id: Optional[UUID] = None + subject: str + predicate: str + object: str + description: Optional[str] = None + subject_id: Optional[UUID] = None + object_id: Optional[UUID] = None + weight: float | None = 1.0 + chunk_ids: Optional[list[UUID]] = [] + parent_id: Optional[UUID] = None + description_embedding: Optional[list[float] | str] = None + metadata: Optional[dict[str, Any] | str] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if isinstance(self.metadata, str): + try: + self.metadata = json.loads(self.metadata) + except json.JSONDecodeError: + self.metadata = self.metadata + + +@dataclass +class Community(R2RSerializable): + name: str = "" + summary: str = "" + level: Optional[int] = None + findings: list[str] = [] + id: Optional[int | UUID] = None + community_id: Optional[UUID] = None + collection_id: Optional[UUID] = None + rating: Optional[float] = None + rating_explanation: Optional[str] = None + description_embedding: Optional[list[float]] = None + attributes: dict[str, Any] | None = None + created_at: datetime = Field( + default_factory=datetime.utcnow, + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, + ) + + def __init__(self, **kwargs): + if isinstance(kwargs.get("attributes", None), str): + kwargs["attributes"] = json.loads(kwargs["attributes"]) + + if isinstance(kwargs.get("embedding", None), str): + kwargs["embedding"] = json.loads(kwargs["embedding"]) + + super().__init__(**kwargs) + + @classmethod + def from_dict(cls, data: dict[str, Any] | str) -> "Community": + parsed_data: dict[str, Any] = ( + json.loads(data) if isinstance(data, str) else data + ) + if isinstance(parsed_data.get("embedding", None), str): + parsed_data["embedding"] = json.loads(parsed_data["embedding"]) + return cls(**parsed_data) + + +class GraphExtraction(R2RSerializable): + """A protocol for a knowledge graph extraction.""" + + entities: list[Entity] + relationships: list[Relationship] + + +class Graph(R2RSerializable): + id: UUID | None = Field() + name: str + description: Optional[str] = None + created_at: datetime = Field( + default_factory=datetime.utcnow, + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, + ) + status: str = "pending" + + class Config: + populate_by_name = True + from_attributes = True + + @classmethod + def from_dict(cls, data: dict[str, Any] | str) -> "Graph": + """Create a Graph instance from a dictionary.""" + # Convert string to dict if needed + parsed_data: dict[str, Any] = ( + json.loads(data) if isinstance(data, str) else data + ) + return cls(**parsed_data) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +class StoreType(str, Enum): + GRAPHS = "graphs" + DOCUMENTS = "documents" + + +class GraphCreationSettings(R2RSerializable): + """Settings for knowledge graph creation.""" + + graph_extraction_prompt: str = Field( + default="graph_extraction", + description="The prompt to use for knowledge graph extraction.", + ) + + graph_entity_description_prompt: str = Field( + default="graph_entity_description", + description="The prompt to use for entity description generation.", + ) + + entity_types: list[str] = Field( + default=[], + description="The types of entities to extract.", + ) + + relation_types: list[str] = Field( + default=[], + description="The types of relations to extract.", + ) + + chunk_merge_count: int = Field( + default=2, + description="""The number of extractions to merge into a single graph + extraction.""", + ) + + max_knowledge_relationships: int = Field( + 
default=100, + description="""The maximum number of knowledge relationships to extract + from each chunk.""", + ) + + max_description_input_length: int = Field( + default=65536, + description="""The maximum length of the description for a node in the + graph.""", + ) + + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="Configuration for text generation during graph enrichment.", + ) + + automatic_deduplication: bool = Field( + default=False, + description="Whether to automatically deduplicate entities.", + ) + + +class GraphEnrichmentSettings(R2RSerializable): + """Settings for knowledge graph enrichment.""" + + force_graph_search_results_enrichment: bool = Field( + default=False, + description="""Force run the enrichment step even if graph creation is + still in progress for some documents.""", + ) + + graph_communities_prompt: str = Field( + default="graph_communities", + description="The prompt to use for knowledge graph enrichment.", + ) + + max_summary_input_length: int = Field( + default=65536, + description="The maximum length of the summary for a community.", + ) + + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="Configuration for text generation during graph enrichment.", + ) + + leiden_params: dict = Field( + default_factory=dict, + description="Parameters for the Leiden algorithm.", + ) + + +class GraphCommunitySettings(R2RSerializable): + """Settings for knowledge graph community enrichment.""" + + force_graph_search_results_enrichment: bool = Field( + default=False, + description="""Force run the enrichment step even if graph creation is + still in progress for some documents.""", + ) + + graph_communities: str = Field( + default="graph_communities", + description="The prompt to use for knowledge graph enrichment.", + ) + + max_summary_input_length: int = Field( + default=65536, + description="The maximum length of the summary for a community.", + ) + + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="Configuration for text generation during graph enrichment.", + ) + + leiden_params: dict = Field( + default_factory=dict, + description="Parameters for the Leiden algorithm.", + ) diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py b/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py new file mode 100644 index 00000000..d71e279e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py @@ -0,0 +1,325 @@ +"""Abstractions for the LLM model.""" + +import json +from enum import Enum +from typing import TYPE_CHECKING, Any, ClassVar, Optional + +from openai.types.chat import ChatCompletionChunk +from pydantic import BaseModel, Field + +from .base import R2RSerializable + +if TYPE_CHECKING: + from .search import AggregateSearchResult + +from typing_extensions import Literal + + +class Function(BaseModel): + arguments: str + """ + The arguments to call the function with, as generated by the model in JSON + format. Note that the model does not always generate valid JSON, and may + hallucinate parameters not defined by your function schema. Validate the + arguments in your code before calling your function. + """ + + name: str + """The name of the function to call.""" + + +class ChatCompletionMessageToolCall(BaseModel): + id: str + """The ID of the tool call.""" + + function: Function + """The function that the model called.""" + + type: Literal["function"] + """The type of the tool. 
Currently, only `function` is supported.""" + + +class FunctionCall(BaseModel): + arguments: str + """ + The arguments to call the function with, as generated by the model in JSON + format. Note that the model does not always generate valid JSON, and may + hallucinate parameters not defined by your function schema. Validate the + arguments in your code before calling your function. + """ + + name: str + """The name of the function to call.""" + + +class ChatCompletionMessage(BaseModel): + content: Optional[str] = None + """The contents of the message.""" + + refusal: Optional[str] = None + """The refusal message generated by the model.""" + + role: Literal["assistant"] + """The role of the author of this message.""" + + # audio: Optional[ChatCompletionAudio] = None + """ + If the audio output modality is requested, this object contains data about the + audio response from the model. + [Learn more](https://platform.openai.com/docs/guides/audio). + """ + + function_call: Optional[FunctionCall] = None + """Deprecated and replaced by `tool_calls`. + + The name and arguments of a function that should be called, as generated by the + model. + """ + + tool_calls: Optional[list[ChatCompletionMessageToolCall]] = None + """The tool calls generated by the model, such as function calls.""" + + structured_content: Optional[list[dict]] = None + + +class Choice(BaseModel): + finish_reason: Literal[ + "stop", + "length", + "tool_calls", + "content_filter", + "function_call", + "max_tokens", + ] + """The reason the model stopped generating tokens. + + This will be `stop` if the model hit a natural stop point or a provided stop + sequence, `length` if the maximum number of tokens specified in the request was + reached, `content_filter` if content was omitted due to a flag from our content + filters, `tool_calls` if the model called a tool, or `function_call` + (deprecated) if the model called a function. + """ + + index: int + """The index of the choice in the list of choices.""" + + # logprobs: Optional[ChoiceLogprobs] = None + """Log probability information for the choice.""" + + message: ChatCompletionMessage + """A chat completion message generated by the model.""" + + +class LLMChatCompletion(BaseModel): + id: str + """A unique identifier for the chat completion.""" + + choices: list[Choice] + """A list of chat completion choices. + + Can be more than one if `n` is greater than 1. + """ + + created: int + """The Unix timestamp (in seconds) of when the chat completion was created.""" + + model: str + """The model used for the chat completion.""" + + object: Literal["chat.completion"] + """The object type, which is always `chat.completion`.""" + + service_tier: Optional[Literal["scale", "default"]] = None + """The service tier used for processing the request.""" + + system_fingerprint: Optional[str] = None + """This fingerprint represents the backend configuration that the model runs with. + + Can be used in conjunction with the `seed` request parameter to understand when + backend changes have been made that might impact determinism. 
+ """ + + usage: Optional[Any] = None + """Usage statistics for the completion request.""" + + +LLMChatCompletionChunk = ChatCompletionChunk + + +class RAGCompletion: + completion: LLMChatCompletion + search_results: "AggregateSearchResult" + + def __init__( + self, + completion: LLMChatCompletion, + search_results: "AggregateSearchResult", + ): + self.completion = completion + self.search_results = search_results + + +class GenerationConfig(R2RSerializable): + _defaults: ClassVar[dict] = { + "model": None, + "temperature": 0.1, + "top_p": 1.0, + "max_tokens_to_sample": 1024, + "stream": False, + "functions": None, + "tools": None, + "add_generation_kwargs": None, + "api_base": None, + "response_format": None, + "extended_thinking": False, + "thinking_budget": None, + "reasoning_effort": None, + } + + model: Optional[str] = Field( + default_factory=lambda: GenerationConfig._defaults["model"] + ) + temperature: float = Field( + default_factory=lambda: GenerationConfig._defaults["temperature"] + ) + top_p: Optional[float] = Field( + default_factory=lambda: GenerationConfig._defaults["top_p"], + ) + max_tokens_to_sample: int = Field( + default_factory=lambda: GenerationConfig._defaults[ + "max_tokens_to_sample" + ], + ) + stream: bool = Field( + default_factory=lambda: GenerationConfig._defaults["stream"] + ) + functions: Optional[list[dict]] = Field( + default_factory=lambda: GenerationConfig._defaults["functions"] + ) + tools: Optional[list[dict]] = Field( + default_factory=lambda: GenerationConfig._defaults["tools"] + ) + add_generation_kwargs: Optional[dict] = Field( + default_factory=lambda: GenerationConfig._defaults[ + "add_generation_kwargs" + ], + ) + api_base: Optional[str] = Field( + default_factory=lambda: GenerationConfig._defaults["api_base"], + ) + response_format: Optional[dict | BaseModel] = None + extended_thinking: bool = Field( + default=False, + description="Flag to enable extended thinking mode (for Anthropic providers)", + ) + thinking_budget: Optional[int] = Field( + default=None, + description=( + "Token budget for internal reasoning when extended thinking mode is enabled. " + "Must be less than max_tokens_to_sample." + ), + ) + reasoning_effort: Optional[str] = Field( + default=None, + description=( + "Effort level for internal reasoning when extended thinking mode is enabled, `low`, `medium`, or `high`." + "Only applicable to OpenAI providers." 
+ ), + ) + + @classmethod + def set_default(cls, **kwargs): + for key, value in kwargs.items(): + if key in cls._defaults: + cls._defaults[key] = value + else: + raise AttributeError( + f"No default attribute '{key}' in GenerationConfig" + ) + + def __init__(self, **data): + # Handle max_tokens mapping to max_tokens_to_sample + if "max_tokens" in data: + # Only set max_tokens_to_sample if it's not already provided + if "max_tokens_to_sample" not in data: + data["max_tokens_to_sample"] = data.pop("max_tokens") + else: + # If both are provided, max_tokens_to_sample takes precedence + data.pop("max_tokens") + + if ( + "response_format" in data + and isinstance(data["response_format"], type) + and issubclass(data["response_format"], BaseModel) + ): + model_class = data["response_format"] + data["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": model_class.__name__, + "schema": model_class.model_json_schema(), + }, + } + + model = data.pop("model", None) + if model is not None: + super().__init__(model=model, **data) + else: + super().__init__(**data) + + def __str__(self): + return json.dumps(self.to_dict()) + + class Config: + populate_by_name = True + json_schema_extra = { + "example": { + "model": "openai/gpt-4o", + "temperature": 0.1, + "top_p": 1.0, + "max_tokens_to_sample": 1024, + "stream": False, + "functions": None, + "tools": None, + "add_generation_kwargs": None, + "api_base": None, + } + } + + +class MessageType(Enum): + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + FUNCTION = "function" + TOOL = "tool" + + def __str__(self): + return self.value + + +class Message(R2RSerializable): + role: MessageType | str + content: Optional[Any] = None + name: Optional[str] = None + function_call: Optional[dict[str, Any]] = None + tool_calls: Optional[list[dict[str, Any]]] = None + tool_call_id: Optional[str] = None + metadata: Optional[dict[str, Any]] = None + structured_content: Optional[list[dict]] = None + image_url: Optional[str] = None # For URL-based images + image_data: Optional[dict[str, str]] = ( + None # For base64 {media_type, data} + ) + + class Config: + populate_by_name = True + json_schema_extra = { + "example": { + "role": "user", + "content": "This is a test message.", + "name": None, + "function_call": None, + "tool_calls": None, + } + } diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/prompt.py b/.venv/lib/python3.12/site-packages/shared/abstractions/prompt.py new file mode 100644 index 00000000..85ab5312 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/prompt.py @@ -0,0 +1,39 @@ +"""Abstraction for a prompt that can be formatted with inputs.""" + +import logging +from datetime import datetime +from typing import Any +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field + +logger = logging.getLogger() + + +class Prompt(BaseModel): + """A prompt that can be formatted with inputs.""" + + id: UUID = Field(default_factory=uuid4) + name: str + template: str + input_types: dict[str, str] + created_at: datetime = Field(default_factory=datetime.utcnow) + updated_at: datetime = Field(default_factory=datetime.utcnow) + + def format_prompt(self, inputs: dict[str, Any]) -> str: + self._validate_inputs(inputs) + return self.template.format(**inputs) + + def _validate_inputs(self, inputs: dict[str, Any]) -> None: + for var, expected_type_name in self.input_types.items(): + expected_type = self._convert_type(expected_type_name) + if var not in inputs: + raise ValueError(f"Missing input: 
{var}") + if not isinstance(inputs[var], expected_type): + raise TypeError( + f"Input '{var}' must be of type {expected_type.__name__}, got {type(inputs[var]).__name__} instead." + ) + + def _convert_type(self, type_name: str) -> type: + type_mapping = {"int": int, "str": str} + return type_mapping.get(type_name, str) diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/search.py b/.venv/lib/python3.12/site-packages/shared/abstractions/search.py new file mode 100644 index 00000000..bf0f650e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/search.py @@ -0,0 +1,614 @@ +"""Abstractions for search functionality.""" + +from copy import copy +from enum import Enum +from typing import Any, Optional +from uuid import NAMESPACE_DNS, UUID, uuid5 + +from pydantic import Field + +from .base import R2RSerializable +from .document import DocumentResponse +from .llm import GenerationConfig +from .vector import IndexMeasure + + +def generate_id_from_label(label) -> UUID: + return uuid5(NAMESPACE_DNS, label) + + +class ChunkSearchResult(R2RSerializable): + """Result of a search operation.""" + + id: UUID + document_id: UUID + owner_id: Optional[UUID] + collection_ids: list[UUID] + score: Optional[float] = None + text: str + metadata: dict[str, Any] + + def __str__(self) -> str: + if self.score: + return ( + f"ChunkSearchResult(score={self.score:.3f}, text={self.text})" + ) + else: + return f"ChunkSearchResult(text={self.text})" + + def __repr__(self) -> str: + return self.__str__() + + def as_dict(self) -> dict: + return { + "id": self.id, + "document_id": self.document_id, + "owner_id": self.owner_id, + "collection_ids": self.collection_ids, + "score": self.score, + "text": self.text, + "metadata": self.metadata, + } + + class Config: + populate_by_name = True + json_schema_extra = { + "example": { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b", + "owner_id": "2acb499e-8428-543b-bd85-0d9098718220", + "collection_ids": [], + "score": 0.23943702876567796, + "text": "Example text from the document", + "metadata": { + "title": "example_document.pdf", + "associated_query": "What is the capital of France?", + }, + } + } + + +class GraphSearchResultType(str, Enum): + ENTITY = "entity" + RELATIONSHIP = "relationship" + COMMUNITY = "community" + + +class GraphEntityResult(R2RSerializable): + id: Optional[UUID] = None + name: str + description: str + metadata: Optional[dict[str, Any]] = None + + class Config: + json_schema_extra = { + "example": { + "name": "Entity Name", + "description": "Entity Description", + "metadata": {}, + } + } + + +class GraphRelationshipResult(R2RSerializable): + id: Optional[UUID] = None + subject: str + predicate: str + object: str + subject_id: Optional[UUID] = None + object_id: Optional[UUID] = None + metadata: Optional[dict[str, Any]] = None + score: Optional[float] = None + description: str | None = None + + class Config: + json_schema_extra = { + "example": { + "name": "Relationship Name", + "description": "Relationship Description", + "metadata": {}, + } + } + + def __str__(self) -> str: + return f"GraphRelationshipResult(subject={self.subject}, predicate={self.predicate}, object={self.object})" + + +class GraphCommunityResult(R2RSerializable): + id: Optional[UUID] = None + name: str + summary: str + metadata: Optional[dict[str, Any]] = None + + class Config: + json_schema_extra = { + "example": { + "name": "Community Name", + "summary": "Community Summary", + "rating": 9, + 
"rating_explanation": "Rating Explanation", + "metadata": {}, + } + } + + def __str__(self) -> str: + return ( + f"GraphCommunityResult(name={self.name}, summary={self.summary})" + ) + + +class GraphSearchResult(R2RSerializable): + content: GraphEntityResult | GraphRelationshipResult | GraphCommunityResult + result_type: Optional[GraphSearchResultType] = None + chunk_ids: Optional[list[UUID]] = None + metadata: dict[str, Any] = {} + score: Optional[float] = None + id: UUID + + def __str__(self) -> str: + return f"GraphSearchResult(content={self.content}, result_type={self.result_type})" + + class Config: + populate_by_name = True + json_schema_extra = { + "example": { + "content": { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "name": "Entity Name", + "description": "Entity Description", + "metadata": {}, + }, + "result_type": "entity", + "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"], + "metadata": { + "associated_query": "What is the capital of France?" + }, + } + } + + +class WebPageSearchResult(R2RSerializable): + title: Optional[str] = None + link: Optional[str] = None + snippet: Optional[str] = None + position: int + type: str = "organic" + date: Optional[str] = None + sitelinks: Optional[list[dict]] = None + id: UUID + + class Config: + json_schema_extra = { + "example": { + "title": "Page Title", + "link": "https://example.com/page", + "snippet": "Page snippet", + "position": 1, + "date": "2021-01-01", + "sitelinks": [ + { + "title": "Sitelink Title", + "link": "https://example.com/sitelink", + } + ], + } + } + + def __str__(self) -> str: + return f"WebPageSearchResult(title={self.title}, link={self.link}, snippet={self.snippet})" + + +class RelatedSearchResult(R2RSerializable): + query: str + type: str = "related" + id: UUID + + +class PeopleAlsoAskResult(R2RSerializable): + question: str + snippet: str + link: str + title: str + id: UUID + type: str = "peopleAlsoAsk" + + +class WebSearchResult(R2RSerializable): + organic_results: list[WebPageSearchResult] = [] + related_searches: list[RelatedSearchResult] = [] + people_also_ask: list[PeopleAlsoAskResult] = [] + + @classmethod + def from_serper_results(cls, results: list[dict]) -> "WebSearchResult": + organic = [] + related = [] + paa = [] + + for result in results: + if result["type"] == "organic": + organic.append( + WebPageSearchResult( + **result, id=generate_id_from_label(result.get("link")) + ) + ) + elif result["type"] == "relatedSearches": + related.append( + RelatedSearchResult( + **result, + id=generate_id_from_label(result.get("query")), + ) + ) + elif result["type"] == "peopleAlsoAsk": + paa.append( + PeopleAlsoAskResult( + **result, id=generate_id_from_label(result.get("link")) + ) + ) + + return cls( + organic_results=organic, + related_searches=related, + people_also_ask=paa, + ) + + +class AggregateSearchResult(R2RSerializable): + """Result of an aggregate search operation.""" + + chunk_search_results: Optional[list[ChunkSearchResult]] = None + graph_search_results: Optional[list[GraphSearchResult]] = None + web_search_results: Optional[list[WebPageSearchResult]] = None + document_search_results: Optional[list[DocumentResponse]] = None + + def __str__(self) -> str: + return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results}, document_search_results={str(self.document_search_results)})" + + def __repr__(self) -> str: + return 
f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results}, document_search_results={str(self.document_search_results)})" + + def as_dict(self) -> dict: + return { + "chunk_search_results": ( + [result.as_dict() for result in self.chunk_search_results] + if self.chunk_search_results + else [] + ), + "graph_search_results": ( + [result.to_dict() for result in self.graph_search_results] + if self.graph_search_results + else [] + ), + "web_search_results": ( + [result.to_dict() for result in self.web_search_results] + if self.web_search_results + else [] + ), + "document_search_results": ( + [cdr.to_dict() for cdr in self.document_search_results] + if self.document_search_results + else [] + ), + } + + class Config: + populate_by_name = True + json_schema_extra = { + "example": { + "chunk_search_results": [ + { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b", + "owner_id": "2acb499e-8428-543b-bd85-0d9098718220", + "collection_ids": [], + "score": 0.23943702876567796, + "text": "Example text from the document", + "metadata": { + "title": "example_document.pdf", + "associated_query": "What is the capital of France?", + }, + } + ], + "graph_search_results": [ + { + "content": { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "name": "Entity Name", + "description": "Entity Description", + "metadata": {}, + }, + "result_type": "entity", + "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"], + "metadata": { + "associated_query": "What is the capital of France?" + }, + } + ], + "web_search_results": [ + { + "title": "Page Title", + "link": "https://example.com/page", + "snippet": "Page snippet", + "position": 1, + "date": "2021-01-01", + "sitelinks": [ + { + "title": "Sitelink Title", + "link": "https://example.com/sitelink", + } + ], + } + ], + "document_search_results": [ + { + "document": { + "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", + "title": "Document Title", + "chunks": ["Chunk 1", "Chunk 2"], + "metadata": {}, + }, + } + ], + } + } + + +class HybridSearchSettings(R2RSerializable): + """Settings for hybrid search combining full-text and semantic search.""" + + full_text_weight: float = Field( + default=1.0, description="Weight to apply to full text search" + ) + semantic_weight: float = Field( + default=5.0, description="Weight to apply to semantic search" + ) + full_text_limit: int = Field( + default=200, + description="Maximum number of results to return from full text search", + ) + rrf_k: int = Field( + default=50, description="K-value for RRF (Rank Reciprocal Fusion)" + ) + + +class ChunkSearchSettings(R2RSerializable): + """Settings specific to chunk/vector search.""" + + index_measure: IndexMeasure = Field( + default=IndexMeasure.cosine_distance, + description="The distance measure to use for indexing", + ) + probes: int = Field( + default=10, + description="Number of ivfflat index lists to query. Higher increases accuracy but decreases speed.", + ) + ef_search: int = Field( + default=40, + description="Size of the dynamic candidate list for HNSW index search. 
Higher increases accuracy but decreases speed.", + ) + enabled: bool = Field( + default=True, + description="Whether to enable chunk search", + ) + + +class GraphSearchSettings(R2RSerializable): + """Settings specific to knowledge graph search.""" + + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="Configuration for text generation during graph search.", + ) + max_community_description_length: int = Field( + default=65536, + ) + max_llm_queries_for_global_search: int = Field( + default=250, + ) + limits: dict[str, int] = Field( + default={}, + ) + enabled: bool = Field( + default=True, + description="Whether to enable graph search", + ) + + +class SearchSettings(R2RSerializable): + """Main search settings class that combines shared settings with + specialized settings for chunks and graph.""" + + # Search type flags + use_hybrid_search: bool = Field( + default=False, + description="Whether to perform a hybrid search. This is equivalent to setting `use_semantic_search=True` and `use_fulltext_search=True`, e.g. combining vector and keyword search.", + ) + use_semantic_search: bool = Field( + default=True, + description="Whether to use semantic search", + ) + use_fulltext_search: bool = Field( + default=False, + description="Whether to use full-text search", + ) + + # Common search parameters + filters: dict[str, Any] = Field( + default_factory=dict, + description="""Filters to apply to the search. Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`. + + Commonly seen filters include operations include the following: + + `{"document_id": {"$eq": "9fbe403b-..."}}` + + `{"document_id": {"$in": ["9fbe403b-...", "3e157b3a-..."]}}` + + `{"collection_ids": {"$overlap": ["122fdf6a-...", "..."]}}` + + `{"$and": {"$document_id": ..., "collection_ids": ...}}`""", + ) + limit: int = Field( + default=10, + description="Maximum number of results to return", + ge=1, + le=1_000, + ) + offset: int = Field( + default=0, + ge=0, + description="Offset to paginate search results", + ) + include_metadatas: bool = Field( + default=True, + description="Whether to include element metadata in the search results", + ) + include_scores: bool = Field( + default=True, + description="""Whether to include search score values in the + search results""", + ) + + # Search strategy and settings + search_strategy: str = Field( + default="vanilla", + description="""Search strategy to use + (e.g., 'vanilla', 'query_fusion', 'hyde')""", + ) + hybrid_settings: HybridSearchSettings = Field( + default_factory=HybridSearchSettings, + description="""Settings for hybrid search (only used if + `use_semantic_search` and `use_fulltext_search` are both true)""", + ) + + # Specialized settings + chunk_settings: ChunkSearchSettings = Field( + default_factory=ChunkSearchSettings, + description="Settings specific to chunk/vector search", + ) + graph_settings: GraphSearchSettings = Field( + default_factory=GraphSearchSettings, + description="Settings specific to knowledge graph search", + ) + + # For HyDE or multi-query: + num_sub_queries: int = Field( + default=5, + description="Number of sub-queries/hypothetical docs to generate when using hyde or rag_fusion search strategies.", + ) + + class Config: + populate_by_name = True + json_encoders = {UUID: str} + json_schema_extra = { + "example": { + "use_semantic_search": True, + "use_fulltext_search": False, + "use_hybrid_search": False, + "filters": {"category": "technology"}, + "limit": 20, + "offset": 0, + 
"search_strategy": "vanilla", + "hybrid_settings": { + "full_text_weight": 1.0, + "semantic_weight": 5.0, + "full_text_limit": 200, + "rrf_k": 50, + }, + "chunk_settings": { + "enabled": True, + "index_measure": "cosine_distance", + "include_metadata": True, + "probes": 10, + "ef_search": 40, + }, + "graph_settings": { + "enabled": True, + "generation_config": GenerationConfig.Config.json_schema_extra, + "max_community_description_length": 65536, + "max_llm_queries_for_global_search": 250, + "limits": { + "entity": 20, + "relationship": 20, + "community": 20, + }, + }, + } + } + + def __init__(self, **data): + # Handle legacy search_filters field + data["filters"] = { + **data.get("filters", {}), + **data.get("search_filters", {}), + } + super().__init__(**data) + + def model_dump(self, *args, **kwargs): + return super().model_dump(*args, **kwargs) + + @classmethod + def get_default(cls, mode: str) -> "SearchSettings": + """Return default search settings for a given mode.""" + if mode == "basic": + # A simpler search that relies primarily on semantic search. + return cls( + use_semantic_search=True, + use_fulltext_search=False, + use_hybrid_search=False, + search_strategy="vanilla", + # Other relevant defaults can be provided here as needed + ) + elif mode == "advanced": + # A more powerful, combined search that leverages both semantic and fulltext. + return cls( + use_semantic_search=True, + use_fulltext_search=True, + use_hybrid_search=True, + search_strategy="hyde", + # Other advanced defaults as needed + ) + else: + # For 'custom' or unrecognized modes, return a basic empty config. + return cls() + + +class SearchMode(str, Enum): + """Search modes for the search endpoint.""" + + basic = "basic" + advanced = "advanced" + custom = "custom" + + +def select_search_filters( + auth_user: Any, + search_settings: SearchSettings, +) -> dict[str, Any]: + filters = copy(search_settings.filters) + selected_collections = None + if not auth_user.is_superuser: + user_collections = set(auth_user.collection_ids) + for key in filters.keys(): + if "collection_ids" in key: + selected_collections = set(map(UUID, filters[key]["$overlap"])) + break + + if selected_collections: + allowed_collections = user_collections.intersection( + selected_collections + ) + else: + allowed_collections = user_collections + # for non-superusers, we filter by user_id and selected & allowed collections + collection_filters = { + "$or": [ + {"owner_id": {"$eq": auth_user.id}}, + {"collection_ids": {"$overlap": list(allowed_collections)}}, + ] # type: ignore + } + + filters.pop("collection_ids", None) + if filters != {}: + filters = {"$and": [collection_filters, filters]} # type: ignore + else: + filters = collection_filters + return filters diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/user.py b/.venv/lib/python3.12/site-packages/shared/abstractions/user.py new file mode 100644 index 00000000..b04ac50b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/user.py @@ -0,0 +1,69 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID + +from pydantic import BaseModel, Field + +from shared.abstractions import R2RSerializable + +from ..utils import generate_default_user_collection_id + + +class Collection(BaseModel): + id: UUID + name: str + description: Optional[str] = None + created_at: datetime = Field( + default_factory=datetime.utcnow, + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, + ) + + class Config: + populate_by_name = True + 
from_attributes = True + + def __init__(self, **data): + super().__init__(**data) + if self.id is None: + self.id = generate_default_user_collection_id(self.name) + + +class Token(BaseModel): + token: str + token_type: str + + +class TokenData(BaseModel): + email: str + token_type: str + exp: datetime + + +class User(R2RSerializable): + id: UUID + email: str + is_active: bool = True + is_superuser: bool = False + created_at: datetime = Field(default_factory=datetime.now) + updated_at: datetime = Field(default_factory=datetime.now) + is_verified: bool = False + collection_ids: list[UUID] = [] + graph_ids: list[UUID] = [] + document_ids: list[UUID] = [] + + # Optional fields (to update or set at creation) + limits_overrides: Optional[dict] = None + metadata: Optional[dict] = None + verification_code_expiry: Optional[datetime] = None + name: Optional[str] = None + bio: Optional[str] = None + profile_picture: Optional[str] = None + total_size_in_bytes: Optional[int] = None + num_files: Optional[int] = None + + account_type: str = "password" + hashed_password: Optional[str] = None + google_id: Optional[str] = None + github_id: Optional[str] = None diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/vector.py b/.venv/lib/python3.12/site-packages/shared/abstractions/vector.py new file mode 100644 index 00000000..0b88a765 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/vector.py @@ -0,0 +1,239 @@ +"""Abstraction for a vector that can be stored in the system.""" + +from enum import Enum +from typing import Any, Optional +from uuid import UUID + +from pydantic import BaseModel, Field + +from .base import R2RSerializable + + +class VectorType(str, Enum): + FIXED = "FIXED" + + +class IndexMethod(str, Enum): + """An enum representing the index methods available. + + Supported methods are 'auto', 'ivfflat', and 'hnsw'. + + Attributes: + auto (str): Automatically choose the best available index method. + ivfflat (str): The ivfflat index method. + hnsw (str): The hnsw index method. + """ + + auto = "auto" + ivfflat = "ivfflat" + hnsw = "hnsw" + + def __str__(self) -> str: + return self.value + + +class IndexMeasure(str, Enum): + """An enum representing the types of distance measures available for + indexing. + + Attributes: + cosine_distance (str): The cosine distance measure for indexing. + l2_distance (str): The Euclidean (L2) distance measure for indexing. + max_inner_product (str): The maximum inner product measure for indexing. 
+ """ + + l2_distance = "l2_distance" + max_inner_product = "max_inner_product" + cosine_distance = "cosine_distance" + l1_distance = "l1_distance" + hamming_distance = "hamming_distance" + jaccard_distance = "jaccard_distance" + + def __str__(self) -> str: + return self.value + + @property + def ops(self) -> str: + return { + IndexMeasure.l2_distance: "_l2_ops", + IndexMeasure.max_inner_product: "_ip_ops", + IndexMeasure.cosine_distance: "_cosine_ops", + IndexMeasure.l1_distance: "_l1_ops", + IndexMeasure.hamming_distance: "_hamming_ops", + IndexMeasure.jaccard_distance: "_jaccard_ops", + }[self] + + @property + def pgvector_repr(self) -> str: + return { + IndexMeasure.l2_distance: "<->", + IndexMeasure.max_inner_product: "<#>", + IndexMeasure.cosine_distance: "<=>", + IndexMeasure.l1_distance: "<+>", + IndexMeasure.hamming_distance: "<~>", + IndexMeasure.jaccard_distance: "<%>", + }[self] + + +class IndexArgsIVFFlat(R2RSerializable): + """A class for arguments that can optionally be supplied to the index + creation method when building an IVFFlat type index. + + Attributes: + nlist (int): The number of IVF centroids that the index should use + """ + + n_lists: int + + +class IndexArgsHNSW(R2RSerializable): + """A class for arguments that can optionally be supplied to the index + creation method when building an HNSW type index. + + Ref: https://github.com/pgvector/pgvector#index-options + + Both attributes are Optional in case the user only wants to specify one and + leave the other as default + + Attributes: + m (int): Maximum number of connections per node per layer (default: 16) + ef_construction (int): Size of the dynamic candidate list for + constructing the graph (default: 64) + """ + + m: Optional[int] = 16 + ef_construction: Optional[int] = 64 + + +class VectorTableName(str, Enum): + """This enum represents the different tables where we store vectors.""" + + CHUNKS = "chunks" + ENTITIES_DOCUMENT = "documents_entities" + GRAPHS_ENTITIES = "graphs_entities" + # TODO: Add support for relationships + # TRIPLES = "relationship" + COMMUNITIES = "graphs_communities" + + def __str__(self) -> str: + return self.value + + +class VectorQuantizationType(str, Enum): + """An enum representing the types of quantization available for vectors. + + Attributes: + FP32 (str): 32-bit floating point quantization. + FP16 (str): 16-bit floating point quantization. + INT1 (str): 1-bit integer quantization. + SPARSE (str): Sparse vector quantization. + """ + + FP32 = "FP32" + FP16 = "FP16" + INT1 = "INT1" + SPARSE = "SPARSE" + + def __str__(self) -> str: + return self.value + + @property + def db_type(self) -> str: + db_type_mapping = { + "FP32": "vector", + "FP16": "halfvec", + "INT1": "bit", + "SPARSE": "sparsevec", + } + return db_type_mapping[self.value] + + +class VectorQuantizationSettings(R2RSerializable): + quantization_type: VectorQuantizationType = Field( + default=VectorQuantizationType.FP32 + ) + + +class Vector(R2RSerializable): + """A vector with the option to fix the number of elements.""" + + data: list[float] + type: VectorType = Field(default=VectorType.FIXED) + length: int = Field(default=-1) + + def __init__(self, **data): + super().__init__(**data) + if ( + self.type == VectorType.FIXED + and self.length > 0 + and len(self.data) != self.length + ): + raise ValueError( + f"Vector must be exactly {self.length} elements long." 
+ ) + + def __repr__(self) -> str: + return ( + f"Vector(data={self.data}, type={self.type}, length={self.length})" + ) + + +class VectorEntry(R2RSerializable): + """A vector entry that can be stored directly in supported vector + databases.""" + + id: UUID + document_id: UUID + owner_id: UUID + collection_ids: list[UUID] + vector: Vector + text: str + metadata: dict[str, Any] + + def __str__(self) -> str: + """Return a string representation of the VectorEntry.""" + return ( + f"VectorEntry(" + f"chunk_id={self.id}, " + f"document_id={self.document_id}, " + f"owner_id={self.owner_id}, " + f"collection_ids={self.collection_ids}, " + f"vector={self.vector}, " + f"text={self.text}, " + f"metadata={self.metadata})" + ) + + def __repr__(self) -> str: + """Return an unambiguous string representation of the VectorEntry.""" + return self.__str__() + + +class StorageResult(R2RSerializable): + """A result of a storage operation.""" + + success: bool + document_id: UUID + num_chunks: int = 0 + error_message: Optional[str] = None + + def __str__(self) -> str: + """Return a string representation of the StorageResult.""" + return f"StorageResult(success={self.success}, error_message={self.error_message})" + + def __repr__(self) -> str: + """Return an unambiguous string representation of the StorageResult.""" + return self.__str__() + + +class IndexConfig(BaseModel): + name: Optional[str] = Field(default=None) + table_name: Optional[str] = Field(default=VectorTableName.CHUNKS) + index_method: Optional[str] = Field(default=IndexMethod.hnsw) + index_measure: Optional[str] = Field(default=IndexMeasure.cosine_distance) + index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = Field( + default=None + ) + index_name: Optional[str] = Field(default=None) + index_column: Optional[str] = Field(default=None) + concurrently: Optional[bool] = Field(default=True) |
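None of the following code appears in the diff above; these are minimal usage sketches of the abstractions it adds, assuming the package imports as `shared.abstractions`. First, the `R2RSerializable` round-trip helpers from base.py; the `Note` model is hypothetical:

from uuid import UUID, uuid4

from shared.abstractions import R2RSerializable


class Note(R2RSerializable):  # hypothetical model, for illustration only
    id: UUID
    tags: list[str] = []


note = Note.from_dict({"id": str(uuid4()), "tags": ["demo"]})
print(note.to_json())  # _serialize_values renders the UUID as a string
restored = Note.from_json(note.to_json())
assert restored.id == note.id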
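A sketch of the `syncable` / `AsyncSyncMeta` pairing, also from base.py: a coroutine method whose name starts with `a` gains a generated synchronous twin with the `a` stripped. `PingClient` is made up for illustration:

import asyncio

from shared.abstractions import AsyncSyncMeta, syncable


class PingClient(metaclass=AsyncSyncMeta):
    @syncable
    async def aping(self) -> str:
        await asyncio.sleep(0)  # stand-in for real async work
        return "pong"


print(PingClient().ping())  # sync wrapper drives the shared loop -> "pong"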
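The `IngestionConfig.get_default` presets in document.py differ mainly in how summaries are handled; a quick comparison (`IngestionConfig` is not re-exported from `__init__.py`, so it is imported from the module directly):

from shared.abstractions.document import IngestionConfig

hi_res = IngestionConfig.get_default("hi-res")
fast = IngestionConfig.get_default("fast")
assert hi_res.chunks_for_document_summary == 256  # larger summary window
assert fast.skip_document_summary is True  # summaries skipped for speed
IngestionConfig(provider="r2r").validate_config()  # ValueError for unknown providers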
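`DocumentResponse.convert_to_db_entry` JSON-encodes metadata and renders the summary embedding as a pgvector-style literal; a sketch with made-up identifiers:

from uuid import uuid4

from shared.abstractions import DocumentResponse, DocumentType

doc = DocumentResponse(
    id=uuid4(),
    collection_ids=[],
    owner_id=uuid4(),
    document_type=DocumentType.PDF,
    metadata={"title": "Sample"},
    version="v1",
    size_in_bytes=1024,
    summary_embedding=[0.1, 0.2],
)
row = doc.convert_to_db_entry()
print(row["summary_embedding"])  # "[0.1,0.2]", a Postgres vector literal
print(row["total_tokens"])  # 0 is substituted when the field is unset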
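`default_embedding_prefixes` is keyed by model name; a hypothetical helper showing how a purpose-specific prefix would be applied before embedding:

from shared.abstractions import EmbeddingPurpose, default_embedding_prefixes


def apply_prefix(model: str, purpose: EmbeddingPurpose, text: str) -> str:
    # models without registered prefixes fall back to the bare text
    return default_embedding_prefixes.get(model, {}).get(purpose, "") + text


print(apply_prefix("nomic-embed-text", EmbeddingPurpose.QUERY, "capital of France"))
# -> "search_query: capital of France"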
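With `detail` and `document_id` stored on the instance (see the corrected `__init__` methods in exception.py above), a processing error serializes cleanly:

from uuid import uuid4

from shared.abstractions import R2RDocumentProcessingError

try:
    raise R2RDocumentProcessingError("PDF parse failed", document_id=uuid4())
except R2RDocumentProcessingError as exc:
    payload = exc.to_dict()
    print(payload["status_code"], payload["document_id"])  # 500 plus the UUID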
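`Community` in graph.py normalizes `attributes` passed as a JSON string, so rows read back from storage can be fed in directly; a sketch:

from shared.abstractions import Community

community = Community.from_dict(
    {"name": "ml-papers", "summary": "Graph ML cluster", "attributes": '{"size": 12}'}
)
assert community.attributes == {"size": 12}  # JSON string parsed during init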
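`GenerationConfig` remaps the legacy `max_tokens` argument onto `max_tokens_to_sample` and supports process-wide default overrides via `set_default`; a sketch:

from shared.abstractions import GenerationConfig

GenerationConfig.set_default(temperature=0.0)  # mutates class-level defaults
cfg = GenerationConfig(model="openai/gpt-4o", max_tokens=512)
assert cfg.max_tokens_to_sample == 512  # legacy alias remapped
assert cfg.temperature == 0.0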
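Passing a Pydantic model class as `response_format` is rewritten into an OpenAI-style `json_schema` payload; the `Answer` model here is hypothetical:

from pydantic import BaseModel

from shared.abstractions import GenerationConfig


class Answer(BaseModel):  # hypothetical structured-output schema
    city: str


cfg = GenerationConfig(model="openai/gpt-4o", response_format=Answer)
assert cfg.response_format["type"] == "json_schema"
assert cfg.response_format["json_schema"]["name"] == "Answer"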
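`Prompt.format_prompt` validates inputs against `input_types` before interpolating the template:

from shared.abstractions import Prompt

prompt = Prompt(
    name="greeting",
    template="Hello {name}, you have {count} new messages.",
    input_types={"name": "str", "count": "int"},
)
print(prompt.format_prompt({"name": "Ada", "count": 3}))
# format_prompt({"name": "Ada"}) would raise ValueError: Missing input: count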
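`SearchSettings.get_default` mirrors the `SearchMode` values, and the constructor folds the legacy `search_filters` key into `filters`; a sketch:

from shared.abstractions import SearchSettings

advanced = SearchSettings.get_default("advanced")
assert advanced.use_hybrid_search and advanced.search_strategy == "hyde"

legacy = SearchSettings(search_filters={"category": {"$eq": "tech"}})
assert legacy.filters == {"category": {"$eq": "tech"}}  # merged by __init__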
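`select_search_filters` scopes a non-superuser's query to owned documents and permitted collections; the stand-in auth user below carries only the attributes the function actually reads:

from types import SimpleNamespace
from uuid import uuid4

from shared.abstractions import SearchSettings, select_search_filters

user = SimpleNamespace(id=uuid4(), is_superuser=False, collection_ids=[uuid4()])
scoped = select_search_filters(
    user, SearchSettings(filters={"category": {"$eq": "tech"}})
)
assert "$and" in scoped  # user filters AND-ed with the ownership restriction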
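`WebSearchResult.from_serper_results` buckets raw Serper-style rows by `type` and derives deterministic `uuid5` IDs from links or queries; a sketch with one organic row (`WebSearchResult` is also not re-exported, so it comes from the module):

from shared.abstractions.search import WebSearchResult

rows = [
    {
        "type": "organic",
        "title": "Paris - Wikipedia",
        "link": "https://en.wikipedia.org/wiki/Paris",
        "snippet": "Paris is the capital of France.",
        "position": 1,
    }
]
web = WebSearchResult.from_serper_results(rows)
print(web.organic_results[0].id)  # uuid5(NAMESPACE_DNS, link), stable across runs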
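Finally, `Vector` enforces its declared length for FIXED vectors, and `IndexMeasure` exposes the matching pgvector operator class and operator:

from shared.abstractions import IndexMeasure, Vector, VectorType

vec = Vector(data=[0.1, 0.2, 0.3], type=VectorType.FIXED, length=3)
# Vector(data=[0.1], type=VectorType.FIXED, length=3) raises ValueError

measure = IndexMeasure.cosine_distance
print(measure.ops, measure.pgvector_repr)  # _cosine_ops <=>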