aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=arguments-renamed

import logging
from typing import Optional

from azure.ai.ml._restclient.v2022_05_01.models import BatchRetrySettings as RestBatchRetrySettings
from azure.ai.ml._restclient.v2022_05_01.models import OnlineRequestSettings as RestOnlineRequestSettings
from azure.ai.ml._restclient.v2022_05_01.models import ProbeSettings as RestProbeSettings
from azure.ai.ml._utils.utils import (
    from_iso_duration_format,
    from_iso_duration_format_ms,
    to_iso_duration_format,
    to_iso_duration_format_ms,
)
from azure.ai.ml.entities._mixins import RestTranslatableMixin

# Module-level logger, named after this module via ``__name__``.
module_logger = logging.getLogger(__name__)


class BatchRetrySettings(RestTranslatableMixin):
    """Retry settings for batch deployment.

    Both constructor arguments default to ``None`` (i.e. "not set on this object").

    :param max_retries: Number of retries in failure, defaults to 3
    :type max_retries: int
    :param timeout: Timeout in seconds, defaults to 30
    :type timeout: int
    """

    def __init__(self, *, max_retries: Optional[int] = None, timeout: Optional[int] = None):
        self.max_retries = max_retries
        self.timeout = timeout

    def _to_rest_object(self) -> RestBatchRetrySettings:
        """Build the REST model for these settings.

        ``max_retries`` is passed through unchanged; ``timeout`` is converted with
        ``to_iso_duration_format`` (presumably to an ISO-8601 duration string -- the
        helper is defined outside this module).

        :return: The REST representation of this object.
        :rtype: RestBatchRetrySettings
        """
        return RestBatchRetrySettings(
            max_retries=self.max_retries,
            timeout=to_iso_duration_format(self.timeout),
        )

    @classmethod
    def _from_rest_object(cls, settings: RestBatchRetrySettings) -> Optional["BatchRetrySettings"]:
        """Create an entity from its REST model.

        :param settings: The REST model to convert; a falsy value (e.g. ``None``) yields ``None``.
        :type settings: RestBatchRetrySettings
        :return: The converted entity, or ``None`` when ``settings`` is falsy.
        :rtype: Optional[BatchRetrySettings]
        """
        if not settings:
            return None
        return BatchRetrySettings(
            max_retries=settings.max_retries,
            timeout=from_iso_duration_format(settings.timeout),
        )

    def _merge_with(self, other: Optional["BatchRetrySettings"]) -> None:
        """Overlay the values that are set on ``other`` onto this object, in place.

        A field of ``other`` that is ``None`` leaves the current value untouched.

        :param other: Settings whose non-``None`` fields take precedence; a falsy value is a no-op.
        :type other: Optional[BatchRetrySettings]
        """
        if not other:
            return
        # Compare against None rather than relying on truthiness, so that an
        # explicit 0 (e.g. max_retries=0) is honoured as an override instead of
        # being silently discarded.
        if other.timeout is not None:
            self.timeout = other.timeout
        if other.max_retries is not None:
            self.max_retries = other.max_retries

class OnlineRequestSettings(RestTranslatableMixin):
    """Request Settings entity.

    All constructor arguments default to ``None`` (i.e. "not set on this object").

    :param request_timeout_ms: defaults to 5000
    :type request_timeout_ms: int
    :param max_concurrent_requests_per_instance: defaults to 1
    :type max_concurrent_requests_per_instance: int
    :param max_queue_wait_ms: defaults to 500
    :type max_queue_wait_ms: int
    """

    def __init__(
        self,
        max_concurrent_requests_per_instance: Optional[int] = None,
        request_timeout_ms: Optional[int] = None,
        max_queue_wait_ms: Optional[int] = None,
    ):
        self.request_timeout_ms = request_timeout_ms
        self.max_concurrent_requests_per_instance = max_concurrent_requests_per_instance
        self.max_queue_wait_ms = max_queue_wait_ms

    def _to_rest_object(self) -> RestOnlineRequestSettings:
        """Build the REST model for these settings.

        The two millisecond fields are converted with ``to_iso_duration_format_ms``
        (presumably to ISO-8601 duration strings -- the helper is defined outside
        this module); the concurrency limit is passed through unchanged.

        :return: The REST representation of this object.
        :rtype: RestOnlineRequestSettings
        """
        return RestOnlineRequestSettings(
            max_queue_wait=to_iso_duration_format_ms(self.max_queue_wait_ms),
            max_concurrent_requests_per_instance=self.max_concurrent_requests_per_instance,
            request_timeout=to_iso_duration_format_ms(self.request_timeout_ms),
        )

    def _merge_with(self, other: Optional["OnlineRequestSettings"]) -> None:
        """Overlay the values that are set on ``other`` onto this object, in place.

        A field of ``other`` that is ``None`` leaves the current value untouched.

        :param other: Settings whose non-``None`` fields take precedence; a falsy value is a no-op.
        :type other: Optional[OnlineRequestSettings]
        """
        if not other:
            return
        # Compare against None rather than relying on truthiness, so that an
        # explicit 0 (e.g. max_queue_wait_ms=0) is honoured as an override
        # instead of being silently discarded.
        if other.max_concurrent_requests_per_instance is not None:
            self.max_concurrent_requests_per_instance = other.max_concurrent_requests_per_instance
        if other.request_timeout_ms is not None:
            self.request_timeout_ms = other.request_timeout_ms
        if other.max_queue_wait_ms is not None:
            self.max_queue_wait_ms = other.max_queue_wait_ms

    @classmethod
    def _from_rest_object(cls, settings: RestOnlineRequestSettings) -> Optional["OnlineRequestSettings"]:
        """Create an entity from its REST model.

        :param settings: The REST model to convert; a falsy value (e.g. ``None``) yields ``None``.
        :type settings: RestOnlineRequestSettings
        :return: The converted entity, or ``None`` when ``settings`` is falsy.
        :rtype: Optional[OnlineRequestSettings]
        """
        if not settings:
            return None
        return OnlineRequestSettings(
            request_timeout_ms=from_iso_duration_format_ms(settings.request_timeout),
            max_concurrent_requests_per_instance=settings.max_concurrent_requests_per_instance,
            max_queue_wait_ms=from_iso_duration_format_ms(settings.max_queue_wait),
        )

    def __eq__(self, other: object) -> bool:
        """Compare field by field with another ``OnlineRequestSettings``.

        Returns ``NotImplemented`` for any other type so Python can try the
        reflected comparison.
        """
        if not isinstance(other, OnlineRequestSettings):
            return NotImplemented
        # only compare mutable fields
        return (
            self.max_concurrent_requests_per_instance == other.max_concurrent_requests_per_instance
            and self.request_timeout_ms == other.request_timeout_ms
            and self.max_queue_wait_ms == other.max_queue_wait_ms
        )

    def __ne__(self, other: object) -> bool:
        """Return the negation of ``__eq__``, propagating ``NotImplemented``."""
        equal = self.__eq__(other)
        # Negating NotImplemented would report a foreign object as "not unequal"
        # (and raises TypeError on newer Python versions), so hand it back as-is.
        if equal is NotImplemented:
            return NotImplemented
        return not equal


class ProbeSettings(RestTranslatableMixin):
    """Settings on how to probe an endpoint."""

    def __init__(
        self,
        *,
        failure_threshold: Optional[int] = None,
        success_threshold: Optional[int] = None,
        timeout: Optional[int] = None,
        period: Optional[int] = None,
        initial_delay: Optional[int] = None,
    ):
        """Settings on how to probe an endpoint.

        All arguments are keyword-only and default to ``None`` (i.e. "not set on
        this object").

        :param failure_threshold: Threshold for probe failures, defaults to 30
        :type failure_threshold: int
        :param success_threshold: Threshold for probe success, defaults to 1
        :type success_threshold: int
        :param timeout: timeout in seconds, defaults to 2
        :type timeout: int
        :param period: How often (in seconds) to perform the probe, defaults to 10
        :type period: int
        :param initial_delay: How long (in seconds) to wait for the first probe, defaults to 10
        :type initial_delay: int
        """

        self.failure_threshold = failure_threshold
        self.success_threshold = success_threshold
        self.timeout = timeout
        self.period = period
        self.initial_delay = initial_delay

    def _to_rest_object(self) -> RestProbeSettings:
        """Build the REST model for these settings.

        The thresholds are passed through unchanged; ``timeout``, ``period`` and
        ``initial_delay`` are converted with ``to_iso_duration_format`` (presumably
        to ISO-8601 duration strings -- the helper is defined outside this module).

        :return: The REST representation of this object.
        :rtype: RestProbeSettings
        """
        return RestProbeSettings(
            failure_threshold=self.failure_threshold,
            success_threshold=self.success_threshold,
            timeout=to_iso_duration_format(self.timeout),
            period=to_iso_duration_format(self.period),
            initial_delay=to_iso_duration_format(self.initial_delay),
        )

    def _merge_with(self, other: Optional["ProbeSettings"]) -> None:
        """Overlay the values that are set on ``other`` onto this object, in place.

        A field of ``other`` that is ``None`` leaves the current value untouched.

        :param other: Settings whose non-``None`` fields take precedence; a falsy value is a no-op.
        :type other: Optional[ProbeSettings]
        """
        if not other:
            return
        # Compare against None rather than relying on truthiness, so that an
        # explicit 0 (e.g. initial_delay=0) is honoured as an override instead
        # of being silently discarded.
        if other.failure_threshold is not None:
            self.failure_threshold = other.failure_threshold
        if other.success_threshold is not None:
            self.success_threshold = other.success_threshold
        if other.timeout is not None:
            self.timeout = other.timeout
        if other.period is not None:
            self.period = other.period
        if other.initial_delay is not None:
            self.initial_delay = other.initial_delay

    @classmethod
    def _from_rest_object(cls, settings: RestProbeSettings) -> Optional["ProbeSettings"]:
        """Create an entity from its REST model.

        :param settings: The REST model to convert; a falsy value (e.g. ``None``) yields ``None``.
        :type settings: RestProbeSettings
        :return: The converted entity, or ``None`` when ``settings`` is falsy.
        :rtype: Optional[ProbeSettings]
        """
        if not settings:
            return None
        return ProbeSettings(
            failure_threshold=settings.failure_threshold,
            success_threshold=settings.success_threshold,
            timeout=from_iso_duration_format(settings.timeout),
            period=from_iso_duration_format(settings.period),
            initial_delay=from_iso_duration_format(settings.initial_delay),
        )

    def __eq__(self, other: object) -> bool:
        """Compare field by field with another ``ProbeSettings``.

        Returns ``NotImplemented`` for any other type so Python can try the
        reflected comparison.
        """
        if not isinstance(other, ProbeSettings):
            return NotImplemented
        # only compare mutable fields
        return (
            self.failure_threshold == other.failure_threshold
            and self.success_threshold == other.success_threshold
            and self.timeout == other.timeout
            and self.period == other.period
            and self.initial_delay == other.initial_delay
        )

    def __ne__(self, other: object) -> bool:
        """Return the negation of ``__eq__``, propagating ``NotImplemented``."""
        equal = self.__eq__(other)
        # Negating NotImplemented would report a foreign object as "not unequal"
        # (and raises TypeError on newer Python versions), so hand it back as-is.
        if equal is NotImplemented:
            return NotImplemented
        return not equal