author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/base/providers
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/providers')
10 files changed, 1477 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py b/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py
new file mode 100644
index 00000000..b902944d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py
@@ -0,0 +1,59 @@
+from .auth import AuthConfig, AuthProvider
+from .base import AppConfig, Provider, ProviderConfig
+from .crypto import CryptoConfig, CryptoProvider
+from .database import (
+    DatabaseConfig,
+    DatabaseConnectionManager,
+    DatabaseProvider,
+    Handler,
+    LimitSettings,
+    PostgresConfigurationSettings,
+)
+from .email import EmailConfig, EmailProvider
+from .embedding import EmbeddingConfig, EmbeddingProvider
+from .ingestion import (
+    ChunkingStrategy,
+    IngestionConfig,
+    IngestionMode,
+    IngestionProvider,
+)
+from .llm import CompletionConfig, CompletionProvider
+from .orchestration import OrchestrationConfig, OrchestrationProvider, Workflow
+
+__all__ = [
+    # Auth provider
+    "AuthConfig",
+    "AuthProvider",
+    # Base provider classes
+    "AppConfig",
+    "Provider",
+    "ProviderConfig",
+    # Ingestion provider
+    "IngestionMode",
+    "IngestionConfig",
+    "IngestionProvider",
+    "ChunkingStrategy",
+    # Crypto provider
+    "CryptoConfig",
+    "CryptoProvider",
+    # Email provider
+    "EmailConfig",
+    "EmailProvider",
+    # Database providers
+    "DatabaseConnectionManager",
+    "DatabaseConfig",
+    "LimitSettings",
+    "PostgresConfigurationSettings",
+    "DatabaseProvider",
+    "Handler",
+    # Embedding provider
+    "EmbeddingConfig",
+    "EmbeddingProvider",
+    # LLM provider
+    "CompletionConfig",
+    "CompletionProvider",
+    # Orchestration provider
+    "OrchestrationConfig",
+    "OrchestrationProvider",
+    "Workflow",
+]
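Everything in this __init__ is a re-export, giving downstream code a single import surface for the provider contracts. A minimal consumer sketch (hypothetical values; assumes `core` is importable):

    from core.base.providers import AppConfig, DatabaseConfig

    app = AppConfig(project_name="demo")  # "demo" is a placeholder
    db = DatabaseConfig.from_dict({"app": app, "provider": "postgres"})
    assert db.provider == "postgres"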
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/auth.py b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
new file mode 100644
index 00000000..352c3331
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
@@ -0,0 +1,231 @@
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime
+from typing import TYPE_CHECKING, Optional
+
+from fastapi import Security
+from fastapi.security import (
+    APIKeyHeader,
+    HTTPAuthorizationCredentials,
+    HTTPBearer,
+)
+
+from ..abstractions import R2RException, Token, TokenData
+from ..api.models import User
+from .base import Provider, ProviderConfig
+from .crypto import CryptoProvider
+from .email import EmailProvider
+
+logger = logging.getLogger()
+
+if TYPE_CHECKING:
+    from core.providers.database import PostgresDatabaseProvider
+
+api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
+
+
+class AuthConfig(ProviderConfig):
+    secret_key: Optional[str] = None
+    require_authentication: bool = False
+    require_email_verification: bool = False
+    default_admin_email: str = "admin@example.com"
+    default_admin_password: str = "change_me_immediately"
+    access_token_lifetime_in_minutes: Optional[int] = None
+    refresh_token_lifetime_in_days: Optional[int] = None
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["r2r"]
+
+    def validate_config(self) -> None:
+        pass
+
+
+class AuthProvider(Provider, ABC):
+    security = HTTPBearer(auto_error=False)
+    crypto_provider: CryptoProvider
+    email_provider: EmailProvider
+    database_provider: "PostgresDatabaseProvider"
+
+    def __init__(
+        self,
+        config: AuthConfig,
+        crypto_provider: CryptoProvider,
+        database_provider: "PostgresDatabaseProvider",
+        email_provider: EmailProvider,
+    ):
+        if not isinstance(config, AuthConfig):
+            raise ValueError(
+                "AuthProvider must be initialized with an AuthConfig"
+            )
+        self.config = config
+        self.admin_email = config.default_admin_email
+        self.admin_password = config.default_admin_password
+        self.crypto_provider = crypto_provider
+        self.database_provider = database_provider
+        self.email_provider = email_provider
+        super().__init__(config)
+        self.config: AuthConfig = config
+        self.database_provider: "PostgresDatabaseProvider" = database_provider
+
+    async def _get_default_admin_user(self) -> User:
+        return await self.database_provider.users_handler.get_user_by_email(
+            self.admin_email
+        )
+
+    @abstractmethod
+    def create_access_token(self, data: dict) -> str:
+        pass
+
+    @abstractmethod
+    def create_refresh_token(self, data: dict) -> str:
+        pass
+
+    @abstractmethod
+    async def decode_token(self, token: str) -> TokenData:
+        pass
+
+    @abstractmethod
+    async def user(self, token: str) -> User:
+        pass
+
+    @abstractmethod
+    def get_current_active_user(self, current_user: User) -> User:
+        pass
+
+    @abstractmethod
+    async def register(self, email: str, password: str) -> User:
+        pass
+
+    @abstractmethod
+    async def send_verification_email(
+        self, email: str, user: Optional[User] = None
+    ) -> tuple[str, datetime]:
+        pass
+
+    @abstractmethod
+    async def verify_email(
+        self, email: str, verification_code: str
+    ) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def login(self, email: str, password: str) -> dict[str, Token]:
+        pass
+
+    @abstractmethod
+    async def refresh_access_token(
+        self, refresh_token: str
+    ) -> dict[str, Token]:
+        pass
+
+    def auth_wrapper(
+        self,
+        public: bool = False,
+    ):
+        async def _auth_wrapper(
+            auth: Optional[HTTPAuthorizationCredentials] = Security(
+                self.security
+            ),
+            api_key: Optional[str] = Security(api_key_header),
+        ) -> User:
+            # If authentication is not required and no credentials are
+            # provided, return the default admin user
+            if (
+                ((not self.config.require_authentication) or public)
+                and auth is None
+                and api_key is None
+            ):
+                return await self._get_default_admin_user()
+            if not auth and not api_key:
+                raise R2RException(
+                    message="No credentials provided. Create an account at https://app.sciphi.ai and set your API key using `r2r configure key` OR change your base URL to a custom deployment.",
+                    status_code=401,
+                )
+            if auth and api_key:
+                raise R2RException(
+                    message="Cannot have both Bearer token and API key",
+                    status_code=400,
+                )
+            # 1. Try JWT if `auth` is present (Bearer token)
+            if auth is not None:
+                credentials = auth.credentials
+                try:
+                    token_data = await self.decode_token(credentials)
+                    user = await self.database_provider.users_handler.get_user_by_email(
+                        token_data.email
+                    )
+                    if user is not None:
+                        return user
+                except R2RException:
+                    # JWT decoding failed for logical reasons (invalid token)
+                    pass
+                except Exception as e:
+                    # JWT decoding failed unexpectedly, log and continue
+                    logger.debug(f"JWT verification failed: {e}")
+
+                # 2. If JWT failed, try API key from Bearer token
+                # Expected format: key_id.raw_api_key
+                if "." in credentials:
+                    key_id, raw_api_key = credentials.split(".", 1)
+                    api_key_record = await self.database_provider.users_handler.get_api_key_record(
+                        key_id
+                    )
+                    if api_key_record is not None:
+                        hashed_key = api_key_record["hashed_key"]
+                        if self.crypto_provider.verify_api_key(
+                            raw_api_key, hashed_key
+                        ):
+                            user = await self.database_provider.users_handler.get_user_by_id(
+                                api_key_record["user_id"]
+                            )
+                            if user is not None and user.is_active:
+                                return user
+
+            # 3. If no Bearer token worked, try the X-API-Key header
+            if api_key is not None and "." in api_key:
+                key_id, raw_api_key = api_key.split(".", 1)
+                api_key_record = await self.database_provider.users_handler.get_api_key_record(
+                    key_id
+                )
+                if api_key_record is not None:
+                    hashed_key = api_key_record["hashed_key"]
+                    if self.crypto_provider.verify_api_key(
+                        raw_api_key, hashed_key
+                    ):
+                        user = await self.database_provider.users_handler.get_user_by_id(
+                            api_key_record["user_id"]
+                        )
+                        if user is not None and user.is_active:
+                            return user
+
+            # If we reach here, both JWT and API key auth failed
+            raise R2RException(
+                message="Invalid token or API key",
+                status_code=401,
+            )
+
+        return _auth_wrapper
+
+    @abstractmethod
+    async def change_password(
+        self, user: User, current_password: str, new_password: str
+    ) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def request_password_reset(self, email: str) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def confirm_password_reset(
+        self, reset_token: str, new_password: str
+    ) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def logout(self, token: str) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def send_reset_email(self, email: str) -> dict[str, str]:
+        pass
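The `auth_wrapper` factory above is designed to be handed to FastAPI's dependency injection. A hedged sketch of the intended wiring, where `my_auth` stands in for a concrete `AuthProvider` subclass constructed elsewhere (hypothetical):

    from fastapi import Depends, FastAPI

    app = FastAPI()

    @app.get("/documents")
    async def list_documents(user=Depends(my_auth.auth_wrapper())):
        # Reached via JWT, a Bearer "key_id.raw_api_key", or X-API-Key.
        return {"email": user.email}

    @app.get("/health")
    async def health(user=Depends(my_auth.auth_wrapper(public=True))):
        # public=True falls back to the default admin user when no
        # credentials are supplied and authentication is not required.
        return {"status": "ok"}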
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/base.py b/.venv/lib/python3.12/site-packages/core/base/providers/base.py
new file mode 100644
index 00000000..3f72a5ea
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/base.py
@@ -0,0 +1,135 @@
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Type
+
+from pydantic import BaseModel
+
+
+class InnerConfig(BaseModel, ABC):
+    """A base provider configuration class."""
+
+    extra_fields: dict[str, Any] = {}
+
+    class Config:
+        populate_by_name = True
+        arbitrary_types_allowed = True
+        ignore_extra = True
+
+    @classmethod
+    def create(cls: Type["InnerConfig"], **kwargs: Any) -> "InnerConfig":
+        base_args = cls.model_fields.keys()
+        filtered_kwargs = {
+            k: v if v != "None" else None
+            for k, v in kwargs.items()
+            if k in base_args
+        }
+        instance = cls(**filtered_kwargs)  # type: ignore
+        for k, v in kwargs.items():
+            if k not in base_args:
+                instance.extra_fields[k] = v
+        return instance
+
+
+class AppConfig(InnerConfig):
+    project_name: Optional[str] = None
+    default_max_documents_per_user: Optional[int] = 100
+    default_max_chunks_per_user: Optional[int] = 10_000
+    default_max_collections_per_user: Optional[int] = 5
+    default_max_upload_size: int = 2_000_000  # e.g. ~2 MB
+    quality_llm: Optional[str] = None
+    fast_llm: Optional[str] = None
+    vlm: Optional[str] = None
+    audio_lm: Optional[str] = None
+    reasoning_llm: Optional[str] = None
+    planning_llm: Optional[str] = None
+
+    # File extension to max-size mapping
+    # These are examples; adjust sizes as needed.
+    max_upload_size_by_type: dict[str, int] = {
+        # Common text-based formats
+        "txt": 2_000_000,
+        "md": 2_000_000,
+        "tsv": 2_000_000,
+        "csv": 5_000_000,
+        "xml": 2_000_000,
+        "html": 5_000_000,
+        # Office docs
+        "doc": 10_000_000,
+        "docx": 10_000_000,
+        "ppt": 20_000_000,
+        "pptx": 20_000_000,
+        "xls": 10_000_000,
+        "xlsx": 10_000_000,
+        "odt": 5_000_000,
+        # PDFs can expand quite a bit when converted to text
+        "pdf": 30_000_000,
+        # E-mail
+        "eml": 5_000_000,
+        "msg": 5_000_000,
+        "p7s": 5_000_000,
+        # Images
+        "bmp": 5_000_000,
+        "heic": 5_000_000,
+        "jpeg": 5_000_000,
+        "jpg": 5_000_000,
+        "png": 5_000_000,
+        "tiff": 5_000_000,
+        # Others
+        "epub": 10_000_000,
+        "rtf": 5_000_000,
+        "rst": 5_000_000,
+        "org": 5_000_000,
+    }
+
+
+class ProviderConfig(BaseModel, ABC):
+    """A base provider configuration class."""
+
+    app: AppConfig  # Add an app_config field
+    extra_fields: dict[str, Any] = {}
+    provider: Optional[str] = None
+
+    class Config:
+        populate_by_name = True
+        arbitrary_types_allowed = True
+        ignore_extra = True
+
+    @abstractmethod
+    def validate_config(self) -> None:
+        pass
+
+    @classmethod
+    def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig":
+        base_args = cls.model_fields.keys()
+        filtered_kwargs = {
+            k: v if v != "None" else None
+            for k, v in kwargs.items()
+            if k in base_args
+        }
+        instance = cls(**filtered_kwargs)  # type: ignore
+        for k, v in kwargs.items():
+            if k not in base_args:
+                instance.extra_fields[k] = v
+        return instance
+
+    @property
+    @abstractmethod
+    def supported_providers(self) -> list[str]:
+        """Define a list of supported providers."""
+        pass
+
+    @classmethod
+    def from_dict(
+        cls: Type["ProviderConfig"], data: dict[str, Any]
+    ) -> "ProviderConfig":
+        """Create a new instance of the config from a dictionary."""
+        return cls.create(**data)
+
+
+class Provider(ABC):
+    """A base provider class to provide a common interface for all
+    providers."""
+
+    def __init__(self, config: ProviderConfig, *args, **kwargs):
+        if config:
+            config.validate_config()
+        self.config = config
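The `create()` classmethod quietly does two things: it coerces the literal string "None" to `None` for declared fields, and it routes unknown keyword arguments into `extra_fields` instead of raising. A toy subclass makes this concrete (hypothetical, for illustration only):

    from typing import Optional

    class ToyConfig(ProviderConfig):
        timeout: Optional[int] = None

        @property
        def supported_providers(self) -> list[str]:
            return ["toy"]

        def validate_config(self) -> None:
            pass

    cfg = ToyConfig.create(
        app=AppConfig(), provider="toy", timeout="None", retries=5
    )
    assert cfg.timeout is None                 # "None" coerced to None
    assert cfg.extra_fields == {"retries": 5}  # unknown keys preserved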
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py b/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py
new file mode 100644
index 00000000..bdf794b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py
@@ -0,0 +1,120 @@
+from abc import ABC, abstractmethod
+from datetime import datetime
+from typing import Optional, Tuple
+
+from .base import Provider, ProviderConfig
+
+
+class CryptoConfig(ProviderConfig):
+    provider: Optional[str] = None
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["bcrypt", "nacl"]
+
+    def validate_config(self) -> None:
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Unsupported crypto provider: {self.provider}")
+
+
+class CryptoProvider(Provider, ABC):
+    def __init__(self, config: CryptoConfig):
+        if not isinstance(config, CryptoConfig):
+            raise ValueError(
+                "CryptoProvider must be initialized with a CryptoConfig"
+            )
+        super().__init__(config)
+
+    @abstractmethod
+    def get_password_hash(self, password: str) -> str:
+        """Hash a plaintext password using a secure password hashing algorithm
+        (e.g., Argon2i)."""
+        pass
+
+    @abstractmethod
+    def verify_password(
+        self, plain_password: str, hashed_password: str
+    ) -> bool:
+        """Verify that a plaintext password matches the given hashed
+        password."""
+        pass
+
+    @abstractmethod
+    def generate_verification_code(self, length: int = 32) -> str:
+        """Generate a random code for email verification or reset tokens."""
+        pass
+
+    @abstractmethod
+    def generate_signing_keypair(self) -> Tuple[str, str, str]:
+        """Generate a new Ed25519 signing keypair for request signing.
+
+        Returns:
+            A tuple of (key_id, private_key, public_key).
+            - key_id: A unique identifier for this keypair.
+            - private_key: Base64 encoded Ed25519 private key.
+            - public_key: Base64 encoded Ed25519 public key.
+        """
+        pass
+
+    @abstractmethod
+    def sign_request(self, private_key: str, data: str) -> str:
+        """Sign request data with an Ed25519 private key, returning the
+        signature."""
+        pass
+
+    @abstractmethod
+    def verify_request_signature(
+        self, public_key: str, signature: str, data: str
+    ) -> bool:
+        """Verify a request signature using the corresponding Ed25519 public
+        key."""
+        pass
+
+    @abstractmethod
+    def generate_api_key(self) -> Tuple[str, str]:
+        """Generate a new API key for a user.
+
+        Returns:
+            A tuple (key_id, raw_api_key):
+            - key_id: A unique identifier for the API key.
+            - raw_api_key: The plaintext API key to provide to the user.
+        """
+        pass
+
+    @abstractmethod
+    def hash_api_key(self, raw_api_key: str) -> str:
+        """Hash a raw API key for secure storage in the database.
+
+        Use strong parameters suitable for long-term secrets.
+        """
+        pass
+
+    @abstractmethod
+    def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool:
+        """Verify that a provided API key matches the stored hashed version."""
+        pass
+
+    @abstractmethod
+    def generate_secure_token(self, data: dict, expiry: datetime) -> str:
+        """Generate a secure, signed token (e.g., JWT) embedding claims.
+
+        Args:
+            data: The claims to include in the token.
+            expiry: A datetime at which the token expires.
+
+        Returns:
+            A JWT string signed with a secret key.
+        """
+        pass
+
+    @abstractmethod
+    def verify_secure_token(self, token: str) -> Optional[dict]:
+        """Verify a secure token (e.g., JWT).
+
+        Args:
+            token: The token string to verify.
+
+        Returns:
+            The token payload if valid, otherwise None.
+        """
+        pass
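Taken together, `generate_api_key`, `hash_api_key`, and `verify_api_key` describe the round trip that auth.py relies on. A hedged sketch, assuming `crypto` is some concrete CryptoProvider (e.g. a bcrypt- or nacl-backed implementation):

    key_id, raw_api_key = crypto.generate_api_key()
    stored_hash = crypto.hash_api_key(raw_api_key)  # persist with key_id

    # Later, a request arrives carrying "key_id.raw_api_key" (the format
    # auth.py splits on); only the hash is read back from storage:
    assert crypto.verify_api_key(raw_api_key, stored_hash)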
+ """ + pass diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/database.py b/.venv/lib/python3.12/site-packages/core/base/providers/database.py new file mode 100644 index 00000000..845a8109 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/providers/database.py @@ -0,0 +1,197 @@ +"""Base classes for database providers.""" + +import logging +from abc import ABC, abstractmethod +from typing import Any, Optional, Sequence, cast +from uuid import UUID + +from pydantic import BaseModel + +from core.base.abstractions import ( + GraphCreationSettings, + GraphEnrichmentSettings, + GraphSearchSettings, +) + +from .base import Provider, ProviderConfig + +logger = logging.getLogger() + + +class DatabaseConnectionManager(ABC): + @abstractmethod + def execute_query( + self, + query: str, + params: Optional[dict[str, Any] | Sequence[Any]] = None, + isolation_level: Optional[str] = None, + ): + pass + + @abstractmethod + async def execute_many(self, query, params=None, batch_size=1000): + pass + + @abstractmethod + def fetch_query( + self, + query: str, + params: Optional[dict[str, Any] | Sequence[Any]] = None, + ): + pass + + @abstractmethod + def fetchrow_query( + self, + query: str, + params: Optional[dict[str, Any] | Sequence[Any]] = None, + ): + pass + + @abstractmethod + async def initialize(self, pool: Any): + pass + + +class Handler(ABC): + def __init__( + self, + project_name: str, + connection_manager: DatabaseConnectionManager, + ): + self.project_name = project_name + self.connection_manager = connection_manager + + def _get_table_name(self, base_name: str) -> str: + return f"{self.project_name}.{base_name}" + + @abstractmethod + def create_tables(self): + pass + + +class PostgresConfigurationSettings(BaseModel): + """Configuration settings with defaults defined by the PGVector docker + image. + + These settings are helpful in managing the connections to the database. 
To + tune these settings for a specific deployment, see + https://pgtune.leopard.in.ua/ + """ + + checkpoint_completion_target: Optional[float] = 0.9 + default_statistics_target: Optional[int] = 100 + effective_io_concurrency: Optional[int] = 1 + effective_cache_size: Optional[int] = 524288 + huge_pages: Optional[str] = "try" + maintenance_work_mem: Optional[int] = 65536 + max_connections: Optional[int] = 256 + max_parallel_workers_per_gather: Optional[int] = 2 + max_parallel_workers: Optional[int] = 8 + max_parallel_maintenance_workers: Optional[int] = 2 + max_wal_size: Optional[int] = 1024 + max_worker_processes: Optional[int] = 8 + min_wal_size: Optional[int] = 80 + shared_buffers: Optional[int] = 16384 + statement_cache_size: Optional[int] = 100 + random_page_cost: Optional[float] = 4 + wal_buffers: Optional[int] = 512 + work_mem: Optional[int] = 4096 + + +class LimitSettings(BaseModel): + global_per_min: Optional[int] = None + route_per_min: Optional[int] = None + monthly_limit: Optional[int] = None + + def merge_with_defaults( + self, defaults: "LimitSettings" + ) -> "LimitSettings": + return LimitSettings( + global_per_min=self.global_per_min or defaults.global_per_min, + route_per_min=self.route_per_min or defaults.route_per_min, + monthly_limit=self.monthly_limit or defaults.monthly_limit, + ) + + +class DatabaseConfig(ProviderConfig): + """A base database configuration class.""" + + provider: str = "postgres" + user: Optional[str] = None + password: Optional[str] = None + host: Optional[str] = None + port: Optional[int] = None + db_name: Optional[str] = None + project_name: Optional[str] = None + postgres_configuration_settings: Optional[ + PostgresConfigurationSettings + ] = None + default_collection_name: str = "Default" + default_collection_description: str = "Your default collection." 
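`LimitSettings.merge_with_defaults` composes a partial override with the config-level defaults, falling back field by field. Illustrative values:

    defaults = LimitSettings(
        global_per_min=60, route_per_min=20, monthly_limit=10000
    )
    override = LimitSettings(route_per_min=5)   # only one field set
    effective = override.merge_with_defaults(defaults)
    assert effective.route_per_min == 5         # override wins where set
    assert effective.global_per_min == 60       # None falls back to default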
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/email.py b/.venv/lib/python3.12/site-packages/core/base/providers/email.py
new file mode 100644
index 00000000..73f88162
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/email.py
@@ -0,0 +1,96 @@
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from .base import Provider, ProviderConfig
+
+
+class EmailConfig(ProviderConfig):
+    smtp_server: Optional[str] = None
+    smtp_port: Optional[int] = None
+    smtp_username: Optional[str] = None
+    smtp_password: Optional[str] = None
+    from_email: Optional[str] = None
+    use_tls: Optional[bool] = True
+    sendgrid_api_key: Optional[str] = None
+    mailersend_api_key: Optional[str] = None
+    verify_email_template_id: Optional[str] = None
+    reset_password_template_id: Optional[str] = None
+    password_changed_template_id: Optional[str] = None
+    frontend_url: Optional[str] = None
+    sender_name: Optional[str] = None
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return [
+            "smtp",
+            "console",
+            "sendgrid",
+            "mailersend",
+        ]  # Could add more providers like AWS SES etc.
+
+    def validate_config(self) -> None:
+        if (
+            self.provider == "sendgrid"
+            and not self.sendgrid_api_key
+            and not os.getenv("SENDGRID_API_KEY")
+        ):
+            raise ValueError(
+                "SendGrid API key is required when using SendGrid provider"
+            )
+
+        if (
+            self.provider == "mailersend"
+            and not self.mailersend_api_key
+            and not os.getenv("MAILERSEND_API_KEY")
+        ):
+            raise ValueError(
+                "MailerSend API key is required when using MailerSend provider"
+            )
+
+
+logger = logging.getLogger(__name__)
+
+
+class EmailProvider(Provider, ABC):
+    def __init__(self, config: EmailConfig):
+        if not isinstance(config, EmailConfig):
+            raise ValueError(
+                "EmailProvider must be initialized with an EmailConfig"
+            )
+        super().__init__(config)
+        self.config: EmailConfig = config
+
+    @abstractmethod
+    async def send_email(
+        self,
+        to_email: str,
+        subject: str,
+        body: str,
+        html_body: Optional[str] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        pass
+
+    @abstractmethod
+    async def send_verification_email(
+        self, to_email: str, verification_code: str, *args, **kwargs
+    ) -> None:
+        pass
+
+    @abstractmethod
+    async def send_password_reset_email(
+        self, to_email: str, reset_token: str, *args, **kwargs
+    ) -> None:
+        pass
+
+    @abstractmethod
+    async def send_password_changed_email(
+        self,
+        to_email: str,
+        *args,
+        **kwargs,
+    ) -> None:
+        pass
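Since every method here is abstract, the smallest useful implementation just prints. A console-backed sketch (hypothetical; the package's actual "console" provider is implemented elsewhere):

    class ConsoleEmailProvider(EmailProvider):
        async def send_email(
            self, to_email, subject, body, html_body=None, *args, **kwargs
        ):
            print(f"[email] to={to_email!r} subject={subject!r}\n{body}")

        async def send_verification_email(
            self, to_email, verification_code, *args, **kwargs
        ):
            await self.send_email(
                to_email, "Verify your email", f"Code: {verification_code}"
            )

        async def send_password_reset_email(
            self, to_email, reset_token, *args, **kwargs
        ):
            await self.send_email(
                to_email, "Reset your password", f"Token: {reset_token}"
            )

        async def send_password_changed_email(self, to_email, *args, **kwargs):
            await self.send_email(
                to_email, "Password changed", "Your password was updated."
            )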
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py b/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
new file mode 100644
index 00000000..d1f9f9d6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
@@ -0,0 +1,197 @@
+import asyncio
+import logging
+import random
+import time
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, Optional
+
+from litellm import AuthenticationError
+
+from core.base.abstractions import VectorQuantizationSettings
+
+from ..abstractions import (
+    ChunkSearchResult,
+    EmbeddingPurpose,
+    default_embedding_prefixes,
+)
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
+class EmbeddingConfig(ProviderConfig):
+    provider: str
+    base_model: str
+    base_dimension: int | float
+    rerank_model: Optional[str] = None
+    rerank_url: Optional[str] = None
+    batch_size: int = 1
+    prefixes: Optional[dict[str, str]] = None
+    add_title_as_prefix: bool = True
+    concurrent_request_limit: int = 256
+    max_retries: int = 3
+    initial_backoff: float = 1
+    max_backoff: float = 64.0
+    quantization_settings: VectorQuantizationSettings = (
+        VectorQuantizationSettings()
+    )
+
+    ## deprecated
+    rerank_dimension: Optional[int] = None
+    rerank_transformer_type: Optional[str] = None
+
+    def validate_config(self) -> None:
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["litellm", "openai", "ollama"]
+
+
+class EmbeddingProvider(Provider):
+    class Step(Enum):
+        BASE = 1
+        RERANK = 2
+
+    def __init__(self, config: EmbeddingConfig):
+        if not isinstance(config, EmbeddingConfig):
+            raise ValueError(
+                "EmbeddingProvider must be initialized with an `EmbeddingConfig`."
+            )
+        logger.info(f"Initializing EmbeddingProvider with config {config}.")
+
+        super().__init__(config)
+        self.config: EmbeddingConfig = config
+        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
+        self.current_requests = 0
+
+    async def _execute_with_backoff_async(self, task: dict[str, Any]):
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                async with self.semaphore:
+                    return await self._execute_task(task)
+            except AuthenticationError:
+                raise
+            except Exception as e:
+                logger.warning(
+                    f"Request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                await asyncio.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
+    def _execute_with_backoff_sync(self, task: dict[str, Any]):
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                return self._execute_task_sync(task)
+            except AuthenticationError:
+                raise
+            except Exception as e:
+                logger.warning(
+                    f"Request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                time.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
+    @abstractmethod
+    async def _execute_task(self, task: dict[str, Any]):
+        pass
+
+    @abstractmethod
+    def _execute_task_sync(self, task: dict[str, Any]):
+        pass
+
+    async def async_get_embedding(
+        self,
+        text: str,
+        stage: Step = Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+    ):
+        task = {
+            "text": text,
+            "stage": stage,
+            "purpose": purpose,
+        }
+        return await self._execute_with_backoff_async(task)
+
+    def get_embedding(
+        self,
+        text: str,
+        stage: Step = Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+    ):
+        task = {
+            "text": text,
+            "stage": stage,
+            "purpose": purpose,
+        }
+        return self._execute_with_backoff_sync(task)
+
+    async def async_get_embeddings(
+        self,
+        texts: list[str],
+        stage: Step = Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+    ):
+        task = {
+            "texts": texts,
+            "stage": stage,
+            "purpose": purpose,
+        }
+        return await self._execute_with_backoff_async(task)
+
+    def get_embeddings(
+        self,
+        texts: list[str],
+        stage: Step = Step.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+    ) -> list[list[float]]:
+        task = {
+            "texts": texts,
+            "stage": stage,
+            "purpose": purpose,
+        }
+        return self._execute_with_backoff_sync(task)
+
+    @abstractmethod
+    def rerank(
+        self,
+        query: str,
+        results: list[ChunkSearchResult],
+        stage: Step = Step.RERANK,
+        limit: int = 10,
+    ):
+        pass
+
+    @abstractmethod
+    async def arerank(
+        self,
+        query: str,
+        results: list[ChunkSearchResult],
+        stage: Step = Step.RERANK,
+        limit: int = 10,
+    ):
+        pass
+
+    def set_prefixes(self, config_prefixes: dict[str, str], base_model: str):
+        self.prefixes = {}
+
+        for t, p in config_prefixes.items():
+            purpose = EmbeddingPurpose(t.lower())
+            self.prefixes[purpose] = p
+
+        if base_model in default_embedding_prefixes:
+            for t, p in default_embedding_prefixes[base_model].items():
+                if t not in self.prefixes:
+                    self.prefixes[t] = p
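Both the sync and async entry points funnel through the backoff helpers above: up to `max_retries` attempts, each sleeping `uniform(0, backoff)` while `backoff` doubles from `initial_backoff` toward `max_backoff`. A usage sketch against a concrete subclass (the `provider` instance is assumed):

    async def embed_corpus(provider: EmbeddingProvider, texts: list[str]):
        # Batched async path; AuthenticationError is re-raised immediately,
        # other failures are retried with jittered exponential backoff.
        return await provider.async_get_embeddings(texts)

    def embed_one(provider: EmbeddingProvider, text: str):
        return provider.get_embedding(text)  # sync path, same retry policy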
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py b/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py
new file mode 100644
index 00000000..70d0d3a0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py
@@ -0,0 +1,172 @@
+import logging
+from abc import ABC
+from enum import Enum
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
+
+from pydantic import Field
+
+from core.base.abstractions import ChunkEnrichmentSettings
+
+from .base import AppConfig, Provider, ProviderConfig
+from .llm import CompletionProvider
+
+logger = logging.getLogger()
+
+if TYPE_CHECKING:
+    from core.providers.database import PostgresDatabaseProvider
+
+
+class ChunkingStrategy(str, Enum):
+    RECURSIVE = "recursive"
+    CHARACTER = "character"
+    BASIC = "basic"
+    BY_TITLE = "by_title"
+
+
+class IngestionMode(str, Enum):
+    hi_res = "hi-res"
+    fast = "fast"
+    custom = "custom"
+
+
+class IngestionConfig(ProviderConfig):
+    _defaults: ClassVar[dict] = {
+        "app": AppConfig(),
+        "provider": "r2r",
+        "excluded_parsers": ["mp4"],
+        "chunking_strategy": "recursive",
+        "chunk_size": 1024,
+        "chunk_enrichment_settings": ChunkEnrichmentSettings(),
+        "extra_parsers": {},
+        "audio_transcription_model": None,
+        "vision_img_prompt_name": "vision_img",
+        "vision_pdf_prompt_name": "vision_pdf",
+        "skip_document_summary": False,
+        "document_summary_system_prompt": "system",
+        "document_summary_task_prompt": "summary",
+        "document_summary_max_length": 100_000,
+        "chunks_for_document_summary": 128,
+        "document_summary_model": None,
+        "parser_overrides": {},
+        "extra_fields": {},
+        "automatic_extraction": False,
+    }
+
+    provider: str = Field(
+        default_factory=lambda: IngestionConfig._defaults["provider"]
+    )
+    excluded_parsers: list[str] = Field(
+        default_factory=lambda: IngestionConfig._defaults["excluded_parsers"]
+    )
+    chunking_strategy: str | ChunkingStrategy = Field(
+        default_factory=lambda: IngestionConfig._defaults["chunking_strategy"]
+    )
+    chunk_size: int = Field(
+        default_factory=lambda: IngestionConfig._defaults["chunk_size"]
+    )
+    chunk_enrichment_settings: ChunkEnrichmentSettings = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "chunk_enrichment_settings"
+        ]
+    )
+    extra_parsers: dict[str, Any] = Field(
+        default_factory=lambda: IngestionConfig._defaults["extra_parsers"]
+    )
+    audio_transcription_model: Optional[str] = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "audio_transcription_model"
+        ]
+    )
+    vision_img_prompt_name: str = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "vision_img_prompt_name"
+        ]
+    )
+    vision_pdf_prompt_name: str = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "vision_pdf_prompt_name"
+        ]
+    )
+    skip_document_summary: bool = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "skip_document_summary"
+        ]
+    )
+    document_summary_system_prompt: str = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "document_summary_system_prompt"
+        ]
+    )
+    document_summary_task_prompt: str = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "document_summary_task_prompt"
+        ]
+    )
+    chunks_for_document_summary: int = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "chunks_for_document_summary"
+        ]
+    )
+    document_summary_model: Optional[str] = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "document_summary_model"
+        ]
+    )
+    parser_overrides: dict[str, str] = Field(
+        default_factory=lambda: IngestionConfig._defaults["parser_overrides"]
+    )
+    automatic_extraction: bool = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "automatic_extraction"
+        ]
+    )
+    document_summary_max_length: int = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "document_summary_max_length"
+        ]
+    )
+
+    @classmethod
+    def set_default(cls, **kwargs):
+        for key, value in kwargs.items():
+            if key in cls._defaults:
+                cls._defaults[key] = value
+            else:
+                raise AttributeError(
+                    f"No default attribute '{key}' in IngestionConfig"
+                )
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["r2r", "unstructured_local", "unstructured_api"]
+
+    def validate_config(self) -> None:
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Provider {self.provider} is not supported.")
+
+    @classmethod
+    def get_default(cls, mode: str, app) -> "IngestionConfig":
+        """Return default ingestion configuration for a given mode."""
+        if mode == "hi-res":
+            return cls(app=app, parser_overrides={"pdf": "zerox"})
+        if mode == "fast":
+            return cls(app=app, skip_document_summary=True)
+        else:
+            return cls(app=app)
+
+
+class IngestionProvider(Provider, ABC):
+    config: IngestionConfig
+    database_provider: "PostgresDatabaseProvider"
+    llm_provider: CompletionProvider
+
+    def __init__(
+        self,
+        config: IngestionConfig,
+        database_provider: "PostgresDatabaseProvider",
+        llm_provider: CompletionProvider,
+    ):
+        super().__init__(config)
+        self.config: IngestionConfig = config
+        self.llm_provider = llm_provider
+        self.database_provider: "PostgresDatabaseProvider" = database_provider
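The `_defaults` ClassVar plus `set_default`/`get_default` give a small, mutable layer of class-wide configuration: every field's `default_factory` reads the current `_defaults` entry at instantiation time. A sketch (values illustrative):

    IngestionConfig.set_default(chunk_size=512)  # key must exist in _defaults

    app = AppConfig()
    hi_res = IngestionConfig.get_default("hi-res", app)
    assert hi_res.parser_overrides == {"pdf": "zerox"}

    fast = IngestionConfig.get_default("fast", app)
    assert fast.skip_document_summary is True
    assert fast.chunk_size == 512  # default_factory reads the new default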
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/llm.py b/.venv/lib/python3.12/site-packages/core/base/providers/llm.py
new file mode 100644
index 00000000..669dfc4f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/llm.py
@@ -0,0 +1,200 @@
+import asyncio
+import logging
+import random
+import time
+from abc import abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from litellm import AuthenticationError
+
+from core.base.abstractions import (
+    GenerationConfig,
+    LLMChatCompletion,
+    LLMChatCompletionChunk,
+)
+
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
+class CompletionConfig(ProviderConfig):
+    provider: Optional[str] = None
+    generation_config: Optional[GenerationConfig] = None
+    concurrent_request_limit: int = 256
+    max_retries: int = 3
+    initial_backoff: float = 1.0
+    max_backoff: float = 64.0
+
+    def validate_config(self) -> None:
+        if not self.provider:
+            raise ValueError("Provider must be set.")
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["anthropic", "litellm", "openai", "r2r"]
+
+
+class CompletionProvider(Provider):
+    def __init__(self, config: CompletionConfig) -> None:
+        if not isinstance(config, CompletionConfig):
+            raise ValueError(
+                "CompletionProvider must be initialized with a `CompletionConfig`."
+            )
+        logger.info(f"Initializing CompletionProvider with config: {config}")
+        super().__init__(config)
+        self.config: CompletionConfig = config
+        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
+        self.thread_pool = ThreadPoolExecutor(
+            max_workers=config.concurrent_request_limit
+        )
+
+    async def _execute_with_backoff_async(self, task: dict[str, Any]):
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                async with self.semaphore:
+                    return await self._execute_task(task)
+            except AuthenticationError:
+                raise
+            except Exception as e:
+                logger.warning(
+                    f"Request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                await asyncio.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
+    async def _execute_with_backoff_async_stream(
+        self, task: dict[str, Any]
+    ) -> AsyncGenerator[Any, None]:
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                async with self.semaphore:
+                    async for chunk in await self._execute_task(task):
+                        yield chunk
+                return  # Successful completion of the stream
+            except AuthenticationError:
+                raise
+            except Exception as e:
+                logger.warning(
+                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                await asyncio.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
+    def _execute_with_backoff_sync(self, task: dict[str, Any]):
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                return self._execute_task_sync(task)
+            except Exception as e:
+                logger.warning(
+                    f"Request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                time.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
+    def _execute_with_backoff_sync_stream(
+        self, task: dict[str, Any]
+    ) -> Generator[Any, None, None]:
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                yield from self._execute_task_sync(task)
+                return  # Successful completion of the stream
+            except Exception as e:
+                logger.warning(
+                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                time.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
+    @abstractmethod
+    async def _execute_task(self, task: dict[str, Any]):
+        pass
+
+    @abstractmethod
+    def _execute_task_sync(self, task: dict[str, Any]):
+        pass
+
+    async def aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        task = {
+            "messages": messages,
+            "generation_config": generation_config,
+            "kwargs": kwargs,
+        }
+        response = await self._execute_with_backoff_async(task)
+        return LLMChatCompletion(**response.dict())
+
+    async def aget_completion_stream(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> AsyncGenerator[LLMChatCompletionChunk, None]:
+        generation_config.stream = True
+        task = {
+            "messages": messages,
+            "generation_config": generation_config,
+            "kwargs": kwargs,
+        }
+        async for chunk in self._execute_with_backoff_async_stream(task):
+            if isinstance(chunk, dict):
+                yield LLMChatCompletionChunk(**chunk)
+                continue
+
+            chunk.choices[0].finish_reason = (
+                chunk.choices[0].finish_reason
+                if chunk.choices[0].finish_reason != ""
+                else None
+            )  # handle error output conventions
+            chunk.choices[0].finish_reason = (
+                chunk.choices[0].finish_reason
+                if chunk.choices[0].finish_reason != "eos"
+                else "stop"
+            )  # hardcode `eos` to `stop` for consistency
+            try:
+                yield LLMChatCompletionChunk(**(chunk.dict()))
+            except Exception as e:
+                logger.error(f"Error parsing chunk: {e}")
+                yield LLMChatCompletionChunk(**(chunk.as_dict()))
+
+    def get_completion_stream(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Generator[LLMChatCompletionChunk, None, None]:
+        generation_config.stream = True
+        task = {
+            "messages": messages,
+            "generation_config": generation_config,
+            "kwargs": kwargs,
+        }
+        for chunk in self._execute_with_backoff_sync_stream(task):
+            yield LLMChatCompletionChunk(**chunk.dict())
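A hedged sketch of consuming the streaming API, assuming a concrete provider and a `GenerationConfig` instance; the `delta.content` access follows the usual OpenAI-style chunk shape and is guarded here since the exact chunk model may differ:

    async def stream_reply(provider: CompletionProvider, gen_cfg):
        messages = [{"role": "user", "content": "Say hello."}]
        async for chunk in provider.aget_completion_stream(messages, gen_cfg):
            delta = chunk.choices[0].delta
            if getattr(delta, "content", None):
                print(delta.content, end="", flush=True)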
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py b/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py
new file mode 100644
index 00000000..c3105f30
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py
@@ -0,0 +1,70 @@
+from abc import abstractmethod
+from enum import Enum
+from typing import Any
+
+from .base import Provider, ProviderConfig
+
+
+class Workflow(Enum):
+    INGESTION = "ingestion"
+    GRAPH = "graph"
+
+
+class OrchestrationConfig(ProviderConfig):
+    provider: str
+    max_runs: int = 2_048
+    graph_search_results_creation_concurrency_limit: int = 32
+    ingestion_concurrency_limit: int = 16
+    graph_search_results_concurrency_limit: int = 8
+
+    def validate_config(self) -> None:
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Provider {self.provider} is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["hatchet", "simple"]
+
+
+class OrchestrationProvider(Provider):
+    def __init__(self, config: OrchestrationConfig):
+        super().__init__(config)
+        self.config = config
+        self.worker = None
+
+    @abstractmethod
+    async def start_worker(self):
+        pass
+
+    @abstractmethod
+    def get_worker(self, name: str, max_runs: int) -> Any:
+        pass
+
+    @abstractmethod
+    def step(self, *args, **kwargs) -> Any:
+        pass
+
+    @abstractmethod
+    def workflow(self, *args, **kwargs) -> Any:
+        pass
+
+    @abstractmethod
+    def failure(self, *args, **kwargs) -> Any:
+        pass
+
+    @abstractmethod
+    def register_workflows(
+        self, workflow: Workflow, service: Any, messages: dict
+    ) -> None:
+        pass
+
+    @abstractmethod
+    async def run_workflow(
+        self,
+        workflow_name: str,
+        parameters: dict,
+        options: dict,
+        *args,
+        **kwargs,
+    ) -> dict[str, str]:
+        pass
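An illustrative call pattern only; the concrete "hatchet" and "simple" providers implement these abstract methods elsewhere, and the workflow name and parameters below are hypothetical:

    async def kick_off(orchestrator: OrchestrationProvider, service) -> dict:
        orchestrator.register_workflows(Workflow.INGESTION, service, messages={})
        await orchestrator.start_worker()
        return await orchestrator.run_workflow(
            workflow_name="ingest-files",
            parameters={"document_id": "..."},
            options={},
        )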