diff options
| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-master.tar.gz | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint')
5 files changed, 993 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/_endpoint_helpers.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/_endpoint_helpers.py new file mode 100644 index 00000000..5d62a229 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/_endpoint_helpers.py @@ -0,0 +1,62 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import re +from typing import Any, Optional + +from azure.ai.ml.constants._endpoint import EndpointConfigurations +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + + +def validate_endpoint_or_deployment_name(name: Optional[str], is_deployment: bool = False) -> None: + """Validates the name of an endpoint or a deployment + + A valid name of an endpoint or deployment: + + 1. Is between 3 and 32 characters long (inclusive of both ends of the range) + 2. Starts with a letter + 3. Is followed by 0 or more alphanumeric characters (`a-zA-Z0-9`) or hyphens (`-`) + 3. Ends with an alphanumeric character (`a-zA-Z0-9`) + + :param name: Either an endpoint or deployment name + :type name: str + :param is_deployment: Whether the name is a deployment name. Defaults to False + :type is_deployment: bool + """ + if name is None: + return + + type_str = "a deployment" if is_deployment else "an endpoint" + target = ErrorTarget.DEPLOYMENT if is_deployment else ErrorTarget.ENDPOINT + if len(name) < EndpointConfigurations.MIN_NAME_LENGTH or len(name) > EndpointConfigurations.MAX_NAME_LENGTH: + msg = f"The name for {type_str} must be at least 3 and at most 32 characters long (inclusive of both limits)." + raise ValidationException( + message=msg, + target=target, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + if not re.match(EndpointConfigurations.NAME_REGEX_PATTERN, name): + msg = f"""The name for {type_str} must start with an upper- or lowercase letter + and only consist of '-'s and alphanumeric characters.""" + raise ValidationException( + message=msg, + target=target, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + +def validate_identity_type_defined(identity: Any) -> None: + if identity and not identity.type: + msg = "Identity type not found in provided yaml file." + raise ValidationException( + message=msg, + target=ErrorTarget.ENDPOINT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/batch_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/batch_endpoint.py new file mode 100644 index 00000000..4883c828 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/batch_endpoint.py @@ -0,0 +1,134 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_10_01.models import BatchEndpoint as BatchEndpointData +from azure.ai.ml._restclient.v2023_10_01.models import BatchEndpointProperties as RestBatchEndpoint +from azure.ai.ml._schema._endpoint import BatchEndpointSchema +from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel +from azure.ai.ml.constants._common import AAD_TOKEN_YAML, BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY +from azure.ai.ml.entities._endpoint._endpoint_helpers import validate_endpoint_or_deployment_name +from azure.ai.ml.entities._util import load_from_dict + +from .endpoint import Endpoint + +module_logger = logging.getLogger(__name__) + + +class BatchEndpoint(Endpoint): + """Batch endpoint entity. + + :param name: Name of the resource. + :type name: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param auth_mode: Possible values include: "AMLToken", "Key", "AADToken", defaults to None + :type auth_mode: str + :param description: Description of the inference endpoint, defaults to None + :type description: str + :param location: defaults to None + :type location: str + :param defaults: Traffic rules on how the traffic will be routed across deployments, defaults to {} + :type defaults: Dict[str, str] + :param default_deployment_name: Equivalent to defaults.default_deployment, will be ignored if defaults is present. + :type default_deployment_name: str + :param scoring_uri: URI to use to perform a prediction, readonly. + :type scoring_uri: str + :param openapi_uri: URI to check the open API definition of the endpoint. + :type openapi_uri: str + """ + + def __init__( + self, + *, + name: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + auth_mode: str = AAD_TOKEN_YAML, + description: Optional[str] = None, + location: Optional[str] = None, + defaults: Optional[Dict[str, str]] = None, + default_deployment_name: Optional[str] = None, + scoring_uri: Optional[str] = None, + openapi_uri: Optional[str] = None, + **kwargs: Any, + ) -> None: + super(BatchEndpoint, self).__init__( + name=name, + tags=tags, + properties=properties, + auth_mode=auth_mode, + description=description, + location=location, + scoring_uri=scoring_uri, + openapi_uri=openapi_uri, + **kwargs, + ) + + self.defaults = defaults + + if not self.defaults and default_deployment_name: + self.defaults = {} + self.defaults["deployment_name"] = default_deployment_name + + def _to_rest_batch_endpoint(self, location: str) -> BatchEndpointData: + validate_endpoint_or_deployment_name(self.name) + batch_endpoint = RestBatchEndpoint( + description=self.description, + auth_mode=snake_to_camel(self.auth_mode), + properties=self.properties, + defaults=self.defaults, + ) + return BatchEndpointData(location=location, tags=self.tags, properties=batch_endpoint) + + @classmethod + def _from_rest_object(cls, obj: BatchEndpointData) -> "BatchEndpoint": + return BatchEndpoint( + id=obj.id, + name=obj.name, + tags=obj.tags, + properties=obj.properties.properties, + auth_mode=camel_to_snake(obj.properties.auth_mode), + description=obj.properties.description, + location=obj.location, + defaults=obj.properties.defaults, + provisioning_state=obj.properties.provisioning_state, + scoring_uri=obj.properties.scoring_uri, + openapi_uri=obj.properties.swagger_uri, + ) + + def dump( + self, + dest: Optional[Union[str, PathLike, IO[AnyStr]]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + context = {BASE_PATH_CONTEXT_KEY: Path(".").parent} + return BatchEndpointSchema(context=context).dump(self) # type: ignore + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "BatchEndpoint": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path.cwd(), + PARAMS_OVERRIDE_KEY: params_override, + } + res: BatchEndpoint = load_from_dict(BatchEndpointSchema, data, context) + return res + + def _to_dict(self) -> Dict: + res: dict = BatchEndpointSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/endpoint.py new file mode 100644 index 00000000..d878742e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/endpoint.py @@ -0,0 +1,145 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from abc import abstractmethod +from os import PathLike +from typing import IO, Any, AnyStr, Dict, Optional, Union + +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +module_logger = logging.getLogger(__name__) + + +class Endpoint(Resource): # pylint: disable=too-many-instance-attributes + """Endpoint base class. + + :param auth_mode: The authentication mode, defaults to None + :type auth_mode: str + :param location: The location of the endpoint, defaults to None + :type location: str + :param name: Name of the resource. + :type name: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: typing.Optional[typing.Dict[str, str]] + :param properties: The asset property dictionary. + :type properties: typing.Optional[typing.Dict[str, str]] + :param description: Description of the resource. + :type description: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to {} + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword scoring_uri: str, Endpoint URI, readonly + :paramtype scoring_uri: typing.Optional[str] + :keyword openapi_uri: str, Endpoint Open API URI, readonly + :paramtype openapi_uri: typing.Optional[str] + :keyword provisioning_state: str, provisioning state, readonly + :paramtype provisioning_state: typing.Optional[str] + """ + + def __init__( + self, + auth_mode: Optional[str] = None, + location: Optional[str] = None, + name: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + properties: Optional[Dict[str, Any]] = None, + description: Optional[str] = None, + **kwargs: Any, + ): + """Endpoint base class. + + Constructor for Endpoint base class. + + :param auth_mode: The authentication mode, defaults to None + :type auth_mode: str + :param location: The location of the endpoint, defaults to None + :type location: str + :param name: Name of the resource. + :type name: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: typing.Optional[typing.Dict[str, str]] + :param properties: The asset property dictionary. + :type properties: typing.Optional[typing.Dict[str, str]] + :param description: Description of the resource. + :type description: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to {} + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword scoring_uri: str, Endpoint URI, readonly + :paramtype scoring_uri: typing.Optional[str] + :keyword openapi_uri: str, Endpoint Open API URI, readonly + :paramtype openapi_uri: typing.Optional[str] + :keyword provisioning_state: str, provisioning state, readonly + :paramtype provisioning_state: typing.Optional[str] + """ + # MFE is case-insensitive for Name. So convert the name into lower case here. + if name: + name = name.lower() + self._scoring_uri: Optional[str] = kwargs.pop("scoring_uri", None) + self._openapi_uri: Optional[str] = kwargs.pop("openapi_uri", None) + self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None) + super().__init__(name, description, tags, properties, **kwargs) + self.auth_mode = auth_mode + self.location = location + + @property + def scoring_uri(self) -> Optional[str]: + """URI to use to perform a prediction, readonly. + + :return: The scoring URI + :rtype: typing.Optional[str] + """ + return self._scoring_uri + + @property + def openapi_uri(self) -> Optional[str]: + """URI to check the open api definition of the endpoint. + + :return: The open API URI + :rtype: typing.Optional[str] + """ + return self._openapi_uri + + @property + def provisioning_state(self) -> Optional[str]: + """Endpoint provisioning state, readonly. + + :return: Endpoint provisioning state. + :rtype: typing.Optional[str] + """ + return self._provisioning_state + + @abstractmethod + def dump(self, dest: Optional[Union[str, PathLike, IO[AnyStr]]] = None, **kwargs: Any) -> Dict: + pass + + @classmethod + @abstractmethod + def _from_rest_object(cls, obj: Any) -> Any: + pass + + def _merge_with(self, other: Any) -> None: + if other: + if self.name != other.name: + msg = "The endpoint name: {} and {} are not matched when merging." + raise ValidationException( + message=msg.format(self.name, other.name), + target=ErrorTarget.ENDPOINT, + no_personal_data_message=msg.format("[name1]", "[name2]"), + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + self.description = other.description or self.description + if other.tags: + if self.tags is not None: + self.tags = {**self.tags, **other.tags} + if other.properties: + self.properties = {**self.properties, **other.properties} + self.auth_mode = other.auth_mode or self.auth_mode + if hasattr(other, "traffic"): + self.traffic = other.traffic # pylint: disable=attribute-defined-outside-init + if hasattr(other, "mirror_traffic"): + self.mirror_traffic = other.mirror_traffic # pylint: disable=attribute-defined-outside-init + if hasattr(other, "defaults"): + self.defaults = other.defaults # pylint: disable=attribute-defined-outside-init diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/online_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/online_endpoint.py new file mode 100644 index 00000000..cdd72536 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/online_endpoint.py @@ -0,0 +1,647 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=no-member + +import logging +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, Optional, Union, cast + +from azure.ai.ml._restclient.v2022_02_01_preview.models import EndpointAuthKeys as RestEndpointAuthKeys +from azure.ai.ml._restclient.v2022_02_01_preview.models import EndpointAuthMode +from azure.ai.ml._restclient.v2022_02_01_preview.models import EndpointAuthToken as RestEndpointAuthToken +from azure.ai.ml._restclient.v2022_02_01_preview.models import OnlineEndpointData +from azure.ai.ml._restclient.v2022_02_01_preview.models import OnlineEndpointDetails as RestOnlineEndpoint +from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestManagedServiceIdentityConfiguration +from azure.ai.ml._schema._endpoint import KubernetesOnlineEndpointSchema, ManagedOnlineEndpointSchema +from azure.ai.ml._utils.utils import dict_eq +from azure.ai.ml.constants._common import ( + AAD_TOKEN_YAML, + AML_TOKEN_YAML, + BASE_PATH_CONTEXT_KEY, + KEY, + PARAMS_OVERRIDE_KEY, +) +from azure.ai.ml.constants._endpoint import EndpointYamlFields +from azure.ai.ml.entities._credentials import IdentityConfiguration +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.entities._util import is_compute_in_override, load_from_dict +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException +from azure.core.credentials import AccessToken + +from ._endpoint_helpers import validate_endpoint_or_deployment_name, validate_identity_type_defined +from .endpoint import Endpoint + +module_logger = logging.getLogger(__name__) + + +class OnlineEndpoint(Endpoint): + """Online endpoint entity. + + :keyword name: Name of the resource, defaults to None + :paramtype name: typing.Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. defaults to None + :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]] + :keyword properties: The asset property dictionary, defaults to None + :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]] + :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY + :type auth_mode: typing.Optional[str] + :keyword description: Description of the inference endpoint, defaults to None + :paramtype description: typing.Optional[str] + :keyword location: Location of the resource, defaults to None + :paramtype location: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword mirror_traffic: Duplicated live traffic used to inference a single deployment, defaults to None + :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]] + :keyword identity: Identity Configuration, defaults to SystemAssigned + :paramtype identity: typing.Optional[IdentityConfiguration] + :keyword scoring_uri: Scoring URI, defaults to None + :paramtype scoring_uri: typing.Optional[str] + :keyword openapi_uri: OpenAPI URI, defaults to None + :paramtype openapi_uri: typing.Optional[str] + :keyword provisioning_state: Provisioning state of an endpoint, defaults to None + :paramtype provisioning_state: typing.Optional[str] + :keyword kind: Kind of the resource, we have two kinds: K8s and Managed online endpoints, defaults to None + :paramtype kind: typing.Optional[str] + """ + + def __init__( + self, + *, + name: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + properties: Optional[Dict[str, Any]] = None, + auth_mode: str = KEY, + description: Optional[str] = None, + location: Optional[str] = None, + traffic: Optional[Dict[str, int]] = None, + mirror_traffic: Optional[Dict[str, int]] = None, + identity: Optional[IdentityConfiguration] = None, + scoring_uri: Optional[str] = None, + openapi_uri: Optional[str] = None, + provisioning_state: Optional[str] = None, + kind: Optional[str] = None, + **kwargs: Any, + ): + """Online endpoint entity. + + Constructor for an Online endpoint entity. + + :keyword name: Name of the resource, defaults to None + :paramtype name: typing.Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. defaults to None + :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]] + :keyword properties: The asset property dictionary, defaults to None + :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]] + :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY + :type auth_mode: typing.Optional[str] + :keyword description: Description of the inference endpoint, defaults to None + :paramtype description: typing.Optional[str] + :keyword location: Location of the resource, defaults to None + :paramtype location: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword mirror_traffic: Duplicated live traffic used to inference a single deployment, defaults to None + :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]] + :keyword identity: Identity Configuration, defaults to SystemAssigned + :paramtype identity: typing.Optional[IdentityConfiguration] + :keyword scoring_uri: Scoring URI, defaults to None + :paramtype scoring_uri: typing.Optional[str] + :keyword openapi_uri: OpenAPI URI, defaults to None + :paramtype openapi_uri: typing.Optional[str] + :keyword provisioning_state: Provisioning state of an endpoint, defaults to None + :paramtype provisioning_state: typing.Optional[str] + :keyword kind: Kind of the resource, we have two kinds: K8s and Managed online endpoints, defaults to None + :type kind: typing.Optional[str] + """ + self._provisioning_state = kwargs.pop("provisioning_state", None) + + super(OnlineEndpoint, self).__init__( + name=name, + properties=properties, + tags=tags, + auth_mode=auth_mode, + description=description, + location=location, + scoring_uri=scoring_uri, + openapi_uri=openapi_uri, + provisioning_state=provisioning_state, + **kwargs, + ) + + self.identity = identity + self.traffic: Dict = dict(traffic) if traffic else {} + self.mirror_traffic: Dict = dict(mirror_traffic) if mirror_traffic else {} + self.kind = kind + + @property + def provisioning_state(self) -> Optional[str]: + """Endpoint provisioning state, readonly. + + :return: Endpoint provisioning state. + :rtype: typing.Optional[str] + """ + return self._provisioning_state + + def _to_rest_online_endpoint(self, location: str) -> OnlineEndpointData: + # pylint: disable=protected-access + identity = ( + self.identity._to_online_endpoint_rest_object() + if self.identity + else RestManagedServiceIdentityConfiguration(type="SystemAssigned") + ) + validate_endpoint_or_deployment_name(self.name) + validate_identity_type_defined(self.identity) + properties = RestOnlineEndpoint( + description=self.description, + auth_mode=OnlineEndpoint._yaml_auth_mode_to_rest_auth_mode(self.auth_mode), + properties=self.properties, + traffic=self.traffic, + mirror_traffic=self.mirror_traffic, + ) + + if hasattr(self, "public_network_access") and self.public_network_access: + properties.public_network_access = self.public_network_access + return OnlineEndpointData( + location=location, + properties=properties, + identity=identity, + tags=self.tags, + ) + + def _to_rest_online_endpoint_traffic_update(self, location: str, no_validation: bool = False) -> OnlineEndpointData: + if not no_validation: + # validate_deployment_name_matches_traffic(self.deployments, self.traffic) + validate_identity_type_defined(self.identity) + # validate_uniqueness_of_deployment_names(self.deployments) + properties = RestOnlineEndpoint( + description=self.description, + auth_mode=OnlineEndpoint._yaml_auth_mode_to_rest_auth_mode(self.auth_mode), + endpoint=self.name, + traffic=self.traffic, + properties=self.properties, + ) + return OnlineEndpointData( + location=location, + properties=properties, + identity=self.identity, + tags=self.tags, + ) + + @classmethod + def _rest_auth_mode_to_yaml_auth_mode(cls, rest_auth_mode: str) -> str: + switcher = { + EndpointAuthMode.AML_TOKEN: AML_TOKEN_YAML, + EndpointAuthMode.AAD_TOKEN: AAD_TOKEN_YAML, + EndpointAuthMode.KEY: KEY, + } + + return switcher.get(rest_auth_mode, rest_auth_mode) + + @classmethod + def _yaml_auth_mode_to_rest_auth_mode(cls, yaml_auth_mode: Optional[str]) -> str: + if yaml_auth_mode is None: + return "" + + yaml_auth_mode = yaml_auth_mode.lower() + + switcher = { + AML_TOKEN_YAML: EndpointAuthMode.AML_TOKEN, + AAD_TOKEN_YAML: EndpointAuthMode.AAD_TOKEN, + KEY: EndpointAuthMode.KEY, + } + + return switcher.get(yaml_auth_mode, yaml_auth_mode) + + @classmethod + def _from_rest_object(cls, obj: OnlineEndpointData) -> "OnlineEndpoint": + auth_mode = cls._rest_auth_mode_to_yaml_auth_mode(obj.properties.auth_mode) + # pylint: disable=protected-access + identity = IdentityConfiguration._from_online_endpoint_rest_object(obj.identity) if obj.identity else None + + endpoint: Any = KubernetesOnlineEndpoint() + + if obj.system_data: + properties_dict = { + "createdBy": obj.system_data.created_by, + "createdAt": obj.system_data.created_at.strftime("%Y-%m-%dT%H:%M:%S.%f%z"), + "lastModifiedAt": obj.system_data.last_modified_at.strftime("%Y-%m-%dT%H:%M:%S.%f%z"), + } + properties_dict.update(obj.properties.properties) + else: + properties_dict = obj.properties.properties + + if obj.properties.compute: + endpoint = KubernetesOnlineEndpoint( + id=obj.id, + name=obj.name, + tags=obj.tags, + properties=properties_dict, + compute=obj.properties.compute, + auth_mode=auth_mode, + description=obj.properties.description, + location=obj.location, + traffic=obj.properties.traffic, + provisioning_state=obj.properties.provisioning_state, + scoring_uri=obj.properties.scoring_uri, + openapi_uri=obj.properties.swagger_uri, + identity=identity, + kind=obj.kind, + ) + else: + endpoint = ManagedOnlineEndpoint( + id=obj.id, + name=obj.name, + tags=obj.tags, + properties=properties_dict, + auth_mode=auth_mode, + description=obj.properties.description, + location=obj.location, + traffic=obj.properties.traffic, + mirror_traffic=obj.properties.mirror_traffic, + provisioning_state=obj.properties.provisioning_state, + scoring_uri=obj.properties.scoring_uri, + openapi_uri=obj.properties.swagger_uri, + identity=identity, + kind=obj.kind, + public_network_access=obj.properties.public_network_access, + ) + + return cast(OnlineEndpoint, endpoint) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, OnlineEndpoint): + return NotImplemented + if not other: + return False + if self.auth_mode is None or other.auth_mode is None: + return False + + if self.name is None and other.name is None: + return ( + self.auth_mode.lower() == other.auth_mode.lower() + and dict_eq(self.tags, other.tags) + and self.description == other.description + and dict_eq(self.traffic, other.traffic) + ) + + if self.name is not None and other.name is not None: + # only compare mutable fields + return ( + self.name.lower() == other.name.lower() + and self.auth_mode.lower() == other.auth_mode.lower() + and dict_eq(self.tags, other.tags) + and self.description == other.description + and dict_eq(self.traffic, other.traffic) + ) + + return False + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Endpoint": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path.cwd(), + PARAMS_OVERRIDE_KEY: params_override, + } + + if data.get(EndpointYamlFields.COMPUTE) or is_compute_in_override(params_override): + res_kub: Endpoint = load_from_dict(KubernetesOnlineEndpointSchema, data, context) + return res_kub + + res_managed: Endpoint = load_from_dict(ManagedOnlineEndpointSchema, data, context) + return res_managed + + +class KubernetesOnlineEndpoint(OnlineEndpoint): + """K8s Online endpoint entity. + + :keyword name: Name of the resource, defaults to None + :paramtype name: typing.Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated, defaults to None + :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]] + :keyword properties: The asset property dictionary, defaults to None + :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]] + :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY + :type auth_mode: typing.Optional[str] + :keyword description: Description of the inference endpoint, defaults to None + :paramtype description: typing.Optional[str] + :keyword location: Location of the resource, defaults to None + :paramtype location: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword mirror_traffic: Duplicated live traffic used to inference a single deployment, defaults to None + :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]] + :keyword compute: Compute cluster id, defaults to None + :paramtype compute: typing.Optional[str] + :keyword identity: Identity Configuration, defaults to SystemAssigned + :paramtype identity: typing.Optional[IdentityConfiguration] + :keyword kind: Kind of the resource, we have two kinds: K8s and Managed online endpoints, defaults to None + :paramtype kind: typing.Optional[str] + """ + + def __init__( + self, + *, + name: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + properties: Optional[Dict[str, Any]] = None, + auth_mode: str = KEY, + description: Optional[str] = None, + location: Optional[str] = None, + traffic: Optional[Dict[str, int]] = None, + mirror_traffic: Optional[Dict[str, int]] = None, + compute: Optional[str] = None, + identity: Optional[IdentityConfiguration] = None, + kind: Optional[str] = None, + **kwargs: Any, + ): + """K8s Online endpoint entity. + + Constructor for K8s Online endpoint entity. + + :keyword name: Name of the resource, defaults to None + :paramtype name: typing.Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated, defaults to None + :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]] + :keyword properties: The asset property dictionary, defaults to None + :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]] + :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY + :type auth_mode: typing.Optional[str] + :keyword description: Description of the inference endpoint, defaults to None + :paramtype description: typing.Optional[str] + :keyword location: Location of the resource, defaults to None + :paramtype location: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword mirror_traffic: Duplicated live traffic used to inference a single deployment, defaults to None + :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]] + :keyword compute: Compute cluster id, defaults to None + :paramtype compute: typing.Optional[str] + :keyword identity: Identity Configuration, defaults to SystemAssigned + :paramtype identity: typing.Optional[IdentityConfiguration] + :keyword kind: Kind of the resource, we have two kinds: K8s and Managed online endpoints, defaults to None + :type kind: typing.Optional[str] + """ + super(KubernetesOnlineEndpoint, self).__init__( + name=name, + properties=properties, + tags=tags, + auth_mode=auth_mode, + description=description, + location=location, + traffic=traffic, + mirror_traffic=mirror_traffic, + identity=identity, + kind=kind, + **kwargs, + ) + + self.compute = compute + + def dump( + self, + dest: Optional[Union[str, PathLike, IO[AnyStr]]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + context = {BASE_PATH_CONTEXT_KEY: Path(".").parent} + res: dict = KubernetesOnlineEndpointSchema(context=context).dump(self) + return res + + def _to_rest_online_endpoint(self, location: str) -> OnlineEndpointData: + resource = super()._to_rest_online_endpoint(location) + resource.properties.compute = self.compute + return resource + + def _to_rest_online_endpoint_traffic_update(self, location: str, no_validation: bool = False) -> OnlineEndpointData: + resource = super()._to_rest_online_endpoint_traffic_update(location, no_validation) + resource.properties.compute = self.compute + return resource + + def _merge_with(self, other: "KubernetesOnlineEndpoint") -> None: + if other: + if self.name != other.name: + msg = "The endpoint name: {} and {} are not matched when merging." + raise ValidationException( + message=msg.format(self.name, other.name), + target=ErrorTarget.ONLINE_ENDPOINT, + no_personal_data_message=msg.format("[name1]", "[name2]"), + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + super()._merge_with(other) + self.compute = other.compute or self.compute + + def _to_dict(self) -> Dict: + res: dict = KubernetesOnlineEndpointSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + +class ManagedOnlineEndpoint(OnlineEndpoint): + """Managed Online endpoint entity. + + :keyword name: Name of the resource, defaults to None + :paramtype name: typing.Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated, defaults to None + :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]] + :keyword properties: The asset property dictionary, defaults to None + :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]] + :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY + :type auth_mode: str + :keyword description: Description of the inference endpoint, defaults to None + :paramtype description: typing.Optional[str] + :keyword location: Location of the resource, defaults to None + :paramtype location: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword mirror_traffic: Duplicated live traffic used to inference a single deployment, defaults to None + :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]] + :keyword identity: Identity Configuration, defaults to SystemAssigned + :paramtype identity: typing.Optional[IdentityConfiguration] + :keyword kind: Kind of the resource, we have two kinds: K8s and Managed online endpoints, defaults to None. + :paramtype kind: typing.Optional[str] + :keyword public_network_access: Whether to allow public endpoint connectivity, defaults to None + Allowed values are: "enabled", "disabled" + :type public_network_access: typing.Optional[str] + """ + + def __init__( + self, + *, + name: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + properties: Optional[Dict[str, Any]] = None, + auth_mode: str = KEY, + description: Optional[str] = None, + location: Optional[str] = None, + traffic: Optional[Dict[str, int]] = None, + mirror_traffic: Optional[Dict[str, int]] = None, + identity: Optional[IdentityConfiguration] = None, + kind: Optional[str] = None, + public_network_access: Optional[str] = None, + **kwargs: Any, + ): + """Managed Online endpoint entity. + + Constructor for Managed Online endpoint entity. + + :keyword name: Name of the resource, defaults to None + :paramtype name: typing.Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated, defaults to None + :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]] + :keyword properties: The asset property dictionary, defaults to None + :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]] + :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY + :type auth_mode: str + :keyword description: Description of the inference endpoint, defaults to None + :paramtype description: typing.Optional[str] + :keyword location: Location of the resource, defaults to None + :paramtype location: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword mirror_traffic: Duplicated live traffic used to inference a single deployment, defaults to None + :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]] + :keyword identity: Identity Configuration, defaults to SystemAssigned + :paramtype identity: typing.Optional[IdentityConfiguration] + :keyword kind: Kind of the resource, we have two kinds: K8s and Managed online endpoints, defaults to None. + :type kind: typing.Optional[str] + :keyword public_network_access: Whether to allow public endpoint connectivity, defaults to None + Allowed values are: "enabled", "disabled" + :type public_network_access: typing.Optional[str] + """ + self.public_network_access = public_network_access + + super(ManagedOnlineEndpoint, self).__init__( + name=name, + properties=properties, + tags=tags, + auth_mode=auth_mode, + description=description, + location=location, + traffic=traffic, + mirror_traffic=mirror_traffic, + identity=identity, + kind=kind, + **kwargs, + ) + + def dump( + self, + dest: Optional[Union[str, PathLike, IO[AnyStr]]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + context = {BASE_PATH_CONTEXT_KEY: Path(".").parent} + res: dict = ManagedOnlineEndpointSchema(context=context).dump(self) + return res + + def _to_dict(self) -> Dict: + res: dict = ManagedOnlineEndpointSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + +class EndpointAuthKeys(RestTranslatableMixin): + """Keys for endpoint authentication. + + :ivar primary_key: The primary key. + :vartype primary_key: str + :ivar secondary_key: The secondary key. + :vartype secondary_key: str + """ + + def __init__(self, **kwargs: Any): + """Constructor for keys for endpoint authentication. + + :keyword primary_key: The primary key. + :paramtype primary_key: str + :keyword secondary_key: The secondary key. + :paramtype secondary_key: str + """ + self.primary_key = kwargs.get("primary_key", None) + self.secondary_key = kwargs.get("secondary_key", None) + + @classmethod + def _from_rest_object(cls, obj: RestEndpointAuthKeys) -> "EndpointAuthKeys": + return cls(primary_key=obj.primary_key, secondary_key=obj.secondary_key) + + def _to_rest_object(self) -> RestEndpointAuthKeys: + return RestEndpointAuthKeys(primary_key=self.primary_key, secondary_key=self.secondary_key) + + +class EndpointAuthToken(RestTranslatableMixin): + """Endpoint authentication token. + + :ivar access_token: Access token for endpoint authentication. + :vartype access_token: str + :ivar expiry_time_utc: Access token expiry time (UTC). + :vartype expiry_time_utc: float + :ivar refresh_after_time_utc: Refresh access token after time (UTC). + :vartype refresh_after_time_utc: float + :ivar token_type: Access token type. + :vartype token_type: str + """ + + def __init__(self, **kwargs: Any): + """Constuctor for Endpoint authentication token. + + :keyword access_token: Access token for endpoint authentication. + :paramtype access_token: str + :keyword expiry_time_utc: Access token expiry time (UTC). + :paramtype expiry_time_utc: float + :keyword refresh_after_time_utc: Refresh access token after time (UTC). + :paramtype refresh_after_time_utc: float + :keyword token_type: Access token type. + :paramtype token_type: str + """ + self.access_token = kwargs.get("access_token", None) + self.expiry_time_utc = kwargs.get("expiry_time_utc", 0) + self.refresh_after_time_utc = kwargs.get("refresh_after_time_utc", 0) + self.token_type = kwargs.get("token_type", None) + + @classmethod + def _from_rest_object(cls, obj: RestEndpointAuthToken) -> "EndpointAuthToken": + return cls( + access_token=obj.access_token, + expiry_time_utc=obj.expiry_time_utc, + refresh_after_time_utc=obj.refresh_after_time_utc, + token_type=obj.token_type, + ) + + def _to_rest_object(self) -> RestEndpointAuthToken: + return RestEndpointAuthToken( + access_token=self.access_token, + expiry_time_utc=self.expiry_time_utc, + refresh_after_time_utc=self.refresh_after_time_utc, + token_type=self.token_type, + ) + + +class EndpointAadToken: + """Endpoint aad token. + + :ivar access_token: Access token for aad authentication. + :vartype access_token: str + :ivar expiry_time_utc: Access token expiry time (UTC). + :vartype expiry_time_utc: float + """ + + def __init__(self, obj: AccessToken): + """Constructor for Endpoint aad token. + + :param obj: Access token object + :type obj: AccessToken + """ + self.access_token = obj.token + self.expiry_time_utc = obj.expires_on |
