diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint')
7 files changed, 187 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/__init__.py new file mode 100644 index 00000000..e9538cbb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/__init__.py @@ -0,0 +1,15 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + + +from .batch.batch_endpoint import BatchEndpointSchema +from .online.online_endpoint import KubernetesOnlineEndpointSchema, ManagedOnlineEndpointSchema + +__all__ = [ + "BatchEndpointSchema", + "KubernetesOnlineEndpointSchema", + "ManagedOnlineEndpointSchema", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint.py new file mode 100644 index 00000000..0bee2493 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint.py @@ -0,0 +1,27 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import post_load + +from azure.ai.ml._schema._endpoint.batch.batch_endpoint_defaults import BatchEndpointsDefaultsSchema +from azure.ai.ml._schema._endpoint.endpoint import EndpointSchema +from azure.ai.ml._schema.core.fields import NestedField +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY + +module_logger = logging.getLogger(__name__) + + +class BatchEndpointSchema(EndpointSchema): + defaults = NestedField(BatchEndpointsDefaultsSchema) + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities import BatchEndpoint + + return BatchEndpoint(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint_defaults.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint_defaults.py new file mode 100644 index 00000000..49699bb0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint_defaults.py @@ -0,0 +1,28 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2023_10_01.models import BatchEndpointDefaults +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + +module_logger = logging.getLogger(__name__) + + +class BatchEndpointsDefaultsSchema(metaclass=PatchedSchemaMeta): + deployment_name = fields.Str( + metadata={ + "description": """Name of the deployment that will be default for the endpoint. + This deployment will end up getting 100% traffic when the endpoint scoring URL is invoked.""" + } + ) + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + return BatchEndpointDefaults(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/endpoint.py new file mode 100644 index 00000000..1ff43338 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/endpoint.py @@ -0,0 +1,41 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging + +from marshmallow import fields, validate + +from azure.ai.ml._restclient.v2022_02_01_preview.models import EndpointAuthMode +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._schema.core.schema import PathAwareSchema +from azure.ai.ml._schema.identity import IdentitySchema +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._endpoint import EndpointConfigurations + +module_logger = logging.getLogger(__name__) + + +class EndpointSchema(PathAwareSchema): + id = fields.Str() + name = fields.Str(required=True, validate=validate.Regexp(EndpointConfigurations.NAME_REGEX_PATTERN)) + description = fields.Str(metadata={"description": "Description of the inference endpoint."}) + tags = fields.Dict() + provisioning_state = fields.Str(metadata={"description": "Provisioning state for the endpoint."}) + properties = fields.Dict() + auth_mode = StringTransformedEnum( + allowed_values=[ + EndpointAuthMode.AML_TOKEN, + EndpointAuthMode.KEY, + EndpointAuthMode.AAD_TOKEN, + ], + casing_transform=camel_to_snake, + metadata={ + "description": """authentication method: no auth, key based or azure ml token based. + aad_token is only valid for batch endpoint.""" + }, + ) + scoring_uri = fields.Str(metadata={"description": "The endpoint uri that can be used for scoring"}) + location = fields.Str() + openapi_uri = fields.Str(metadata={"description": "Endpoint Open API URI."}) + identity = NestedField(IdentitySchema) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py new file mode 100644 index 00000000..84b34636 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py @@ -0,0 +1,66 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import ValidationError, fields, post_load, validates + +from azure.ai.ml._schema._endpoint.endpoint import EndpointSchema +from azure.ai.ml._schema.core.fields import ArmStr, StringTransformedEnum +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType, PublicNetworkAccess + +module_logger = logging.getLogger(__name__) + + +class OnlineEndpointSchema(EndpointSchema): + traffic = fields.Dict( + keys=fields.Str(), + values=fields.Int(), + metadata={ + "description": """a dict with key as deployment name and value as traffic percentage. + The values need to sum to 100 """ + }, + ) + kind = fields.Str(dump_only=True) + + mirror_traffic = fields.Dict( + keys=fields.Str(), + values=fields.Int(), + metadata={ + "description": """a dict with key as deployment name and value as traffic percentage. + Only one key will be accepted and value needs to be less than or equal to 50%""" + }, + ) + + @validates("traffic") + def validate_traffic(self, data, **kwargs): + if sum(data.values()) > 100: + raise ValidationError("Traffic rule percentages must sum to less than or equal to 100%") + + +class KubernetesOnlineEndpointSchema(OnlineEndpointSchema): + provisioning_state = fields.Str(metadata={"description": "status of the deployment provisioning operation"}) + compute = ArmStr(azureml_type=AzureMLResourceType.COMPUTE) + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities import KubernetesOnlineEndpoint + + return KubernetesOnlineEndpoint(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data) + + +class ManagedOnlineEndpointSchema(OnlineEndpointSchema): + provisioning_state = fields.Str() + public_network_access = StringTransformedEnum( + allowed_values=[PublicNetworkAccess.ENABLED, PublicNetworkAccess.DISABLED] + ) + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities import ManagedOnlineEndpoint + + return ManagedOnlineEndpoint(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data) |