about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/core/base/providers
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/providers')
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/__init__.py59
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/auth.py231
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/base.py135
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/crypto.py120
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/database.py197
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/email.py96
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/embedding.py197
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py172
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/llm.py200
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py70
10 files changed, 1477 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py b/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py
new file mode 100644
index 00000000..b902944d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py
@@ -0,0 +1,59 @@
+from .auth import AuthConfig, AuthProvider
+from .base import AppConfig, Provider, ProviderConfig
+from .crypto import CryptoConfig, CryptoProvider
+from .database import (
+ DatabaseConfig,
+ DatabaseConnectionManager,
+ DatabaseProvider,
+ Handler,
+ LimitSettings,
+ PostgresConfigurationSettings,
+)
+from .email import EmailConfig, EmailProvider
+from .embedding import EmbeddingConfig, EmbeddingProvider
+from .ingestion import (
+ ChunkingStrategy,
+ IngestionConfig,
+ IngestionMode,
+ IngestionProvider,
+)
+from .llm import CompletionConfig, CompletionProvider
+from .orchestration import OrchestrationConfig, OrchestrationProvider, Workflow
+
+__all__ = [
+ # Auth provider
+ "AuthConfig",
+ "AuthProvider",
+ # Base provider classes
+ "AppConfig",
+ "Provider",
+ "ProviderConfig",
+ # Ingestion provider
+ "IngestionMode",
+ "IngestionConfig",
+ "IngestionProvider",
+ "ChunkingStrategy",
+ # Crypto provider
+ "CryptoConfig",
+ "CryptoProvider",
+ # Email provider
+ "EmailConfig",
+ "EmailProvider",
+ # Database providers
+ "DatabaseConnectionManager",
+ "DatabaseConfig",
+ "LimitSettings",
+ "PostgresConfigurationSettings",
+ "DatabaseProvider",
+ "Handler",
+ # Embedding provider
+ "EmbeddingConfig",
+ "EmbeddingProvider",
+ # LLM provider
+ "CompletionConfig",
+ "CompletionProvider",
+ # Orchestration provider
+ "OrchestrationConfig",
+ "OrchestrationProvider",
+ "Workflow",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/auth.py b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
new file mode 100644
index 00000000..352c3331
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
@@ -0,0 +1,231 @@
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime
+from typing import TYPE_CHECKING, Optional
+
+from fastapi import Security
+from fastapi.security import (
+ APIKeyHeader,
+ HTTPAuthorizationCredentials,
+ HTTPBearer,
+)
+
+from ..abstractions import R2RException, Token, TokenData
+from ..api.models import User
+from .base import Provider, ProviderConfig
+from .crypto import CryptoProvider
+from .email import EmailProvider
+
+logger = logging.getLogger()
+
+if TYPE_CHECKING:
+ from core.providers.database import PostgresDatabaseProvider
+
+api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
+
+
+class AuthConfig(ProviderConfig):
+ secret_key: Optional[str] = None
+ require_authentication: bool = False
+ require_email_verification: bool = False
+ default_admin_email: str = "admin@example.com"
+ default_admin_password: str = "change_me_immediately"
+ access_token_lifetime_in_minutes: Optional[int] = None
+ refresh_token_lifetime_in_days: Optional[int] = None
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["r2r"]
+
+ def validate_config(self) -> None:
+ pass
+
+
+class AuthProvider(Provider, ABC):
+ security = HTTPBearer(auto_error=False)
+ crypto_provider: CryptoProvider
+ email_provider: EmailProvider
+ database_provider: "PostgresDatabaseProvider"
+
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: "PostgresDatabaseProvider",
+ email_provider: EmailProvider,
+ ):
+ if not isinstance(config, AuthConfig):
+ raise ValueError(
+ "AuthProvider must be initialized with an AuthConfig"
+ )
+ self.config = config
+ self.admin_email = config.default_admin_email
+ self.admin_password = config.default_admin_password
+ self.crypto_provider = crypto_provider
+ self.database_provider = database_provider
+ self.email_provider = email_provider
+ super().__init__(config)
+ self.config: AuthConfig = config
+ self.database_provider: "PostgresDatabaseProvider" = database_provider
+
+ async def _get_default_admin_user(self) -> User:
+ return await self.database_provider.users_handler.get_user_by_email(
+ self.admin_email
+ )
+
+ @abstractmethod
+ def create_access_token(self, data: dict) -> str:
+ pass
+
+ @abstractmethod
+ def create_refresh_token(self, data: dict) -> str:
+ pass
+
+ @abstractmethod
+ async def decode_token(self, token: str) -> TokenData:
+ pass
+
+ @abstractmethod
+ async def user(self, token: str) -> User:
+ pass
+
+ @abstractmethod
+ def get_current_active_user(self, current_user: User) -> User:
+ pass
+
+ @abstractmethod
+ async def register(self, email: str, password: str) -> User:
+ pass
+
+ @abstractmethod
+ async def send_verification_email(
+ self, email: str, user: Optional[User] = None
+ ) -> tuple[str, datetime]:
+ pass
+
+ @abstractmethod
+ async def verify_email(
+ self, email: str, verification_code: str
+ ) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def login(self, email: str, password: str) -> dict[str, Token]:
+ pass
+
+ @abstractmethod
+ async def refresh_access_token(
+ self, refresh_token: str
+ ) -> dict[str, Token]:
+ pass
+
+ def auth_wrapper(
+ self,
+ public: bool = False,
+ ):
+ async def _auth_wrapper(
+ auth: Optional[HTTPAuthorizationCredentials] = Security(
+ self.security
+ ),
+ api_key: Optional[str] = Security(api_key_header),
+ ) -> User:
+ # If authentication is not required and no credentials are provided, return the default admin user
+ if (
+ ((not self.config.require_authentication) or public)
+ and auth is None
+ and api_key is None
+ ):
+ return await self._get_default_admin_user()
+ if not auth and not api_key:
+ raise R2RException(
+ message="No credentials provided. Create an account at https://app.sciphi.ai and set your API key using `r2r configure key` OR change your base URL to a custom deployment.",
+ status_code=401,
+ )
+ if auth and api_key:
+ raise R2RException(
+ message="Cannot have both Bearer token and API key",
+ status_code=400,
+ )
+ # 1. Try JWT if `auth` is present (Bearer token)
+ if auth is not None:
+ credentials = auth.credentials
+ try:
+ token_data = await self.decode_token(credentials)
+ user = await self.database_provider.users_handler.get_user_by_email(
+ token_data.email
+ )
+ if user is not None:
+ return user
+ except R2RException:
+ # JWT decoding failed for logical reasons (invalid token)
+ pass
+ except Exception as e:
+ # JWT decoding failed unexpectedly, log and continue
+ logger.debug(f"JWT verification failed: {e}")
+
+ # 2. If JWT failed, try API key from Bearer token
+ # Expected format: key_id.raw_api_key
+ if "." in credentials:
+ key_id, raw_api_key = credentials.split(".", 1)
+ api_key_record = await self.database_provider.users_handler.get_api_key_record(
+ key_id
+ )
+ if api_key_record is not None:
+ hashed_key = api_key_record["hashed_key"]
+ if self.crypto_provider.verify_api_key(
+ raw_api_key, hashed_key
+ ):
+ user = await self.database_provider.users_handler.get_user_by_id(
+ api_key_record["user_id"]
+ )
+ if user is not None and user.is_active:
+ return user
+
+ # 3. If no Bearer token worked, try the X-API-Key header
+ if api_key is not None and "." in api_key:
+ key_id, raw_api_key = api_key.split(".", 1)
+ api_key_record = await self.database_provider.users_handler.get_api_key_record(
+ key_id
+ )
+ if api_key_record is not None:
+ hashed_key = api_key_record["hashed_key"]
+ if self.crypto_provider.verify_api_key(
+ raw_api_key, hashed_key
+ ):
+ user = await self.database_provider.users_handler.get_user_by_id(
+ api_key_record["user_id"]
+ )
+ if user is not None and user.is_active:
+ return user
+
+ # If we reach here, both JWT and API key auth failed
+ raise R2RException(
+ message="Invalid token or API key",
+ status_code=401,
+ )
+
+ return _auth_wrapper
+
+ @abstractmethod
+ async def change_password(
+ self, user: User, current_password: str, new_password: str
+ ) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def request_password_reset(self, email: str) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def confirm_password_reset(
+ self, reset_token: str, new_password: str
+ ) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def logout(self, token: str) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def send_reset_email(self, email: str) -> dict[str, str]:
+ pass
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/base.py b/.venv/lib/python3.12/site-packages/core/base/providers/base.py
new file mode 100644
index 00000000..3f72a5ea
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/base.py
@@ -0,0 +1,135 @@
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Type
+
+from pydantic import BaseModel
+
+
+class InnerConfig(BaseModel, ABC):
+ """A base provider configuration class."""
+
+ extra_fields: dict[str, Any] = {}
+
+ class Config:
+ populate_by_name = True
+ arbitrary_types_allowed = True
+ ignore_extra = True
+
+ @classmethod
+ def create(cls: Type["InnerConfig"], **kwargs: Any) -> "InnerConfig":
+ base_args = cls.model_fields.keys()
+ filtered_kwargs = {
+ k: v if v != "None" else None
+ for k, v in kwargs.items()
+ if k in base_args
+ }
+ instance = cls(**filtered_kwargs) # type: ignore
+ for k, v in kwargs.items():
+ if k not in base_args:
+ instance.extra_fields[k] = v
+ return instance
+
+
+class AppConfig(InnerConfig):
+ project_name: Optional[str] = None
+ default_max_documents_per_user: Optional[int] = 100
+ default_max_chunks_per_user: Optional[int] = 10_000
+ default_max_collections_per_user: Optional[int] = 5
+ default_max_upload_size: int = 2_000_000 # e.g. ~2 MB
+ quality_llm: Optional[str] = None
+ fast_llm: Optional[str] = None
+ vlm: Optional[str] = None
+ audio_lm: Optional[str] = None
+ reasoning_llm: Optional[str] = None
+ planning_llm: Optional[str] = None
+
+ # File extension to max-size mapping
+ # These are examples; adjust sizes as needed.
+ max_upload_size_by_type: dict[str, int] = {
+ # Common text-based formats
+ "txt": 2_000_000,
+ "md": 2_000_000,
+ "tsv": 2_000_000,
+ "csv": 5_000_000,
+ "xml": 2_000_000,
+ "html": 5_000_000,
+ # Office docs
+ "doc": 10_000_000,
+ "docx": 10_000_000,
+ "ppt": 20_000_000,
+ "pptx": 20_000_000,
+ "xls": 10_000_000,
+ "xlsx": 10_000_000,
+ "odt": 5_000_000,
+ # PDFs can expand quite a bit when converted to text
+ "pdf": 30_000_000,
+ # E-mail
+ "eml": 5_000_000,
+ "msg": 5_000_000,
+ "p7s": 5_000_000,
+ # Images
+ "bmp": 5_000_000,
+ "heic": 5_000_000,
+ "jpeg": 5_000_000,
+ "jpg": 5_000_000,
+ "png": 5_000_000,
+ "tiff": 5_000_000,
+ # Others
+ "epub": 10_000_000,
+ "rtf": 5_000_000,
+ "rst": 5_000_000,
+ "org": 5_000_000,
+ }
+
+
+class ProviderConfig(BaseModel, ABC):
+ """A base provider configuration class."""
+
+ app: AppConfig # Add an app_config field
+ extra_fields: dict[str, Any] = {}
+ provider: Optional[str] = None
+
+ class Config:
+ populate_by_name = True
+ arbitrary_types_allowed = True
+ ignore_extra = True
+
+ @abstractmethod
+ def validate_config(self) -> None:
+ pass
+
+ @classmethod
+ def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig":
+ base_args = cls.model_fields.keys()
+ filtered_kwargs = {
+ k: v if v != "None" else None
+ for k, v in kwargs.items()
+ if k in base_args
+ }
+ instance = cls(**filtered_kwargs) # type: ignore
+ for k, v in kwargs.items():
+ if k not in base_args:
+ instance.extra_fields[k] = v
+ return instance
+
+ @property
+ @abstractmethod
+ def supported_providers(self) -> list[str]:
+ """Define a list of supported providers."""
+ pass
+
+ @classmethod
+ def from_dict(
+ cls: Type["ProviderConfig"], data: dict[str, Any]
+ ) -> "ProviderConfig":
+ """Create a new instance of the config from a dictionary."""
+ return cls.create(**data)
+
+
+class Provider(ABC):
+ """A base provider class to provide a common interface for all
+ providers."""
+
+ def __init__(self, config: ProviderConfig, *args, **kwargs):
+ if config:
+ config.validate_config()
+ self.config = config
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py b/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py
new file mode 100644
index 00000000..bdf794b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py
@@ -0,0 +1,120 @@
+from abc import ABC, abstractmethod
+from datetime import datetime
+from typing import Optional, Tuple
+
+from .base import Provider, ProviderConfig
+
+
+class CryptoConfig(ProviderConfig):
+ provider: Optional[str] = None
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["bcrypt", "nacl"]
+
+ def validate_config(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Unsupported crypto provider: {self.provider}")
+
+
+class CryptoProvider(Provider, ABC):
+ def __init__(self, config: CryptoConfig):
+ if not isinstance(config, CryptoConfig):
+ raise ValueError(
+ "CryptoProvider must be initialized with a CryptoConfig"
+ )
+ super().__init__(config)
+
+ @abstractmethod
+ def get_password_hash(self, password: str) -> str:
+ """Hash a plaintext password using a secure password hashing algorithm
+ (e.g., Argon2i)."""
+ pass
+
+ @abstractmethod
+ def verify_password(
+ self, plain_password: str, hashed_password: str
+ ) -> bool:
+ """Verify that a plaintext password matches the given hashed
+ password."""
+ pass
+
+ @abstractmethod
+ def generate_verification_code(self, length: int = 32) -> str:
+ """Generate a random code for email verification or reset tokens."""
+ pass
+
+ @abstractmethod
+ def generate_signing_keypair(self) -> Tuple[str, str, str]:
+ """Generate a new Ed25519 signing keypair for request signing.
+
+ Returns:
+ A tuple of (key_id, private_key, public_key).
+ - key_id: A unique identifier for this keypair.
+ - private_key: Base64 encoded Ed25519 private key.
+ - public_key: Base64 encoded Ed25519 public key.
+ """
+ pass
+
+ @abstractmethod
+ def sign_request(self, private_key: str, data: str) -> str:
+ """Sign request data with an Ed25519 private key, returning the
+ signature."""
+ pass
+
+ @abstractmethod
+ def verify_request_signature(
+ self, public_key: str, signature: str, data: str
+ ) -> bool:
+ """Verify a request signature using the corresponding Ed25519 public
+ key."""
+ pass
+
+ @abstractmethod
+ def generate_api_key(self) -> Tuple[str, str]:
+ """Generate a new API key for a user.
+
+ Returns:
+ A tuple (key_id, raw_api_key):
+ - key_id: A unique identifier for the API key.
+ - raw_api_key: The plaintext API key to provide to the user.
+ """
+ pass
+
+ @abstractmethod
+ def hash_api_key(self, raw_api_key: str) -> str:
+ """Hash a raw API key for secure storage in the database.
+
+ Use strong parameters suitable for long-term secrets.
+ """
+ pass
+
+ @abstractmethod
+ def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool:
+ """Verify that a provided API key matches the stored hashed version."""
+ pass
+
+ @abstractmethod
+ def generate_secure_token(self, data: dict, expiry: datetime) -> str:
+ """Generate a secure, signed token (e.g., JWT) embedding claims.
+
+ Args:
+ data: The claims to include in the token.
+ expiry: A datetime at which the token expires.
+
+ Returns:
+ A JWT string signed with a secret key.
+ """
+ pass
+
+ @abstractmethod
+ def verify_secure_token(self, token: str) -> Optional[dict]:
+ """Verify a secure token (e.g., JWT).
+
+ Args:
+ token: The token string to verify.
+
+ Returns:
+ The token payload if valid, otherwise None.
+ """
+ pass
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/database.py b/.venv/lib/python3.12/site-packages/core/base/providers/database.py
new file mode 100644
index 00000000..845a8109
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/database.py
@@ -0,0 +1,197 @@
+"""Base classes for database providers."""
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Sequence, cast
+from uuid import UUID
+
+from pydantic import BaseModel
+
+from core.base.abstractions import (
+ GraphCreationSettings,
+ GraphEnrichmentSettings,
+ GraphSearchSettings,
+)
+
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
+class DatabaseConnectionManager(ABC):
+ @abstractmethod
+ def execute_query(
+ self,
+ query: str,
+ params: Optional[dict[str, Any] | Sequence[Any]] = None,
+ isolation_level: Optional[str] = None,
+ ):
+ pass
+
+ @abstractmethod
+ async def execute_many(self, query, params=None, batch_size=1000):
+ pass
+
+ @abstractmethod
+ def fetch_query(
+ self,
+ query: str,
+ params: Optional[dict[str, Any] | Sequence[Any]] = None,
+ ):
+ pass
+
+ @abstractmethod
+ def fetchrow_query(
+ self,
+ query: str,
+ params: Optional[dict[str, Any] | Sequence[Any]] = None,
+ ):
+ pass
+
+ @abstractmethod
+ async def initialize(self, pool: Any):
+ pass
+
+
+class Handler(ABC):
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: DatabaseConnectionManager,
+ ):
+ self.project_name = project_name
+ self.connection_manager = connection_manager
+
+ def _get_table_name(self, base_name: str) -> str:
+ return f"{self.project_name}.{base_name}"
+
+ @abstractmethod
+ def create_tables(self):
+ pass
+
+
+class PostgresConfigurationSettings(BaseModel):
+ """Configuration settings with defaults defined by the PGVector docker
+ image.
+
+ These settings are helpful in managing the connections to the database. To
+ tune these settings for a specific deployment, see
+ https://pgtune.leopard.in.ua/
+ """
+
+ checkpoint_completion_target: Optional[float] = 0.9
+ default_statistics_target: Optional[int] = 100
+ effective_io_concurrency: Optional[int] = 1
+ effective_cache_size: Optional[int] = 524288
+ huge_pages: Optional[str] = "try"
+ maintenance_work_mem: Optional[int] = 65536
+ max_connections: Optional[int] = 256
+ max_parallel_workers_per_gather: Optional[int] = 2
+ max_parallel_workers: Optional[int] = 8
+ max_parallel_maintenance_workers: Optional[int] = 2
+ max_wal_size: Optional[int] = 1024
+ max_worker_processes: Optional[int] = 8
+ min_wal_size: Optional[int] = 80
+ shared_buffers: Optional[int] = 16384
+ statement_cache_size: Optional[int] = 100
+ random_page_cost: Optional[float] = 4
+ wal_buffers: Optional[int] = 512
+ work_mem: Optional[int] = 4096
+
+
+class LimitSettings(BaseModel):
+ global_per_min: Optional[int] = None
+ route_per_min: Optional[int] = None
+ monthly_limit: Optional[int] = None
+
+ def merge_with_defaults(
+ self, defaults: "LimitSettings"
+ ) -> "LimitSettings":
+ return LimitSettings(
+ global_per_min=self.global_per_min or defaults.global_per_min,
+ route_per_min=self.route_per_min or defaults.route_per_min,
+ monthly_limit=self.monthly_limit or defaults.monthly_limit,
+ )
+
+
+class DatabaseConfig(ProviderConfig):
+ """A base database configuration class."""
+
+ provider: str = "postgres"
+ user: Optional[str] = None
+ password: Optional[str] = None
+ host: Optional[str] = None
+ port: Optional[int] = None
+ db_name: Optional[str] = None
+ project_name: Optional[str] = None
+ postgres_configuration_settings: Optional[
+ PostgresConfigurationSettings
+ ] = None
+ default_collection_name: str = "Default"
+ default_collection_description: str = "Your default collection."
+ collection_summary_system_prompt: str = "system"
+ collection_summary_prompt: str = "collection_summary"
+ enable_fts: bool = False
+
+ # Graph settings
+ batch_size: Optional[int] = 1
+ graph_search_results_store_path: Optional[str] = None
+ graph_enrichment_settings: GraphEnrichmentSettings = (
+ GraphEnrichmentSettings()
+ )
+ graph_creation_settings: GraphCreationSettings = GraphCreationSettings()
+ graph_search_settings: GraphSearchSettings = GraphSearchSettings()
+
+ # Rate limits
+ limits: LimitSettings = LimitSettings(
+ global_per_min=60, route_per_min=20, monthly_limit=10000
+ )
+ route_limits: dict[str, LimitSettings] = {}
+ user_limits: dict[UUID, LimitSettings] = {}
+
+ def validate_config(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["postgres"]
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
+ instance = cls.create(**data)
+
+ instance = cast(DatabaseConfig, instance)
+
+ limits_data = data.get("limits", {})
+ default_limits = LimitSettings(
+ global_per_min=limits_data.get("global_per_min", 60),
+ route_per_min=limits_data.get("route_per_min", 20),
+ monthly_limit=limits_data.get("monthly_limit", 10000),
+ )
+
+ instance.limits = default_limits
+
+ route_limits_data = limits_data.get("routes", {})
+ for route_str, route_cfg in route_limits_data.items():
+ instance.route_limits[route_str] = LimitSettings(**route_cfg)
+
+ return instance
+
+
+class DatabaseProvider(Provider):
+ connection_manager: DatabaseConnectionManager
+ config: DatabaseConfig
+ project_name: str
+
+ def __init__(self, config: DatabaseConfig):
+ logger.info(f"Initializing DatabaseProvider with config {config}.")
+ super().__init__(config)
+
+ @abstractmethod
+ async def __aenter__(self):
+ pass
+
+ @abstractmethod
+ async def __aexit__(self, exc_type, exc, tb):
+ pass
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/email.py b/.venv/lib/python3.12/site-packages/core/base/providers/email.py
new file mode 100644
index 00000000..73f88162
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/email.py
@@ -0,0 +1,96 @@
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from .base import Provider, ProviderConfig
+
+
+class EmailConfig(ProviderConfig):
+ smtp_server: Optional[str] = None
+ smtp_port: Optional[int] = None
+ smtp_username: Optional[str] = None
+ smtp_password: Optional[str] = None
+ from_email: Optional[str] = None
+ use_tls: Optional[bool] = True
+ sendgrid_api_key: Optional[str] = None
+ mailersend_api_key: Optional[str] = None
+ verify_email_template_id: Optional[str] = None
+ reset_password_template_id: Optional[str] = None
+ password_changed_template_id: Optional[str] = None
+ frontend_url: Optional[str] = None
+ sender_name: Optional[str] = None
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return [
+ "smtp",
+ "console",
+ "sendgrid",
+ "mailersend",
+ ] # Could add more providers like AWS SES, SendGrid etc.
+
+ def validate_config(self) -> None:
+ if (
+ self.provider == "sendgrid"
+ and not self.sendgrid_api_key
+ and not os.getenv("SENDGRID_API_KEY")
+ ):
+ raise ValueError(
+ "SendGrid API key is required when using SendGrid provider"
+ )
+
+ if (
+ self.provider == "mailersend"
+ and not self.mailersend_api_key
+ and not os.getenv("MAILERSEND_API_KEY")
+ ):
+ raise ValueError(
+ "MailerSend API key is required when using MailerSend provider"
+ )
+
+
+logger = logging.getLogger(__name__)
+
+
+class EmailProvider(Provider, ABC):
+ def __init__(self, config: EmailConfig):
+ if not isinstance(config, EmailConfig):
+ raise ValueError(
+ "EmailProvider must be initialized with an EmailConfig"
+ )
+ super().__init__(config)
+ self.config: EmailConfig = config
+
+ @abstractmethod
+ async def send_email(
+ self,
+ to_email: str,
+ subject: str,
+ body: str,
+ html_body: Optional[str] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ pass
+
+ @abstractmethod
+ async def send_verification_email(
+ self, to_email: str, verification_code: str, *args, **kwargs
+ ) -> None:
+ pass
+
+ @abstractmethod
+ async def send_password_reset_email(
+ self, to_email: str, reset_token: str, *args, **kwargs
+ ) -> None:
+ pass
+
+ @abstractmethod
+ async def send_password_changed_email(
+ self,
+ to_email: str,
+ *args,
+ **kwargs,
+ ) -> None:
+ pass
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py b/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
new file mode 100644
index 00000000..d1f9f9d6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
@@ -0,0 +1,197 @@
+import asyncio
+import logging
+import random
+import time
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, Optional
+
+from litellm import AuthenticationError
+
+from core.base.abstractions import VectorQuantizationSettings
+
+from ..abstractions import (
+ ChunkSearchResult,
+ EmbeddingPurpose,
+ default_embedding_prefixes,
+)
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
+class EmbeddingConfig(ProviderConfig):
+ provider: str
+ base_model: str
+ base_dimension: int | float
+ rerank_model: Optional[str] = None
+ rerank_url: Optional[str] = None
+ batch_size: int = 1
+ prefixes: Optional[dict[str, str]] = None
+ add_title_as_prefix: bool = True
+ concurrent_request_limit: int = 256
+ max_retries: int = 3
+ initial_backoff: float = 1
+ max_backoff: float = 64.0
+ quantization_settings: VectorQuantizationSettings = (
+ VectorQuantizationSettings()
+ )
+
+ ## deprecated
+ rerank_dimension: Optional[int] = None
+ rerank_transformer_type: Optional[str] = None
+
+ def validate_config(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["litellm", "openai", "ollama"]
+
+
+class EmbeddingProvider(Provider):
+ class Step(Enum):
+ BASE = 1
+ RERANK = 2
+
+ def __init__(self, config: EmbeddingConfig):
+ if not isinstance(config, EmbeddingConfig):
+ raise ValueError(
+ "EmbeddingProvider must be initialized with a `EmbeddingConfig`."
+ )
+ logger.info(f"Initializing EmbeddingProvider with config {config}.")
+
+ super().__init__(config)
+ self.config: EmbeddingConfig = config
+ self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
+ self.current_requests = 0
+
+ async def _execute_with_backoff_async(self, task: dict[str, Any]):
+ retries = 0
+ backoff = self.config.initial_backoff
+ while retries < self.config.max_retries:
+ try:
+ async with self.semaphore:
+ return await self._execute_task(task)
+ except AuthenticationError:
+ raise
+ except Exception as e:
+ logger.warning(
+ f"Request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.config.max_retries:
+ raise
+ await asyncio.sleep(random.uniform(0, backoff))
+ backoff = min(backoff * 2, self.config.max_backoff)
+
+ def _execute_with_backoff_sync(self, task: dict[str, Any]):
+ retries = 0
+ backoff = self.config.initial_backoff
+ while retries < self.config.max_retries:
+ try:
+ return self._execute_task_sync(task)
+ except AuthenticationError:
+ raise
+ except Exception as e:
+ logger.warning(
+ f"Request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.config.max_retries:
+ raise
+ time.sleep(random.uniform(0, backoff))
+ backoff = min(backoff * 2, self.config.max_backoff)
+
+ @abstractmethod
+ async def _execute_task(self, task: dict[str, Any]):
+ pass
+
+ @abstractmethod
+ def _execute_task_sync(self, task: dict[str, Any]):
+ pass
+
+ async def async_get_embedding(
+ self,
+ text: str,
+ stage: Step = Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ ):
+ task = {
+ "text": text,
+ "stage": stage,
+ "purpose": purpose,
+ }
+ return await self._execute_with_backoff_async(task)
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: Step = Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ ):
+ task = {
+ "text": text,
+ "stage": stage,
+ "purpose": purpose,
+ }
+ return self._execute_with_backoff_sync(task)
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: Step = Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ ):
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ }
+ return await self._execute_with_backoff_async(task)
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: Step = Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ ) -> list[list[float]]:
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ }
+ return self._execute_with_backoff_sync(task)
+
+ @abstractmethod
+ def rerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: Step = Step.RERANK,
+ limit: int = 10,
+ ):
+ pass
+
+ @abstractmethod
+ async def arerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: Step = Step.RERANK,
+ limit: int = 10,
+ ):
+ pass
+
+ def set_prefixes(self, config_prefixes: dict[str, str], base_model: str):
+ self.prefixes = {}
+
+ for t, p in config_prefixes.items():
+ purpose = EmbeddingPurpose(t.lower())
+ self.prefixes[purpose] = p
+
+ if base_model in default_embedding_prefixes:
+ for t, p in default_embedding_prefixes[base_model].items():
+ if t not in self.prefixes:
+ self.prefixes[t] = p
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py b/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py
new file mode 100644
index 00000000..70d0d3a0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py
@@ -0,0 +1,172 @@
+import logging
+from abc import ABC
+from enum import Enum
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
+
+from pydantic import Field
+
+from core.base.abstractions import ChunkEnrichmentSettings
+
+from .base import AppConfig, Provider, ProviderConfig
+from .llm import CompletionProvider
+
+logger = logging.getLogger()
+
+if TYPE_CHECKING:
+ from core.providers.database import PostgresDatabaseProvider
+
+
class ChunkingStrategy(str, Enum):
    """Supported strategies for splitting documents into chunks."""

    RECURSIVE = "recursive"
    CHARACTER = "character"
    BASIC = "basic"
    BY_TITLE = "by_title"
+
+
class IngestionMode(str, Enum):
    """Preset ingestion pipelines; mode strings map to presets in
    IngestionConfig.get_default ("hi-res", "fast", anything else = custom).

    NOTE(review): members are lowercase, unlike the other enums in this
    package — kept as-is since the values are part of the public API.
    """

    hi_res = "hi-res"
    fast = "fast"
    custom = "custom"
+
+
class IngestionConfig(ProviderConfig):
    """Configuration for the document-ingestion pipeline.

    Class-wide defaults live in the ``_defaults`` mapping and can be
    overridden process-wide via :meth:`set_default`; each field's
    ``default_factory`` reads the *current* default at instantiation time.
    Mutable defaults (lists/dicts) are copied per instance so mutating one
    config cannot corrupt the shared defaults or other instances.
    """

    _defaults: ClassVar[dict] = {
        "app": AppConfig(),
        "provider": "r2r",
        "excluded_parsers": ["mp4"],
        "chunking_strategy": "recursive",
        "chunk_size": 1024,
        "chunk_enrichment_settings": ChunkEnrichmentSettings(),
        "extra_parsers": {},
        "audio_transcription_model": None,
        "vision_img_prompt_name": "vision_img",
        "vision_pdf_prompt_name": "vision_pdf",
        "skip_document_summary": False,
        "document_summary_system_prompt": "system",
        "document_summary_task_prompt": "summary",
        "document_summary_max_length": 100_000,
        "chunks_for_document_summary": 128,
        "document_summary_model": None,
        "parser_overrides": {},
        "extra_fields": {},
        "automatic_extraction": False,
    }

    provider: str = Field(
        default_factory=lambda: IngestionConfig._defaults["provider"]
    )
    excluded_parsers: list[str] = Field(
        # Copy so per-instance mutation cannot corrupt the shared default.
        default_factory=lambda: list(
            IngestionConfig._defaults["excluded_parsers"]
        )
    )
    chunking_strategy: str | ChunkingStrategy = Field(
        default_factory=lambda: IngestionConfig._defaults["chunking_strategy"]
    )
    chunk_size: int = Field(
        default_factory=lambda: IngestionConfig._defaults["chunk_size"]
    )
    # NOTE(review): all configs that don't override this share one
    # ChunkEnrichmentSettings instance — confirm it is never mutated in place.
    chunk_enrichment_settings: ChunkEnrichmentSettings = Field(
        default_factory=lambda: IngestionConfig._defaults[
            "chunk_enrichment_settings"
        ]
    )
    extra_parsers: dict[str, Any] = Field(
        # Copy so per-instance mutation cannot corrupt the shared default.
        default_factory=lambda: dict(
            IngestionConfig._defaults["extra_parsers"]
        )
    )
    audio_transcription_model: Optional[str] = Field(
        default_factory=lambda: IngestionConfig._defaults[
            "audio_transcription_model"
        ]
    )
    vision_img_prompt_name: str = Field(
        default_factory=lambda: IngestionConfig._defaults[
            "vision_img_prompt_name"
        ]
    )
    vision_pdf_prompt_name: str = Field(
        default_factory=lambda: IngestionConfig._defaults[
            "vision_pdf_prompt_name"
        ]
    )
    skip_document_summary: bool = Field(
        default_factory=lambda: IngestionConfig._defaults[
            "skip_document_summary"
        ]
    )
    document_summary_system_prompt: str = Field(
        default_factory=lambda: IngestionConfig._defaults[
            "document_summary_system_prompt"
        ]
    )
    document_summary_task_prompt: str = Field(
        default_factory=lambda: IngestionConfig._defaults[
            "document_summary_task_prompt"
        ]
    )
    chunks_for_document_summary: int = Field(
        default_factory=lambda: IngestionConfig._defaults[
            "chunks_for_document_summary"
        ]
    )
    document_summary_model: Optional[str] = Field(
        default_factory=lambda: IngestionConfig._defaults[
            "document_summary_model"
        ]
    )
    parser_overrides: dict[str, str] = Field(
        # Copy so per-instance mutation cannot corrupt the shared default.
        default_factory=lambda: dict(
            IngestionConfig._defaults["parser_overrides"]
        )
    )
    automatic_extraction: bool = Field(
        default_factory=lambda: IngestionConfig._defaults[
            "automatic_extraction"
        ]
    )
    document_summary_max_length: int = Field(
        default_factory=lambda: IngestionConfig._defaults[
            "document_summary_max_length"
        ]
    )

    @classmethod
    def set_default(cls, **kwargs):
        """Override class-wide defaults; affects configs created afterwards.

        Raises:
            AttributeError: if a key is not a known default.
        """
        for key, value in kwargs.items():
            if key in cls._defaults:
                cls._defaults[key] = value
            else:
                raise AttributeError(
                    f"No default attribute '{key}' in IngestionConfig"
                )

    @property
    def supported_providers(self) -> list[str]:
        """Names of the ingestion backends this config accepts."""
        return ["r2r", "unstructured_local", "unstructured_api"]

    def validate_config(self) -> None:
        """Raise ValueError when the configured provider is unsupported."""
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider {self.provider} is not supported.")

    @classmethod
    def get_default(cls, mode: str, app) -> "IngestionConfig":
        """Return default ingestion configuration for a given mode.

        "hi-res" routes PDFs through the `zerox` parser; "fast" skips
        document summarization; any other mode uses plain defaults.
        """
        if mode == "hi-res":
            return cls(app=app, parser_overrides={"pdf": "zerox"})
        elif mode == "fast":
            return cls(app=app, skip_document_summary=True)
        else:
            return cls(app=app)
+
+
class IngestionProvider(Provider, ABC):
    """Abstract base for ingestion backends.

    Bundles the ingestion configuration with the database and LLM
    providers that concrete implementations depend on.
    """

    config: IngestionConfig
    database_provider: "PostgresDatabaseProvider"
    llm_provider: CompletionProvider

    def __init__(
        self,
        config: IngestionConfig,
        database_provider: "PostgresDatabaseProvider",
        llm_provider: CompletionProvider,
    ):
        super().__init__(config)
        # Re-declare with the narrowed types for static checkers.
        self.config: IngestionConfig = config
        self.database_provider: "PostgresDatabaseProvider" = database_provider
        self.llm_provider = llm_provider
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/llm.py b/.venv/lib/python3.12/site-packages/core/base/providers/llm.py
new file mode 100644
index 00000000..669dfc4f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/llm.py
@@ -0,0 +1,200 @@
+import asyncio
+import logging
+import random
+import time
+from abc import abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from litellm import AuthenticationError
+
+from core.base.abstractions import (
+ GenerationConfig,
+ LLMChatCompletion,
+ LLMChatCompletionChunk,
+)
+
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
class CompletionConfig(ProviderConfig):
    """Configuration for a chat-completion provider, including the
    concurrency cap and retry/backoff parameters."""

    provider: Optional[str] = None
    generation_config: Optional[GenerationConfig] = None
    concurrent_request_limit: int = 256
    max_retries: int = 3
    initial_backoff: float = 1.0
    max_backoff: float = 64.0

    @property
    def supported_providers(self) -> list[str]:
        """Names of the completion backends this config accepts."""
        return ["anthropic", "litellm", "openai", "r2r"]

    def validate_config(self) -> None:
        """Ensure a supported provider has been configured."""
        provider = self.provider
        if not provider:
            raise ValueError("Provider must be set.")
        if provider not in self.supported_providers:
            raise ValueError(f"Provider '{provider}' is not supported.")
+
+
class CompletionProvider(Provider):
    """Base class for chat-completion providers.

    Wraps the provider-specific `_execute_task` / `_execute_task_sync`
    hooks with bounded concurrency (async path) and jittered
    exponential-backoff retries, for both plain and streaming completions.
    `AuthenticationError` is never retried on either path, since bad
    credentials will not become valid by retrying.
    """

    def __init__(self, config: CompletionConfig) -> None:
        if not isinstance(config, CompletionConfig):
            raise ValueError(
                "CompletionProvider must be initialized with a `CompletionConfig`."
            )
        logger.info(f"Initializing CompletionProvider with config: {config}")
        super().__init__(config)
        self.config: CompletionConfig = config
        # Caps the number of concurrent in-flight async requests.
        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
        # Available to subclasses that need to off-load blocking work.
        self.thread_pool = ThreadPoolExecutor(
            max_workers=config.concurrent_request_limit
        )

    async def _execute_with_backoff_async(self, task: dict[str, Any]):
        """Run `_execute_task` with retries and jittered exponential backoff."""
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                async with self.semaphore:
                    return await self._execute_task(task)
            except AuthenticationError:
                # Auth failures are permanent; do not retry.
                raise
            except Exception as e:
                logger.warning(
                    f"Request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    async def _execute_with_backoff_async_stream(
        self, task: dict[str, Any]
    ) -> AsyncGenerator[Any, None]:
        """Stream `_execute_task` output, retrying the whole stream on failure.

        NOTE: a failure mid-stream restarts the stream from the beginning,
        so consumers may observe repeated chunks across retries.
        """
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                async with self.semaphore:
                    async for chunk in await self._execute_task(task):
                        yield chunk
                return  # Successful completion of the stream
            except AuthenticationError:
                # Auth failures are permanent; do not retry.
                raise
            except Exception as e:
                logger.warning(
                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    def _execute_with_backoff_sync(self, task: dict[str, Any]):
        """Synchronous counterpart of `_execute_with_backoff_async`."""
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                return self._execute_task_sync(task)
            except AuthenticationError:
                # Fail fast on auth errors, matching the async path.
                raise
            except Exception as e:
                logger.warning(
                    f"Request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    def _execute_with_backoff_sync_stream(
        self, task: dict[str, Any]
    ) -> Generator[Any, None, None]:
        """Synchronous counterpart of `_execute_with_backoff_async_stream`."""
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                yield from self._execute_task_sync(task)
                return  # Successful completion of the stream
            except AuthenticationError:
                # Fail fast on auth errors, matching the async path.
                raise
            except Exception as e:
                logger.warning(
                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    @abstractmethod
    async def _execute_task(self, task: dict[str, Any]):
        """Provider-specific async execution of a completion task."""
        pass

    @abstractmethod
    def _execute_task_sync(self, task: dict[str, Any]):
        """Provider-specific sync execution of a completion task."""
        pass

    async def aget_completion(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> LLMChatCompletion:
        """Request a full (non-streaming) completion asynchronously."""
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        response = await self._execute_with_backoff_async(task)
        return LLMChatCompletion(**response.dict())

    async def aget_completion_stream(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> AsyncGenerator[LLMChatCompletionChunk, None]:
        """Stream completion chunks asynchronously.

        NOTE: mutates the caller's `generation_config` by forcing
        `stream = True`.
        """
        generation_config.stream = True
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        async for chunk in self._execute_with_backoff_async_stream(task):
            if isinstance(chunk, dict):
                yield LLMChatCompletionChunk(**chunk)
                continue

            chunk.choices[0].finish_reason = (
                chunk.choices[0].finish_reason
                if chunk.choices[0].finish_reason != ""
                else None
            )  # handle error output conventions
            chunk.choices[0].finish_reason = (
                chunk.choices[0].finish_reason
                if chunk.choices[0].finish_reason != "eos"
                else "stop"
            )  # hardcode `eos` to `stop` for consistency
            try:
                yield LLMChatCompletionChunk(**(chunk.dict()))
            except Exception as e:
                # Some backends expose as_dict() instead of dict().
                logger.error(f"Error parsing chunk: {e}")
                yield LLMChatCompletionChunk(**(chunk.as_dict()))

    def get_completion_stream(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> Generator[LLMChatCompletionChunk, None, None]:
        """Stream completion chunks synchronously.

        NOTE: mutates the caller's `generation_config` by forcing
        `stream = True`.
        """
        generation_config.stream = True
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        for chunk in self._execute_with_backoff_sync_stream(task):
            yield LLMChatCompletionChunk(**chunk.dict())
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py b/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py
new file mode 100644
index 00000000..c3105f30
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py
@@ -0,0 +1,70 @@
+from abc import abstractmethod
+from enum import Enum
+from typing import Any
+
+from .base import Provider, ProviderConfig
+
+
class Workflow(Enum):
    """Named workflow categories an orchestrator can register and run."""

    INGESTION = "ingestion"
    GRAPH = "graph"
+
+
class OrchestrationConfig(ProviderConfig):
    """Configuration for a workflow-orchestration backend, including
    per-workflow concurrency limits."""

    provider: str
    max_runs: int = 2_048
    graph_search_results_creation_concurrency_limit: int = 32
    ingestion_concurrency_limit: int = 16
    graph_search_results_concurrency_limit: int = 8

    @property
    def supported_providers(self) -> list[str]:
        """Names of the orchestration backends this config accepts."""
        return ["hatchet", "simple"]

    def validate_config(self) -> None:
        """Reject any provider outside the supported set."""
        if self.provider in self.supported_providers:
            return
        raise ValueError(f"Provider {self.provider} is not supported.")
+
+
class OrchestrationProvider(Provider):
    """Abstract interface for workflow-orchestration backends (e.g. hatchet,
    simple)."""

    def __init__(self, config: OrchestrationConfig):
        super().__init__(config)
        self.config = config
        # Populated by concrete implementations (see get_worker/start_worker).
        self.worker = None

    @abstractmethod
    async def start_worker(self):
        """Start the background worker that executes registered workflows."""
        pass

    @abstractmethod
    def get_worker(self, name: str, max_runs: int) -> Any:
        """Return a worker handle identified by `name` with a run cap."""
        pass

    @abstractmethod
    def step(self, *args, **kwargs) -> Any:
        """Decorator/hook marking a single step within a workflow."""
        pass

    @abstractmethod
    def workflow(self, *args, **kwargs) -> Any:
        """Decorator/hook declaring a workflow."""
        pass

    @abstractmethod
    def failure(self, *args, **kwargs) -> Any:
        """Decorator/hook declaring a failure handler."""
        pass

    @abstractmethod
    def register_workflows(
        self, workflow: Workflow, service: Any, messages: dict
    ) -> None:
        """Register all workflows of the given category against `service`."""
        pass

    @abstractmethod
    async def run_workflow(
        self,
        workflow_name: str,
        parameters: dict,
        options: dict,
        *args,
        **kwargs,
    ) -> dict[str, str]:
        """Launch `workflow_name` with `parameters`; returns run metadata."""
        pass