diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online')
| -rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/__init__.py | 5 | ||||
| -rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py | 66 |
2 files changed, 71 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py new file mode 100644 index 00000000..84b34636 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py @@ -0,0 +1,66 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import ValidationError, fields, post_load, validates + +from azure.ai.ml._schema._endpoint.endpoint import EndpointSchema +from azure.ai.ml._schema.core.fields import ArmStr, StringTransformedEnum +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType, PublicNetworkAccess + +module_logger = logging.getLogger(__name__) + + +class OnlineEndpointSchema(EndpointSchema): + traffic = fields.Dict( + keys=fields.Str(), + values=fields.Int(), + metadata={ + "description": """a dict with key as deployment name and value as traffic percentage. + The values need to sum to 100 """ + }, + ) + kind = fields.Str(dump_only=True) + + mirror_traffic = fields.Dict( + keys=fields.Str(), + values=fields.Int(), + metadata={ + "description": """a dict with key as deployment name and value as traffic percentage. + Only one key will be accepted and value needs to be less than or equal to 50%""" + }, + ) + + @validates("traffic") + def validate_traffic(self, data, **kwargs): + if sum(data.values()) > 100: + raise ValidationError("Traffic rule percentages must sum to less than or equal to 100%") + + +class KubernetesOnlineEndpointSchema(OnlineEndpointSchema): + provisioning_state = fields.Str(metadata={"description": "status of the deployment provisioning operation"}) + compute = ArmStr(azureml_type=AzureMLResourceType.COMPUTE) + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities import KubernetesOnlineEndpoint + + return KubernetesOnlineEndpoint(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data) + + +class ManagedOnlineEndpointSchema(OnlineEndpointSchema): + provisioning_state = fields.Str() + public_network_access = StringTransformedEnum( + allowed_values=[PublicNetworkAccess.ENABLED, PublicNetworkAccess.DISABLED] + ) + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities import ManagedOnlineEndpoint + + return ManagedOnlineEndpoint(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data) |
