diff options
| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-master.tar.gz | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py')
| -rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py | 424 |
1 files changed, 424 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py new file mode 100644 index 00000000..a97048fc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py @@ -0,0 +1,424 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging +from typing import Any, Dict, Optional, cast + +from typing_extensions import Literal + +from azure.ai.ml._restclient.v2023_04_01_preview.models import AllNodes +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobService as RestJobService +from azure.ai.ml.constants._job.job import JobServiceTypeNames +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +module_logger = logging.getLogger(__name__) + + +class JobServiceBase(RestTranslatableMixin, DictMixin): + """Base class for job service configuration. + + This class should not be instantiated directly. Instead, use one of its subclasses. + + :keyword endpoint: The endpoint URL. + :paramtype endpoint: Optional[str] + :keyword type: The endpoint type. Accepted values are "jupyter_lab", "ssh", "tensor_board", and "vs_code". + :paramtype type: Optional[Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]] + :keyword port: The port for the endpoint. + :paramtype port: Optional[int] + :keyword nodes: Indicates whether the service has to run in all nodes. + :paramtype nodes: Optional[Literal["all"]] + :keyword properties: Additional properties to set on the endpoint. + :paramtype properties: Optional[dict[str, str]] + :keyword status: The status of the endpoint. + :paramtype status: Optional[str] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: dict + """ + + def __init__( # pylint: disable=unused-argument + self, + *, + endpoint: Optional[str] = None, + type: Optional[ # pylint: disable=redefined-builtin + Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"] + ] = None, + nodes: Optional[Literal["all"]] = None, + status: Optional[str] = None, + port: Optional[int] = None, + properties: Optional[Dict[str, str]] = None, + **kwargs: Dict, + ) -> None: + self.endpoint = endpoint + self.type: Any = type + self.nodes = nodes + self.status = status + self.port = port + self.properties = properties + self._validate_nodes() + self._validate_type_name() + + def _validate_nodes(self) -> None: + if not self.nodes in ["all", None]: + msg = f"nodes should be either 'all' or None, but received '{self.nodes}'." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + def _validate_type_name(self) -> None: + if self.type and not self.type in JobServiceTypeNames.ENTITY_TO_REST: + msg = ( + f"type should be one of " f"{JobServiceTypeNames.NAMES_ALLOWED_FOR_PUBLIC}, but received '{self.type}'." + ) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + def _to_rest_job_service(self, updated_properties: Optional[Dict[str, str]] = None) -> RestJobService: + return RestJobService( + endpoint=self.endpoint, + job_service_type=JobServiceTypeNames.ENTITY_TO_REST.get(self.type, None) if self.type else None, + nodes=AllNodes() if self.nodes else None, + status=self.status, + port=self.port, + properties=updated_properties if updated_properties else self.properties, + ) + + @classmethod + def _to_rest_job_services( + cls, + services: Optional[Dict], + ) -> Optional[Dict[str, RestJobService]]: + if services is None: + return None + + return {name: service._to_rest_object() for name, service in services.items()} + + @classmethod + def _from_rest_job_service_object(cls, obj: RestJobService) -> "JobServiceBase": + return cls( + endpoint=obj.endpoint, + type=( + JobServiceTypeNames.REST_TO_ENTITY.get(obj.job_service_type, None) # type: ignore[arg-type] + if obj.job_service_type + else None + ), + nodes="all" if obj.nodes else None, + status=obj.status, + port=obj.port, + # ssh_public_keys=_get_property(obj.properties, "sshPublicKeys"), + properties=obj.properties, + ) + + @classmethod + def _from_rest_job_services(cls, services: Dict[str, RestJobService]) -> Dict: + # """Resolve Dict[str, RestJobService] to Dict[str, Specific JobService]""" + if services is None: + return None + + result: dict = {} + for name, service in services.items(): + if service.job_service_type == JobServiceTypeNames.RestNames.JUPYTER_LAB: + result[name] = JupyterLabJobService._from_rest_object(service) + elif service.job_service_type == JobServiceTypeNames.RestNames.SSH: + result[name] = SshJobService._from_rest_object(service) + elif service.job_service_type == JobServiceTypeNames.RestNames.TENSOR_BOARD: + result[name] = TensorBoardJobService._from_rest_object(service) + elif service.job_service_type == JobServiceTypeNames.RestNames.VS_CODE: + result[name] = VsCodeJobService._from_rest_object(service) + else: + result[name] = JobService._from_rest_object(service) + return result + + +class JobService(JobServiceBase): + """Basic job service configuration for backward compatibility. + + This class is not intended to be used directly. Instead, use one of its subclasses specific to your job type. + + :keyword endpoint: The endpoint URL. + :paramtype endpoint: Optional[str] + :keyword type: The endpoint type. Accepted values are "jupyter_lab", "ssh", "tensor_board", and "vs_code". + :paramtype type: Optional[Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]] + :keyword port: The port for the endpoint. + :paramtype port: Optional[int] + :keyword nodes: Indicates whether the service has to run in all nodes. + :paramtype nodes: Optional[Literal["all"]] + :keyword properties: Additional properties to set on the endpoint. + :paramtype properties: Optional[dict[str, str]] + :keyword status: The status of the endpoint. + :paramtype status: Optional[str] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: dict + """ + + @classmethod + def _from_rest_object(cls, obj: RestJobService) -> "JobService": + return cast(JobService, cls._from_rest_job_service_object(obj)) + + def _to_rest_object(self) -> RestJobService: + return self._to_rest_job_service() + + +class SshJobService(JobServiceBase): + """SSH job service configuration. + + :ivar type: Specifies the type of job service. Set automatically to "ssh" for this class. + :vartype type: str + :keyword endpoint: The endpoint URL. + :paramtype endpoint: Optional[str] + :keyword port: The port for the endpoint. + :paramtype port: Optional[int] + :keyword nodes: Indicates whether the service has to run in all nodes. + :paramtype nodes: Optional[Literal["all"]] + :keyword properties: Additional properties to set on the endpoint. + :paramtype properties: Optional[dict[str, str]] + :keyword status: The status of the endpoint. + :paramtype status: Optional[str] + :keyword ssh_public_keys: The SSH Public Key to access the job container. + :paramtype ssh_public_keys: Optional[str] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START ssh_job_service_configuration] + :end-before: [END ssh_job_service_configuration] + :language: python + :dedent: 8 + :caption: Configuring a SshJobService configuration on a command job. + """ + + def __init__( + self, + *, + endpoint: Optional[str] = None, + nodes: Optional[Literal["all"]] = None, + status: Optional[str] = None, + port: Optional[int] = None, + ssh_public_keys: Optional[str] = None, + properties: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + endpoint=endpoint, + nodes=nodes, + status=status, + port=port, + properties=properties, + **kwargs, + ) + self.type = JobServiceTypeNames.EntityNames.SSH + self.ssh_public_keys = ssh_public_keys + + @classmethod + def _from_rest_object(cls, obj: RestJobService) -> "SshJobService": + ssh_job_service = cast(SshJobService, cls._from_rest_job_service_object(obj)) + ssh_job_service.ssh_public_keys = _get_property(obj.properties, "sshPublicKeys") + return ssh_job_service + + def _to_rest_object(self) -> RestJobService: + updated_properties = _append_or_update_properties(self.properties, "sshPublicKeys", self.ssh_public_keys) + return self._to_rest_job_service(updated_properties) + + +class TensorBoardJobService(JobServiceBase): + """TensorBoard job service configuration. + + :ivar type: Specifies the type of job service. Set automatically to "tensor_board" for this class. + :vartype type: str + :keyword endpoint: The endpoint URL. + :paramtype endpoint: Optional[str] + :keyword port: The port for the endpoint. + :paramtype port: Optional[int] + :keyword nodes: Indicates whether the service has to run in all nodes. + :paramtype nodes: Optional[Literal["all"]] + :keyword properties: Additional properties to set on the endpoint. + :paramtype properties: Optional[dict[str, str]] + :keyword status: The status of the endpoint. + :paramtype status: Optional[str] + :keyword log_dir: The directory path for the log file. + :paramtype log_dir: Optional[str] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START ssh_job_service_configuration] + :end-before: [END ssh_job_service_configuration] + :language: python + :dedent: 8 + :caption: Configuring TensorBoardJobService configuration on a command job. + """ + + def __init__( + self, + *, + endpoint: Optional[str] = None, + nodes: Optional[Literal["all"]] = None, + status: Optional[str] = None, + port: Optional[int] = None, + log_dir: Optional[str] = None, + properties: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + endpoint=endpoint, + nodes=nodes, + status=status, + port=port, + properties=properties, + **kwargs, + ) + self.type = JobServiceTypeNames.EntityNames.TENSOR_BOARD + self.log_dir = log_dir + + @classmethod + def _from_rest_object(cls, obj: RestJobService) -> "TensorBoardJobService": + tensorboard_job_Service = cast(TensorBoardJobService, cls._from_rest_job_service_object(obj)) + tensorboard_job_Service.log_dir = _get_property(obj.properties, "logDir") + return tensorboard_job_Service + + def _to_rest_object(self) -> RestJobService: + updated_properties = _append_or_update_properties(self.properties, "logDir", self.log_dir) + return self._to_rest_job_service(updated_properties) + + +class JupyterLabJobService(JobServiceBase): + """JupyterLab job service configuration. + + :ivar type: Specifies the type of job service. Set automatically to "jupyter_lab" for this class. + :vartype type: str + :keyword endpoint: The endpoint URL. + :paramtype endpoint: Optional[str] + :keyword port: The port for the endpoint. + :paramtype port: Optional[int] + :keyword nodes: Indicates whether the service has to run in all nodes. + :paramtype nodes: Optional[Literal["all"]] + :keyword properties: Additional properties to set on the endpoint. + :paramtype properties: Optional[dict[str, str]] + :keyword status: The status of the endpoint. + :paramtype status: Optional[str] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START ssh_job_service_configuration] + :end-before: [END ssh_job_service_configuration] + :language: python + :dedent: 8 + :caption: Configuring JupyterLabJobService configuration on a command job. + """ + + def __init__( + self, + *, + endpoint: Optional[str] = None, + nodes: Optional[Literal["all"]] = None, + status: Optional[str] = None, + port: Optional[int] = None, + properties: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + endpoint=endpoint, + nodes=nodes, + status=status, + port=port, + properties=properties, + **kwargs, + ) + self.type = JobServiceTypeNames.EntityNames.JUPYTER_LAB + + @classmethod + def _from_rest_object(cls, obj: RestJobService) -> "JupyterLabJobService": + return cast(JupyterLabJobService, cls._from_rest_job_service_object(obj)) + + def _to_rest_object(self) -> RestJobService: + return self._to_rest_job_service() + + +class VsCodeJobService(JobServiceBase): + """VS Code job service configuration. + + :ivar type: Specifies the type of job service. Set automatically to "vs_code" for this class. + :vartype type: str + :keyword endpoint: The endpoint URL. + :paramtype endpoint: Optional[str] + :keyword port: The port for the endpoint. + :paramtype port: Optional[int] + :keyword nodes: Indicates whether the service has to run in all nodes. + :paramtype nodes: Optional[Literal["all"]] + :keyword properties: Additional properties to set on the endpoint. + :paramtype properties: Optional[dict[str, str]] + :keyword status: The status of the endpoint. + :paramtype status: Optional[str] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START ssh_job_service_configuration] + :end-before: [END ssh_job_service_configuration] + :language: python + :dedent: 8 + :caption: Configuring a VsCodeJobService configuration on a command job. + """ + + def __init__( + self, + *, + endpoint: Optional[str] = None, + nodes: Optional[Literal["all"]] = None, + status: Optional[str] = None, + port: Optional[int] = None, + properties: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + endpoint=endpoint, + nodes=nodes, + status=status, + port=port, + properties=properties, + **kwargs, + ) + self.type = JobServiceTypeNames.EntityNames.VS_CODE + + @classmethod + def _from_rest_object(cls, obj: RestJobService) -> "VsCodeJobService": + return cast(VsCodeJobService, cls._from_rest_job_service_object(obj)) + + def _to_rest_object(self) -> RestJobService: + return self._to_rest_job_service() + + +def _append_or_update_properties( + properties: Optional[Dict[str, str]], key: str, value: Optional[str] +) -> Dict[str, str]: + if value and not properties: + properties = {key: value} + + if value and properties: + properties.update({key: value}) + return properties if properties is not None else {} + + +def _get_property(properties: Dict[str, str], key: str) -> Optional[str]: + return properties.get(key, None) if properties else None |
