about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py424
1 files changed, 424 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py
new file mode 100644
index 00000000..a97048fc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py
@@ -0,0 +1,424 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+from typing import Any, Dict, Optional, cast
+
+from typing_extensions import Literal
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AllNodes
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobService as RestJobService
+from azure.ai.ml.constants._job.job import JobServiceTypeNames
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+module_logger = logging.getLogger(__name__)
+
+
+class JobServiceBase(RestTranslatableMixin, DictMixin):
+    """Base class for job service configuration.
+
+    This class should not be instantiated directly. Instead, use one of its subclasses.
+
+    :keyword endpoint: The endpoint URL.
+    :paramtype endpoint: Optional[str]
+    :keyword type: The endpoint type. Accepted values are "jupyter_lab", "ssh", "tensor_board", and "vs_code".
+    :paramtype type: Optional[Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]]
+    :keyword port: The port for the endpoint.
+    :paramtype port: Optional[int]
+    :keyword nodes: Indicates whether the service has to run in all nodes.
+    :paramtype nodes: Optional[Literal["all"]]
+    :keyword properties: Additional properties to set on the endpoint.
+    :paramtype properties: Optional[dict[str, str]]
+    :keyword status: The status of the endpoint.
+    :paramtype status: Optional[str]
+    :keyword kwargs: A dictionary of additional configuration parameters.
+    :paramtype kwargs: dict
+    """
+
+    def __init__(  # pylint: disable=unused-argument
+        self,
+        *,
+        endpoint: Optional[str] = None,
+        type: Optional[  # pylint: disable=redefined-builtin
+            Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]
+        ] = None,
+        nodes: Optional[Literal["all"]] = None,
+        status: Optional[str] = None,
+        port: Optional[int] = None,
+        properties: Optional[Dict[str, str]] = None,
+        **kwargs: Dict,
+    ) -> None:
+        self.endpoint = endpoint
+        self.type: Any = type
+        self.nodes = nodes
+        self.status = status
+        self.port = port
+        self.properties = properties
+        self._validate_nodes()
+        self._validate_type_name()
+
+    def _validate_nodes(self) -> None:
+        if not self.nodes in ["all", None]:
+            msg = f"nodes should be either 'all' or None, but received '{self.nodes}'."
+            raise ValidationException(
+                message=msg,
+                no_personal_data_message=msg,
+                target=ErrorTarget.JOB,
+                error_category=ErrorCategory.USER_ERROR,
+                error_type=ValidationErrorType.INVALID_VALUE,
+            )
+
+    def _validate_type_name(self) -> None:
+        if self.type and not self.type in JobServiceTypeNames.ENTITY_TO_REST:
+            msg = (
+                f"type should be one of " f"{JobServiceTypeNames.NAMES_ALLOWED_FOR_PUBLIC}, but received '{self.type}'."
+            )
+            raise ValidationException(
+                message=msg,
+                no_personal_data_message=msg,
+                target=ErrorTarget.JOB,
+                error_category=ErrorCategory.USER_ERROR,
+                error_type=ValidationErrorType.INVALID_VALUE,
+            )
+
+    def _to_rest_job_service(self, updated_properties: Optional[Dict[str, str]] = None) -> RestJobService:
+        return RestJobService(
+            endpoint=self.endpoint,
+            job_service_type=JobServiceTypeNames.ENTITY_TO_REST.get(self.type, None) if self.type else None,
+            nodes=AllNodes() if self.nodes else None,
+            status=self.status,
+            port=self.port,
+            properties=updated_properties if updated_properties else self.properties,
+        )
+
+    @classmethod
+    def _to_rest_job_services(
+        cls,
+        services: Optional[Dict],
+    ) -> Optional[Dict[str, RestJobService]]:
+        if services is None:
+            return None
+
+        return {name: service._to_rest_object() for name, service in services.items()}
+
+    @classmethod
+    def _from_rest_job_service_object(cls, obj: RestJobService) -> "JobServiceBase":
+        return cls(
+            endpoint=obj.endpoint,
+            type=(
+                JobServiceTypeNames.REST_TO_ENTITY.get(obj.job_service_type, None)  # type: ignore[arg-type]
+                if obj.job_service_type
+                else None
+            ),
+            nodes="all" if obj.nodes else None,
+            status=obj.status,
+            port=obj.port,
+            # ssh_public_keys=_get_property(obj.properties, "sshPublicKeys"),
+            properties=obj.properties,
+        )
+
+    @classmethod
+    def _from_rest_job_services(cls, services: Dict[str, RestJobService]) -> Dict:
+        # """Resolve Dict[str, RestJobService] to Dict[str, Specific JobService]"""
+        if services is None:
+            return None
+
+        result: dict = {}
+        for name, service in services.items():
+            if service.job_service_type == JobServiceTypeNames.RestNames.JUPYTER_LAB:
+                result[name] = JupyterLabJobService._from_rest_object(service)
+            elif service.job_service_type == JobServiceTypeNames.RestNames.SSH:
+                result[name] = SshJobService._from_rest_object(service)
+            elif service.job_service_type == JobServiceTypeNames.RestNames.TENSOR_BOARD:
+                result[name] = TensorBoardJobService._from_rest_object(service)
+            elif service.job_service_type == JobServiceTypeNames.RestNames.VS_CODE:
+                result[name] = VsCodeJobService._from_rest_object(service)
+            else:
+                result[name] = JobService._from_rest_object(service)
+        return result
+
+
+class JobService(JobServiceBase):
+    """Basic job service configuration for backward compatibility.
+
+    This class is not intended to be used directly. Instead, use one of its subclasses specific to your job type.
+
+    :keyword endpoint: The endpoint URL.
+    :paramtype endpoint: Optional[str]
+    :keyword type: The endpoint type. Accepted values are "jupyter_lab", "ssh", "tensor_board", and "vs_code".
+    :paramtype type: Optional[Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]]
+    :keyword port: The port for the endpoint.
+    :paramtype port: Optional[int]
+    :keyword nodes: Indicates whether the service has to run in all nodes.
+    :paramtype nodes: Optional[Literal["all"]]
+    :keyword properties: Additional properties to set on the endpoint.
+    :paramtype properties: Optional[dict[str, str]]
+    :keyword status: The status of the endpoint.
+    :paramtype status: Optional[str]
+    :keyword kwargs: A dictionary of additional configuration parameters.
+    :paramtype kwargs: dict
+    """
+
+    @classmethod
+    def _from_rest_object(cls, obj: RestJobService) -> "JobService":
+        return cast(JobService, cls._from_rest_job_service_object(obj))
+
+    def _to_rest_object(self) -> RestJobService:
+        return self._to_rest_job_service()
+
+
+class SshJobService(JobServiceBase):
+    """SSH job service configuration.
+
+    :ivar type: Specifies the type of job service. Set automatically to "ssh" for this class.
+    :vartype type: str
+    :keyword endpoint: The endpoint URL.
+    :paramtype endpoint: Optional[str]
+    :keyword port: The port for the endpoint.
+    :paramtype port: Optional[int]
+    :keyword nodes: Indicates whether the service has to run in all nodes.
+    :paramtype nodes: Optional[Literal["all"]]
+    :keyword properties: Additional properties to set on the endpoint.
+    :paramtype properties: Optional[dict[str, str]]
+    :keyword status: The status of the endpoint.
+    :paramtype status: Optional[str]
+    :keyword ssh_public_keys: The SSH Public Key to access the job container.
+    :paramtype ssh_public_keys: Optional[str]
+    :keyword kwargs: A dictionary of additional configuration parameters.
+    :paramtype kwargs: dict
+
+    .. admonition:: Example:
+
+        .. literalinclude:: ../samples/ml_samples_misc.py
+            :start-after: [START ssh_job_service_configuration]
+            :end-before: [END ssh_job_service_configuration]
+            :language: python
+            :dedent: 8
+            :caption: Configuring a SshJobService configuration on a command job.
+    """
+
+    def __init__(
+        self,
+        *,
+        endpoint: Optional[str] = None,
+        nodes: Optional[Literal["all"]] = None,
+        status: Optional[str] = None,
+        port: Optional[int] = None,
+        ssh_public_keys: Optional[str] = None,
+        properties: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            endpoint=endpoint,
+            nodes=nodes,
+            status=status,
+            port=port,
+            properties=properties,
+            **kwargs,
+        )
+        self.type = JobServiceTypeNames.EntityNames.SSH
+        self.ssh_public_keys = ssh_public_keys
+
+    @classmethod
+    def _from_rest_object(cls, obj: RestJobService) -> "SshJobService":
+        ssh_job_service = cast(SshJobService, cls._from_rest_job_service_object(obj))
+        ssh_job_service.ssh_public_keys = _get_property(obj.properties, "sshPublicKeys")
+        return ssh_job_service
+
+    def _to_rest_object(self) -> RestJobService:
+        updated_properties = _append_or_update_properties(self.properties, "sshPublicKeys", self.ssh_public_keys)
+        return self._to_rest_job_service(updated_properties)
+
+
+class TensorBoardJobService(JobServiceBase):
+    """TensorBoard job service configuration.
+
+    :ivar type: Specifies the type of job service. Set automatically to "tensor_board" for this class.
+    :vartype type: str
+    :keyword endpoint: The endpoint URL.
+    :paramtype endpoint: Optional[str]
+    :keyword port: The port for the endpoint.
+    :paramtype port: Optional[int]
+    :keyword nodes: Indicates whether the service has to run in all nodes.
+    :paramtype nodes: Optional[Literal["all"]]
+    :keyword properties: Additional properties to set on the endpoint.
+    :paramtype properties: Optional[dict[str, str]]
+    :keyword status: The status of the endpoint.
+    :paramtype status: Optional[str]
+    :keyword log_dir: The directory path for the log file.
+    :paramtype log_dir: Optional[str]
+    :keyword kwargs: A dictionary of additional configuration parameters.
+    :paramtype kwargs: dict
+
+    .. admonition:: Example:
+
+        .. literalinclude:: ../samples/ml_samples_misc.py
+            :start-after: [START ssh_job_service_configuration]
+            :end-before: [END ssh_job_service_configuration]
+            :language: python
+            :dedent: 8
+            :caption: Configuring TensorBoardJobService configuration on a command job.
+    """
+
+    def __init__(
+        self,
+        *,
+        endpoint: Optional[str] = None,
+        nodes: Optional[Literal["all"]] = None,
+        status: Optional[str] = None,
+        port: Optional[int] = None,
+        log_dir: Optional[str] = None,
+        properties: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            endpoint=endpoint,
+            nodes=nodes,
+            status=status,
+            port=port,
+            properties=properties,
+            **kwargs,
+        )
+        self.type = JobServiceTypeNames.EntityNames.TENSOR_BOARD
+        self.log_dir = log_dir
+
+    @classmethod
+    def _from_rest_object(cls, obj: RestJobService) -> "TensorBoardJobService":
+        tensorboard_job_Service = cast(TensorBoardJobService, cls._from_rest_job_service_object(obj))
+        tensorboard_job_Service.log_dir = _get_property(obj.properties, "logDir")
+        return tensorboard_job_Service
+
+    def _to_rest_object(self) -> RestJobService:
+        updated_properties = _append_or_update_properties(self.properties, "logDir", self.log_dir)
+        return self._to_rest_job_service(updated_properties)
+
+
+class JupyterLabJobService(JobServiceBase):
+    """JupyterLab job service configuration.
+
+    :ivar type: Specifies the type of job service. Set automatically to "jupyter_lab" for this class.
+    :vartype type: str
+    :keyword endpoint: The endpoint URL.
+    :paramtype endpoint: Optional[str]
+    :keyword port: The port for the endpoint.
+    :paramtype port: Optional[int]
+    :keyword nodes: Indicates whether the service has to run in all nodes.
+    :paramtype nodes: Optional[Literal["all"]]
+    :keyword properties: Additional properties to set on the endpoint.
+    :paramtype properties: Optional[dict[str, str]]
+    :keyword status: The status of the endpoint.
+    :paramtype status: Optional[str]
+    :keyword kwargs: A dictionary of additional configuration parameters.
+    :paramtype kwargs: dict
+
+    .. admonition:: Example:
+
+        .. literalinclude:: ../samples/ml_samples_misc.py
+            :start-after: [START ssh_job_service_configuration]
+            :end-before: [END ssh_job_service_configuration]
+            :language: python
+            :dedent: 8
+            :caption: Configuring JupyterLabJobService configuration on a command job.
+    """
+
+    def __init__(
+        self,
+        *,
+        endpoint: Optional[str] = None,
+        nodes: Optional[Literal["all"]] = None,
+        status: Optional[str] = None,
+        port: Optional[int] = None,
+        properties: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            endpoint=endpoint,
+            nodes=nodes,
+            status=status,
+            port=port,
+            properties=properties,
+            **kwargs,
+        )
+        self.type = JobServiceTypeNames.EntityNames.JUPYTER_LAB
+
+    @classmethod
+    def _from_rest_object(cls, obj: RestJobService) -> "JupyterLabJobService":
+        return cast(JupyterLabJobService, cls._from_rest_job_service_object(obj))
+
+    def _to_rest_object(self) -> RestJobService:
+        return self._to_rest_job_service()
+
+
+class VsCodeJobService(JobServiceBase):
+    """VS Code job service configuration.
+
+    :ivar type: Specifies the type of job service. Set automatically to "vs_code" for this class.
+    :vartype type: str
+    :keyword endpoint: The endpoint URL.
+    :paramtype endpoint: Optional[str]
+    :keyword port: The port for the endpoint.
+    :paramtype port: Optional[int]
+    :keyword nodes: Indicates whether the service has to run in all nodes.
+    :paramtype nodes: Optional[Literal["all"]]
+    :keyword properties: Additional properties to set on the endpoint.
+    :paramtype properties: Optional[dict[str, str]]
+    :keyword status: The status of the endpoint.
+    :paramtype status: Optional[str]
+    :keyword kwargs: A dictionary of additional configuration parameters.
+    :paramtype kwargs: dict
+
+    .. admonition:: Example:
+
+        .. literalinclude:: ../samples/ml_samples_misc.py
+            :start-after: [START ssh_job_service_configuration]
+            :end-before: [END ssh_job_service_configuration]
+            :language: python
+            :dedent: 8
+            :caption: Configuring a VsCodeJobService configuration on a command job.
+    """
+
+    def __init__(
+        self,
+        *,
+        endpoint: Optional[str] = None,
+        nodes: Optional[Literal["all"]] = None,
+        status: Optional[str] = None,
+        port: Optional[int] = None,
+        properties: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            endpoint=endpoint,
+            nodes=nodes,
+            status=status,
+            port=port,
+            properties=properties,
+            **kwargs,
+        )
+        self.type = JobServiceTypeNames.EntityNames.VS_CODE
+
+    @classmethod
+    def _from_rest_object(cls, obj: RestJobService) -> "VsCodeJobService":
+        return cast(VsCodeJobService, cls._from_rest_job_service_object(obj))
+
+    def _to_rest_object(self) -> RestJobService:
+        return self._to_rest_job_service()
+
+
+def _append_or_update_properties(
+    properties: Optional[Dict[str, str]], key: str, value: Optional[str]
+) -> Dict[str, str]:
+    if value and not properties:
+        properties = {key: value}
+
+    if value and properties:
+        properties.update({key: value})
+    return properties if properties is not None else {}
+
+
+def _get_property(properties: Dict[str, str], key: str) -> Optional[str]:
+    return properties.get(key, None) if properties else None