about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/__init__.py37
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/connection_subtypes.py225
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/credentials.py178
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/one_lake_artifacts.py26
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/workspace_connection.py86
5 files changed, 552 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/__init__.py
new file mode 100644
index 00000000..fa462cfb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/__init__.py
@@ -0,0 +1,37 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)  # type: ignore
+
+from .workspace_connection import WorkspaceConnectionSchema
+from .connection_subtypes import (
+    AzureBlobStoreConnectionSchema,
+    MicrosoftOneLakeConnectionSchema,
+    AzureOpenAIConnectionSchema,
+    AzureAIServicesConnectionSchema,
+    AzureAISearchConnectionSchema,
+    AzureContentSafetyConnectionSchema,
+    AzureSpeechServicesConnectionSchema,
+    APIKeyConnectionSchema,
+    OpenAIConnectionSchema,
+    SerpConnectionSchema,
+    ServerlessConnectionSchema,
+    OneLakeArtifactSchema,
+)
+
+__all__ = [
+    "WorkspaceConnectionSchema",
+    "AzureBlobStoreConnectionSchema",
+    "MicrosoftOneLakeConnectionSchema",
+    "AzureOpenAIConnectionSchema",
+    "AzureAIServicesConnectionSchema",
+    "AzureAISearchConnectionSchema",
+    "AzureContentSafetyConnectionSchema",
+    "AzureSpeechServicesConnectionSchema",
+    "APIKeyConnectionSchema",
+    "OpenAIConnectionSchema",
+    "SerpConnectionSchema",
+    "ServerlessConnectionSchema",
+    "OneLakeArtifactSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/connection_subtypes.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/connection_subtypes.py
new file mode 100644
index 00000000..d04b3e76
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/connection_subtypes.py
@@ -0,0 +1,225 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+from marshmallow.exceptions import ValidationError
+from marshmallow.decorators import pre_load
+
+from azure.ai.ml._restclient.v2024_04_01_preview.models import ConnectionCategory
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._common import ConnectionTypes
+from azure.ai.ml._schema.workspace.connections.one_lake_artifacts import OneLakeArtifactSchema
+from azure.ai.ml._schema.workspace.connections.credentials import (
+    SasTokenConfigurationSchema,
+    ServicePrincipalConfigurationSchema,
+    AccountKeyConfigurationSchema,
+    AadCredentialConfigurationSchema,
+)
+from azure.ai.ml.entities import AadCredentialConfiguration
+from .workspace_connection import WorkspaceConnectionSchema
+
+
+class AzureBlobStoreConnectionSchema(WorkspaceConnectionSchema):
+    # type and credentials limited
+    type = StringTransformedEnum(
+        allowed_values=ConnectionCategory.AZURE_BLOB, casing_transform=camel_to_snake, required=True
+    )
+    credentials = UnionField(
+        [
+            NestedField(SasTokenConfigurationSchema),
+            NestedField(AccountKeyConfigurationSchema),
+            NestedField(AadCredentialConfigurationSchema),
+        ],
+        required=False,
+        load_default=AadCredentialConfiguration(),
+    )
+
+    url = fields.Str()
+
+    account_name = fields.Str(required=True, allow_none=False)
+    container_name = fields.Str(required=True, allow_none=False)
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import AzureBlobStoreConnection
+
+        return AzureBlobStoreConnection(**data)
+
+
+class MicrosoftOneLakeConnectionSchema(WorkspaceConnectionSchema):
+    type = StringTransformedEnum(
+        allowed_values=ConnectionCategory.AZURE_ONE_LAKE, casing_transform=camel_to_snake, required=True
+    )
+    credentials = UnionField(
+        [NestedField(ServicePrincipalConfigurationSchema), NestedField(AadCredentialConfigurationSchema)],
+        required=False,
+        load_default=AadCredentialConfiguration(),
+    )
+    artifact = NestedField(OneLakeArtifactSchema, required=False, allow_none=True)
+
+    endpoint = fields.Str(required=False)
+    one_lake_workspace_name = fields.Str(required=False)
+
+    @pre_load
+    def check_for_target(self, data, **kwargs):
+        target = data.get("target", None)
+        artifact = data.get("artifact", None)
+        endpoint = data.get("endpoint", None)
+        one_lake_workspace_name = data.get("one_lake_workspace_name", None)
+        # If the user is using a target, then they don't need the artifact and one lake workspace name.
+        # This is distinct from when the user set's the 'endpoint' value, which is also used to construct
+        # the target. If the target is already present, then the loaded connection YAML was probably produced
+        # by dumping an extant connection.
+        if target is None:
+            if artifact is None:
+                raise ValidationError("If target is unset, then artifact must be set")
+            if endpoint is None:
+                raise ValidationError("If target is unset, then endpoint must be set")
+            if one_lake_workspace_name is None:
+                raise ValidationError("If target is unset, then one_lake_workspace_name must be set")
+        return data
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import MicrosoftOneLakeConnection
+
+        return MicrosoftOneLakeConnection(**data)
+
+
+class AzureOpenAIConnectionSchema(WorkspaceConnectionSchema):
+    # type and credentials limited
+    type = StringTransformedEnum(
+        allowed_values=ConnectionCategory.AZURE_OPEN_AI, casing_transform=camel_to_snake, required=True
+    )
+    api_key = fields.Str(required=False, allow_none=True)
+    api_version = fields.Str(required=False, allow_none=True)
+
+    azure_endpoint = fields.Str()
+    open_ai_resource_id = fields.Str(required=False, allow_none=True)
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import AzureOpenAIConnection
+
+        return AzureOpenAIConnection(**data)
+
+
+class AzureAIServicesConnectionSchema(WorkspaceConnectionSchema):
+    # type and credentials limited
+    type = StringTransformedEnum(
+        allowed_values=ConnectionTypes.AZURE_AI_SERVICES, casing_transform=camel_to_snake, required=True
+    )
+    api_key = fields.Str(required=False, allow_none=True)
+    endpoint = fields.Str()
+    ai_services_resource_id = fields.Str()
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import AzureAIServicesConnection
+
+        return AzureAIServicesConnection(**data)
+
+
+class AzureAISearchConnectionSchema(WorkspaceConnectionSchema):
+    # type and credentials limited
+    type = StringTransformedEnum(
+        allowed_values=ConnectionTypes.AZURE_SEARCH, casing_transform=camel_to_snake, required=True
+    )
+    api_key = fields.Str(required=False, allow_none=True)
+    endpoint = fields.Str()
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import AzureAISearchConnection
+
+        return AzureAISearchConnection(**data)
+
+
+class AzureContentSafetyConnectionSchema(WorkspaceConnectionSchema):
+    # type and credentials limited
+    type = StringTransformedEnum(
+        allowed_values=ConnectionTypes.AZURE_CONTENT_SAFETY, casing_transform=camel_to_snake, required=True
+    )
+    api_key = fields.Str(required=False, allow_none=True)
+    endpoint = fields.Str()
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import AzureContentSafetyConnection
+
+        return AzureContentSafetyConnection(**data)
+
+
+class AzureSpeechServicesConnectionSchema(WorkspaceConnectionSchema):
+    # type and credentials limited
+    type = StringTransformedEnum(
+        allowed_values=ConnectionTypes.AZURE_SPEECH_SERVICES, casing_transform=camel_to_snake, required=True
+    )
+    api_key = fields.Str(required=False, allow_none=True)
+    endpoint = fields.Str()
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import AzureSpeechServicesConnection
+
+        return AzureSpeechServicesConnection(**data)
+
+
+class APIKeyConnectionSchema(WorkspaceConnectionSchema):
+    # type and credentials limited
+    type = StringTransformedEnum(
+        allowed_values=ConnectionCategory.API_KEY, casing_transform=camel_to_snake, required=True
+    )
+    api_key = fields.Str(required=True)
+    api_base = fields.Str(required=True)
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import APIKeyConnection
+
+        return APIKeyConnection(**data)
+
+
+class OpenAIConnectionSchema(WorkspaceConnectionSchema):
+    # type and credentials limited
+    type = StringTransformedEnum(
+        allowed_values=ConnectionCategory.OPEN_AI, casing_transform=camel_to_snake, required=True
+    )
+    api_key = fields.Str(required=True)
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import OpenAIConnection
+
+        return OpenAIConnection(**data)
+
+
+class SerpConnectionSchema(WorkspaceConnectionSchema):
+    # type and credentials limited
+    type = StringTransformedEnum(allowed_values=ConnectionCategory.SERP, casing_transform=camel_to_snake, required=True)
+    api_key = fields.Str(required=True)
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import SerpConnection
+
+        return SerpConnection(**data)
+
+
+class ServerlessConnectionSchema(WorkspaceConnectionSchema):
+    # type and credentials limited
+    type = StringTransformedEnum(
+        allowed_values=ConnectionCategory.SERVERLESS, casing_transform=camel_to_snake, required=True
+    )
+    api_key = fields.Str(required=True)
+    endpoint = fields.Str()
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import ServerlessConnection
+
+        return ServerlessConnection(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/credentials.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/credentials.py
new file mode 100644
index 00000000..52213c08
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/credentials.py
@@ -0,0 +1,178 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+##### DEV NOTE: For some reason, these schemas correlate to the classes defined in ~azure.ai.ml.entities._credentials.
+# There used to be a credentials.py file in ~azure.ai.ml.entities.workspace.connections,
+# but it was, as far as I could tell, never used. So I removed it and added this comment.
+
+from typing import Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2024_04_01_preview.models import ConnectionAuthType
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._credentials import (
+    ManagedIdentityConfiguration,
+    PatTokenConfiguration,
+    SasTokenConfiguration,
+    ServicePrincipalConfiguration,
+    UsernamePasswordConfiguration,
+    AccessKeyConfiguration,
+    ApiKeyConfiguration,
+    AccountKeyConfiguration,
+    AadCredentialConfiguration,
+    NoneCredentialConfiguration,
+)
+
+
+class WorkspaceCredentialsSchema(metaclass=PatchedSchemaMeta):
+    type = fields.Str()
+
+
+class PatTokenConfigurationSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        allowed_values=ConnectionAuthType.PAT,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+    pat = fields.Str()
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> PatTokenConfiguration:
+        data.pop("type")
+        return PatTokenConfiguration(**data)
+
+
+class SasTokenConfigurationSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        allowed_values=ConnectionAuthType.SAS,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+    sas_token = fields.Str()
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> SasTokenConfiguration:
+        data.pop("type")
+        return SasTokenConfiguration(**data)
+
+
+class UsernamePasswordConfigurationSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        allowed_values=ConnectionAuthType.USERNAME_PASSWORD,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+    username = fields.Str()
+    password = fields.Str()
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> UsernamePasswordConfiguration:
+        data.pop("type")
+        return UsernamePasswordConfiguration(**data)
+
+
+class ManagedIdentityConfigurationSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        allowed_values=ConnectionAuthType.MANAGED_IDENTITY,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+    client_id = fields.Str()
+    resource_id = fields.Str()
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> ManagedIdentityConfiguration:
+        data.pop("type")
+        return ManagedIdentityConfiguration(**data)
+
+
+class ServicePrincipalConfigurationSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        allowed_values=ConnectionAuthType.SERVICE_PRINCIPAL,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+
+    client_id = fields.Str()
+    client_secret = fields.Str()
+    tenant_id = fields.Str()
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> ServicePrincipalConfiguration:
+        data.pop("type")
+        return ServicePrincipalConfiguration(**data)
+
+
+class AccessKeyConfigurationSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        allowed_values=ConnectionAuthType.ACCESS_KEY,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+    access_key_id = fields.Str()
+    secret_access_key = fields.Str()
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> AccessKeyConfiguration:
+        data.pop("type")
+        return AccessKeyConfiguration(**data)
+
+
+class ApiKeyConfigurationSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        allowed_values=ConnectionAuthType.API_KEY,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+    key = fields.Str()
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> ApiKeyConfiguration:
+        data.pop("type")
+        return ApiKeyConfiguration(**data)
+
+
+class AccountKeyConfigurationSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        allowed_values=ConnectionAuthType.ACCOUNT_KEY,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+    account_key = fields.Str()
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> AccountKeyConfiguration:
+        data.pop("type")
+        return AccountKeyConfiguration(**data)
+
+
+class AadCredentialConfigurationSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        allowed_values=ConnectionAuthType.AAD,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> AadCredentialConfiguration:
+        data.pop("type")
+        return AadCredentialConfiguration(**data)
+
+
+class NoneCredentialConfigurationSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        allowed_values=ConnectionAuthType.NONE,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> NoneCredentialConfiguration:
+        data.pop("type")
+        return NoneCredentialConfiguration(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/one_lake_artifacts.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/one_lake_artifacts.py
new file mode 100644
index 00000000..563a9359
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/one_lake_artifacts.py
@@ -0,0 +1,26 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._common import OneLakeArtifactTypes
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class OneLakeArtifactSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        allowed_values=OneLakeArtifactTypes.ONE_LAKE, casing_transform=camel_to_snake, required=True
+    )
+    name = fields.Str(required=True)
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import OneLakeConnectionArtifact
+
+        return OneLakeConnectionArtifact(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/workspace_connection.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/workspace_connection.py
new file mode 100644
index 00000000..20863a5a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/workspace_connection.py
@@ -0,0 +1,86 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2024_04_01_preview.models import ConnectionCategory
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.resource import ResourceSchema
+from azure.ai.ml._schema.job import CreationContextSchema
+from azure.ai.ml._schema.workspace.connections.credentials import (
+    AccountKeyConfigurationSchema,
+    ManagedIdentityConfigurationSchema,
+    PatTokenConfigurationSchema,
+    SasTokenConfigurationSchema,
+    ServicePrincipalConfigurationSchema,
+    UsernamePasswordConfigurationSchema,
+    AccessKeyConfigurationSchema,
+    ApiKeyConfigurationSchema,
+    AadCredentialConfigurationSchema,
+    NoneCredentialConfigurationSchema,
+)
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._common import ConnectionTypes
+from azure.ai.ml.entities import NoneCredentialConfiguration, AadCredentialConfiguration
+
+
+class WorkspaceConnectionSchema(ResourceSchema):
+    # Inherits name, id, tags, and description fields from ResourceSchema
+    creation_context = NestedField(CreationContextSchema, dump_only=True)
+    type = StringTransformedEnum(
+        allowed_values=[
+            ConnectionCategory.GIT,
+            ConnectionCategory.CONTAINER_REGISTRY,
+            ConnectionCategory.PYTHON_FEED,
+            ConnectionCategory.S3,
+            ConnectionCategory.SNOWFLAKE,
+            ConnectionCategory.AZURE_SQL_DB,
+            ConnectionCategory.AZURE_SYNAPSE_ANALYTICS,
+            ConnectionCategory.AZURE_MY_SQL_DB,
+            ConnectionCategory.AZURE_POSTGRES_DB,
+            ConnectionTypes.CUSTOM,
+            ConnectionTypes.AZURE_DATA_LAKE_GEN_2,
+        ],
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+
+    # Sorta false, some connection types require this field, some don't.
+    # And some rename it... for client familiarity reasons.
+    target = fields.Str(required=False)
+
+    credentials = UnionField(
+        [
+            NestedField(PatTokenConfigurationSchema),
+            NestedField(SasTokenConfigurationSchema),
+            NestedField(UsernamePasswordConfigurationSchema),
+            NestedField(ManagedIdentityConfigurationSchema),
+            NestedField(ServicePrincipalConfigurationSchema),
+            NestedField(AccessKeyConfigurationSchema),
+            NestedField(ApiKeyConfigurationSchema),
+            NestedField(AccountKeyConfigurationSchema),
+            NestedField(AadCredentialConfigurationSchema),
+            NestedField(NoneCredentialConfigurationSchema),
+        ],
+        required=False,
+        load_default=NoneCredentialConfiguration(),
+    )
+
+    is_shared = fields.Bool(load_default=True)
+    metadata = fields.Dict(required=False)
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import WorkspaceConnection
+
+        # Most non-subclassed connections default to a none credential if none
+        # is provided. ALDS Gen 2 connections default to AAD with this code.
+        if (
+            data.get("type", None) == ConnectionTypes.AZURE_DATA_LAKE_GEN_2
+            and data.get("credentials", None) == NoneCredentialConfiguration()
+        ):
+            data["credentials"] = AadCredentialConfiguration()
+        return WorkspaceConnection(**data)