aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=protected-access

import json
import logging
from typing import Any, Iterable, List, Optional

from marshmallow.exceptions import ValidationError as SchemaValidationError

from azure.ai.ml._exception_helper import log_and_raise_error
from azure.ai.ml._local_endpoints import EndpointStub
from azure.ai.ml._local_endpoints.docker_client import (
    DockerClient,
    get_endpoint_json_from_container,
    get_scoring_uri_from_container,
    get_status_from_container,
)
from azure.ai.ml._utils._endpoint_utils import local_endpoint_polling_wrapper
from azure.ai.ml._utils._http_utils import HttpPipeline
from azure.ai.ml._utils.utils import DockerProxy
from azure.ai.ml.constants._endpoint import EndpointInvokeFields, LocalEndpointConstants
from azure.ai.ml.entities import OnlineEndpoint
from azure.ai.ml.exceptions import InvalidLocalEndpointError, LocalEndpointNotFoundError, ValidationException

# NOTE(review): presumably a lazy stand-in for the `docker` package — confirm against DockerProxy.
# The name `docker` is referenced by the string type annotation on `_convert_container_to_endpoint`.
docker = DockerProxy()
# Logger named after this module.
module_logger = logging.getLogger(__name__)


class _LocalEndpointHelper:
    """A helper class to interact with Azure ML endpoints locally.

    Use this helper to manage Azure ML endpoints locally, e.g. create, invoke, show, list, delete.
    Endpoint information is gathered from two collaborators: an ``EndpointStub`` and a ``DockerClient``.
    """

    def __init__(self, *, requests_pipeline: HttpPipeline):
        """Set up the helper.

        :keyword requests_pipeline: Pipeline used by ``invoke`` to POST scoring requests.
        :paramtype requests_pipeline: HttpPipeline
        """
        self._docker_client = DockerClient()
        self._endpoint_stub = EndpointStub()
        self._requests_pipeline = requests_pipeline

    def create_or_update(self, endpoint: OnlineEndpoint) -> OnlineEndpoint:  # type: ignore
        """Create or update an endpoint locally using Docker.

        :param endpoint: OnlineEndpoint object with information from user yaml.
        :type endpoint: OnlineEndpoint
        :return: The endpoint as reported by ``get`` once the create/update call has completed.
        :rtype: OnlineEndpoint
        :raises InvalidLocalEndpointError: If ``endpoint`` is None.
        """
        try:
            if endpoint is None:
                msg = "The entity provided for local endpoint was null. Please provide valid entity."
                raise InvalidLocalEndpointError(message=msg, no_personal_data_message=msg)

            # Whether the endpoint already exists only changes the progress message shown to the user.
            try:
                self.get(endpoint_name=str(endpoint.name))
                operation_message = "Updating local endpoint"
            except LocalEndpointNotFoundError:
                operation_message = "Creating local endpoint"

            local_endpoint_polling_wrapper(
                func=self._endpoint_stub.create_or_update,
                message=f"{operation_message} ({endpoint.name}) ",
                endpoint=endpoint,
            )
            return self.get(endpoint_name=str(endpoint.name))
        except (ValidationException, SchemaValidationError) as ex:
            # Only validation failures are routed through the shared error helper;
            # every other exception propagates unchanged with its original traceback.
            log_and_raise_error(ex)

    def invoke(self, endpoint_name: str, data: dict, deployment_name: Optional[str] = None) -> str:
        """Invoke a local endpoint.

        :param endpoint_name: Name of endpoint to invoke.
        :type endpoint_name: str
        :param data: json data to pass
        :type data: dict
        :param deployment_name: Name of specific deployment to invoke.
        :type deployment_name: (str, optional)
        :return: The text response
        :rtype: str
        :raises LocalEndpointNotFoundError: If there is neither a scoring URI nor an endpoint stub
            for ``endpoint_name``.
        """
        # get_scoring_uri will throw user error if there are multiple deployments and no deployment_name is specified
        scoring_uri = self._docker_client.get_scoring_uri(endpoint_name=endpoint_name, deployment_name=deployment_name)
        if scoring_uri:
            headers = {}
            if deployment_name is not None:
                headers[EndpointInvokeFields.MODEL_DEPLOYMENT] = deployment_name
            return str(self._requests_pipeline.post(scoring_uri, json=data, headers=headers).text())

        # No scoring URI available: fall back to the endpoint stub if one exists.
        endpoint_stub = self._endpoint_stub.get(endpoint_name=endpoint_name)
        if endpoint_stub:
            return str(self._endpoint_stub.invoke())
        raise LocalEndpointNotFoundError(endpoint_name=endpoint_name, deployment_name=deployment_name)

    def get(self, endpoint_name: str) -> OnlineEndpoint:
        """Get a local endpoint.

        :param endpoint_name: Name of endpoint.
        :type endpoint_name: str
        :return: The endpoint; built from the container when one exists, otherwise the stub itself.
        :rtype: OnlineEndpoint
        :raises LocalEndpointNotFoundError: If neither an endpoint stub nor a container is found.
        """
        endpoint = self._endpoint_stub.get(endpoint_name=endpoint_name)
        container = self._docker_client.get_endpoint_container(endpoint_name=endpoint_name, include_stopped=True)
        if endpoint:
            if container:
                return _convert_container_to_endpoint(container=container, endpoint_json=endpoint.dump())
            return endpoint
        if container:
            return _convert_container_to_endpoint(container=container)
        raise LocalEndpointNotFoundError(endpoint_name=endpoint_name)

    def list(self) -> Iterable[OnlineEndpoint]:
        """List all local endpoints.

        :return: An iterable of local endpoints
        :rtype: Iterable[OnlineEndpoint]
        """
        endpoints: List = []
        containers = self._docker_client.list_containers()
        endpoint_stubs = self._endpoint_stub.list()
        # Iterate through all cached endpoint files
        for endpoint_file in endpoint_stubs:
            endpoint_json = json.loads(endpoint_file.read_text())
            container = self._docker_client.get_endpoint_container(
                endpoint_name=endpoint_json.get("name"), include_stopped=True
            )
            # If a deployment is associated with endpoint,
            # override certain endpoint properties with deployment information and remove it from containers list.
            # Otherwise, return endpoint spec.
            if container:
                endpoints.append(_convert_container_to_endpoint(endpoint_json=endpoint_json, container=container))
                # The lookup above includes stopped containers, so the match is not guaranteed to be
                # present in `containers`; an unguarded remove() would raise ValueError and abort the listing.
                if container in containers:
                    containers.remove(container)
            else:
                endpoints.append(
                    OnlineEndpoint._load(
                        data=endpoint_json,
                        params_override=[{"location": LocalEndpointConstants.ENDPOINT_STATE_LOCATION}],
                    )
                )
        # Iterate through any deployments that don't have an explicit local endpoint stub.
        for container in containers:
            endpoints.append(_convert_container_to_endpoint(container=container))
        return endpoints

    def delete(self, name: str) -> None:
        """Delete a local endpoint.

        The endpoint stub is deleted first; the endpoint's container is deleted afterwards if one is found.

        :param name: Name of endpoint to delete.
        :type name: str
        :raises LocalEndpointNotFoundError: If no endpoint stub exists for ``name``.
        """
        if not self._endpoint_stub.get(endpoint_name=name):
            raise LocalEndpointNotFoundError(endpoint_name=name)

        self._endpoint_stub.delete(endpoint_name=name)
        if self._docker_client.get_endpoint_container(endpoint_name=name):
            self._docker_client.delete(endpoint_name=name)


def _convert_container_to_endpoint(
    # Bug Item number: 2885719
    container: "docker.models.containers.Container",  # type: ignore
    endpoint_json: Optional[dict] = None,
) -> OnlineEndpoint:
    """Build an OnlineEndpoint entity from the container of a local deployment.

    The endpoint json is read from the container when the caller does not supply it. A container
    whose status is ``CONTAINER_EXITED`` yields an endpoint in the failed provisioning state without
    a scoring URI; any other status yields a succeeded endpoint carrying the container's scoring URI.

    :param container: Container for a local deployment.
    :type container: docker.models.containers.Container
    :param endpoint_json: The endpoint json
    :type endpoint_json: Optional[dict]
    :return: The OnlineEndpoint entity
    :rtype: OnlineEndpoint
    """
    if endpoint_json is None:
        endpoint_json = get_endpoint_json_from_container(container=container)

    has_exited = get_status_from_container(container=container) == LocalEndpointConstants.CONTAINER_EXITED

    # Insertion order matters: it becomes the order of the params_override entries downstream.
    overrides: dict = {"location": LocalEndpointConstants.ENDPOINT_STATE_LOCATION}
    if has_exited:
        overrides["provisioning_state"] = LocalEndpointConstants.ENDPOINT_STATE_FAILED
    else:
        overrides["provisioning_state"] = LocalEndpointConstants.ENDPOINT_STATE_SUCCEEDED
        # The scoring URI is only looked up for containers that have not exited.
        overrides["scoring_uri"] = get_scoring_uri_from_container(container=container)

    return _convert_json_to_endpoint(endpoint_json=endpoint_json, **overrides)


def _convert_json_to_endpoint(endpoint_json: Optional[dict], **kwargs: Any) -> OnlineEndpoint:
    """Load an OnlineEndpoint entity from metadata json, applying kwargs as overrides.

    Each keyword argument is turned into its own single-entry ``{name: value}`` dict, and the
    resulting list is handed to ``OnlineEndpoint._load`` as ``params_override``.

    :param endpoint_json: dictionary representation of OnlineEndpoint entity.
    :type endpoint_json: dict
    :return: The OnlineEndpoint entity
    :rtype: OnlineEndpoint
    """
    overrides = [{name: value} for name, value in kwargs.items()]
    return OnlineEndpoint._load(data=endpoint_json, params_override=overrides)  # type: ignore