about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/batch_endpoint.py
blob: 4883c828b9827127f2c908ea0762ae26c040641c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import logging
from os import PathLike
from pathlib import Path
from typing import IO, Any, AnyStr, Dict, Optional, Union

from azure.ai.ml._restclient.v2023_10_01.models import BatchEndpoint as BatchEndpointData
from azure.ai.ml._restclient.v2023_10_01.models import BatchEndpointProperties as RestBatchEndpoint
from azure.ai.ml._schema._endpoint import BatchEndpointSchema
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel
from azure.ai.ml.constants._common import AAD_TOKEN_YAML, BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
from azure.ai.ml.entities._endpoint._endpoint_helpers import validate_endpoint_or_deployment_name
from azure.ai.ml.entities._util import load_from_dict

from .endpoint import Endpoint

# Module-level logger named after this module's import path.
module_logger = logging.getLogger(__name__)


class BatchEndpoint(Endpoint):
    """Batch endpoint entity.

    :param name: Name of the resource.
    :type name: str
    :param tags: Tag dictionary. Tags can be added, removed, and updated.
    :type tags: dict[str, str]
    :param properties: The asset property dictionary.
    :type properties: dict[str, str]
    :param auth_mode: Possible values include: "AMLToken", "Key", "AADToken", defaults to ``AAD_TOKEN_YAML``
    :type auth_mode: str
    :param description: Description of the inference endpoint, defaults to None
    :type description: str
    :param location: defaults to None
    :type location: str
    :param defaults: Endpoint defaults mapping. The default deployment is stored under the
        ``"deployment_name"`` key. Defaults to None.
    :type defaults: Dict[str, str]
    :param default_deployment_name: Shorthand for ``defaults={"deployment_name": <value>}``. Only used when
        ``defaults`` is None or empty; ignored otherwise.
    :type default_deployment_name: str
    :param scoring_uri: URI to use to perform a prediction, readonly.
    :type scoring_uri: str
    :param openapi_uri: URI to check the open API definition of the endpoint.
    :type openapi_uri: str
    """

    def __init__(
        self,
        *,
        name: Optional[str] = None,
        tags: Optional[Dict] = None,
        properties: Optional[Dict] = None,
        auth_mode: str = AAD_TOKEN_YAML,
        description: Optional[str] = None,
        location: Optional[str] = None,
        defaults: Optional[Dict[str, str]] = None,
        default_deployment_name: Optional[str] = None,
        scoring_uri: Optional[str] = None,
        openapi_uri: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            name=name,
            tags=tags,
            properties=properties,
            auth_mode=auth_mode,
            description=description,
            location=location,
            scoring_uri=scoring_uri,
            openapi_uri=openapi_uri,
            **kwargs,
        )

        self.defaults = defaults

        # A non-empty ``defaults`` mapping takes precedence. Only when it is missing or
        # empty is a fresh mapping built from ``default_deployment_name``.
        if not self.defaults and default_deployment_name:
            self.defaults = {"deployment_name": default_deployment_name}

    def _to_rest_batch_endpoint(self, location: str) -> BatchEndpointData:
        """Build the REST representation of this endpoint.

        The endpoint name is passed through ``validate_endpoint_or_deployment_name`` before
        the REST objects are constructed.

        :param location: Location to set on the returned REST object.
        :type location: str
        :return: REST envelope carrying ``location``, this entity's tags, and the endpoint properties.
        :rtype: BatchEndpointData
        """
        validate_endpoint_or_deployment_name(self.name)
        batch_endpoint = RestBatchEndpoint(
            description=self.description,
            # The entity's auth mode is converted with snake_to_camel for the REST payload.
            auth_mode=snake_to_camel(self.auth_mode),
            properties=self.properties,
            defaults=self.defaults,
        )
        return BatchEndpointData(location=location, tags=self.tags, properties=batch_endpoint)

    @classmethod
    def _from_rest_object(cls, obj: BatchEndpointData) -> "BatchEndpoint":
        """Create an entity from its REST representation.

        :param obj: REST object to convert.
        :type obj: BatchEndpointData
        :return: Instance of ``cls`` populated from ``obj``.
        :rtype: BatchEndpoint
        """
        rest_properties = obj.properties
        # Instantiate through ``cls`` so that subclasses get an instance of their own type.
        return cls(
            id=obj.id,
            name=obj.name,
            tags=obj.tags,
            properties=rest_properties.properties,
            auth_mode=camel_to_snake(rest_properties.auth_mode),
            description=rest_properties.description,
            location=obj.location,
            defaults=rest_properties.defaults,
            provisioning_state=rest_properties.provisioning_state,
            scoring_uri=rest_properties.scoring_uri,
            # The REST model exposes the OpenAPI document URI as ``swagger_uri``.
            openapi_uri=rest_properties.swagger_uri,
        )

    def dump(
        self,
        dest: Optional[Union[str, PathLike, IO[AnyStr]]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Serialize this endpoint with ``BatchEndpointSchema`` and return the result.

        NOTE(review): ``dest`` and ``kwargs`` are accepted but not used by this implementation;
        nothing is written to ``dest``. Callers that need a file must persist the returned
        dictionary themselves.

        :param dest: Unused by this implementation.
        :type dest: Optional[Union[str, PathLike, IO[AnyStr]]]
        :return: The schema dump of this endpoint.
        :rtype: Dict[str, Any]
        """
        context = {BASE_PATH_CONTEXT_KEY: Path(".")}
        return BatchEndpointSchema(context=context).dump(self)  # type: ignore

    @classmethod
    def _load(
        cls,
        data: Optional[Dict] = None,
        yaml_path: Optional[Union[PathLike, str]] = None,
        params_override: Optional[list] = None,
        **kwargs: Any,
    ) -> "BatchEndpoint":
        """Load a batch endpoint from a dictionary using ``BatchEndpointSchema``.

        The schema context's base path is the parent directory of ``yaml_path`` when one is
        given, otherwise the current working directory. ``kwargs`` is accepted but not used here.

        :param data: Endpoint definition; a missing or empty value is treated as ``{}``.
        :type data: Optional[Dict]
        :param yaml_path: Path of the YAML file the data came from, if any.
        :type yaml_path: Optional[Union[PathLike, str]]
        :param params_override: Overrides placed in the schema context; defaults to an empty list.
        :type params_override: Optional[list]
        :return: The loaded endpoint.
        :rtype: BatchEndpoint
        """
        data = data or {}
        params_override = params_override or []
        base_path = Path(yaml_path).parent if yaml_path else Path.cwd()
        context = {
            BASE_PATH_CONTEXT_KEY: base_path,
            PARAMS_OVERRIDE_KEY: params_override,
        }
        res: BatchEndpoint = load_from_dict(BatchEndpointSchema, data, context)
        return res

    def _to_dict(self) -> Dict:
        """Return the ``BatchEndpointSchema`` dump of this endpoint, using ``"./"`` as the base path.

        :return: The schema dump of this endpoint.
        :rtype: Dict
        """
        res: dict = BatchEndpointSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
        return res