about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/core/main/assembly
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/main/assembly
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/main/assembly')
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/assembly/__init__.py12
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/assembly/builder.py127
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/assembly/factory.py417
3 files changed, 556 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/main/assembly/__init__.py b/.venv/lib/python3.12/site-packages/core/main/assembly/__init__.py
new file mode 100644
index 00000000..3d10f2b6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/assembly/__init__.py
@@ -0,0 +1,12 @@
+from ..config import R2RConfig
+from .builder import R2RBuilder
+from .factory import R2RProviderFactory
+
+# Names exported by ``from <this package> import *``: the builder, the
+# configuration class re-exported from ``..config``, and the provider factory.
+__all__ = [
+    # Builder
+    "R2RBuilder",
+    # Config
+    "R2RConfig",
+    # Factory
+    "R2RProviderFactory",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/main/assembly/builder.py b/.venv/lib/python3.12/site-packages/core/main/assembly/builder.py
new file mode 100644
index 00000000..f72a15c9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/assembly/builder.py
@@ -0,0 +1,127 @@
+import logging
+from typing import Any, Type
+
+from ..abstractions import R2RProviders, R2RServices
+from ..api.v3.chunks_router import ChunksRouter
+from ..api.v3.collections_router import CollectionsRouter
+from ..api.v3.conversations_router import ConversationsRouter
+from ..api.v3.documents_router import DocumentsRouter
+from ..api.v3.graph_router import GraphRouter
+from ..api.v3.indices_router import IndicesRouter
+from ..api.v3.prompts_router import PromptsRouter
+from ..api.v3.retrieval_router import RetrievalRouter
+from ..api.v3.system_router import SystemRouter
+from ..api.v3.users_router import UsersRouter
+from ..app import R2RApp
+from ..config import R2RConfig
+from ..services.auth_service import AuthService  # noqa: F401
+from ..services.graph_service import GraphService  # noqa: F401
+from ..services.ingestion_service import IngestionService  # noqa: F401
+from ..services.management_service import ManagementService  # noqa: F401
+from ..services.retrieval_service import (  # type: ignore
+    RetrievalService,  # noqa: F401 # type: ignore
+)
+from .factory import R2RProviderFactory
+
+# Called without a name, so this is the root logger rather than a
+# per-module logger.
+logger = logging.getLogger()
+
+
+class R2RBuilder:
+    _SERVICES = ["auth", "ingestion", "management", "retrieval", "graph"]
+
+    def __init__(self, config: R2RConfig):
+        self.config = config
+
+    async def build(self, *args, **kwargs) -> R2RApp:
+        provider_factory = R2RProviderFactory
+
+        try:
+            providers = await self._create_providers(
+                provider_factory, *args, **kwargs
+            )
+        except Exception as e:
+            logger.error(f"Error {e} while creating R2RProviders.")
+            raise
+
+        service_params = {
+            "config": self.config,
+            "providers": providers,
+        }
+
+        services = self._create_services(service_params)
+
+        routers = {
+            "chunks_router": ChunksRouter(
+                providers=providers,
+                services=services,
+                config=self.config,
+            ).get_router(),
+            "collections_router": CollectionsRouter(
+                providers=providers,
+                services=services,
+                config=self.config,
+            ).get_router(),
+            "conversations_router": ConversationsRouter(
+                providers=providers,
+                services=services,
+                config=self.config,
+            ).get_router(),
+            "documents_router": DocumentsRouter(
+                providers=providers,
+                services=services,
+                config=self.config,
+            ).get_router(),
+            "graph_router": GraphRouter(
+                providers=providers,
+                services=services,
+                config=self.config,
+            ).get_router(),
+            "indices_router": IndicesRouter(
+                providers=providers,
+                services=services,
+                config=self.config,
+            ).get_router(),
+            "prompts_router": PromptsRouter(
+                providers=providers,
+                services=services,
+                config=self.config,
+            ).get_router(),
+            "retrieval_router": RetrievalRouter(
+                providers=providers,
+                services=services,
+                config=self.config,
+            ).get_router(),
+            "system_router": SystemRouter(
+                providers=providers,
+                services=services,
+                config=self.config,
+            ).get_router(),
+            "users_router": UsersRouter(
+                providers=providers,
+                services=services,
+                config=self.config,
+            ).get_router(),
+        }
+
+        return R2RApp(
+            config=self.config,
+            orchestration_provider=providers.orchestration,
+            services=services,
+            **routers,
+        )
+
+    async def _create_providers(
+        self, provider_factory: Type[R2RProviderFactory], *args, **kwargs
+    ) -> R2RProviders:
+        factory = provider_factory(self.config)
+        return await factory.create_providers(*args, **kwargs)
+
+    def _create_services(self, service_params: dict[str, Any]) -> R2RServices:
+        services = R2RBuilder._SERVICES
+        service_instances = {}
+
+        for service_type in services:
+            service_class = globals()[f"{service_type.capitalize()}Service"]
+            service_instances[service_type] = service_class(**service_params)
+
+        return R2RServices(**service_instances)
diff --git a/.venv/lib/python3.12/site-packages/core/main/assembly/factory.py b/.venv/lib/python3.12/site-packages/core/main/assembly/factory.py
new file mode 100644
index 00000000..b982aa18
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/assembly/factory.py
@@ -0,0 +1,417 @@
+import logging
+import math
+import os
+from typing import Any, Optional
+
+from core.base import (
+    AuthConfig,
+    CompletionConfig,
+    CompletionProvider,
+    CryptoConfig,
+    DatabaseConfig,
+    EmailConfig,
+    EmbeddingConfig,
+    EmbeddingProvider,
+    IngestionConfig,
+    OrchestrationConfig,
+)
+from core.providers import (
+    AnthropicCompletionProvider,
+    AsyncSMTPEmailProvider,
+    BcryptCryptoConfig,
+    BCryptCryptoProvider,
+    ClerkAuthProvider,
+    ConsoleMockEmailProvider,
+    HatchetOrchestrationProvider,
+    JwtAuthProvider,
+    LiteLLMCompletionProvider,
+    LiteLLMEmbeddingProvider,
+    MailerSendEmailProvider,
+    NaClCryptoConfig,
+    NaClCryptoProvider,
+    OllamaEmbeddingProvider,
+    OpenAICompletionProvider,
+    OpenAIEmbeddingProvider,
+    PostgresDatabaseProvider,
+    R2RAuthProvider,
+    R2RCompletionProvider,
+    R2RIngestionConfig,
+    R2RIngestionProvider,
+    SendGridEmailProvider,
+    SimpleOrchestrationProvider,
+    SupabaseAuthProvider,
+    UnstructuredIngestionConfig,
+    UnstructuredIngestionProvider,
+)
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+
+# Called without a name, so this is the root logger rather than a
+# per-module logger.
+logger = logging.getLogger()
+
+
+class R2RProviderFactory:
+    def __init__(self, config: R2RConfig):
+        self.config = config
+
+    @staticmethod
+    async def create_auth_provider(
+        auth_config: AuthConfig,
+        crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
+        database_provider: PostgresDatabaseProvider,
+        email_provider: (
+            AsyncSMTPEmailProvider
+            | ConsoleMockEmailProvider
+            | SendGridEmailProvider
+            | MailerSendEmailProvider
+        ),
+        *args,
+        **kwargs,
+    ) -> (
+        R2RAuthProvider
+        | SupabaseAuthProvider
+        | JwtAuthProvider
+        | ClerkAuthProvider
+    ):
+        if auth_config.provider == "r2r":
+            r2r_auth = R2RAuthProvider(
+                auth_config, crypto_provider, database_provider, email_provider
+            )
+            await r2r_auth.initialize()
+            return r2r_auth
+        elif auth_config.provider == "supabase":
+            return SupabaseAuthProvider(
+                auth_config, crypto_provider, database_provider, email_provider
+            )
+        elif auth_config.provider == "jwt":
+            return JwtAuthProvider(
+                auth_config, crypto_provider, database_provider, email_provider
+            )
+        elif auth_config.provider == "clerk":
+            return ClerkAuthProvider(
+                auth_config, crypto_provider, database_provider, email_provider
+            )
+        else:
+            raise ValueError(
+                f"Auth provider {auth_config.provider} not supported."
+            )
+
+    @staticmethod
+    def create_crypto_provider(
+        crypto_config: CryptoConfig, *args, **kwargs
+    ) -> BCryptCryptoProvider | NaClCryptoProvider:
+        if crypto_config.provider == "bcrypt":
+            return BCryptCryptoProvider(
+                BcryptCryptoConfig(**crypto_config.model_dump())
+            )
+        if crypto_config.provider == "nacl":
+            return NaClCryptoProvider(
+                NaClCryptoConfig(**crypto_config.model_dump())
+            )
+        else:
+            raise ValueError(
+                f"Crypto provider {crypto_config.provider} not supported."
+            )
+
+    @staticmethod
+    def create_ingestion_provider(
+        ingestion_config: IngestionConfig,
+        database_provider: PostgresDatabaseProvider,
+        llm_provider: (
+            AnthropicCompletionProvider
+            | LiteLLMCompletionProvider
+            | OpenAICompletionProvider
+            | R2RCompletionProvider
+        ),
+        *args,
+        **kwargs,
+    ) -> R2RIngestionProvider | UnstructuredIngestionProvider:
+        config_dict = (
+            ingestion_config.model_dump()
+            if isinstance(ingestion_config, IngestionConfig)
+            else ingestion_config
+        )
+
+        extra_fields = config_dict.pop("extra_fields", {})
+
+        if config_dict["provider"] == "r2r":
+            r2r_ingestion_config = R2RIngestionConfig(
+                **config_dict, **extra_fields
+            )
+            return R2RIngestionProvider(
+                r2r_ingestion_config, database_provider, llm_provider
+            )
+        elif config_dict["provider"] in [
+            "unstructured_local",
+            "unstructured_api",
+        ]:
+            unstructured_ingestion_config = UnstructuredIngestionConfig(
+                **config_dict, **extra_fields
+            )
+
+            return UnstructuredIngestionProvider(
+                unstructured_ingestion_config, database_provider, llm_provider
+            )
+        else:
+            raise ValueError(
+                f"Ingestion provider {ingestion_config.provider} not supported"
+            )
+
+    @staticmethod
+    def create_orchestration_provider(
+        config: OrchestrationConfig, *args, **kwargs
+    ) -> HatchetOrchestrationProvider | SimpleOrchestrationProvider:
+        if config.provider == "hatchet":
+            orchestration_provider = HatchetOrchestrationProvider(config)
+            orchestration_provider.get_worker("r2r-worker")
+            return orchestration_provider
+        elif config.provider == "simple":
+            from core.providers import SimpleOrchestrationProvider
+
+            return SimpleOrchestrationProvider(config)
+        else:
+            raise ValueError(
+                f"Orchestration provider {config.provider} not supported"
+            )
+
+    async def create_database_provider(
+        self,
+        db_config: DatabaseConfig,
+        crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
+        *args,
+        **kwargs,
+    ) -> PostgresDatabaseProvider:
+        if not self.config.embedding.base_dimension:
+            raise ValueError(
+                "Embedding config must have a base dimension to initialize database."
+            )
+
+        dimension = self.config.embedding.base_dimension
+        quantization_type = (
+            self.config.embedding.quantization_settings.quantization_type
+        )
+        if db_config.provider == "postgres":
+            database_provider = PostgresDatabaseProvider(
+                db_config,
+                dimension,
+                crypto_provider=crypto_provider,
+                quantization_type=quantization_type,
+            )
+            await database_provider.initialize()
+            return database_provider
+        else:
+            raise ValueError(
+                f"Database provider {db_config.provider} not supported"
+            )
+
+    @staticmethod
+    def create_embedding_provider(
+        embedding: EmbeddingConfig, *args, **kwargs
+    ) -> (
+        LiteLLMEmbeddingProvider
+        | OllamaEmbeddingProvider
+        | OpenAIEmbeddingProvider
+    ):
+        embedding_provider: Optional[EmbeddingProvider] = None
+
+        if embedding.provider == "openai":
+            if not os.getenv("OPENAI_API_KEY"):
+                raise ValueError(
+                    "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
+                )
+            from core.providers import OpenAIEmbeddingProvider
+
+            embedding_provider = OpenAIEmbeddingProvider(embedding)
+
+        elif embedding.provider == "litellm":
+            from core.providers import LiteLLMEmbeddingProvider
+
+            embedding_provider = LiteLLMEmbeddingProvider(embedding)
+
+        elif embedding.provider == "ollama":
+            from core.providers import OllamaEmbeddingProvider
+
+            embedding_provider = OllamaEmbeddingProvider(embedding)
+
+        else:
+            raise ValueError(
+                f"Embedding provider {embedding.provider} not supported"
+            )
+
+        return embedding_provider
+
+    @staticmethod
+    def create_llm_provider(
+        llm_config: CompletionConfig, *args, **kwargs
+    ) -> (
+        AnthropicCompletionProvider
+        | LiteLLMCompletionProvider
+        | OpenAICompletionProvider
+        | R2RCompletionProvider
+    ):
+        llm_provider: Optional[CompletionProvider] = None
+        if llm_config.provider == "anthropic":
+            llm_provider = AnthropicCompletionProvider(llm_config)
+        elif llm_config.provider == "litellm":
+            llm_provider = LiteLLMCompletionProvider(llm_config)
+        elif llm_config.provider == "openai":
+            llm_provider = OpenAICompletionProvider(llm_config)
+        elif llm_config.provider == "r2r":
+            llm_provider = R2RCompletionProvider(llm_config)
+        else:
+            raise ValueError(
+                f"Language model provider {llm_config.provider} not supported"
+            )
+        if not llm_provider:
+            raise ValueError("Language model provider not found")
+        return llm_provider
+
+    @staticmethod
+    async def create_email_provider(
+        email_config: Optional[EmailConfig] = None, *args, **kwargs
+    ) -> (
+        AsyncSMTPEmailProvider
+        | ConsoleMockEmailProvider
+        | SendGridEmailProvider
+        | MailerSendEmailProvider
+    ):
+        """Creates an email provider based on configuration."""
+        if not email_config:
+            raise ValueError(
+                "No email configuration provided for email provider, please add `[email]` to your `r2r.toml`."
+            )
+
+        if email_config.provider == "smtp":
+            return AsyncSMTPEmailProvider(email_config)
+        elif email_config.provider == "console_mock":
+            return ConsoleMockEmailProvider(email_config)
+        elif email_config.provider == "sendgrid":
+            return SendGridEmailProvider(email_config)
+        elif email_config.provider == "mailersend":
+            return MailerSendEmailProvider(email_config)
+        else:
+            raise ValueError(
+                f"Email provider {email_config.provider} not supported."
+            )
+
+    async def create_providers(
+        self,
+        auth_provider_override: Optional[
+            R2RAuthProvider | SupabaseAuthProvider
+        ] = None,
+        crypto_provider_override: Optional[
+            BCryptCryptoProvider | NaClCryptoProvider
+        ] = None,
+        database_provider_override: Optional[PostgresDatabaseProvider] = None,
+        email_provider_override: Optional[
+            AsyncSMTPEmailProvider
+            | ConsoleMockEmailProvider
+            | SendGridEmailProvider
+            | MailerSendEmailProvider
+        ] = None,
+        embedding_provider_override: Optional[
+            LiteLLMEmbeddingProvider
+            | OpenAIEmbeddingProvider
+            | OllamaEmbeddingProvider
+        ] = None,
+        ingestion_provider_override: Optional[
+            R2RIngestionProvider | UnstructuredIngestionProvider
+        ] = None,
+        llm_provider_override: Optional[
+            AnthropicCompletionProvider
+            | OpenAICompletionProvider
+            | LiteLLMCompletionProvider
+            | R2RCompletionProvider
+        ] = None,
+        orchestration_provider_override: Optional[Any] = None,
+        *args,
+        **kwargs,
+    ) -> R2RProviders:
+        if (
+            math.isnan(self.config.embedding.base_dimension)
+            != math.isnan(self.config.completion_embedding.base_dimension)
+        ) or (
+            not math.isnan(self.config.embedding.base_dimension)
+            and not math.isnan(self.config.completion_embedding.base_dimension)
+            and self.config.embedding.base_dimension
+            != self.config.completion_embedding.base_dimension
+        ):
+            raise ValueError(
+                f"Both embedding configurations must use the same dimensions. Got {self.config.embedding.base_dimension} and {self.config.completion_embedding.base_dimension}"
+            )
+
+        embedding_provider = (
+            embedding_provider_override
+            or self.create_embedding_provider(
+                self.config.embedding, *args, **kwargs
+            )
+        )
+
+        completion_embedding_provider = (
+            embedding_provider_override
+            or self.create_embedding_provider(
+                self.config.completion_embedding, *args, **kwargs
+            )
+        )
+
+        llm_provider = llm_provider_override or self.create_llm_provider(
+            self.config.completion, *args, **kwargs
+        )
+
+        crypto_provider = (
+            crypto_provider_override
+            or self.create_crypto_provider(self.config.crypto, *args, **kwargs)
+        )
+
+        database_provider = (
+            database_provider_override
+            or await self.create_database_provider(
+                self.config.database, crypto_provider, *args, **kwargs
+            )
+        )
+
+        ingestion_provider = (
+            ingestion_provider_override
+            or self.create_ingestion_provider(
+                self.config.ingestion,
+                database_provider,
+                llm_provider,
+                *args,
+                **kwargs,
+            )
+        )
+
+        email_provider = (
+            email_provider_override
+            or await self.create_email_provider(
+                self.config.email, crypto_provider, *args, **kwargs
+            )
+        )
+
+        auth_provider = (
+            auth_provider_override
+            or await self.create_auth_provider(
+                self.config.auth,
+                crypto_provider,
+                database_provider,
+                email_provider,
+                *args,
+                **kwargs,
+            )
+        )
+
+        orchestration_provider = (
+            orchestration_provider_override
+            or self.create_orchestration_provider(self.config.orchestration)
+        )
+
+        return R2RProviders(
+            auth=auth_provider,
+            database=database_provider,
+            embedding=embedding_provider,
+            completion_embedding=completion_embedding_provider,
+            ingestion=ingestion_provider,
+            llm=llm_provider,
+            email=email_provider,
+            orchestration=orchestration_provider,
+        )