author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/inference
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download | gn-ai-master.tar.gz
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/inference')
35 files changed, 14450 insertions, 0 deletions
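The diff that follows vendors the generated `azure-ai-inference` package. For orientation, a minimal construction sketch using only names that appear in the added files; the endpoint strings are placeholders, and the token-credential path assumes the separate `azure-identity` package is available:

```python
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential  # assumption: azure-identity is installed

from azure.ai.inference import ChatCompletionsClient, EmbeddingsClient

# Key-based auth: the generated configuration sends the key as "Authorization: Bearer <key>".
chat_client = ChatCompletionsClient(
    endpoint="https://<your-endpoint>",            # placeholder service host
    credential=AzureKeyCredential("<api-key>"),    # placeholder key
)

# Token-based auth: any credential exposing get_token() is accepted;
# the default credential scope is "https://ml.azure.com/.default".
embeddings_client = EmbeddingsClient(
    endpoint="https://<your-endpoint>",            # placeholder service host
    credential=DefaultAzureCredential(),
)
```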
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/__init__.py
new file mode 100644
index 00000000..b7537d16
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/__init__.py
@@ -0,0 +1,36 @@
+# coding=utf-8
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for license information.
+# Code generated by Microsoft (R) Python Code Generator.
+# Changes may cause incorrect behavior and will be lost if the code is regenerated.
+# --------------------------------------------------------------------------
+# pylint: disable=wrong-import-position
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ._patch import *  # pylint: disable=unused-wildcard-import
+
+from ._client import ChatCompletionsClient  # type: ignore
+from ._client import EmbeddingsClient  # type: ignore
+from ._client import ImageEmbeddingsClient  # type: ignore
+from ._version import VERSION
+
+__version__ = VERSION
+
+try:
+    from ._patch import __all__ as _patch_all
+    from ._patch import *
+except ImportError:
+    _patch_all = []
+from ._patch import patch_sdk as _patch_sdk
+
+__all__ = [
+    "ChatCompletionsClient",
+    "EmbeddingsClient",
+    "ImageEmbeddingsClient",
+]
+__all__.extend([p for p in _patch_all if p not in __all__])  # pyright: ignore
+
+_patch_sdk()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/_client.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/_client.py
new file mode 100644
index 00000000..0cde08ff
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/_client.py
@@ -0,0 +1,265 @@
+# coding=utf-8
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for license information.
+# Code generated by Microsoft (R) Python Code Generator.
+# Changes may cause incorrect behavior and will be lost if the code is regenerated.
+# --------------------------------------------------------------------------
+
+from copy import deepcopy
+from typing import Any, TYPE_CHECKING, Union
+from typing_extensions import Self
+
+from azure.core import PipelineClient
+from azure.core.credentials import AzureKeyCredential
+from azure.core.pipeline import policies
+from azure.core.rest import HttpRequest, HttpResponse
+
+from ._configuration import (
+    ChatCompletionsClientConfiguration,
+    EmbeddingsClientConfiguration,
+    ImageEmbeddingsClientConfiguration,
+)
+from ._operations import (
+    ChatCompletionsClientOperationsMixin,
+    EmbeddingsClientOperationsMixin,
+    ImageEmbeddingsClientOperationsMixin,
+)
+from ._serialization import Deserializer, Serializer
+
+if TYPE_CHECKING:
+    from azure.core.credentials import TokenCredential
+
+
+class ChatCompletionsClient(ChatCompletionsClientOperationsMixin):
+    """ChatCompletionsClient.
+
+    :param endpoint: Service host. Required.
+    :type endpoint: str
+    :param credential: Credential used to authenticate requests to the service. Is either a key
+     credential type or a token credential type. Required.
+    :type credential: ~azure.core.credentials.AzureKeyCredential or
+     ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential
+    :keyword api_version: The API version to use for this operation.
Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. + :paramtype api_version: str + """ + + def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any) -> None: + _endpoint = "{endpoint}" + self._config = ChatCompletionsClientConfiguration(endpoint=endpoint, credential=credential, **kwargs) + _policies = kwargs.pop("policies", None) + if _policies is None: + _policies = [ + policies.RequestIdPolicy(**kwargs), + self._config.headers_policy, + self._config.user_agent_policy, + self._config.proxy_policy, + policies.ContentDecodePolicy(**kwargs), + self._config.redirect_policy, + self._config.retry_policy, + self._config.authentication_policy, + self._config.custom_hook_policy, + self._config.logging_policy, + policies.DistributedTracingPolicy(**kwargs), + policies.SensitiveHeaderCleanupPolicy(**kwargs) if self._config.redirect_policy else None, + self._config.http_logging_policy, + ] + self._client: PipelineClient = PipelineClient(base_url=_endpoint, policies=_policies, **kwargs) + + self._serialize = Serializer() + self._deserialize = Deserializer() + self._serialize.client_side_validation = False + + def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: + """Runs the network request through the client's chained policies. + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest("GET", "https://www.example.org/") + <HttpRequest [GET], url: 'https://www.example.org/'> + >>> response = client.send_request(request) + <HttpResponse: 200 OK> + + For more information on this code flow, see https://aka.ms/azsdk/dpcodegen/python/send_request + + :param request: The network request you want to make. Required. + :type request: ~azure.core.rest.HttpRequest + :keyword bool stream: Whether the response payload will be streamed. Defaults to False. + :return: The response of your network call. Does not do error handling on your response. + :rtype: ~azure.core.rest.HttpResponse + """ + + request_copy = deepcopy(request) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + request_copy.url = self._client.format_url(request_copy.url, **path_format_arguments) + return self._client.send_request(request_copy, stream=stream, **kwargs) # type: ignore + + def close(self) -> None: + self._client.close() + + def __enter__(self) -> Self: + self._client.__enter__() + return self + + def __exit__(self, *exc_details: Any) -> None: + self._client.__exit__(*exc_details) + + +class EmbeddingsClient(EmbeddingsClientOperationsMixin): + """EmbeddingsClient. + + :param endpoint: Service host. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a key + credential type or a token credential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. 
+ :paramtype api_version: str + """ + + def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any) -> None: + _endpoint = "{endpoint}" + self._config = EmbeddingsClientConfiguration(endpoint=endpoint, credential=credential, **kwargs) + _policies = kwargs.pop("policies", None) + if _policies is None: + _policies = [ + policies.RequestIdPolicy(**kwargs), + self._config.headers_policy, + self._config.user_agent_policy, + self._config.proxy_policy, + policies.ContentDecodePolicy(**kwargs), + self._config.redirect_policy, + self._config.retry_policy, + self._config.authentication_policy, + self._config.custom_hook_policy, + self._config.logging_policy, + policies.DistributedTracingPolicy(**kwargs), + policies.SensitiveHeaderCleanupPolicy(**kwargs) if self._config.redirect_policy else None, + self._config.http_logging_policy, + ] + self._client: PipelineClient = PipelineClient(base_url=_endpoint, policies=_policies, **kwargs) + + self._serialize = Serializer() + self._deserialize = Deserializer() + self._serialize.client_side_validation = False + + def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: + """Runs the network request through the client's chained policies. + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest("GET", "https://www.example.org/") + <HttpRequest [GET], url: 'https://www.example.org/'> + >>> response = client.send_request(request) + <HttpResponse: 200 OK> + + For more information on this code flow, see https://aka.ms/azsdk/dpcodegen/python/send_request + + :param request: The network request you want to make. Required. + :type request: ~azure.core.rest.HttpRequest + :keyword bool stream: Whether the response payload will be streamed. Defaults to False. + :return: The response of your network call. Does not do error handling on your response. + :rtype: ~azure.core.rest.HttpResponse + """ + + request_copy = deepcopy(request) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + request_copy.url = self._client.format_url(request_copy.url, **path_format_arguments) + return self._client.send_request(request_copy, stream=stream, **kwargs) # type: ignore + + def close(self) -> None: + self._client.close() + + def __enter__(self) -> Self: + self._client.__enter__() + return self + + def __exit__(self, *exc_details: Any) -> None: + self._client.__exit__(*exc_details) + + +class ImageEmbeddingsClient(ImageEmbeddingsClientOperationsMixin): + """ImageEmbeddingsClient. + + :param endpoint: Service host. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a key + credential type or a token credential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. 
+ :paramtype api_version: str + """ + + def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any) -> None: + _endpoint = "{endpoint}" + self._config = ImageEmbeddingsClientConfiguration(endpoint=endpoint, credential=credential, **kwargs) + _policies = kwargs.pop("policies", None) + if _policies is None: + _policies = [ + policies.RequestIdPolicy(**kwargs), + self._config.headers_policy, + self._config.user_agent_policy, + self._config.proxy_policy, + policies.ContentDecodePolicy(**kwargs), + self._config.redirect_policy, + self._config.retry_policy, + self._config.authentication_policy, + self._config.custom_hook_policy, + self._config.logging_policy, + policies.DistributedTracingPolicy(**kwargs), + policies.SensitiveHeaderCleanupPolicy(**kwargs) if self._config.redirect_policy else None, + self._config.http_logging_policy, + ] + self._client: PipelineClient = PipelineClient(base_url=_endpoint, policies=_policies, **kwargs) + + self._serialize = Serializer() + self._deserialize = Deserializer() + self._serialize.client_side_validation = False + + def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: + """Runs the network request through the client's chained policies. + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest("GET", "https://www.example.org/") + <HttpRequest [GET], url: 'https://www.example.org/'> + >>> response = client.send_request(request) + <HttpResponse: 200 OK> + + For more information on this code flow, see https://aka.ms/azsdk/dpcodegen/python/send_request + + :param request: The network request you want to make. Required. + :type request: ~azure.core.rest.HttpRequest + :keyword bool stream: Whether the response payload will be streamed. Defaults to False. + :return: The response of your network call. Does not do error handling on your response. + :rtype: ~azure.core.rest.HttpResponse + """ + + request_copy = deepcopy(request) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + request_copy.url = self._client.format_url(request_copy.url, **path_format_arguments) + return self._client.send_request(request_copy, stream=stream, **kwargs) # type: ignore + + def close(self) -> None: + self._client.close() + + def __enter__(self) -> Self: + self._client.__enter__() + return self + + def __exit__(self, *exc_details: Any) -> None: + self._client.__exit__(*exc_details) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/_configuration.py new file mode 100644 index 00000000..894ec657 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/_configuration.py @@ -0,0 +1,188 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
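The three clients defined above expose the same `send_request` surface shown in their docstrings and act as context managers. A short usage sketch; the endpoint and key are placeholders, and the request URL mirrors the docstring example:

```python
from azure.core.credentials import AzureKeyCredential
from azure.core.rest import HttpRequest
from azure.ai.inference import ChatCompletionsClient

# The client implements __enter__/__exit__, so the underlying pipeline is closed on exit.
with ChatCompletionsClient(endpoint="https://<your-endpoint>", credential=AzureKeyCredential("<api-key>")) as client:
    # send_request runs a raw request through the client's policy pipeline;
    # it does not raise on HTTP error status codes.
    request = HttpRequest("GET", "https://www.example.org/")
    response = client.send_request(request)
    print(response.status_code)
```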
+# -------------------------------------------------------------------------- + +from typing import Any, TYPE_CHECKING, Union + +from azure.core.credentials import AzureKeyCredential +from azure.core.pipeline import policies + +from ._version import VERSION + +if TYPE_CHECKING: + from azure.core.credentials import TokenCredential + + +class ChatCompletionsClientConfiguration: # pylint: disable=too-many-instance-attributes + """Configuration for ChatCompletionsClient. + + Note that all parameters used to create this instance are saved as instance + attributes. + + :param endpoint: Service host. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a key + credential type or a token credential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. + :paramtype api_version: str + """ + + def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any) -> None: + api_version: str = kwargs.pop("api_version", "2024-05-01-preview") + + if endpoint is None: + raise ValueError("Parameter 'endpoint' must not be None.") + if credential is None: + raise ValueError("Parameter 'credential' must not be None.") + + self.endpoint = endpoint + self.credential = credential + self.api_version = api_version + self.credential_scopes = kwargs.pop("credential_scopes", ["https://ml.azure.com/.default"]) + kwargs.setdefault("sdk_moniker", "ai-inference/{}".format(VERSION)) + self.polling_interval = kwargs.get("polling_interval", 30) + self._configure(**kwargs) + + def _infer_policy(self, **kwargs): + if isinstance(self.credential, AzureKeyCredential): + return policies.AzureKeyCredentialPolicy(self.credential, "Authorization", prefix="Bearer", **kwargs) + if isinstance(self.credential, AzureKeyCredential): + return policies.AzureKeyCredentialPolicy(self.credential, "api-key", **kwargs) + if hasattr(self.credential, "get_token"): + return policies.BearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs) + raise TypeError(f"Unsupported credential: {self.credential}") + + def _configure(self, **kwargs: Any) -> None: + self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs) + self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs) + self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs) + self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs) + self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs) + self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs) + self.redirect_policy = kwargs.get("redirect_policy") or policies.RedirectPolicy(**kwargs) + self.retry_policy = kwargs.get("retry_policy") or policies.RetryPolicy(**kwargs) + self.authentication_policy = kwargs.get("authentication_policy") + if self.credential and not self.authentication_policy: + self.authentication_policy = self._infer_policy(**kwargs) + + +class EmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes + """Configuration for EmbeddingsClient. 
+ + Note that all parameters used to create this instance are saved as instance + attributes. + + :param endpoint: Service host. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a key + credential type or a token credential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. + :paramtype api_version: str + """ + + def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any) -> None: + api_version: str = kwargs.pop("api_version", "2024-05-01-preview") + + if endpoint is None: + raise ValueError("Parameter 'endpoint' must not be None.") + if credential is None: + raise ValueError("Parameter 'credential' must not be None.") + + self.endpoint = endpoint + self.credential = credential + self.api_version = api_version + self.credential_scopes = kwargs.pop("credential_scopes", ["https://ml.azure.com/.default"]) + kwargs.setdefault("sdk_moniker", "ai-inference/{}".format(VERSION)) + self.polling_interval = kwargs.get("polling_interval", 30) + self._configure(**kwargs) + + def _infer_policy(self, **kwargs): + if isinstance(self.credential, AzureKeyCredential): + return policies.AzureKeyCredentialPolicy(self.credential, "Authorization", prefix="Bearer", **kwargs) + if isinstance(self.credential, AzureKeyCredential): + return policies.AzureKeyCredentialPolicy(self.credential, "api-key", **kwargs) + if hasattr(self.credential, "get_token"): + return policies.BearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs) + raise TypeError(f"Unsupported credential: {self.credential}") + + def _configure(self, **kwargs: Any) -> None: + self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs) + self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs) + self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs) + self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs) + self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs) + self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs) + self.redirect_policy = kwargs.get("redirect_policy") or policies.RedirectPolicy(**kwargs) + self.retry_policy = kwargs.get("retry_policy") or policies.RetryPolicy(**kwargs) + self.authentication_policy = kwargs.get("authentication_policy") + if self.credential and not self.authentication_policy: + self.authentication_policy = self._infer_policy(**kwargs) + + +class ImageEmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes + """Configuration for ImageEmbeddingsClient. + + Note that all parameters used to create this instance are saved as instance + attributes. + + :param endpoint: Service host. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a key + credential type or a token credential type. Required. 
+ :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. + :paramtype api_version: str + """ + + def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any) -> None: + api_version: str = kwargs.pop("api_version", "2024-05-01-preview") + + if endpoint is None: + raise ValueError("Parameter 'endpoint' must not be None.") + if credential is None: + raise ValueError("Parameter 'credential' must not be None.") + + self.endpoint = endpoint + self.credential = credential + self.api_version = api_version + self.credential_scopes = kwargs.pop("credential_scopes", ["https://ml.azure.com/.default"]) + kwargs.setdefault("sdk_moniker", "ai-inference/{}".format(VERSION)) + self.polling_interval = kwargs.get("polling_interval", 30) + self._configure(**kwargs) + + def _infer_policy(self, **kwargs): + if isinstance(self.credential, AzureKeyCredential): + return policies.AzureKeyCredentialPolicy(self.credential, "Authorization", prefix="Bearer", **kwargs) + if isinstance(self.credential, AzureKeyCredential): + return policies.AzureKeyCredentialPolicy(self.credential, "api-key", **kwargs) + if hasattr(self.credential, "get_token"): + return policies.BearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs) + raise TypeError(f"Unsupported credential: {self.credential}") + + def _configure(self, **kwargs: Any) -> None: + self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs) + self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs) + self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs) + self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs) + self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs) + self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs) + self.redirect_policy = kwargs.get("redirect_policy") or policies.RedirectPolicy(**kwargs) + self.retry_policy = kwargs.get("retry_policy") or policies.RetryPolicy(**kwargs) + self.authentication_policy = kwargs.get("authentication_policy") + if self.credential and not self.authentication_policy: + self.authentication_policy = self._infer_policy(**kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/_model_base.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/_model_base.py new file mode 100644 index 00000000..359ecebe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/_model_base.py @@ -0,0 +1,1235 @@ +# pylint: disable=too-many-lines,arguments-differ,signature-differs,no-member +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
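The configuration classes above receive the same keyword arguments passed to the client constructors, so the API version, credential scopes, and individual pipeline policies can all be overridden at construction time. A sketch of the knobs read by `__init__` and `_configure`; the values shown are illustrative, not recommendations:

```python
from azure.core.credentials import AzureKeyCredential
from azure.core.pipeline import policies
from azure.ai.inference import ChatCompletionsClient

client = ChatCompletionsClient(
    endpoint="https://<your-endpoint>",                   # placeholder service host
    credential=AzureKeyCredential("<api-key>"),           # placeholder key
    api_version="2024-05-01-preview",                     # the default; overriding may be unsupported
    credential_scopes=["https://ml.azure.com/.default"],  # only consulted for token credentials
    retry_policy=policies.RetryPolicy(retry_total=3),     # any policy slot can be replaced this way
    logging_enable=True,                                  # forwarded to the logging policies
)
```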
+# -------------------------------------------------------------------------- +# pylint: disable=protected-access, broad-except + +import copy +import calendar +import decimal +import functools +import sys +import logging +import base64 +import re +import typing +import enum +import email.utils +from datetime import datetime, date, time, timedelta, timezone +from json import JSONEncoder +import xml.etree.ElementTree as ET +from typing_extensions import Self +import isodate +from azure.core.exceptions import DeserializationError +from azure.core import CaseInsensitiveEnumMeta +from azure.core.pipeline import PipelineResponse +from azure.core.serialization import _Null + +if sys.version_info >= (3, 9): + from collections.abc import MutableMapping +else: + from typing import MutableMapping + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["SdkJSONEncoder", "Model", "rest_field", "rest_discriminator"] + +TZ_UTC = timezone.utc +_T = typing.TypeVar("_T") + + +def _timedelta_as_isostr(td: timedelta) -> str: + """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S' + + Function adapted from the Tin Can Python project: https://github.com/RusticiSoftware/TinCanPython + + :param timedelta td: The timedelta to convert + :rtype: str + :return: ISO8601 version of this timedelta + """ + + # Split seconds to larger units + seconds = td.total_seconds() + minutes, seconds = divmod(seconds, 60) + hours, minutes = divmod(minutes, 60) + days, hours = divmod(hours, 24) + + days, hours, minutes = list(map(int, (days, hours, minutes))) + seconds = round(seconds, 6) + + # Build date + date_str = "" + if days: + date_str = "%sD" % days + + if hours or minutes or seconds: + # Build time + time_str = "T" + + # Hours + bigger_exists = date_str or hours + if bigger_exists: + time_str += "{:02}H".format(hours) + + # Minutes + bigger_exists = bigger_exists or minutes + if bigger_exists: + time_str += "{:02}M".format(minutes) + + # Seconds + try: + if seconds.is_integer(): + seconds_string = "{:02}".format(int(seconds)) + else: + # 9 chars long w/ leading 0, 6 digits after decimal + seconds_string = "%09.6f" % seconds + # Remove trailing zeros + seconds_string = seconds_string.rstrip("0") + except AttributeError: # int.is_integer() raises + seconds_string = "{:02}".format(seconds) + + time_str += "{}S".format(seconds_string) + else: + time_str = "" + + return "P" + date_str + time_str + + +def _serialize_bytes(o, format: typing.Optional[str] = None) -> str: + encoded = base64.b64encode(o).decode() + if format == "base64url": + return encoded.strip("=").replace("+", "-").replace("/", "_") + return encoded + + +def _serialize_datetime(o, format: typing.Optional[str] = None): + if hasattr(o, "year") and hasattr(o, "hour"): + if format == "rfc7231": + return email.utils.format_datetime(o, usegmt=True) + if format == "unix-timestamp": + return int(calendar.timegm(o.utctimetuple())) + + # astimezone() fails for naive times in Python 2.7, so make make sure o is aware (tzinfo is set) + if not o.tzinfo: + iso_formatted = o.replace(tzinfo=TZ_UTC).isoformat() + else: + iso_formatted = o.astimezone(TZ_UTC).isoformat() + # Replace the trailing "+00:00" UTC offset with "Z" (RFC 3339: https://www.ietf.org/rfc/rfc3339.txt) + return iso_formatted.replace("+00:00", "Z") + # Next try datetime.date or datetime.time + return o.isoformat() + + +def _is_readonly(p): + try: + return p._visibility == ["read"] + except AttributeError: + return False + + +class SdkJSONEncoder(JSONEncoder): + """A JSON 
encoder that's capable of serializing datetime objects and bytes.""" + + def __init__(self, *args, exclude_readonly: bool = False, format: typing.Optional[str] = None, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + self.format = format + + def default(self, o): # pylint: disable=too-many-return-statements + if _is_model(o): + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) + try: + return super(SdkJSONEncoder, self).default(o) + except TypeError: + if isinstance(o, _Null): + return None + if isinstance(o, decimal.Decimal): + return float(o) + if isinstance(o, (bytes, bytearray)): + return _serialize_bytes(o, self.format) + try: + # First try datetime.datetime + return _serialize_datetime(o, self.format) + except AttributeError: + pass + # Last, try datetime.timedelta + try: + return _timedelta_as_isostr(o) + except AttributeError: + # This will be raised when it hits value.total_seconds in the method above + pass + return super(SdkJSONEncoder, self).default(o) + + +_VALID_DATE = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" + r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?") +_VALID_RFC7231 = re.compile( + r"(Mon|Tue|Wed|Thu|Fri|Sat|Sun),\s\d{2}\s" + r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s\d{4}\s\d{2}:\d{2}:\d{2}\sGMT" +) + + +def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime: + """Deserialize ISO-8601 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: ~datetime.datetime + :returns: The datetime object from that input + """ + if isinstance(attr, datetime): + # i'm already deserialized + return attr + attr = attr.upper() + match = _VALID_DATE.match(attr) + if not match: + raise ValueError("Invalid datetime string: " + attr) + + check_decimal = attr.split(".") + if len(check_decimal) > 1: + decimal_str = "" + for digit in check_decimal[1]: + if digit.isdigit(): + decimal_str += digit + else: + break + if len(decimal_str) > 6: + attr = attr.replace(decimal_str, decimal_str[0:6]) + + date_obj = isodate.parse_datetime(attr) + test_utc = date_obj.utctimetuple() + if test_utc.tm_year > 9999 or test_utc.tm_year < 1: + raise OverflowError("Hit max or min date") + return date_obj + + +def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime]) -> datetime: + """Deserialize RFC7231 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: ~datetime.datetime + :returns: The datetime object from that input + """ + if isinstance(attr, datetime): + # i'm already deserialized + return attr + match = _VALID_RFC7231.match(attr) + if not match: + raise ValueError("Invalid datetime string: " + attr) + + return email.utils.parsedate_to_datetime(attr) + + +def _deserialize_datetime_unix_timestamp(attr: typing.Union[float, datetime]) -> datetime: + """Deserialize unix timestamp into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: ~datetime.datetime + :returns: The datetime object from that input + """ + if isinstance(attr, datetime): + # i'm already deserialized + return attr + return datetime.fromtimestamp(attr, TZ_UTC) + + +def _deserialize_date(attr: typing.Union[str, date]) -> date: + """Deserialize ISO-8601 formatted string into Date object. + :param str attr: response string to be deserialized. 
+ :rtype: date + :returns: The date object from that input + """ + # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception. + if isinstance(attr, date): + return attr + return isodate.parse_date(attr, defaultmonth=None, defaultday=None) # type: ignore + + +def _deserialize_time(attr: typing.Union[str, time]) -> time: + """Deserialize ISO-8601 formatted string into time object. + + :param str attr: response string to be deserialized. + :rtype: datetime.time + :returns: The time object from that input + """ + if isinstance(attr, time): + return attr + return isodate.parse_time(attr) + + +def _deserialize_bytes(attr): + if isinstance(attr, (bytes, bytearray)): + return attr + return bytes(base64.b64decode(attr)) + + +def _deserialize_bytes_base64(attr): + if isinstance(attr, (bytes, bytearray)): + return attr + padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore + attr = attr + padding # type: ignore + encoded = attr.replace("-", "+").replace("_", "/") + return bytes(base64.b64decode(encoded)) + + +def _deserialize_duration(attr): + if isinstance(attr, timedelta): + return attr + return isodate.parse_duration(attr) + + +def _deserialize_decimal(attr): + if isinstance(attr, decimal.Decimal): + return attr + return decimal.Decimal(str(attr)) + + +def _deserialize_int_as_str(attr): + if isinstance(attr, int): + return attr + return int(attr) + + +_DESERIALIZE_MAPPING = { + datetime: _deserialize_datetime, + date: _deserialize_date, + time: _deserialize_time, + bytes: _deserialize_bytes, + bytearray: _deserialize_bytes, + timedelta: _deserialize_duration, + typing.Any: lambda x: x, + decimal.Decimal: _deserialize_decimal, +} + +_DESERIALIZE_MAPPING_WITHFORMAT = { + "rfc3339": _deserialize_datetime, + "rfc7231": _deserialize_datetime_rfc7231, + "unix-timestamp": _deserialize_datetime_unix_timestamp, + "base64": _deserialize_bytes, + "base64url": _deserialize_bytes_base64, +} + + +def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None): + if annotation is int and rf and rf._format == "str": + return _deserialize_int_as_str + if rf and rf._format: + return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format) + return _DESERIALIZE_MAPPING.get(annotation) # pyright: ignore + + +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + +def _get_model(module_name: str, model_name: str): + models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} + module_end = module_name.rsplit(".", 1)[0] + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) + if isinstance(model_name, str): + model_name = model_name.split(".")[-1] + if model_name not in models: + return model_name + return models[model_name] + + +_UNSET = object() + + +class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=unsubscriptable-object + def __init__(self, data: typing.Dict[str, typing.Any]) -> None: + self._data = data + + def __contains__(self, key: typing.Any) -> bool: + return key in self._data + + def __getitem__(self, key: str) -> typing.Any: + return self._data.__getitem__(key) + + def __setitem__(self, key: str, value: typing.Any) -> None: + self._data.__setitem__(key, value) + + def __delitem__(self, key: str) -> None: + self._data.__delitem__(key) + + 
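As context for the encoder defined earlier in this file, `SdkJSONEncoder` extends `json.JSONEncoder` so that datetimes, durations, bytes, and generated models serialize cleanly. A small sketch; importing from the private `_model_base` module is for illustration only:

```python
import json
from datetime import datetime, timedelta, timezone

from azure.ai.inference._model_base import SdkJSONEncoder  # private module, shown for illustration

payload = {
    "when": datetime(2024, 5, 1, tzinfo=timezone.utc),  # -> "2024-05-01T00:00:00Z" (RFC 3339)
    "ttl": timedelta(hours=1, minutes=30),               # -> ISO 8601 duration, e.g. "PT01H30M00S"
    "blob": b"\x00\x01",                                  # -> base64, "AAE="
}
print(json.dumps(payload, cls=SdkJSONEncoder))
```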
def __iter__(self) -> typing.Iterator[typing.Any]: + return self._data.__iter__() + + def __len__(self) -> int: + return self._data.__len__() + + def __ne__(self, other: typing.Any) -> bool: + return not self.__eq__(other) + + def keys(self) -> typing.KeysView[str]: + """ + :returns: a set-like object providing a view on D's keys + :rtype: ~typing.KeysView + """ + return self._data.keys() + + def values(self) -> typing.ValuesView[typing.Any]: + """ + :returns: an object providing a view on D's values + :rtype: ~typing.ValuesView + """ + return self._data.values() + + def items(self) -> typing.ItemsView[str, typing.Any]: + """ + :returns: set-like object providing a view on D's items + :rtype: ~typing.ItemsView + """ + return self._data.items() + + def get(self, key: str, default: typing.Any = None) -> typing.Any: + """ + Get the value for key if key is in the dictionary, else default. + :param str key: The key to look up. + :param any default: The value to return if key is not in the dictionary. Defaults to None + :returns: D[k] if k in D, else d. + :rtype: any + """ + try: + return self[key] + except KeyError: + return default + + @typing.overload + def pop(self, key: str) -> typing.Any: ... + + @typing.overload + def pop(self, key: str, default: _T) -> _T: ... + + @typing.overload + def pop(self, key: str, default: typing.Any) -> typing.Any: ... + + def pop(self, key: str, default: typing.Any = _UNSET) -> typing.Any: + """ + Removes specified key and return the corresponding value. + :param str key: The key to pop. + :param any default: The value to return if key is not in the dictionary + :returns: The value corresponding to the key. + :rtype: any + :raises KeyError: If key is not found and default is not given. + """ + if default is _UNSET: + return self._data.pop(key) + return self._data.pop(key, default) + + def popitem(self) -> typing.Tuple[str, typing.Any]: + """ + Removes and returns some (key, value) pair + :returns: The (key, value) pair. + :rtype: tuple + :raises KeyError: if D is empty. + """ + return self._data.popitem() + + def clear(self) -> None: + """ + Remove all items from D. + """ + self._data.clear() + + def update(self, *args: typing.Any, **kwargs: typing.Any) -> None: + """ + Updates D from mapping/iterable E and F. + :param any args: Either a mapping object or an iterable of key-value pairs. + """ + self._data.update(*args, **kwargs) + + @typing.overload + def setdefault(self, key: str, default: None = None) -> None: ... + + @typing.overload + def setdefault(self, key: str, default: typing.Any) -> typing.Any: ... + + def setdefault(self, key: str, default: typing.Any = _UNSET) -> typing.Any: + """ + Same as calling D.get(k, d), and setting D[k]=d if k not found + :param str key: The key to look up. + :param any default: The value to set if key is not in the dictionary + :returns: D[k] if k in D, else d. 
+ :rtype: any + """ + if default is _UNSET: + return self._data.setdefault(key) + return self._data.setdefault(key, default) + + def __eq__(self, other: typing.Any) -> bool: + try: + other_model = self.__class__(other) + except Exception: + return False + return self._data == other_model._data + + def __repr__(self) -> str: + return str(self._data) + + +def _is_model(obj: typing.Any) -> bool: + return getattr(obj, "_is_model", False) + + +def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-many-return-statements + if isinstance(o, list): + return [_serialize(x, format) for x in o] + if isinstance(o, dict): + return {k: _serialize(v, format) for k, v in o.items()} + if isinstance(o, set): + return {_serialize(x, format) for x in o} + if isinstance(o, tuple): + return tuple(_serialize(x, format) for x in o) + if isinstance(o, (bytes, bytearray)): + return _serialize_bytes(o, format) + if isinstance(o, decimal.Decimal): + return float(o) + if isinstance(o, enum.Enum): + return o.value + if isinstance(o, int): + if format == "str": + return str(o) + return o + try: + # First try datetime.datetime + return _serialize_datetime(o, format) + except AttributeError: + pass + # Last, try datetime.timedelta + try: + return _timedelta_as_isostr(o) + except AttributeError: + # This will be raised when it hits value.total_seconds in the method above + pass + return o + + +def _get_rest_field( + attr_to_rest_field: typing.Dict[str, "_RestField"], rest_name: str +) -> typing.Optional["_RestField"]: + try: + return next(rf for rf in attr_to_rest_field.values() if rf._rest_name == rest_name) + except StopIteration: + return None + + +def _create_value(rf: typing.Optional["_RestField"], value: typing.Any) -> typing.Any: + if not rf: + return _serialize(value, None) + if rf._is_multipart_file_input: + return value + if rf._is_model: + return _deserialize(rf._type, value) + if isinstance(value, ET.Element): + value = _deserialize(rf._type, value) + return _serialize(value, rf._format) + + +class Model(_MyMutableMapping): + _is_model = True + # label whether current class's _attr_to_rest_field has been calculated + # could not see _attr_to_rest_field directly because subclass inherits it from parent class + _calculated: typing.Set[str] = set() + + def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: + class_name = self.__class__.__name__ + if len(args) > 1: + raise TypeError(f"{class_name}.__init__() takes 2 positional arguments but {len(args) + 1} were given") + dict_to_pass = { + rest_field._rest_name: rest_field._default + for rest_field in self._attr_to_rest_field.values() + if rest_field._default is not _UNSET + } + if args: # pylint: disable=too-many-nested-blocks + if isinstance(args[0], ET.Element): + existed_attr_keys = [] + model_meta = getattr(self, "_xml", {}) + + for rf in self._attr_to_rest_field.values(): + prop_meta = getattr(rf, "_xml", {}) + xml_name = prop_meta.get("name", rf._rest_name) + xml_ns = prop_meta.get("ns", model_meta.get("ns", None)) + if xml_ns: + xml_name = "{" + xml_ns + "}" + xml_name + + # attribute + if prop_meta.get("attribute", False) and args[0].get(xml_name) is not None: + existed_attr_keys.append(xml_name) + dict_to_pass[rf._rest_name] = _deserialize(rf._type, args[0].get(xml_name)) + continue + + # unwrapped element is array + if prop_meta.get("unwrapped", False): + # unwrapped array could either use prop items meta/prop meta + if prop_meta.get("itemsName"): + xml_name = prop_meta.get("itemsName") + xml_ns = 
prop_meta.get("itemNs") + if xml_ns: + xml_name = "{" + xml_ns + "}" + xml_name + items = args[0].findall(xml_name) # pyright: ignore + if len(items) > 0: + existed_attr_keys.append(xml_name) + dict_to_pass[rf._rest_name] = _deserialize(rf._type, items) + continue + + # text element is primitive type + if prop_meta.get("text", False): + if args[0].text is not None: + dict_to_pass[rf._rest_name] = _deserialize(rf._type, args[0].text) + continue + + # wrapped element could be normal property or array, it should only have one element + item = args[0].find(xml_name) + if item is not None: + existed_attr_keys.append(xml_name) + dict_to_pass[rf._rest_name] = _deserialize(rf._type, item) + + # rest thing is additional properties + for e in args[0]: + if e.tag not in existed_attr_keys: + dict_to_pass[e.tag] = _convert_element(e) + else: + dict_to_pass.update( + {k: _create_value(_get_rest_field(self._attr_to_rest_field, k), v) for k, v in args[0].items()} + ) + else: + non_attr_kwargs = [k for k in kwargs if k not in self._attr_to_rest_field] + if non_attr_kwargs: + # actual type errors only throw the first wrong keyword arg they see, so following that. + raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") + dict_to_pass.update( + { + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) + for k, v in kwargs.items() + if v is not None + } + ) + super().__init__(dict_to_pass) + + def copy(self) -> "Model": + return Model(self.__dict__) + + def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: + if f"{cls.__module__}.{cls.__qualname__}" not in cls._calculated: + # we know the last nine classes in mro are going to be 'Model', '_MyMutableMapping', 'MutableMapping', + # 'Mapping', 'Collection', 'Sized', 'Iterable', 'Container' and 'object' + mros = cls.__mro__[:-9][::-1] # ignore parents, and reverse the mro order + attr_to_rest_field: typing.Dict[str, _RestField] = { # map attribute name to rest_field property + k: v for mro_class in mros for k, v in mro_class.__dict__.items() if k[0] != "_" and hasattr(v, "_type") + } + annotations = { + k: v + for mro_class in mros + if hasattr(mro_class, "__annotations__") + for k, v in mro_class.__annotations__.items() + } + for attr, rf in attr_to_rest_field.items(): + rf._module = cls.__module__ + if not rf._type: + rf._type = rf._get_deserialize_callable_from_annotation(annotations.get(attr, None)) + if not rf._rest_name_input: + rf._rest_name_input = attr + cls._attr_to_rest_field: typing.Dict[str, _RestField] = dict(attr_to_rest_field.items()) + cls._calculated.add(f"{cls.__module__}.{cls.__qualname__}") + + return super().__new__(cls) # pylint: disable=no-value-for-parameter + + def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: + for base in cls.__bases__: + if hasattr(base, "__mapping__"): + base.__mapping__[discriminator or cls.__name__] = cls # type: ignore + + @classmethod + def _get_discriminator(cls, exist_discriminators) -> typing.Optional["_RestField"]: + for v in cls.__dict__.values(): + if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators: + return v + return None + + @classmethod + def _deserialize(cls, data, exist_discriminators): + if not hasattr(cls, "__mapping__"): + return cls(data) + discriminator = cls._get_discriminator(exist_discriminators) + if discriminator is None: + return cls(data) + exist_discriminators.append(discriminator._rest_name) + if isinstance(data, 
ET.Element): + model_meta = getattr(cls, "_xml", {}) + prop_meta = getattr(discriminator, "_xml", {}) + xml_name = prop_meta.get("name", discriminator._rest_name) + xml_ns = prop_meta.get("ns", model_meta.get("ns", None)) + if xml_ns: + xml_name = "{" + xml_ns + "}" + xml_name + + if data.get(xml_name) is not None: + discriminator_value = data.get(xml_name) + else: + discriminator_value = data.find(xml_name).text # pyright: ignore + else: + discriminator_value = data.get(discriminator._rest_name) + mapped_cls = cls.__mapping__.get(discriminator_value, cls) # pyright: ignore + return mapped_cls._deserialize(data, exist_discriminators) + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be turned into json using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + readonly_props = [] + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: ignore + continue + is_multipart_file_input = False + try: + is_multipart_file_input = next( + rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k + )._is_multipart_file_input + except StopIteration: + pass + result[k] = v if is_multipart_file_input else Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return type(v)(Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v) + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v + + +def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj): + if _is_model(obj): + return obj + return _deserialize(model_deserializer, obj) + + +def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Callable], obj): + if obj is None: + return obj + return _deserialize_with_callable(if_obj_deserializer, obj) + + +def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + +def _deserialize_dict( + value_deserializer: typing.Optional[typing.Callable], + module: typing.Optional[str], + obj: typing.Dict[typing.Any, typing.Any], +): + if obj is None: + return obj + if isinstance(obj, ET.Element): + obj = {child.tag: child for child in obj} + return {k: _deserialize(value_deserializer, v, module) for k, v in obj.items()} + + +def _deserialize_multiple_sequence( + entry_deserializers: typing.List[typing.Optional[typing.Callable]], + module: typing.Optional[str], + obj, +): + if obj is None: + return obj + return type(obj)(_deserialize(deserializer, entry, module) for entry, deserializer in zip(obj, entry_deserializers)) + + +def _deserialize_sequence( + deserializer: typing.Optional[typing.Callable], + module: typing.Optional[str], + obj, +): + if obj is None: + return obj + if isinstance(obj, ET.Element): + obj = list(obj) + return type(obj)(_deserialize(deserializer, entry, module) for entry in obj) + + +def 
_sorted_annotations(types: typing.List[typing.Any]) -> typing.List[typing.Any]: + return sorted( + types, + key=lambda x: hasattr(x, "__name__") and x.__name__.lower() in ("str", "float", "int", "bool"), + ) + + +def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-branches + annotation: typing.Any, + module: typing.Optional[str], + rf: typing.Optional["_RestField"] = None, +) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]: + if not annotation: + return None + + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) # type: ignore + + try: + if module and _is_model(annotation): + if rf: + rf._is_model = True + + return functools.partial(_deserialize_model, annotation) # pyright: ignore + except Exception: + pass + + # is it a literal? + try: + if annotation.__origin__ is typing.Literal: # pyright: ignore + return None + except AttributeError: + pass + + # is it optional? + try: + if any(a for a in annotation.__args__ if a == type(None)): # pyright: ignore + if len(annotation.__args__) <= 2: # pyright: ignore + if_obj_deserializer = _get_deserialize_callable_from_annotation( + next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore + ) + + return functools.partial(_deserialize_with_optional, if_obj_deserializer) + # the type is Optional[Union[...]], we need to remove the None type from the Union + annotation_copy = copy.copy(annotation) + annotation_copy.__args__ = [a for a in annotation_copy.__args__ if a != type(None)] # pyright: ignore + return _get_deserialize_callable_from_annotation(annotation_copy, module, rf) + except AttributeError: + pass + + # is it union? 
+ if getattr(annotation, "__origin__", None) is typing.Union: + # initial ordering is we make `string` the last deserialization option, because it is often them most generic + deserializers = [ + _get_deserialize_callable_from_annotation(arg, module, rf) + for arg in _sorted_annotations(annotation.__args__) # pyright: ignore + ] + + return functools.partial(_deserialize_with_union, deserializers) + + try: + if annotation._name == "Dict": # pyright: ignore + value_deserializer = _get_deserialize_callable_from_annotation( + annotation.__args__[1], module, rf # pyright: ignore + ) + + return functools.partial( + _deserialize_dict, + value_deserializer, + module, + ) + except (AttributeError, IndexError): + pass + try: + if annotation._name in ["List", "Set", "Tuple", "Sequence"]: # pyright: ignore + if len(annotation.__args__) > 1: # pyright: ignore + entry_deserializers = [ + _get_deserialize_callable_from_annotation(dt, module, rf) + for dt in annotation.__args__ # pyright: ignore + ] + return functools.partial(_deserialize_multiple_sequence, entry_deserializers, module) + deserializer = _get_deserialize_callable_from_annotation( + annotation.__args__[0], module, rf # pyright: ignore + ) + + return functools.partial(_deserialize_sequence, deserializer, module) + except (TypeError, IndexError, AttributeError, SyntaxError): + pass + + def _deserialize_default( + deserializer, + obj, + ): + if obj is None: + return obj + try: + return _deserialize_with_callable(deserializer, obj) + except Exception: + pass + return obj + + if get_deserializer(annotation, rf): + return functools.partial(_deserialize_default, get_deserializer(annotation, rf)) + + return functools.partial(_deserialize_default, annotation) + + +def _deserialize_with_callable( + deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]], + value: typing.Any, +): # pylint: disable=too-many-return-statements + try: + if value is None or isinstance(value, _Null): + return None + if isinstance(value, ET.Element): + if deserializer is str: + return value.text or "" + if deserializer is int: + return int(value.text) if value.text else None + if deserializer is float: + return float(value.text) if value.text else None + if deserializer is bool: + return value.text == "true" if value.text else None + if deserializer is None: + return value + if deserializer in [int, float, bool]: + return deserializer(value) + if isinstance(deserializer, CaseInsensitiveEnumMeta): + try: + return deserializer(value) + except ValueError: + # for unknown value, return raw value + return value + if isinstance(deserializer, type) and issubclass(deserializer, Model): + return deserializer._deserialize(value, []) + return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) + except Exception as e: + raise DeserializationError() from e + + +def _deserialize( + deserializer: typing.Any, + value: typing.Any, + module: typing.Optional[str] = None, + rf: typing.Optional["_RestField"] = None, + format: typing.Optional[str] = None, +) -> typing.Any: + if isinstance(value, PipelineResponse): + value = value.http_response.json() + if rf is None and format: + rf = _RestField(format=format) + if not isinstance(deserializer, functools.partial): + deserializer = _get_deserialize_callable_from_annotation(deserializer, module, rf) + return _deserialize_with_callable(deserializer, value) + + +def _failsafe_deserialize( + deserializer: typing.Any, + value: typing.Any, + module: typing.Optional[str] = None, + rf: typing.Optional["_RestField"] 
= None, + format: typing.Optional[str] = None, +) -> typing.Any: + try: + return _deserialize(deserializer, value, module, rf, format) + except DeserializationError: + _LOGGER.warning( + "Ran into a deserialization error. Ignoring since this is failsafe deserialization", exc_info=True + ) + return None + + +def _failsafe_deserialize_xml( + deserializer: typing.Any, + value: typing.Any, +) -> typing.Any: + try: + return _deserialize_xml(deserializer, value) + except DeserializationError: + _LOGGER.warning( + "Ran into a deserialization error. Ignoring since this is failsafe deserialization", exc_info=True + ) + return None + + +class _RestField: + def __init__( + self, + *, + name: typing.Optional[str] = None, + type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin + is_discriminator: bool = False, + visibility: typing.Optional[typing.List[str]] = None, + default: typing.Any = _UNSET, + format: typing.Optional[str] = None, + is_multipart_file_input: bool = False, + xml: typing.Optional[typing.Dict[str, typing.Any]] = None, + ): + self._type = type + self._rest_name_input = name + self._module: typing.Optional[str] = None + self._is_discriminator = is_discriminator + self._visibility = visibility + self._is_model = False + self._default = default + self._format = format + self._is_multipart_file_input = is_multipart_file_input + self._xml = xml if xml is not None else {} + + @property + def _class_type(self) -> typing.Any: + return getattr(self._type, "args", [None])[0] + + @property + def _rest_name(self) -> str: + if self._rest_name_input is None: + raise ValueError("Rest name was never set") + return self._rest_name_input + + def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin + # by this point, type and rest_name will have a value bc we default + # them in __new__ of the Model class + item = obj.get(self._rest_name) + if item is None: + return item + if self._is_model: + return item + return _deserialize(self._type, _serialize(item, self._format), rf=self) + + def __set__(self, obj: Model, value) -> None: + if value is None: + # we want to wipe out entries if users set attr to None + try: + obj.__delitem__(self._rest_name) + except KeyError: + pass + return + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return + obj.__setitem__(self._rest_name, _serialize(value, self._format)) + + def _get_deserialize_callable_from_annotation( + self, annotation: typing.Any + ) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]: + return _get_deserialize_callable_from_annotation(annotation, self._module, self) + + +def rest_field( + *, + name: typing.Optional[str] = None, + type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin + visibility: typing.Optional[typing.List[str]] = None, + default: typing.Any = _UNSET, + format: typing.Optional[str] = None, + is_multipart_file_input: bool = False, + xml: typing.Optional[typing.Dict[str, typing.Any]] = None, +) -> typing.Any: + return _RestField( + name=name, + type=type, + visibility=visibility, + default=default, + format=format, + is_multipart_file_input=is_multipart_file_input, + xml=xml, + ) + + +def rest_discriminator( + *, + name: typing.Optional[str] = None, + type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin + visibility: typing.Optional[typing.List[str]] = None, + xml: typing.Optional[typing.Dict[str, typing.Any]] = None, +) -> typing.Any: 
+ return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility, xml=xml) + + +def serialize_xml(model: Model, exclude_readonly: bool = False) -> str: + """Serialize a model to XML. + + :param Model model: The model to serialize. + :param bool exclude_readonly: Whether to exclude readonly properties. + :returns: The XML representation of the model. + :rtype: str + """ + return ET.tostring(_get_element(model, exclude_readonly), encoding="unicode") # type: ignore + + +def _get_element( + o: typing.Any, + exclude_readonly: bool = False, + parent_meta: typing.Optional[typing.Dict[str, typing.Any]] = None, + wrapped_element: typing.Optional[ET.Element] = None, +) -> typing.Union[ET.Element, typing.List[ET.Element]]: + if _is_model(o): + model_meta = getattr(o, "_xml", {}) + + # if prop is a model, then use the prop element directly, else generate a wrapper of model + if wrapped_element is None: + wrapped_element = _create_xml_element( + model_meta.get("name", o.__class__.__name__), + model_meta.get("prefix"), + model_meta.get("ns"), + ) + + readonly_props = [] + if exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + + for k, v in o.items(): + # do not serialize readonly properties + if exclude_readonly and k in readonly_props: + continue + + prop_rest_field = _get_rest_field(o._attr_to_rest_field, k) + if prop_rest_field: + prop_meta = getattr(prop_rest_field, "_xml").copy() + # use the wire name as xml name if no specific name is set + if prop_meta.get("name") is None: + prop_meta["name"] = k + else: + # additional properties will not have rest field, use the wire name as xml name + prop_meta = {"name": k} + + # if no ns for prop, use model's + if prop_meta.get("ns") is None and model_meta.get("ns"): + prop_meta["ns"] = model_meta.get("ns") + prop_meta["prefix"] = model_meta.get("prefix") + + if prop_meta.get("unwrapped", False): + # unwrapped could only set on array + wrapped_element.extend(_get_element(v, exclude_readonly, prop_meta)) + elif prop_meta.get("text", False): + # text could only set on primitive type + wrapped_element.text = _get_primitive_type_value(v) + elif prop_meta.get("attribute", False): + xml_name = prop_meta.get("name", k) + if prop_meta.get("ns"): + ET.register_namespace(prop_meta.get("prefix"), prop_meta.get("ns")) # pyright: ignore + xml_name = "{" + prop_meta.get("ns") + "}" + xml_name # pyright: ignore + # attribute should be primitive type + wrapped_element.set(xml_name, _get_primitive_type_value(v)) + else: + # other wrapped prop element + wrapped_element.append(_get_wrapped_element(v, exclude_readonly, prop_meta)) + return wrapped_element + if isinstance(o, list): + return [_get_element(x, exclude_readonly, parent_meta) for x in o] # type: ignore + if isinstance(o, dict): + result = [] + for k, v in o.items(): + result.append( + _get_wrapped_element( + v, + exclude_readonly, + { + "name": k, + "ns": parent_meta.get("ns") if parent_meta else None, + "prefix": parent_meta.get("prefix") if parent_meta else None, + }, + ) + ) + return result + + # primitive case need to create element based on parent_meta + if parent_meta: + return _get_wrapped_element( + o, + exclude_readonly, + { + "name": parent_meta.get("itemsName", parent_meta.get("name")), + "prefix": parent_meta.get("itemsPrefix", parent_meta.get("prefix")), + "ns": parent_meta.get("itemsNs", parent_meta.get("ns")), + }, + ) + + raise ValueError("Could not serialize value into xml: " + o) + + +def _get_wrapped_element( + v: 
typing.Any, + exclude_readonly: bool, + meta: typing.Optional[typing.Dict[str, typing.Any]], +) -> ET.Element: + wrapped_element = _create_xml_element( + meta.get("name") if meta else None, meta.get("prefix") if meta else None, meta.get("ns") if meta else None + ) + if isinstance(v, (dict, list)): + wrapped_element.extend(_get_element(v, exclude_readonly, meta)) + elif _is_model(v): + _get_element(v, exclude_readonly, meta, wrapped_element) + else: + wrapped_element.text = _get_primitive_type_value(v) + return wrapped_element + + +def _get_primitive_type_value(v) -> str: + if v is True: + return "true" + if v is False: + return "false" + if isinstance(v, _Null): + return "" + return str(v) + + +def _create_xml_element(tag, prefix=None, ns=None): + if prefix and ns: + ET.register_namespace(prefix, ns) + if ns: + return ET.Element("{" + ns + "}" + tag) + return ET.Element(tag) + + +def _deserialize_xml( + deserializer: typing.Any, + value: str, +) -> typing.Any: + element = ET.fromstring(value) # nosec + return _deserialize(deserializer, element) + + +def _convert_element(e: ET.Element): + # dict case + if len(e.attrib) > 0 or len({child.tag for child in e}) > 1: + dict_result: typing.Dict[str, typing.Any] = {} + for child in e: + if dict_result.get(child.tag) is not None: + if isinstance(dict_result[child.tag], list): + dict_result[child.tag].append(_convert_element(child)) + else: + dict_result[child.tag] = [dict_result[child.tag], _convert_element(child)] + else: + dict_result[child.tag] = _convert_element(child) + dict_result.update(e.attrib) + return dict_result + # array case + if len(e) > 0: + array_result: typing.List[typing.Any] = [] + for child in e: + array_result.append(_convert_element(child)) + return array_result + # primitive case + return e.text diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/_operations/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/_operations/__init__.py new file mode 100644 index 00000000..ab870887 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/_operations/__init__.py @@ -0,0 +1,29 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + +from ._operations import ChatCompletionsClientOperationsMixin # type: ignore +from ._operations import EmbeddingsClientOperationsMixin # type: ignore +from ._operations import ImageEmbeddingsClientOperationsMixin # type: ignore + +from ._patch import __all__ as _patch_all +from ._patch import * +from ._patch import patch_sdk as _patch_sdk + +__all__ = [ + "ChatCompletionsClientOperationsMixin", + "EmbeddingsClientOperationsMixin", + "ImageEmbeddingsClientOperationsMixin", +] +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore +_patch_sdk() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/_operations/_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/_operations/_operations.py new file mode 100644 index 00000000..78e5ee35 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/_operations/_operations.py @@ -0,0 +1,912 @@ +# pylint: disable=too-many-locals +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +from io import IOBase +import json +import sys +from typing import Any, Callable, Dict, IO, List, Optional, TypeVar, Union, overload + +from azure.core.exceptions import ( + ClientAuthenticationError, + HttpResponseError, + ResourceExistsError, + ResourceNotFoundError, + ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, + map_error, +) +from azure.core.pipeline import PipelineResponse +from azure.core.rest import HttpRequest, HttpResponse +from azure.core.tracing.decorator import distributed_trace +from azure.core.utils import case_insensitive_dict + +from .. 
import models as _models +from .._model_base import SdkJSONEncoder, _deserialize +from .._serialization import Serializer +from .._vendor import ChatCompletionsClientMixinABC, EmbeddingsClientMixinABC, ImageEmbeddingsClientMixinABC + +if sys.version_info >= (3, 9): + from collections.abc import MutableMapping +else: + from typing import MutableMapping # type: ignore +JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object +_Unset: Any = object() +T = TypeVar("T") +ClsType = Optional[Callable[[PipelineResponse[HttpRequest, HttpResponse], T, Dict[str, Any]], Any]] + +_SERIALIZER = Serializer() +_SERIALIZER.client_side_validation = False + + +def build_chat_completions_complete_request( + *, extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-01-preview")) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/chat/completions" + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + if extra_params is not None: + _headers["extra-parameters"] = _SERIALIZER.header("extra_params", extra_params, "str") + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_chat_completions_get_model_info_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-01-preview")) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/info" + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_embeddings_embed_request( + *, extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-01-preview")) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/embeddings" + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + if extra_params is not None: + _headers["extra-parameters"] = _SERIALIZER.header("extra_params", extra_params, "str") + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return 
HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_embeddings_get_model_info_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-01-preview")) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/info" + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_image_embeddings_embed_request( + *, extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-01-preview")) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/images/embeddings" + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + if extra_params is not None: + _headers["extra-parameters"] = _SERIALIZER.header("extra_params", extra_params, "str") + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_image_embeddings_get_model_info_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-01-preview")) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/info" + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) + + +class ChatCompletionsClientOperationsMixin(ChatCompletionsClientMixinABC): + + @overload + def _complete( + self, + *, + messages: List[_models._models.ChatRequestMessage], + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + frequency_penalty: Optional[float] = None, + stream_parameter: Optional[bool] = None, + presence_penalty: Optional[float] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + response_format: Optional[_models._models.ChatCompletionsResponseFormat] = None, + stop: Optional[List[str]] = None, + tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, + tool_choice: Optional[ + Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + ] = None, + seed: Optional[int] = None, + model: Optional[str] = None, + **kwargs: 
Any + ) -> _models.ChatCompletions: ... + @overload + def _complete( + self, + body: JSON, + *, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + **kwargs: Any + ) -> _models.ChatCompletions: ... + @overload + def _complete( + self, + body: IO[bytes], + *, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + **kwargs: Any + ) -> _models.ChatCompletions: ... + + @distributed_trace + def _complete( + self, + body: Union[JSON, IO[bytes]] = _Unset, + *, + messages: List[_models._models.ChatRequestMessage] = _Unset, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + frequency_penalty: Optional[float] = None, + stream_parameter: Optional[bool] = None, + presence_penalty: Optional[float] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + response_format: Optional[_models._models.ChatCompletionsResponseFormat] = None, + stop: Optional[List[str]] = None, + tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, + tool_choice: Optional[ + Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + ] = None, + seed: Optional[int] = None, + model: Optional[str] = None, + **kwargs: Any + ) -> _models.ChatCompletions: + """Gets chat completions for the provided chat messages. + Completions support a wide variety of tasks and generate text that continues from or + "completes" + provided prompt data. The method makes a REST API call to the ``/chat/completions`` route + on the given endpoint. + + :param body: Is either a JSON type or a IO[bytes] type. Required. + :type body: JSON or IO[bytes] + :keyword messages: The collection of context messages associated with this chat completions + request. + Typical usage begins with a chat message for the System role that provides instructions for + the behavior of the assistant, followed by alternating messages between the User and + Assistant roles. Required. + :paramtype messages: list[~azure.ai.inference.models._models.ChatRequestMessage] + :keyword extra_params: Controls what happens if extra parameters, undefined by the REST API, + are passed in the JSON request payload. + This sets the HTTP request header ``extra-parameters``. Known values are: "error", "drop", and + "pass-through". Default value is None. + :paramtype extra_params: str or ~azure.ai.inference.models.ExtraParameters + :keyword frequency_penalty: A value that influences the probability of generated tokens + appearing based on their cumulative + frequency in generated text. + Positive values will make tokens less likely to appear as their frequency increases and + decrease the likelihood of the model repeating the same statements verbatim. + Supported range is [-2, 2]. Default value is None. + :paramtype frequency_penalty: float + :keyword stream_parameter: A value indicating whether chat completions should be streamed for + this request. Default value is None. + :paramtype stream_parameter: bool + :keyword presence_penalty: A value that influences the probability of generated tokens + appearing based on their existing + presence in generated text. + Positive values will make tokens less likely to appear when they already exist and increase + the + model's likelihood to output new topics. + Supported range is [-2, 2]. Default value is None. 
+ :paramtype presence_penalty: float + :keyword temperature: The sampling temperature to use that controls the apparent creativity of + generated completions. + Higher values will make output more random while lower values will make results more focused + and deterministic. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. Default value is None. + :paramtype temperature: float + :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value + causes the + model to consider the results of tokens with the provided probability mass. As an example, a + value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be + considered. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. Default value is None. + :paramtype top_p: float + :keyword max_tokens: The maximum number of tokens to generate. Default value is None. + :paramtype max_tokens: int + :keyword response_format: An object specifying the format that the model must output. + + Setting to ``{ "type": "json_schema", "json_schema": {...} }`` enables Structured Outputs + which ensures the model will match your supplied JSON schema. + + Setting to ``{ "type": "json_object" }`` enables JSON mode, which ensures the message the + model generates is valid JSON. + + **Important:** when using JSON mode, you **must** also instruct the model to produce JSON + yourself via a system or user message. Without this, the model may generate an unending stream + of whitespace until the generation reaches the token limit, resulting in a long-running and + seemingly "stuck" request. Also note that the message content may be partially cut off if + ``finish_reason="length"``\\ , which indicates the generation exceeded ``max_tokens`` or the + conversation exceeded the max context length. Default value is None. + :paramtype response_format: ~azure.ai.inference.models._models.ChatCompletionsResponseFormat + :keyword stop: A collection of textual sequences that will end completions generation. Default + value is None. + :paramtype stop: list[str] + :keyword tools: A list of tools the model may request to call. Currently, only functions are + supported as a tool. The model + may response with a function call request and provide the input arguments in JSON format for + that function. Default value is None. + :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] + :keyword tool_choice: If specified, the model will configure which of the provided tools it can + use for the chat completions response. Is either a Union[str, + "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. + Default value is None. + :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or + ~azure.ai.inference.models.ChatCompletionsNamedToolChoice + :keyword seed: If specified, the system will make a best effort to sample deterministically + such that repeated requests with the + same seed and parameters should return the same result. Determinism is not guaranteed. Default + value is None. + :paramtype seed: int + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. 
+ :paramtype model: str + :return: ChatCompletions. The ChatCompletions is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.ChatCompletions + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.ChatCompletions] = kwargs.pop("cls", None) + + if body is _Unset: + if messages is _Unset: + raise TypeError("missing required argument: messages") + body = { + "frequency_penalty": frequency_penalty, + "max_tokens": max_tokens, + "messages": messages, + "model": model, + "presence_penalty": presence_penalty, + "response_format": response_format, + "seed": seed, + "stop": stop, + "stream": stream_parameter, + "temperature": temperature, + "tool_choice": tool_choice, + "tools": tools, + "top_p": top_p, + } + body = {k: v for k, v in body.items() if v is not None} + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_chat_completions_complete_request( + extra_params=extra_params, + content_type=content_type, + api_version=self._config.api_version, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize(_models.ChatCompletions, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: + """Returns information about the AI model. + The method makes a REST API call to the ``/info`` route on the given endpoint. + This method will only work when using Serverless API or Managed Compute endpoint. + It will not work for GitHub Models endpoint or Azure OpenAI endpoint. + + :return: ModelInfo. 
The ModelInfo is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.ModelInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.ModelInfo] = kwargs.pop("cls", None) + + _request = build_chat_completions_get_model_info_request( + api_version=self._config.api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize(_models.ModelInfo, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + +class EmbeddingsClientOperationsMixin(EmbeddingsClientMixinABC): + + @overload + def _embed( + self, + *, + input: List[str], + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + **kwargs: Any + ) -> _models.EmbeddingsResult: ... + @overload + def _embed( + self, + body: JSON, + *, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + **kwargs: Any + ) -> _models.EmbeddingsResult: ... + @overload + def _embed( + self, + body: IO[bytes], + *, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + **kwargs: Any + ) -> _models.EmbeddingsResult: ... + + @distributed_trace + def _embed( + self, + body: Union[JSON, IO[bytes]] = _Unset, + *, + input: List[str] = _Unset, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + **kwargs: Any + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given text prompts. + The method makes a REST API call to the ``/embeddings`` route on the given endpoint. + + :param body: Is either a JSON type or a IO[bytes] type. Required. + :type body: JSON or IO[bytes] + :keyword input: Input text to embed, encoded as a string or array of tokens. + To embed multiple inputs in a single request, pass an array + of strings or array of token arrays. 
Required. + :paramtype input: list[str] + :keyword extra_params: Controls what happens if extra parameters, undefined by the REST API, + are passed in the JSON request payload. + This sets the HTTP request header ``extra-parameters``. Known values are: "error", "drop", and + "pass-through". Default value is None. + :paramtype extra_params: str or ~azure.ai.inference.models.ExtraParameters + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. + Passing null causes the model to use its default value. + Returns a 422 error if the model doesn't support the value or parameter. Default value is + None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. Known + values are: "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. + Returns a 422 error if the model doesn't support the value or parameter. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.EmbeddingsResult] = kwargs.pop("cls", None) + + if body is _Unset: + if input is _Unset: + raise TypeError("missing required argument: input") + body = { + "dimensions": dimensions, + "encoding_format": encoding_format, + "input": input, + "input_type": input_type, + "model": model, + } + body = {k: v for k, v in body.items() if v is not None} + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_embeddings_embed_request( + extra_params=extra_params, + content_type=content_type, + api_version=self._config.api_version, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + 
raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize(_models.EmbeddingsResult, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: + """Returns information about the AI model. + The method makes a REST API call to the ``/info`` route on the given endpoint. + This method will only work when using Serverless API or Managed Compute endpoint. + It will not work for GitHub Models endpoint or Azure OpenAI endpoint. + + :return: ModelInfo. The ModelInfo is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.ModelInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.ModelInfo] = kwargs.pop("cls", None) + + _request = build_embeddings_get_model_info_request( + api_version=self._config.api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize(_models.ModelInfo, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + +class ImageEmbeddingsClientOperationsMixin(ImageEmbeddingsClientMixinABC): + + @overload + def _embed( + self, + *, + input: List[_models.ImageEmbeddingInput], + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + **kwargs: Any + ) -> _models.EmbeddingsResult: ... + @overload + def _embed( + self, + body: JSON, + *, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + **kwargs: Any + ) -> _models.EmbeddingsResult: ... + @overload + def _embed( + self, + body: IO[bytes], + *, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + **kwargs: Any + ) -> _models.EmbeddingsResult: ... 
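[Editor's note] The `_embed` implementations in these mixins (the text variant above and the image variant that follows) share the same payload-assembly step: keyword arguments are folded into a dict, unset (None) entries are dropped, and the remainder is JSON-encoded before the request is built. A minimal standalone sketch of that step is below; the helper name is illustrative only and is not part of the SDK.

import json
from typing import Any, Dict, List, Optional

def build_embed_payload(
    input: List[str],
    dimensions: Optional[int] = None,
    encoding_format: Optional[str] = None,
    input_type: Optional[str] = None,
    model: Optional[str] = None,
) -> str:
    # Mirror the generated code: collect keyword arguments, drop anything unset,
    # then serialize what remains as the JSON request body.
    body: Dict[str, Any] = {
        "dimensions": dimensions,
        "encoding_format": encoding_format,
        "input": input,
        "input_type": input_type,
        "model": model,
    }
    return json.dumps({k: v for k, v in body.items() if v is not None})

# Example: build_embed_payload(["first phrase", "second phrase"], dimensions=256)
# produces '{"dimensions": 256, "input": ["first phrase", "second phrase"]}'.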
+ + @distributed_trace + def _embed( + self, + body: Union[JSON, IO[bytes]] = _Unset, + *, + input: List[_models.ImageEmbeddingInput] = _Unset, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + **kwargs: Any + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given images. + The method makes a REST API call to the ``/images/embeddings`` route on the given endpoint. + + :param body: Is either a JSON type or a IO[bytes] type. Required. + :type body: JSON or IO[bytes] + :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an + array. + The input must not exceed the max input tokens for the model. Required. + :paramtype input: list[~azure.ai.inference.models.ImageEmbeddingInput] + :keyword extra_params: Controls what happens if extra parameters, undefined by the REST API, + are passed in the JSON request payload. + This sets the HTTP request header ``extra-parameters``. Known values are: "error", "drop", and + "pass-through". Default value is None. + :paramtype extra_params: str or ~azure.ai.inference.models.ExtraParameters + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. + Passing null causes the model to use its default value. + Returns a 422 error if the model doesn't support the value or parameter. Default value is + None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. Known + values are: "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. + Returns a 422 error if the model doesn't support the value or parameter. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :return: EmbeddingsResult.
The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.EmbeddingsResult] = kwargs.pop("cls", None) + + if body is _Unset: + if input is _Unset: + raise TypeError("missing required argument: input") + body = { + "dimensions": dimensions, + "encoding_format": encoding_format, + "input": input, + "input_type": input_type, + "model": model, + } + body = {k: v for k, v in body.items() if v is not None} + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_image_embeddings_embed_request( + extra_params=extra_params, + content_type=content_type, + api_version=self._config.api_version, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize(_models.EmbeddingsResult, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: + """Returns information about the AI model. + The method makes a REST API call to the ``/info`` route on the given endpoint. + This method will only work when using Serverless API or Managed Compute endpoint. + It will not work for GitHub Models endpoint or Azure OpenAI endpoint. + + :return: ModelInfo. 
The ModelInfo is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.ModelInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.ModelInfo] = kwargs.pop("cls", None) + + _request = build_image_embeddings_get_model_info_request( + api_version=self._config.api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize(_models.ModelInfo, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/_operations/_patch.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/_operations/_patch.py new file mode 100644 index 00000000..f7dd3251 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/_operations/_patch.py @@ -0,0 +1,20 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Customize generated code here. + +Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize +""" +from typing import List + +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level + + +def patch_sdk(): + """Do not remove from this file. + + `patch_sdk` is a last resort escape hatch that allows you to do customizations + you can't accomplish using the techniques described in + https://aka.ms/azsdk/python/dpcodegen/python/customize + """ diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/_patch.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/_patch.py new file mode 100644 index 00000000..da95cf93 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/_patch.py @@ -0,0 +1,1387 @@ +# pylint: disable=too-many-lines +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Customize generated code here. + +Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize + +Why do we patch auto-generated code? Below is a summary of the changes made in all _patch files (not just this one): +1. Add support for input argument `model_extras` (all clients) +2. Add support for function load_client +3. 
Add support for setting sticky chat completions/embeddings input arguments in the client constructor +4. Add support for get_model_info, while caching the result (all clients) +5. Add support for chat completion streaming (ChatCompletionsClient client only) +6. Add support for friendly print of result objects (__str__ method) (all clients) +7. Add support for load() method in ImageUrl class (see /models/_patch.py) +8. Add support for sending two auth headers for api-key auth (all clients) +9. Simplify how chat completions "response_format" is set. Define "response_format" as a flat Union of strings and + JsonSchemaFormat object, instead of using auto-generated base/derived classes named + ChatCompletionsResponseFormatXxxInternal. +10. Allow UserMessage("my message") in addition to UserMessage(content="my message"). Same applies to +AssistantMessage, SystemMessage, DeveloperMessage and ToolMessage. + +""" +import json +import logging +import sys + +from io import IOBase +from typing import Any, Dict, Union, IO, List, Literal, Optional, overload, Type, TYPE_CHECKING, Iterable + +from azure.core.pipeline import PipelineResponse +from azure.core.credentials import AzureKeyCredential +from azure.core.tracing.decorator import distributed_trace +from azure.core.utils import case_insensitive_dict +from azure.core.exceptions import ( + ClientAuthenticationError, + HttpResponseError, + map_error, + ResourceExistsError, + ResourceNotFoundError, + ResourceNotModifiedError, +) +from . import models as _models +from ._model_base import SdkJSONEncoder, _deserialize +from ._serialization import Serializer +from ._operations._operations import ( + build_chat_completions_complete_request, + build_embeddings_embed_request, + build_image_embeddings_embed_request, +) +from ._client import ChatCompletionsClient as ChatCompletionsClientGenerated +from ._client import EmbeddingsClient as EmbeddingsClientGenerated +from ._client import ImageEmbeddingsClient as ImageEmbeddingsClientGenerated + +if sys.version_info >= (3, 9): + from collections.abc import MutableMapping +else: + from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports + +if TYPE_CHECKING: + # pylint: disable=unused-import,ungrouped-imports + from azure.core.credentials import TokenCredential + +JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object +_Unset: Any = object() + +_SERIALIZER = Serializer() +_SERIALIZER.client_side_validation = False + +_LOGGER = logging.getLogger(__name__) + + +def _get_internal_response_format( + response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] +) -> Optional[_models._models.ChatCompletionsResponseFormat]: + """ + Internal helper method to convert between the public response format type that's supported in the `complete` method, + and the internal response format type that's used in the generated code. + + :param response_format: Response format. Required. + :type response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] + :return: Internal response format. 
+ :rtype: ~azure.ai.inference._models._models.ChatCompletionsResponseFormat + """ + if response_format is not None: + + # To make mypy tool happy, start by declaring the type as the base class + internal_response_format: _models._models.ChatCompletionsResponseFormat + + if isinstance(response_format, str) and response_format == "text": + internal_response_format = ( + _models._models.ChatCompletionsResponseFormatText() # pylint: disable=protected-access + ) + elif isinstance(response_format, str) and response_format == "json_object": + internal_response_format = ( + _models._models.ChatCompletionsResponseFormatJsonObject() # pylint: disable=protected-access + ) + elif isinstance(response_format, _models.JsonSchemaFormat): + internal_response_format = ( + _models._models.ChatCompletionsResponseFormatJsonSchema( # pylint: disable=protected-access + json_schema=response_format + ) + ) + else: + raise ValueError(f"Unsupported `response_format` {response_format}") + + return internal_response_format + + return None + + +def load_client( + endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any +) -> Union["ChatCompletionsClient", "EmbeddingsClient", "ImageEmbeddingsClient"]: + """ + Load a client from a given endpoint URL. The method makes a REST API call to the `/info` route + on the given endpoint, to determine the model type and therefore which client to instantiate. + Keyword arguments are passed to the appropriate client's constructor, so if you need to set things like + `api_version`, `logging_enable`, `user_agent`, etc., you can do so here. + This method will only work when using Serverless API or Managed Compute endpoint. + It will not work for GitHub Models endpoint or Azure OpenAI endpoint. + Keyword arguments are passed through to the client constructor (you can set keywords such as + `api_version`, `user_agent`, `logging_enable` etc. on the client constructor). + + :param endpoint: Service endpoint URL for AI model inference. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.TokenCredential + :return: The appropriate synchronous client associated with the given endpoint + :rtype: ~azure.ai.inference.ChatCompletionsClient or ~azure.ai.inference.EmbeddingsClient + or ~azure.ai.inference.ImageEmbeddingsClient + :raises ~azure.core.exceptions.HttpResponseError: + """ + + with ChatCompletionsClient( + endpoint, credential, **kwargs + ) as client: # Pick any of the clients, it does not matter. + try: + model_info = client.get_model_info() # type: ignore + except ResourceNotFoundError as error: + error.message = ( + "`load_client` function does not work on this endpoint (`/info` route not supported). " + "Please construct one of the clients (e.g. `ChatCompletionsClient`) directly." + ) + raise error + + _LOGGER.info("model_info=%s", model_info) + if not model_info.model_type: + raise ValueError( + "The AI model information is missing a value for `model type`. Cannot create an appropriate client." 
+ ) + + # TODO: Remove "completions", "chat-completions" and "embedding" once Mistral Large and Cohere fix their model type + if model_info.model_type in ( + _models.ModelType.CHAT_COMPLETION, + "chat_completions", + "chat", + "completion", + "chat-completion", + "chat-completions", + "chat completion", + "chat completions", + ): + chat_completion_client = ChatCompletionsClient(endpoint, credential, **kwargs) + chat_completion_client._model_info = ( # pylint: disable=protected-access,attribute-defined-outside-init + model_info + ) + return chat_completion_client + + if model_info.model_type in ( + _models.ModelType.EMBEDDINGS, + "embedding", + "text_embedding", + "text-embeddings", + "text embedding", + "text embeddings", + ): + embedding_client = EmbeddingsClient(endpoint, credential, **kwargs) + embedding_client._model_info = model_info # pylint: disable=protected-access,attribute-defined-outside-init + return embedding_client + + if model_info.model_type in ( + _models.ModelType.IMAGE_EMBEDDINGS, + "image_embedding", + "image-embeddings", + "image-embedding", + "image embedding", + "image embeddings", + ): + image_embedding_client = ImageEmbeddingsClient(endpoint, credential, **kwargs) + image_embedding_client._model_info = ( # pylint: disable=protected-access,attribute-defined-outside-init + model_info + ) + return image_embedding_client + + raise ValueError(f"No client available to support AI model type `{model_info.model_type}`") + + +class ChatCompletionsClient(ChatCompletionsClientGenerated): # pylint: disable=too-many-instance-attributes + """ChatCompletionsClient. + + :param endpoint: Service endpoint URL for AI model inference. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.TokenCredential + :keyword frequency_penalty: A value that influences the probability of generated tokens + appearing based on their cumulative frequency in generated text. + Positive values will make tokens less likely to appear as their frequency increases and + decrease the likelihood of the model repeating the same statements verbatim. + Supported range is [-2, 2]. + Default value is None. + :paramtype frequency_penalty: float + :keyword presence_penalty: A value that influences the probability of generated tokens + appearing based on their existing + presence in generated text. + Positive values will make tokens less likely to appear when they already exist and increase + the model's likelihood to output new topics. + Supported range is [-2, 2]. + Default value is None. + :paramtype presence_penalty: float + :keyword temperature: The sampling temperature to use that controls the apparent creativity of + generated completions. + Higher values will make output more random while lower values will make results more focused + and deterministic. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. + Default value is None. + :paramtype temperature: float + :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value + causes the + model to consider the results of tokens with the provided probability mass.
As an example, a + value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be + considered. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. + Default value is None. + :paramtype top_p: float + :keyword max_tokens: The maximum number of tokens to generate. Default value is None. + :paramtype max_tokens: int + :keyword response_format: The format that the AI model must output. AI chat completions models typically output + unformatted text by default. This is equivalent to setting "text" as the response_format. + To output JSON format, without adhering to any schema, set to "json_object". + To output JSON format adhering to a provided schema, set this to an object of the class + ~azure.ai.inference.models.JsonSchemaFormat. Default value is None. + :paramtype response_format: Union[Literal['text', 'json_object'], ~azure.ai.inference.models.JsonSchemaFormat] + :keyword stop: A collection of textual sequences that will end completions generation. Default + value is None. + :paramtype stop: list[str] + :keyword tools: The available tool definitions that the chat completions request can use, + including caller-defined functions. Default value is None. + :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] + :keyword tool_choice: If specified, the model will configure which of the provided tools it can + use for the chat completions response. Is either a Union[str, + "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. + Default value is None. + :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or + ~azure.ai.inference.models.ChatCompletionsNamedToolChoice + :keyword seed: If specified, the system will make a best effort to sample deterministically + such that repeated requests with the + same seed and parameters should return the same result. Determinism is not guaranteed. + Default value is None. + :paramtype seed: int + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. 
+ :paramtype api_version: str + """ + + def __init__( + self, + endpoint: str, + credential: Union[AzureKeyCredential, "TokenCredential"], + *, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + stop: Optional[List[str]] = None, + tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, + tool_choice: Optional[ + Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + ] = None, + seed: Optional[int] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + + self._model_info: Optional[_models.ModelInfo] = None + + # Store default chat completions settings, to be applied in all future service calls + # unless overridden by arguments in the `complete` method. + self._frequency_penalty = frequency_penalty + self._presence_penalty = presence_penalty + self._temperature = temperature + self._top_p = top_p + self._max_tokens = max_tokens + self._internal_response_format = _get_internal_response_format(response_format) + self._stop = stop + self._tools = tools + self._tool_choice = tool_choice + self._seed = seed + self._model = model + self._model_extras = model_extras + + # For Key auth, we need to send these two auth HTTP request headers simultaneously: + # 1. "Authorization: Bearer <key>" + # 2. "api-key: <key>" + # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, + # and Azure OpenAI and the new Unified Inference endpoints support the second header. + # The first header will be taken care of by auto-generated code. + # The second one is added here. + if isinstance(credential, AzureKeyCredential): + headers = kwargs.pop("headers", {}) + if "api-key" not in headers: + headers["api-key"] = credential.key + kwargs["headers"] = headers + + super().__init__(endpoint, credential, **kwargs) + + @overload + def complete( + self, + *, + messages: Union[List[_models.ChatRequestMessage], List[Dict[str, Any]]], + stream: Literal[False] = False, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + stop: Optional[List[str]] = None, + tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, + tool_choice: Optional[ + Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + ] = None, + seed: Optional[int] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> _models.ChatCompletions: ... 
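A minimal usage sketch (illustrative only, not part of the vendored file; the endpoint and key are placeholders): constructor keyword arguments such as temperature are stored as client-level defaults, and complete falls back to them whenever the corresponding per-call argument is None.

from azure.core.credentials import AzureKeyCredential
from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage

# Placeholder endpoint and key. With AzureKeyCredential, the client also sends the
# "api-key" request header, as described in the __init__ comment above.
client = ChatCompletionsClient(
    endpoint="https://<your-endpoint>",
    credential=AzureKeyCredential("<your-api-key>"),
    temperature=0.2,  # stored on the client as a default for all future calls
)

response = client.complete(
    messages=[
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="How many feet are in a mile?"),
    ],
    temperature=0.8,  # per-call value; overrides the stored default for this request only
)
print(response.choices[0].message.content)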
+ + @overload + def complete( + self, + *, + messages: Union[List[_models.ChatRequestMessage], List[Dict[str, Any]]], + stream: Literal[True], + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + stop: Optional[List[str]] = None, + tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, + tool_choice: Optional[ + Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + ] = None, + seed: Optional[int] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Iterable[_models.StreamingChatCompletionsUpdate]: ... + + @overload + def complete( + self, + *, + messages: Union[List[_models.ChatRequestMessage], List[Dict[str, Any]]], + stream: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + stop: Optional[List[str]] = None, + tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, + tool_choice: Optional[ + Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + ] = None, + seed: Optional[int] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: + # pylint: disable=line-too-long + """Gets chat completions for the provided chat messages. + Completions support a wide variety of tasks and generate text that continues from or + "completes" provided prompt data. The method makes a REST API call to the `/chat/completions` route + on the given endpoint. + When using this method with `stream=True`, the response is streamed + back to the client. Iterate over the resulting StreamingChatCompletions + object to get content updates as they arrive. By default, the response is a ChatCompletions object + (non-streaming). + + :keyword messages: The collection of context messages associated with this chat completions + request. + Typical usage begins with a chat message for the System role that provides instructions for + the behavior of the assistant, followed by alternating messages between the User and + Assistant roles. Required. + :paramtype messages: list[~azure.ai.inference.models.ChatRequestMessage] or list[dict[str, Any]] + :keyword stream: A value indicating whether chat completions should be streamed for this request. + Default value is False. If streaming is enabled, the response will be a StreamingChatCompletions. + Otherwise the response will be a ChatCompletions. + :paramtype stream: bool + :keyword frequency_penalty: A value that influences the probability of generated tokens + appearing based on their cumulative frequency in generated text. + Positive values will make tokens less likely to appear as their frequency increases and + decrease the likelihood of the model repeating the same statements verbatim. + Supported range is [-2, 2]. + Default value is None. 
+ :paramtype frequency_penalty: float + :keyword presence_penalty: A value that influences the probability of generated tokens + appearing based on their existing + presence in generated text. + Positive values will make tokens less likely to appear when they already exist and increase + the model's likelihood to output new topics. + Supported range is [-2, 2]. + Default value is None. + :paramtype presence_penalty: float + :keyword temperature: The sampling temperature to use that controls the apparent creativity of + generated completions. + Higher values will make output more random while lower values will make results more focused + and deterministic. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. + Default value is None. + :paramtype temperature: float + :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value + causes the + model to consider the results of tokens with the provided probability mass. As an example, a + value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be + considered. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. + Default value is None. + :paramtype top_p: float + :keyword max_tokens: The maximum number of tokens to generate. Default value is None. + :paramtype max_tokens: int + :keyword response_format: The format that the AI model must output. AI chat completions models typically output + unformatted text by default. This is equivalent to setting "text" as the response_format. + To output JSON format, without adhering to any schema, set to "json_object". + To output JSON format adhering to a provided schema, set this to an object of the class + ~azure.ai.inference.models.JsonSchemaFormat. Default value is None. + :paramtype response_format: Union[Literal['text', 'json_object'], ~azure.ai.inference.models.JsonSchemaFormat] + :keyword stop: A collection of textual sequences that will end completions generation. Default + value is None. + :paramtype stop: list[str] + :keyword tools: The available tool definitions that the chat completions request can use, + including caller-defined functions. Default value is None. + :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] + :keyword tool_choice: If specified, the model will configure which of the provided tools it can + use for the chat completions response. Is either a Union[str, + "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. + Default value is None. + :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or + ~azure.ai.inference.models.ChatCompletionsNamedToolChoice + :keyword seed: If specified, the system will make a best effort to sample deterministically + such that repeated requests with the + same seed and parameters should return the same result. Determinism is not guaranteed. + Default value is None. + :paramtype seed: int + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. 
+ How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming. + :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def complete( + self, + body: JSON, + *, + content_type: str = "application/json", + **kwargs: Any, + ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: + # pylint: disable=line-too-long + """Gets chat completions for the provided chat messages. + Completions support a wide variety of tasks and generate text that continues from or + "completes" provided prompt data. + + :param body: An object of type MutableMapping[str, Any], such as a dictionary, that + specifies the full request payload. Required. + :type body: JSON + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming. + :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def complete( + self, + body: IO[bytes], + *, + content_type: str = "application/json", + **kwargs: Any, + ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: + # pylint: disable=line-too-long + # pylint: disable=too-many-locals + """Gets chat completions for the provided chat messages. + Completions support a wide variety of tasks and generate text that continues from or + "completes" provided prompt data. + + :param body: Specifies the full request payload. Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming. 
+ :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions + :raises ~azure.core.exceptions.HttpResponseError: + """ + + # pylint:disable=client-method-missing-tracing-decorator + def complete( + self, + body: Union[JSON, IO[bytes]] = _Unset, + *, + messages: Union[List[_models.ChatRequestMessage], List[Dict[str, Any]]] = _Unset, + stream: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + stop: Optional[List[str]] = None, + tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, + tool_choice: Optional[ + Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + ] = None, + seed: Optional[int] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: + # pylint: disable=line-too-long + # pylint: disable=too-many-locals + """Gets chat completions for the provided chat messages. + Completions support a wide variety of tasks and generate text that continues from or + "completes" provided prompt data. When using this method with `stream=True`, the response is streamed + back to the client. Iterate over the resulting :class:`~azure.ai.inference.models.StreamingChatCompletions` + object to get content updates as they arrive. + + :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type + that specifies the full request payload. Required. + :type body: JSON or IO[bytes] + :keyword messages: The collection of context messages associated with this chat completions + request. + Typical usage begins with a chat message for the System role that provides instructions for + the behavior of the assistant, followed by alternating messages between the User and + Assistant roles. Required. + :paramtype messages: list[~azure.ai.inference.models.ChatRequestMessage] or list[dict[str, Any]] + :keyword stream: A value indicating whether chat completions should be streamed for this request. + Default value is False. If streaming is enabled, the response will be a StreamingChatCompletions. + Otherwise the response will be a ChatCompletions. + :paramtype stream: bool + :keyword frequency_penalty: A value that influences the probability of generated tokens + appearing based on their cumulative frequency in generated text. + Positive values will make tokens less likely to appear as their frequency increases and + decrease the likelihood of the model repeating the same statements verbatim. + Supported range is [-2, 2]. + Default value is None. + :paramtype frequency_penalty: float + :keyword presence_penalty: A value that influences the probability of generated tokens + appearing based on their existing + presence in generated text. + Positive values will make tokens less likely to appear when they already exist and increase + the model's likelihood to output new topics. + Supported range is [-2, 2]. + Default value is None. + :paramtype presence_penalty: float + :keyword temperature: The sampling temperature to use that controls the apparent creativity of + generated completions. 
+ Higher values will make output more random while lower values will make results more focused + and deterministic. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. + Default value is None. + :paramtype temperature: float + :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value + causes the + model to consider the results of tokens with the provided probability mass. As an example, a + value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be + considered. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. + Default value is None. + :paramtype top_p: float + :keyword max_tokens: The maximum number of tokens to generate. Default value is None. + :paramtype max_tokens: int + :keyword response_format: The format that the AI model must output. AI chat completions models typically output + unformatted text by default. This is equivalent to setting "text" as the response_format. + To output JSON format, without adhering to any schema, set to "json_object". + To output JSON format adhering to a provided schema, set this to an object of the class + ~azure.ai.inference.models.JsonSchemaFormat. Default value is None. + :paramtype response_format: Union[Literal['text', 'json_object'], ~azure.ai.inference.models.JsonSchemaFormat] + :keyword stop: A collection of textual sequences that will end completions generation. Default + value is None. + :paramtype stop: list[str] + :keyword tools: The available tool definitions that the chat completions request can use, + including caller-defined functions. Default value is None. + :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] + :keyword tool_choice: If specified, the model will configure which of the provided tools it can + use for the chat completions response. Is either a Union[str, + "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. + Default value is None. + :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or + ~azure.ai.inference.models.ChatCompletionsNamedToolChoice + :keyword seed: If specified, the system will make a best effort to sample deterministically + such that repeated requests with the + same seed and parameters should return the same result. Determinism is not guaranteed. + Default value is None. + :paramtype seed: int + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming. 
+ :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + _extra_parameters: Union[_models._enums.ExtraParameters, None] = None + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + + internal_response_format = _get_internal_response_format(response_format) + + if body is _Unset: + if messages is _Unset: + raise TypeError("missing required argument: messages") + body = { + "messages": messages, + "stream": stream, + "frequency_penalty": frequency_penalty if frequency_penalty is not None else self._frequency_penalty, + "max_tokens": max_tokens if max_tokens is not None else self._max_tokens, + "model": model if model is not None else self._model, + "presence_penalty": presence_penalty if presence_penalty is not None else self._presence_penalty, + "response_format": ( + internal_response_format if internal_response_format is not None else self._internal_response_format + ), + "seed": seed if seed is not None else self._seed, + "stop": stop if stop is not None else self._stop, + "temperature": temperature if temperature is not None else self._temperature, + "tool_choice": tool_choice if tool_choice is not None else self._tool_choice, + "tools": tools if tools is not None else self._tools, + "top_p": top_p if top_p is not None else self._top_p, + } + if model_extras is not None and bool(model_extras): + body.update(model_extras) + _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + elif self._model_extras is not None and bool(self._model_extras): + body.update(self._model_extras) + _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + body = {k: v for k, v in body.items() if v is not None} + elif isinstance(body, dict) and "stream" in body and isinstance(body["stream"], bool): + stream = body["stream"] + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_chat_completions_complete_request( + extra_params=_extra_parameters, + content_type=content_type, + api_version=self._config.api_version, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = stream or False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + response.read() # Load the body in memory and close the socket + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + return _models.StreamingChatCompletions(response) + + return _deserialize(_models._patch.ChatCompletions, 
response.json()) # pylint: disable=protected-access + + @distributed_trace + def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: + # pylint: disable=line-too-long + """Returns information about the AI model. + The method makes a REST API call to the ``/info`` route on the given endpoint. + This method will only work when using Serverless API or Managed Compute endpoint. + It will not work for GitHub Models endpoint or Azure OpenAI endpoint. + + :return: ModelInfo. The ModelInfo is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.ModelInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + if not self._model_info: + try: + self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init + except ResourceNotFoundError as error: + error.message = "Model information is not available on this endpoint (`/info` route not supported)." + raise error + + return self._model_info + + def __str__(self) -> str: + # pylint: disable=client-method-name-no-double-underscore + return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() + + +class EmbeddingsClient(EmbeddingsClientGenerated): + """EmbeddingsClient. + + :param endpoint: Service endpoint URL for AI model inference. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.TokenCredential + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. Default value is None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. + Known values are: + "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. 
+ :paramtype api_version: str + """ + + def __init__( + self, + endpoint: str, + credential: Union[AzureKeyCredential, "TokenCredential"], + *, + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + + self._model_info: Optional[_models.ModelInfo] = None + + # Store default embeddings settings, to be applied in all future service calls + # unless overridden by arguments in the `embed` method. + self._dimensions = dimensions + self._encoding_format = encoding_format + self._input_type = input_type + self._model = model + self._model_extras = model_extras + + # For Key auth, we need to send these two auth HTTP request headers simultaneously: + # 1. "Authorization: Bearer <key>" + # 2. "api-key: <key>" + # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, + # and Azure OpenAI and the new Unified Inference endpoints support the second header. + # The first header will be taken care of by auto-generated code. + # The second one is added here. + if isinstance(credential, AzureKeyCredential): + headers = kwargs.pop("headers", {}) + if "api-key" not in headers: + headers["api-key"] = credential.key + kwargs["headers"] = headers + + super().__init__(endpoint, credential, **kwargs) + + @overload + def embed( + self, + *, + input: List[str], + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given text prompts. + The method makes a REST API call to the `/embeddings` route on the given endpoint. + + :keyword input: Input text to embed, encoded as a string or array of tokens. + To embed multiple inputs in a single request, pass an array + of strings or array of token arrays. Required. + :paramtype input: list[str] + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. Default value is None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. + Known values are: + "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :return: EmbeddingsResult. 
The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def embed( + self, + body: JSON, + *, + content_type: str = "application/json", + **kwargs: Any, + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given text prompts. + The method makes a REST API call to the `/embeddings` route on the given endpoint. + + :param body: An object of type MutableMapping[str, Any], such as a dictionary, that + specifies the full request payload. Required. + :type body: JSON + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def embed( + self, + body: IO[bytes], + *, + content_type: str = "application/json", + **kwargs: Any, + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given text prompts. + The method makes a REST API call to the `/embeddings` route on the given endpoint. + + :param body: Specifies the full request payload. Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def embed( + self, + body: Union[JSON, IO[bytes]] = _Unset, + *, + input: List[str] = _Unset, + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> _models.EmbeddingsResult: + # pylint: disable=line-too-long + """Return the embedding vectors for given text prompts. + The method makes a REST API call to the `/embeddings` route on the given endpoint. + + :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type + that specifies the full request payload. Required. + :type body: JSON or IO[bytes] + :keyword input: Input text to embed, encoded as a string or array of tokens. + To embed multiple inputs in a single request, pass an array + of strings or array of token arrays. Required. + :paramtype input: list[str] + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. Default value is None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. + Known values are: + "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. 
+ :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping[int, Type[HttpResponseError]] = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + _extra_parameters: Union[_models._enums.ExtraParameters, None] = None + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + + if body is _Unset: + if input is _Unset: + raise TypeError("missing required argument: input") + body = { + "input": input, + "dimensions": dimensions if dimensions is not None else self._dimensions, + "encoding_format": encoding_format if encoding_format is not None else self._encoding_format, + "input_type": input_type if input_type is not None else self._input_type, + "model": model if model is not None else self._model, + } + if model_extras is not None and bool(model_extras): + body.update(model_extras) + _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + elif self._model_extras is not None and bool(self._model_extras): + body.update(self._model_extras) + _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + body = {k: v for k, v in body.items() if v is not None} + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_embeddings_embed_request( + extra_params=_extra_parameters, + content_type=content_type, + api_version=self._config.api_version, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + response.read() # Load the body in memory and close the socket + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize( + _models._patch.EmbeddingsResult, response.json() # pylint: disable=protected-access + ) + + return deserialized # type: ignore + + @distributed_trace + def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: + # pylint: disable=line-too-long + """Returns information about the AI model. 
+ The method makes a REST API call to the ``/info`` route on the given endpoint. + This method will only work when using Serverless API or Managed Compute endpoint. + It will not work for GitHub Models endpoint or Azure OpenAI endpoint. + + :return: ModelInfo. The ModelInfo is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.ModelInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + if not self._model_info: + try: + self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init + except ResourceNotFoundError as error: + error.message = "Model information is not available on this endpoint (`/info` route not supported)." + raise error + + return self._model_info + + def __str__(self) -> str: + # pylint: disable=client-method-name-no-double-underscore + return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() + + +class ImageEmbeddingsClient(ImageEmbeddingsClientGenerated): + """ImageEmbeddingsClient. + + :param endpoint: Service endpoint URL for AI model inference. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.TokenCredential + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. Default value is None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. + Known values are: + "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. + :paramtype api_version: str + """ + + def __init__( + self, + endpoint: str, + credential: Union[AzureKeyCredential, "TokenCredential"], + *, + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + + self._model_info: Optional[_models.ModelInfo] = None + + # Store default embeddings settings, to be applied in all future service calls + # unless overridden by arguments in the `embed` method. 
+ self._dimensions = dimensions + self._encoding_format = encoding_format + self._input_type = input_type + self._model = model + self._model_extras = model_extras + + # For Key auth, we need to send these two auth HTTP request headers simultaneously: + # 1. "Authorization: Bearer <key>" + # 2. "api-key: <key>" + # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, + # and Azure OpenAI and the new Unified Inference endpoints support the second header. + # The first header will be taken care of by auto-generated code. + # The second one is added here. + if isinstance(credential, AzureKeyCredential): + headers = kwargs.pop("headers", {}) + if "api-key" not in headers: + headers["api-key"] = credential.key + kwargs["headers"] = headers + + super().__init__(endpoint, credential, **kwargs) + + @overload + def embed( + self, + *, + input: List[_models.ImageEmbeddingInput], + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given images. + The method makes a REST API call to the `/images/embeddings` route on the given endpoint. + + :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an + array. + The input must not exceed the max input tokens for the model. Required. + :paramtype input: list[~azure.ai.inference.models.ImageEmbeddingInput] + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. Default value is None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. + Known values are: + "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def embed( + self, + body: JSON, + *, + content_type: str = "application/json", + **kwargs: Any, + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given images. + The method makes a REST API call to the `/images/embeddings` route on the given endpoint. + + :param body: An object of type MutableMapping[str, Any], such as a dictionary, that + specifies the full request payload. Required. + :type body: JSON + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. 
+ Default value is "application/json". + :paramtype content_type: str + :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def embed( + self, + body: IO[bytes], + *, + content_type: str = "application/json", + **kwargs: Any, + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given images. + The method makes a REST API call to the `/images/embeddings` route on the given endpoint. + + :param body: Specifies the full request payload. Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def embed( + self, + body: Union[JSON, IO[bytes]] = _Unset, + *, + input: List[_models.ImageEmbeddingInput] = _Unset, + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> _models.EmbeddingsResult: + # pylint: disable=line-too-long + """Return the embedding vectors for given images. + The method makes a REST API call to the `/images/embeddings` route on the given endpoint. + + :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type + that specifies the full request payload. Required. + :type body: JSON or IO[bytes] + :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an + array. + The input must not exceed the max input tokens for the model. Required. + :paramtype input: list[~azure.ai.inference.models.ImageEmbeddingInput] + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. Default value is None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. + Known values are: + "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :return: EmbeddingsResult. 
The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping[int, Type[HttpResponseError]] = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + _extra_parameters: Union[_models._enums.ExtraParameters, None] = None + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + + if body is _Unset: + if input is _Unset: + raise TypeError("missing required argument: input") + body = { + "input": input, + "dimensions": dimensions if dimensions is not None else self._dimensions, + "encoding_format": encoding_format if encoding_format is not None else self._encoding_format, + "input_type": input_type if input_type is not None else self._input_type, + "model": model if model is not None else self._model, + } + if model_extras is not None and bool(model_extras): + body.update(model_extras) + _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + elif self._model_extras is not None and bool(self._model_extras): + body.update(self._model_extras) + _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + body = {k: v for k, v in body.items() if v is not None} + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_image_embeddings_embed_request( + extra_params=_extra_parameters, + content_type=content_type, + api_version=self._config.api_version, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + response.read() # Load the body in memory and close the socket + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize( + _models._patch.EmbeddingsResult, response.json() # pylint: disable=protected-access + ) + + return deserialized # type: ignore + + @distributed_trace + def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: + # pylint: disable=line-too-long + """Returns information about the AI model. + The method makes a REST API call to the ``/info`` route on the given endpoint. + This method will only work when using Serverless API or Managed Compute endpoint. + It will not work for GitHub Models endpoint or Azure OpenAI endpoint. + + :return: ModelInfo. 
The ModelInfo is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.ModelInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + if not self._model_info: + try: + self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init + except ResourceNotFoundError as error: + error.message = "Model information is not available on this endpoint (`/info` route not supported)." + raise error + + return self._model_info + + def __str__(self) -> str: + # pylint: disable=client-method-name-no-double-underscore + return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() + + +__all__: List[str] = [ + "load_client", + "ChatCompletionsClient", + "EmbeddingsClient", + "ImageEmbeddingsClient", +] # Add all objects you want publicly available to users at this package level + + +def patch_sdk(): + """Do not remove from this file. + + `patch_sdk` is a last resort escape hatch that allows you to do customizations + you can't accomplish using the techniques described in + https://aka.ms/azsdk/python/dpcodegen/python/customize + """ diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/_serialization.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/_serialization.py new file mode 100644 index 00000000..a066e16a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/_serialization.py @@ -0,0 +1,2050 @@ +# pylint: disable=too-many-lines +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
+# +# -------------------------------------------------------------------------- + +# pyright: reportUnnecessaryTypeIgnoreComment=false + +from base64 import b64decode, b64encode +import calendar +import datetime +import decimal +import email +from enum import Enum +import json +import logging +import re +import sys +import codecs +from typing import ( + Dict, + Any, + cast, + Optional, + Union, + AnyStr, + IO, + Mapping, + Callable, + MutableMapping, + List, +) + +try: + from urllib import quote # type: ignore +except ImportError: + from urllib.parse import quote +import xml.etree.ElementTree as ET + +import isodate # type: ignore +from typing_extensions import Self + +from azure.core.exceptions import DeserializationError, SerializationError +from azure.core.serialization import NULL as CoreNull + +_BOM = codecs.BOM_UTF8.decode(encoding="utf-8") + +JSON = MutableMapping[str, Any] + + +class RawDeserializer: + + # Accept "text" because we're open minded people... + JSON_REGEXP = re.compile(r"^(application|text)/([a-z+.]+\+)?json$") + + # Name used in context + CONTEXT_NAME = "deserialized_data" + + @classmethod + def deserialize_from_text(cls, data: Optional[Union[AnyStr, IO]], content_type: Optional[str] = None) -> Any: + """Decode data according to content-type. + + Accept a stream of data as well, but will be load at once in memory for now. + + If no content-type, will return the string version (not bytes, not stream) + + :param data: Input, could be bytes or stream (will be decoded with UTF8) or text + :type data: str or bytes or IO + :param str content_type: The content type. + :return: The deserialized data. + :rtype: object + """ + if hasattr(data, "read"): + # Assume a stream + data = cast(IO, data).read() + + if isinstance(data, bytes): + data_as_str = data.decode(encoding="utf-8-sig") + else: + # Explain to mypy the correct type. + data_as_str = cast(str, data) + + # Remove Byte Order Mark if present in string + data_as_str = data_as_str.lstrip(_BOM) + + if content_type is None: + return data + + if cls.JSON_REGEXP.match(content_type): + try: + return json.loads(data_as_str) + except ValueError as err: + raise DeserializationError("JSON is invalid: {}".format(err), err) from err + elif "xml" in (content_type or []): + try: + + try: + if isinstance(data, unicode): # type: ignore + # If I'm Python 2.7 and unicode XML will scream if I try a "fromstring" on unicode string + data_as_str = data_as_str.encode(encoding="utf-8") # type: ignore + except NameError: + pass + + return ET.fromstring(data_as_str) # nosec + except ET.ParseError as err: + # It might be because the server has an issue, and returned JSON with + # content-type XML.... + # So let's try a JSON load, and if it's still broken + # let's flow the initial exception + def _json_attemp(data): + try: + return True, json.loads(data) + except ValueError: + return False, None # Don't care about this one + + success, json_result = _json_attemp(data) + if success: + return json_result + # If i'm here, it's not JSON, it's not XML, let's scream + # and raise the last context in this block (the XML exception) + # The function hack is because Py2.7 messes up with exception + # context otherwise. 
+ _LOGGER.critical("Wasn't XML not JSON, failing") + raise DeserializationError("XML is invalid") from err + elif content_type.startswith("text/"): + return data_as_str + raise DeserializationError("Cannot deserialize content-type: {}".format(content_type)) + + @classmethod + def deserialize_from_http_generics(cls, body_bytes: Optional[Union[AnyStr, IO]], headers: Mapping) -> Any: + """Deserialize from HTTP response. + + Use bytes and headers to NOT use any requests/aiohttp or whatever + specific implementation. + Headers will tested for "content-type" + + :param bytes body_bytes: The body of the response. + :param dict headers: The headers of the response. + :returns: The deserialized data. + :rtype: object + """ + # Try to use content-type from headers if available + content_type = None + if "content-type" in headers: + content_type = headers["content-type"].split(";")[0].strip().lower() + # Ouch, this server did not declare what it sent... + # Let's guess it's JSON... + # Also, since Autorest was considering that an empty body was a valid JSON, + # need that test as well.... + else: + content_type = "application/json" + + if body_bytes: + return cls.deserialize_from_text(body_bytes, content_type) + return None + + +_LOGGER = logging.getLogger(__name__) + +try: + _long_type = long # type: ignore +except NameError: + _long_type = int + +TZ_UTC = datetime.timezone.utc + +_FLATTEN = re.compile(r"(?<!\\)\.") + + +def attribute_transformer(key, attr_desc, value): # pylint: disable=unused-argument + """A key transformer that returns the Python attribute. + + :param str key: The attribute name + :param dict attr_desc: The attribute metadata + :param object value: The value + :returns: A key using attribute name + :rtype: str + """ + return (key, value) + + +def full_restapi_key_transformer(key, attr_desc, value): # pylint: disable=unused-argument + """A key transformer that returns the full RestAPI key path. + + :param str key: The attribute name + :param dict attr_desc: The attribute metadata + :param object value: The value + :returns: A list of keys using RestAPI syntax. + :rtype: list + """ + keys = _FLATTEN.split(attr_desc["key"]) + return ([_decode_attribute_map_key(k) for k in keys], value) + + +def last_restapi_key_transformer(key, attr_desc, value): + """A key transformer that returns the last RestAPI key. + + :param str key: The attribute name + :param dict attr_desc: The attribute metadata + :param object value: The value + :returns: The last RestAPI key. + :rtype: str + """ + key, value = full_restapi_key_transformer(key, attr_desc, value) + return (key[-1], value) + + +def _create_xml_node(tag, prefix=None, ns=None): + """Create a XML node. + + :param str tag: The tag name + :param str prefix: The prefix + :param str ns: The namespace + :return: The XML node + :rtype: xml.etree.ElementTree.Element + """ + if prefix and ns: + ET.register_namespace(prefix, ns) + if ns: + return ET.Element("{" + ns + "}" + tag) + return ET.Element(tag) + + +class Model: + """Mixin for all client request body/response body models to support + serialization and deserialization. 
+ """ + + _subtype_map: Dict[str, Dict[str, Any]] = {} + _attribute_map: Dict[str, Dict[str, Any]] = {} + _validation: Dict[str, Dict[str, Any]] = {} + + def __init__(self, **kwargs: Any) -> None: + self.additional_properties: Optional[Dict[str, Any]] = {} + for k in kwargs: # pylint: disable=consider-using-dict-items + if k not in self._attribute_map: + _LOGGER.warning("%s is not a known attribute of class %s and will be ignored", k, self.__class__) + elif k in self._validation and self._validation[k].get("readonly", False): + _LOGGER.warning("Readonly attribute %s will be ignored in class %s", k, self.__class__) + else: + setattr(self, k, kwargs[k]) + + def __eq__(self, other: Any) -> bool: + """Compare objects by comparing all attributes. + + :param object other: The object to compare + :returns: True if objects are equal + :rtype: bool + """ + if isinstance(other, self.__class__): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other: Any) -> bool: + """Compare objects by comparing all attributes. + + :param object other: The object to compare + :returns: True if objects are not equal + :rtype: bool + """ + return not self.__eq__(other) + + def __str__(self) -> str: + return str(self.__dict__) + + @classmethod + def enable_additional_properties_sending(cls) -> None: + cls._attribute_map["additional_properties"] = {"key": "", "type": "{object}"} + + @classmethod + def is_xml_model(cls) -> bool: + try: + cls._xml_map # type: ignore + except AttributeError: + return False + return True + + @classmethod + def _create_xml_node(cls): + """Create XML node. + + :returns: The XML node + :rtype: xml.etree.ElementTree.Element + """ + try: + xml_map = cls._xml_map # type: ignore + except AttributeError: + xml_map = {} + + return _create_xml_node(xml_map.get("name", cls.__name__), xml_map.get("prefix", None), xml_map.get("ns", None)) + + def serialize(self, keep_readonly: bool = False, **kwargs: Any) -> JSON: + """Return the JSON that would be sent to server from this model. + + This is an alias to `as_dict(full_restapi_key_transformer, keep_readonly=False)`. + + If you want XML serialization, you can pass the kwargs is_xml=True. + + :param bool keep_readonly: If you want to serialize the readonly attributes + :returns: A dict JSON compatible object + :rtype: dict + """ + serializer = Serializer(self._infer_class_models()) + return serializer._serialize( # type: ignore # pylint: disable=protected-access + self, keep_readonly=keep_readonly, **kwargs + ) + + def as_dict( + self, + keep_readonly: bool = True, + key_transformer: Callable[[str, Dict[str, Any], Any], Any] = attribute_transformer, + **kwargs: Any + ) -> JSON: + """Return a dict that can be serialized using json.dump. + + Advanced usage might optionally use a callback as parameter: + + .. code::python + + def my_key_transformer(key, attr_desc, value): + return key + + Key is the attribute name used in Python. Attr_desc + is a dict of metadata. Currently contains 'type' with the + msrest type and 'key' with the RestAPI encoded key. + Value is the current value in this object. + + The string returned will be used to serialize the key. + If the return type is a list, this is considered hierarchical + result dict. + + See the three examples in this file: + + - attribute_transformer + - full_restapi_key_transformer + - last_restapi_key_transformer + + If you want XML serialization, you can pass the kwargs is_xml=True. 
+ + :param bool keep_readonly: If you want to serialize the readonly attributes + :param function key_transformer: A key transformer function. + :returns: A dict JSON compatible object + :rtype: dict + """ + serializer = Serializer(self._infer_class_models()) + return serializer._serialize( # type: ignore # pylint: disable=protected-access + self, key_transformer=key_transformer, keep_readonly=keep_readonly, **kwargs + ) + + @classmethod + def _infer_class_models(cls): + try: + str_models = cls.__module__.rsplit(".", 1)[0] + models = sys.modules[str_models] + client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)} + if cls.__name__ not in client_models: + raise ValueError("Not Autorest generated code") + except Exception: # pylint: disable=broad-exception-caught + # Assume it's not Autorest generated (tests?). Add ourselves as dependencies. + client_models = {cls.__name__: cls} + return client_models + + @classmethod + def deserialize(cls, data: Any, content_type: Optional[str] = None) -> Self: + """Parse a str using the RestAPI syntax and return a model. + + :param str data: A str using RestAPI structure. JSON by default. + :param str content_type: JSON by default, set application/xml if XML. + :returns: An instance of this model + :raises DeserializationError: if something went wrong + :rtype: Self + """ + deserializer = Deserializer(cls._infer_class_models()) + return deserializer(cls.__name__, data, content_type=content_type) # type: ignore + + @classmethod + def from_dict( + cls, + data: Any, + key_extractors: Optional[Callable[[str, Dict[str, Any], Any], Any]] = None, + content_type: Optional[str] = None, + ) -> Self: + """Parse a dict using given key extractor return a model. + + By default consider key + extractors (rest_key_case_insensitive_extractor, attribute_key_case_insensitive_extractor + and last_rest_key_case_insensitive_extractor) + + :param dict data: A dict using RestAPI structure + :param function key_extractors: A key extractor function. + :param str content_type: JSON by default, set application/xml if XML. + :returns: An instance of this model + :raises: DeserializationError if something went wrong + :rtype: Self + """ + deserializer = Deserializer(cls._infer_class_models()) + deserializer.key_extractors = ( # type: ignore + [ # type: ignore + attribute_key_case_insensitive_extractor, + rest_key_case_insensitive_extractor, + last_rest_key_case_insensitive_extractor, + ] + if key_extractors is None + else key_extractors + ) + return deserializer(cls.__name__, data, content_type=content_type) # type: ignore + + @classmethod + def _flatten_subtype(cls, key, objects): + if "_subtype_map" not in cls.__dict__: + return {} + result = dict(cls._subtype_map[key]) + for valuetype in cls._subtype_map[key].values(): + result.update(objects[valuetype]._flatten_subtype(key, objects)) # pylint: disable=protected-access + return result + + @classmethod + def _classify(cls, response, objects): + """Check the class _subtype_map for any child classes. + We want to ignore any inherited _subtype_maps. 
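+
+        For example, a polymorphic base class is typically generated with a map such as
+        (names hypothetical)::
+
+            _subtype_map = {"kind": {"dog": "Dog", "cat": "Cat"}}
+
+        so a response containing ``{"kind": "dog"}`` is classified as the ``Dog`` model.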
+ + :param dict response: The initial data + :param dict objects: The class objects + :returns: The class to be used + :rtype: class + """ + for subtype_key in cls.__dict__.get("_subtype_map", {}).keys(): + subtype_value = None + + if not isinstance(response, ET.Element): + rest_api_response_key = cls._get_rest_key_parts(subtype_key)[-1] + subtype_value = response.get(rest_api_response_key, None) or response.get(subtype_key, None) + else: + subtype_value = xml_key_extractor(subtype_key, cls._attribute_map[subtype_key], response) + if subtype_value: + # Try to match base class. Can be class name only + # (bug to fix in Autorest to support x-ms-discriminator-name) + if cls.__name__ == subtype_value: + return cls + flatten_mapping_type = cls._flatten_subtype(subtype_key, objects) + try: + return objects[flatten_mapping_type[subtype_value]] # type: ignore + except KeyError: + _LOGGER.warning( + "Subtype value %s has no mapping, use base class %s.", + subtype_value, + cls.__name__, + ) + break + else: + _LOGGER.warning("Discriminator %s is absent or null, use base class %s.", subtype_key, cls.__name__) + break + return cls + + @classmethod + def _get_rest_key_parts(cls, attr_key): + """Get the RestAPI key of this attr, split it and decode part + :param str attr_key: Attribute key must be in attribute_map. + :returns: A list of RestAPI part + :rtype: list + """ + rest_split_key = _FLATTEN.split(cls._attribute_map[attr_key]["key"]) + return [_decode_attribute_map_key(key_part) for key_part in rest_split_key] + + +def _decode_attribute_map_key(key): + """This decode a key in an _attribute_map to the actual key we want to look at + inside the received data. + + :param str key: A key string from the generated code + :returns: The decoded key + :rtype: str + """ + return key.replace("\\.", ".") + + +class Serializer: # pylint: disable=too-many-public-methods + """Request object model serializer.""" + + basic_types = {str: "str", int: "int", bool: "bool", float: "float"} + + _xml_basic_types_serializers = {"bool": lambda x: str(x).lower()} + days = {0: "Mon", 1: "Tue", 2: "Wed", 3: "Thu", 4: "Fri", 5: "Sat", 6: "Sun"} + months = { + 1: "Jan", + 2: "Feb", + 3: "Mar", + 4: "Apr", + 5: "May", + 6: "Jun", + 7: "Jul", + 8: "Aug", + 9: "Sep", + 10: "Oct", + 11: "Nov", + 12: "Dec", + } + validation = { + "min_length": lambda x, y: len(x) < y, + "max_length": lambda x, y: len(x) > y, + "minimum": lambda x, y: x < y, + "maximum": lambda x, y: x > y, + "minimum_ex": lambda x, y: x <= y, + "maximum_ex": lambda x, y: x >= y, + "min_items": lambda x, y: len(x) < y, + "max_items": lambda x, y: len(x) > y, + "pattern": lambda x, y: not re.match(y, x, re.UNICODE), + "unique": lambda x, y: len(x) != len(set(x)), + "multiple": lambda x, y: x % y != 0, + } + + def __init__(self, classes: Optional[Mapping[str, type]] = None) -> None: + self.serialize_type = { + "iso-8601": Serializer.serialize_iso, + "rfc-1123": Serializer.serialize_rfc, + "unix-time": Serializer.serialize_unix, + "duration": Serializer.serialize_duration, + "date": Serializer.serialize_date, + "time": Serializer.serialize_time, + "decimal": Serializer.serialize_decimal, + "long": Serializer.serialize_long, + "bytearray": Serializer.serialize_bytearray, + "base64": Serializer.serialize_base64, + "object": self.serialize_object, + "[]": self.serialize_iter, + "{}": self.serialize_dict, + } + self.dependencies: Dict[str, type] = dict(classes) if classes else {} + self.key_transformer = full_restapi_key_transformer + self.client_side_validation = True + + 
def _serialize( # pylint: disable=too-many-nested-blocks, too-many-branches, too-many-statements, too-many-locals + self, target_obj, data_type=None, **kwargs + ): + """Serialize data into a string according to type. + + :param object target_obj: The data to be serialized. + :param str data_type: The type to be serialized from. + :rtype: str, dict + :raises SerializationError: if serialization fails. + :returns: The serialized data. + """ + key_transformer = kwargs.get("key_transformer", self.key_transformer) + keep_readonly = kwargs.get("keep_readonly", False) + if target_obj is None: + return None + + attr_name = None + class_name = target_obj.__class__.__name__ + + if data_type: + return self.serialize_data(target_obj, data_type, **kwargs) + + if not hasattr(target_obj, "_attribute_map"): + data_type = type(target_obj).__name__ + if data_type in self.basic_types.values(): + return self.serialize_data(target_obj, data_type, **kwargs) + + # Force "is_xml" kwargs if we detect a XML model + try: + is_xml_model_serialization = kwargs["is_xml"] + except KeyError: + is_xml_model_serialization = kwargs.setdefault("is_xml", target_obj.is_xml_model()) + + serialized = {} + if is_xml_model_serialization: + serialized = target_obj._create_xml_node() # pylint: disable=protected-access + try: + attributes = target_obj._attribute_map # pylint: disable=protected-access + for attr, attr_desc in attributes.items(): + attr_name = attr + if not keep_readonly and target_obj._validation.get( # pylint: disable=protected-access + attr_name, {} + ).get("readonly", False): + continue + + if attr_name == "additional_properties" and attr_desc["key"] == "": + if target_obj.additional_properties is not None: + serialized.update(target_obj.additional_properties) + continue + try: + + orig_attr = getattr(target_obj, attr) + if is_xml_model_serialization: + pass # Don't provide "transformer" for XML for now. Keep "orig_attr" + else: # JSON + keys, orig_attr = key_transformer(attr, attr_desc.copy(), orig_attr) + keys = keys if isinstance(keys, list) else [keys] + + kwargs["serialization_ctxt"] = attr_desc + new_attr = self.serialize_data(orig_attr, attr_desc["type"], **kwargs) + + if is_xml_model_serialization: + xml_desc = attr_desc.get("xml", {}) + xml_name = xml_desc.get("name", attr_desc["key"]) + xml_prefix = xml_desc.get("prefix", None) + xml_ns = xml_desc.get("ns", None) + if xml_desc.get("attr", False): + if xml_ns: + ET.register_namespace(xml_prefix, xml_ns) + xml_name = "{{{}}}{}".format(xml_ns, xml_name) + serialized.set(xml_name, new_attr) # type: ignore + continue + if xml_desc.get("text", False): + serialized.text = new_attr # type: ignore + continue + if isinstance(new_attr, list): + serialized.extend(new_attr) # type: ignore + elif isinstance(new_attr, ET.Element): + # If the down XML has no XML/Name, + # we MUST replace the tag with the local tag. But keeping the namespaces. 
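+                            # For example, a child element tagged "{http://ns}Child" is re-tagged
+                            # below as "{http://ns}<xml_name>", so the parent's declared element
+                            # name wins while the namespace is preserved.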
+ if "name" not in getattr(orig_attr, "_xml_map", {}): + splitted_tag = new_attr.tag.split("}") + if len(splitted_tag) == 2: # Namespace + new_attr.tag = "}".join([splitted_tag[0], xml_name]) + else: + new_attr.tag = xml_name + serialized.append(new_attr) # type: ignore + else: # That's a basic type + # Integrate namespace if necessary + local_node = _create_xml_node(xml_name, xml_prefix, xml_ns) + local_node.text = str(new_attr) + serialized.append(local_node) # type: ignore + else: # JSON + for k in reversed(keys): # type: ignore + new_attr = {k: new_attr} + + _new_attr = new_attr + _serialized = serialized + for k in keys: # type: ignore + if k not in _serialized: + _serialized.update(_new_attr) # type: ignore + _new_attr = _new_attr[k] # type: ignore + _serialized = _serialized[k] + except ValueError as err: + if isinstance(err, SerializationError): + raise + + except (AttributeError, KeyError, TypeError) as err: + msg = "Attribute {} in object {} cannot be serialized.\n{}".format(attr_name, class_name, str(target_obj)) + raise SerializationError(msg) from err + return serialized + + def body(self, data, data_type, **kwargs): + """Serialize data intended for a request body. + + :param object data: The data to be serialized. + :param str data_type: The type to be serialized from. + :rtype: dict + :raises SerializationError: if serialization fails. + :raises ValueError: if data is None + :returns: The serialized request body + """ + + # Just in case this is a dict + internal_data_type_str = data_type.strip("[]{}") + internal_data_type = self.dependencies.get(internal_data_type_str, None) + try: + is_xml_model_serialization = kwargs["is_xml"] + except KeyError: + if internal_data_type and issubclass(internal_data_type, Model): + is_xml_model_serialization = kwargs.setdefault("is_xml", internal_data_type.is_xml_model()) + else: + is_xml_model_serialization = False + if internal_data_type and not isinstance(internal_data_type, Enum): + try: + deserializer = Deserializer(self.dependencies) + # Since it's on serialization, it's almost sure that format is not JSON REST + # We're not able to deal with additional properties for now. + deserializer.additional_properties_detection = False + if is_xml_model_serialization: + deserializer.key_extractors = [ # type: ignore + attribute_key_case_insensitive_extractor, + ] + else: + deserializer.key_extractors = [ + rest_key_case_insensitive_extractor, + attribute_key_case_insensitive_extractor, + last_rest_key_case_insensitive_extractor, + ] + data = deserializer._deserialize(data_type, data) # pylint: disable=protected-access + except DeserializationError as err: + raise SerializationError("Unable to build a model: " + str(err)) from err + + return self._serialize(data, data_type, **kwargs) + + def url(self, name, data, data_type, **kwargs): + """Serialize data intended for a URL path. + + :param str name: The name of the URL path parameter. + :param object data: The data to be serialized. + :param str data_type: The type to be serialized from. + :rtype: str + :returns: The serialized URL path + :raises TypeError: if serialization fails. 
+ :raises ValueError: if data is None + """ + try: + output = self.serialize_data(data, data_type, **kwargs) + if data_type == "bool": + output = json.dumps(output) + + if kwargs.get("skip_quote") is True: + output = str(output) + output = output.replace("{", quote("{")).replace("}", quote("}")) + else: + output = quote(str(output), safe="") + except SerializationError as exc: + raise TypeError("{} must be type {}.".format(name, data_type)) from exc + return output + + def query(self, name, data, data_type, **kwargs): + """Serialize data intended for a URL query. + + :param str name: The name of the query parameter. + :param object data: The data to be serialized. + :param str data_type: The type to be serialized from. + :rtype: str, list + :raises TypeError: if serialization fails. + :raises ValueError: if data is None + :returns: The serialized query parameter + """ + try: + # Treat the list aside, since we don't want to encode the div separator + if data_type.startswith("["): + internal_data_type = data_type[1:-1] + do_quote = not kwargs.get("skip_quote", False) + return self.serialize_iter(data, internal_data_type, do_quote=do_quote, **kwargs) + + # Not a list, regular serialization + output = self.serialize_data(data, data_type, **kwargs) + if data_type == "bool": + output = json.dumps(output) + if kwargs.get("skip_quote") is True: + output = str(output) + else: + output = quote(str(output), safe="") + except SerializationError as exc: + raise TypeError("{} must be type {}.".format(name, data_type)) from exc + return str(output) + + def header(self, name, data, data_type, **kwargs): + """Serialize data intended for a request header. + + :param str name: The name of the header. + :param object data: The data to be serialized. + :param str data_type: The type to be serialized from. + :rtype: str + :raises TypeError: if serialization fails. + :raises ValueError: if data is None + :returns: The serialized header + """ + try: + if data_type in ["[str]"]: + data = ["" if d is None else d for d in data] + + output = self.serialize_data(data, data_type, **kwargs) + if data_type == "bool": + output = json.dumps(output) + except SerializationError as exc: + raise TypeError("{} must be type {}.".format(name, data_type)) from exc + return str(output) + + def serialize_data(self, data, data_type, **kwargs): + """Serialize generic data according to supplied data type. + + :param object data: The data to be serialized. + :param str data_type: The type to be serialized from. + :raises AttributeError: if required data is None. + :raises ValueError: if data is None + :raises SerializationError: if serialization fails. + :returns: The serialized data. 
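+            For example, ``serialize_data(datetime.date(2024, 1, 2), "date")`` returns ``"2024-01-02"``.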
+ :rtype: str, int, float, bool, dict, list + """ + if data is None: + raise ValueError("No value for given attribute") + + try: + if data is CoreNull: + return None + if data_type in self.basic_types.values(): + return self.serialize_basic(data, data_type, **kwargs) + + if data_type in self.serialize_type: + return self.serialize_type[data_type](data, **kwargs) + + # If dependencies is empty, try with current data class + # It has to be a subclass of Enum anyway + enum_type = self.dependencies.get(data_type, data.__class__) + if issubclass(enum_type, Enum): + return Serializer.serialize_enum(data, enum_obj=enum_type) + + iter_type = data_type[0] + data_type[-1] + if iter_type in self.serialize_type: + return self.serialize_type[iter_type](data, data_type[1:-1], **kwargs) + + except (ValueError, TypeError) as err: + msg = "Unable to serialize value: {!r} as type: {!r}." + raise SerializationError(msg.format(data, data_type)) from err + return self._serialize(data, **kwargs) + + @classmethod + def _get_custom_serializers(cls, data_type, **kwargs): # pylint: disable=inconsistent-return-statements + custom_serializer = kwargs.get("basic_types_serializers", {}).get(data_type) + if custom_serializer: + return custom_serializer + if kwargs.get("is_xml", False): + return cls._xml_basic_types_serializers.get(data_type) + + @classmethod + def serialize_basic(cls, data, data_type, **kwargs): + """Serialize basic builting data type. + Serializes objects to str, int, float or bool. + + Possible kwargs: + - basic_types_serializers dict[str, callable] : If set, use the callable as serializer + - is_xml bool : If set, use xml_basic_types_serializers + + :param obj data: Object to be serialized. + :param str data_type: Type of object in the iterable. + :rtype: str, int, float, bool + :return: serialized object + """ + custom_serializer = cls._get_custom_serializers(data_type, **kwargs) + if custom_serializer: + return custom_serializer(data) + if data_type == "str": + return cls.serialize_unicode(data) + return eval(data_type)(data) # nosec # pylint: disable=eval-used + + @classmethod + def serialize_unicode(cls, data): + """Special handling for serializing unicode strings in Py2. + Encode to UTF-8 if unicode, otherwise handle as a str. + + :param str data: Object to be serialized. + :rtype: str + :return: serialized object + """ + try: # If I received an enum, return its value + return data.value + except AttributeError: + pass + + try: + if isinstance(data, unicode): # type: ignore + # Don't change it, JSON and XML ElementTree are totally able + # to serialize correctly u'' strings + return data + except NameError: + return str(data) + return str(data) + + def serialize_iter(self, data, iter_type, div=None, **kwargs): + """Serialize iterable. + + Supported kwargs: + - serialization_ctxt dict : The current entry of _attribute_map, or same format. + serialization_ctxt['type'] should be same as data_type. + - is_xml bool : If set, serialize as XML + + :param list data: Object to be serialized. + :param str iter_type: Type of object in the iterable. + :param str div: If set, this str will be used to combine the elements + in the iterable into a combined string. Default is 'None'. + Defaults to False. 
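+            For example, ``serialize_iter([1, 2, 3], "int", div=",")`` returns ``"1,2,3"``.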
+ :rtype: list, str + :return: serialized iterable + """ + if isinstance(data, str): + raise SerializationError("Refuse str type as a valid iter type.") + + serialization_ctxt = kwargs.get("serialization_ctxt", {}) + is_xml = kwargs.get("is_xml", False) + + serialized = [] + for d in data: + try: + serialized.append(self.serialize_data(d, iter_type, **kwargs)) + except ValueError as err: + if isinstance(err, SerializationError): + raise + serialized.append(None) + + if kwargs.get("do_quote", False): + serialized = ["" if s is None else quote(str(s), safe="") for s in serialized] + + if div: + serialized = ["" if s is None else str(s) for s in serialized] + serialized = div.join(serialized) + + if "xml" in serialization_ctxt or is_xml: + # XML serialization is more complicated + xml_desc = serialization_ctxt.get("xml", {}) + xml_name = xml_desc.get("name") + if not xml_name: + xml_name = serialization_ctxt["key"] + + # Create a wrap node if necessary (use the fact that Element and list have "append") + is_wrapped = xml_desc.get("wrapped", False) + node_name = xml_desc.get("itemsName", xml_name) + if is_wrapped: + final_result = _create_xml_node(xml_name, xml_desc.get("prefix", None), xml_desc.get("ns", None)) + else: + final_result = [] + # All list elements to "local_node" + for el in serialized: + if isinstance(el, ET.Element): + el_node = el + else: + el_node = _create_xml_node(node_name, xml_desc.get("prefix", None), xml_desc.get("ns", None)) + if el is not None: # Otherwise it writes "None" :-p + el_node.text = str(el) + final_result.append(el_node) + return final_result + return serialized + + def serialize_dict(self, attr, dict_type, **kwargs): + """Serialize a dictionary of objects. + + :param dict attr: Object to be serialized. + :param str dict_type: Type of object in the dictionary. + :rtype: dict + :return: serialized dictionary + """ + serialization_ctxt = kwargs.get("serialization_ctxt", {}) + serialized = {} + for key, value in attr.items(): + try: + serialized[self.serialize_unicode(key)] = self.serialize_data(value, dict_type, **kwargs) + except ValueError as err: + if isinstance(err, SerializationError): + raise + serialized[self.serialize_unicode(key)] = None + + if "xml" in serialization_ctxt: + # XML serialization is more complicated + xml_desc = serialization_ctxt["xml"] + xml_name = xml_desc["name"] + + final_result = _create_xml_node(xml_name, xml_desc.get("prefix", None), xml_desc.get("ns", None)) + for key, value in serialized.items(): + ET.SubElement(final_result, key).text = value + return final_result + + return serialized + + def serialize_object(self, attr, **kwargs): # pylint: disable=too-many-return-statements + """Serialize a generic object. + This will be handled as a dictionary. If object passed in is not + a basic type (str, int, float, dict, list) it will simply be + cast to str. + + :param dict attr: Object to be serialized. 
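+            For example, ``{"when": datetime.date(2024, 1, 2)}`` serializes to ``{"when": "2024-01-02"}``.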
+ :rtype: dict or str + :return: serialized object + """ + if attr is None: + return None + if isinstance(attr, ET.Element): + return attr + obj_type = type(attr) + if obj_type in self.basic_types: + return self.serialize_basic(attr, self.basic_types[obj_type], **kwargs) + if obj_type is _long_type: + return self.serialize_long(attr) + if obj_type is str: + return self.serialize_unicode(attr) + if obj_type is datetime.datetime: + return self.serialize_iso(attr) + if obj_type is datetime.date: + return self.serialize_date(attr) + if obj_type is datetime.time: + return self.serialize_time(attr) + if obj_type is datetime.timedelta: + return self.serialize_duration(attr) + if obj_type is decimal.Decimal: + return self.serialize_decimal(attr) + + # If it's a model or I know this dependency, serialize as a Model + if obj_type in self.dependencies.values() or isinstance(attr, Model): + return self._serialize(attr) + + if obj_type == dict: + serialized = {} + for key, value in attr.items(): + try: + serialized[self.serialize_unicode(key)] = self.serialize_object(value, **kwargs) + except ValueError: + serialized[self.serialize_unicode(key)] = None + return serialized + + if obj_type == list: + serialized = [] + for obj in attr: + try: + serialized.append(self.serialize_object(obj, **kwargs)) + except ValueError: + pass + return serialized + return str(attr) + + @staticmethod + def serialize_enum(attr, enum_obj=None): + try: + result = attr.value + except AttributeError: + result = attr + try: + enum_obj(result) # type: ignore + return result + except ValueError as exc: + for enum_value in enum_obj: # type: ignore + if enum_value.value.lower() == str(attr).lower(): + return enum_value.value + error = "{!r} is not valid value for enum {!r}" + raise SerializationError(error.format(attr, enum_obj)) from exc + + @staticmethod + def serialize_bytearray(attr, **kwargs): # pylint: disable=unused-argument + """Serialize bytearray into base-64 string. + + :param str attr: Object to be serialized. + :rtype: str + :return: serialized base64 + """ + return b64encode(attr).decode() + + @staticmethod + def serialize_base64(attr, **kwargs): # pylint: disable=unused-argument + """Serialize str into base-64 string. + + :param str attr: Object to be serialized. + :rtype: str + :return: serialized base64 + """ + encoded = b64encode(attr).decode("ascii") + return encoded.strip("=").replace("+", "-").replace("/", "_") + + @staticmethod + def serialize_decimal(attr, **kwargs): # pylint: disable=unused-argument + """Serialize Decimal object to float. + + :param decimal attr: Object to be serialized. + :rtype: float + :return: serialized decimal + """ + return float(attr) + + @staticmethod + def serialize_long(attr, **kwargs): # pylint: disable=unused-argument + """Serialize long (Py2) or int (Py3). + + :param int attr: Object to be serialized. + :rtype: int/long + :return: serialized long + """ + return _long_type(attr) + + @staticmethod + def serialize_date(attr, **kwargs): # pylint: disable=unused-argument + """Serialize Date object into ISO-8601 formatted string. + + :param Date attr: Object to be serialized. + :rtype: str + :return: serialized date + """ + if isinstance(attr, str): + attr = isodate.parse_date(attr) + t = "{:04}-{:02}-{:02}".format(attr.year, attr.month, attr.day) + return t + + @staticmethod + def serialize_time(attr, **kwargs): # pylint: disable=unused-argument + """Serialize Time object into ISO-8601 formatted string. + + :param datetime.time attr: Object to be serialized. 
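+            For example, ``datetime.time(13, 5, 7)`` serializes to ``"13:05:07"``.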
+ :rtype: str + :return: serialized time + """ + if isinstance(attr, str): + attr = isodate.parse_time(attr) + t = "{:02}:{:02}:{:02}".format(attr.hour, attr.minute, attr.second) + if attr.microsecond: + t += ".{:02}".format(attr.microsecond) + return t + + @staticmethod + def serialize_duration(attr, **kwargs): # pylint: disable=unused-argument + """Serialize TimeDelta object into ISO-8601 formatted string. + + :param TimeDelta attr: Object to be serialized. + :rtype: str + :return: serialized duration + """ + if isinstance(attr, str): + attr = isodate.parse_duration(attr) + return isodate.duration_isoformat(attr) + + @staticmethod + def serialize_rfc(attr, **kwargs): # pylint: disable=unused-argument + """Serialize Datetime object into RFC-1123 formatted string. + + :param Datetime attr: Object to be serialized. + :rtype: str + :raises TypeError: if format invalid. + :return: serialized rfc + """ + try: + if not attr.tzinfo: + _LOGGER.warning("Datetime with no tzinfo will be considered UTC.") + utc = attr.utctimetuple() + except AttributeError as exc: + raise TypeError("RFC1123 object must be valid Datetime object.") from exc + + return "{}, {:02} {} {:04} {:02}:{:02}:{:02} GMT".format( + Serializer.days[utc.tm_wday], + utc.tm_mday, + Serializer.months[utc.tm_mon], + utc.tm_year, + utc.tm_hour, + utc.tm_min, + utc.tm_sec, + ) + + @staticmethod + def serialize_iso(attr, **kwargs): # pylint: disable=unused-argument + """Serialize Datetime object into ISO-8601 formatted string. + + :param Datetime attr: Object to be serialized. + :rtype: str + :raises SerializationError: if format invalid. + :return: serialized iso + """ + if isinstance(attr, str): + attr = isodate.parse_datetime(attr) + try: + if not attr.tzinfo: + _LOGGER.warning("Datetime with no tzinfo will be considered UTC.") + utc = attr.utctimetuple() + if utc.tm_year > 9999 or utc.tm_year < 1: + raise OverflowError("Hit max or min date") + + microseconds = str(attr.microsecond).rjust(6, "0").rstrip("0").ljust(3, "0") + if microseconds: + microseconds = "." + microseconds + date = "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}".format( + utc.tm_year, utc.tm_mon, utc.tm_mday, utc.tm_hour, utc.tm_min, utc.tm_sec + ) + return date + microseconds + "Z" + except (ValueError, OverflowError) as err: + msg = "Unable to serialize datetime object." + raise SerializationError(msg) from err + except AttributeError as err: + msg = "ISO-8601 object must be valid Datetime object." + raise TypeError(msg) from err + + @staticmethod + def serialize_unix(attr, **kwargs): # pylint: disable=unused-argument + """Serialize Datetime object into IntTime format. + This is represented as seconds. + + :param Datetime attr: Object to be serialized. + :rtype: int + :raises SerializationError: if format invalid + :return: serialied unix + """ + if isinstance(attr, int): + return attr + try: + if not attr.tzinfo: + _LOGGER.warning("Datetime with no tzinfo will be considered UTC.") + return int(calendar.timegm(attr.utctimetuple())) + except AttributeError as exc: + raise TypeError("Unix time object must be valid Datetime object.") from exc + + +def rest_key_extractor(attr, attr_desc, data): # pylint: disable=unused-argument + key = attr_desc["key"] + working_data = data + + while "." 
in key: + # Need the cast, as for some reasons "split" is typed as list[str | Any] + dict_keys = cast(List[str], _FLATTEN.split(key)) + if len(dict_keys) == 1: + key = _decode_attribute_map_key(dict_keys[0]) + break + working_key = _decode_attribute_map_key(dict_keys[0]) + working_data = working_data.get(working_key, data) + if working_data is None: + # If at any point while following flatten JSON path see None, it means + # that all properties under are None as well + return None + key = ".".join(dict_keys[1:]) + + return working_data.get(key) + + +def rest_key_case_insensitive_extractor( # pylint: disable=unused-argument, inconsistent-return-statements + attr, attr_desc, data +): + key = attr_desc["key"] + working_data = data + + while "." in key: + dict_keys = _FLATTEN.split(key) + if len(dict_keys) == 1: + key = _decode_attribute_map_key(dict_keys[0]) + break + working_key = _decode_attribute_map_key(dict_keys[0]) + working_data = attribute_key_case_insensitive_extractor(working_key, None, working_data) + if working_data is None: + # If at any point while following flatten JSON path see None, it means + # that all properties under are None as well + return None + key = ".".join(dict_keys[1:]) + + if working_data: + return attribute_key_case_insensitive_extractor(key, None, working_data) + + +def last_rest_key_extractor(attr, attr_desc, data): # pylint: disable=unused-argument + """Extract the attribute in "data" based on the last part of the JSON path key. + + :param str attr: The attribute to extract + :param dict attr_desc: The attribute description + :param dict data: The data to extract from + :rtype: object + :returns: The extracted attribute + """ + key = attr_desc["key"] + dict_keys = _FLATTEN.split(key) + return attribute_key_extractor(dict_keys[-1], None, data) + + +def last_rest_key_case_insensitive_extractor(attr, attr_desc, data): # pylint: disable=unused-argument + """Extract the attribute in "data" based on the last part of the JSON path key. + + This is the case insensitive version of "last_rest_key_extractor" + :param str attr: The attribute to extract + :param dict attr_desc: The attribute description + :param dict data: The data to extract from + :rtype: object + :returns: The extracted attribute + """ + key = attr_desc["key"] + dict_keys = _FLATTEN.split(key) + return attribute_key_case_insensitive_extractor(dict_keys[-1], None, data) + + +def attribute_key_extractor(attr, _, data): + return data.get(attr) + + +def attribute_key_case_insensitive_extractor(attr, _, data): + found_key = None + lower_attr = attr.lower() + for key in data: + if lower_attr == key.lower(): + found_key = key + break + + return data.get(found_key) + + +def _extract_name_from_internal_type(internal_type): + """Given an internal type XML description, extract correct XML name with namespace. 
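+
+    For example, an internal type whose ``_xml_map`` is
+    ``{"name": "Pet", "ns": "http://example.com/ns"}`` (hypothetical) yields
+    ``"{http://example.com/ns}Pet"``.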
+ + :param dict internal_type: An model type + :rtype: tuple + :returns: A tuple XML name + namespace dict + """ + internal_type_xml_map = getattr(internal_type, "_xml_map", {}) + xml_name = internal_type_xml_map.get("name", internal_type.__name__) + xml_ns = internal_type_xml_map.get("ns", None) + if xml_ns: + xml_name = "{{{}}}{}".format(xml_ns, xml_name) + return xml_name + + +def xml_key_extractor(attr, attr_desc, data): # pylint: disable=unused-argument,too-many-return-statements + if isinstance(data, dict): + return None + + # Test if this model is XML ready first + if not isinstance(data, ET.Element): + return None + + xml_desc = attr_desc.get("xml", {}) + xml_name = xml_desc.get("name", attr_desc["key"]) + + # Look for a children + is_iter_type = attr_desc["type"].startswith("[") + is_wrapped = xml_desc.get("wrapped", False) + internal_type = attr_desc.get("internalType", None) + internal_type_xml_map = getattr(internal_type, "_xml_map", {}) + + # Integrate namespace if necessary + xml_ns = xml_desc.get("ns", internal_type_xml_map.get("ns", None)) + if xml_ns: + xml_name = "{{{}}}{}".format(xml_ns, xml_name) + + # If it's an attribute, that's simple + if xml_desc.get("attr", False): + return data.get(xml_name) + + # If it's x-ms-text, that's simple too + if xml_desc.get("text", False): + return data.text + + # Scenario where I take the local name: + # - Wrapped node + # - Internal type is an enum (considered basic types) + # - Internal type has no XML/Name node + if is_wrapped or (internal_type and (issubclass(internal_type, Enum) or "name" not in internal_type_xml_map)): + children = data.findall(xml_name) + # If internal type has a local name and it's not a list, I use that name + elif not is_iter_type and internal_type and "name" in internal_type_xml_map: + xml_name = _extract_name_from_internal_type(internal_type) + children = data.findall(xml_name) + # That's an array + else: + if internal_type: # Complex type, ignore itemsName and use the complex type name + items_name = _extract_name_from_internal_type(internal_type) + else: + items_name = xml_desc.get("itemsName", xml_name) + children = data.findall(items_name) + + if len(children) == 0: + if is_iter_type: + if is_wrapped: + return None # is_wrapped no node, we want None + return [] # not wrapped, assume empty list + return None # Assume it's not there, maybe an optional node. + + # If is_iter_type and not wrapped, return all found children + if is_iter_type: + if not is_wrapped: + return children + # Iter and wrapped, should have found one node only (the wrap one) + if len(children) != 1: + raise DeserializationError( + "Tried to deserialize an array not wrapped, and found several nodes '{}'. Maybe you should declare this array as wrapped?".format( # pylint: disable=line-too-long + xml_name + ) + ) + return list(children[0]) # Might be empty list and that's ok. + + # Here it's not a itertype, we should have found one element only or empty + if len(children) > 1: + raise DeserializationError("Find several XML '{}' where it was not expected".format(xml_name)) + return children[0] + + +class Deserializer: + """Response object model deserializer. + + :param dict classes: Class type dictionary for deserializing complex types. + :ivar list key_extractors: Ordered list of extractors to be used by this deserializer. 
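+
+    A minimal usage sketch (``Pet`` is a hypothetical generated model registered in
+    the class dictionary)::
+
+        deserializer = Deserializer({"Pet": Pet})
+        pet = deserializer("Pet", '{"name": "Fido"}', content_type="application/json")
+        pets = deserializer("[Pet]", '[{"name": "Fido"}]', content_type="application/json")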
+ """ + + basic_types = {str: "str", int: "int", bool: "bool", float: "float"} + + valid_date = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?") + + def __init__(self, classes: Optional[Mapping[str, type]] = None) -> None: + self.deserialize_type = { + "iso-8601": Deserializer.deserialize_iso, + "rfc-1123": Deserializer.deserialize_rfc, + "unix-time": Deserializer.deserialize_unix, + "duration": Deserializer.deserialize_duration, + "date": Deserializer.deserialize_date, + "time": Deserializer.deserialize_time, + "decimal": Deserializer.deserialize_decimal, + "long": Deserializer.deserialize_long, + "bytearray": Deserializer.deserialize_bytearray, + "base64": Deserializer.deserialize_base64, + "object": self.deserialize_object, + "[]": self.deserialize_iter, + "{}": self.deserialize_dict, + } + self.deserialize_expected_types = { + "duration": (isodate.Duration, datetime.timedelta), + "iso-8601": (datetime.datetime), + } + self.dependencies: Dict[str, type] = dict(classes) if classes else {} + self.key_extractors = [rest_key_extractor, xml_key_extractor] + # Additional properties only works if the "rest_key_extractor" is used to + # extract the keys. Making it to work whatever the key extractor is too much + # complicated, with no real scenario for now. + # So adding a flag to disable additional properties detection. This flag should be + # used if your expect the deserialization to NOT come from a JSON REST syntax. + # Otherwise, result are unexpected + self.additional_properties_detection = True + + def __call__(self, target_obj, response_data, content_type=None): + """Call the deserializer to process a REST response. + + :param str target_obj: Target data type to deserialize to. + :param requests.Response response_data: REST response object. + :param str content_type: Swagger "produces" if available. + :raises DeserializationError: if deserialization fails. + :return: Deserialized object. + :rtype: object + """ + data = self._unpack_content(response_data, content_type) + return self._deserialize(target_obj, data) + + def _deserialize(self, target_obj, data): # pylint: disable=inconsistent-return-statements + """Call the deserializer on a model. + + Data needs to be already deserialized as JSON or XML ElementTree + + :param str target_obj: Target data type to deserialize to. + :param object data: Object to deserialize. + :raises DeserializationError: if deserialization fails. + :return: Deserialized object. 
+ :rtype: object + """ + # This is already a model, go recursive just in case + if hasattr(data, "_attribute_map"): + constants = [name for name, config in getattr(data, "_validation", {}).items() if config.get("constant")] + try: + for attr, mapconfig in data._attribute_map.items(): # pylint: disable=protected-access + if attr in constants: + continue + value = getattr(data, attr) + if value is None: + continue + local_type = mapconfig["type"] + internal_data_type = local_type.strip("[]{}") + if internal_data_type not in self.dependencies or isinstance(internal_data_type, Enum): + continue + setattr(data, attr, self._deserialize(local_type, value)) + return data + except AttributeError: + return + + response, class_name = self._classify_target(target_obj, data) + + if isinstance(response, str): + return self.deserialize_data(data, response) + if isinstance(response, type) and issubclass(response, Enum): + return self.deserialize_enum(data, response) + + if data is None or data is CoreNull: + return data + try: + attributes = response._attribute_map # type: ignore # pylint: disable=protected-access + d_attrs = {} + for attr, attr_desc in attributes.items(): + # Check empty string. If it's not empty, someone has a real "additionalProperties"... + if attr == "additional_properties" and attr_desc["key"] == "": + continue + raw_value = None + # Enhance attr_desc with some dynamic data + attr_desc = attr_desc.copy() # Do a copy, do not change the real one + internal_data_type = attr_desc["type"].strip("[]{}") + if internal_data_type in self.dependencies: + attr_desc["internalType"] = self.dependencies[internal_data_type] + + for key_extractor in self.key_extractors: + found_value = key_extractor(attr, attr_desc, data) + if found_value is not None: + if raw_value is not None and raw_value != found_value: + msg = ( + "Ignoring extracted value '%s' from %s for key '%s'" + " (duplicate extraction, follow extractors order)" + ) + _LOGGER.warning(msg, found_value, key_extractor, attr) + continue + raw_value = found_value + + value = self.deserialize_data(raw_value, attr_desc["type"]) + d_attrs[attr] = value + except (AttributeError, TypeError, KeyError) as err: + msg = "Unable to deserialize to object: " + class_name # type: ignore + raise DeserializationError(msg) from err + additional_properties = self._build_additional_properties(attributes, data) + return self._instantiate_model(response, d_attrs, additional_properties) + + def _build_additional_properties(self, attribute_map, data): + if not self.additional_properties_detection: + return None + if "additional_properties" in attribute_map and attribute_map.get("additional_properties", {}).get("key") != "": + # Check empty string. If it's not empty, someone has a real "additionalProperties" + return None + if isinstance(data, ET.Element): + data = {el.tag: el.text for el in data} + + known_keys = { + _decode_attribute_map_key(_FLATTEN.split(desc["key"])[0]) + for desc in attribute_map.values() + if desc["key"] != "" + } + present_keys = set(data.keys()) + missing_keys = present_keys - known_keys + return {key: data[key] for key in missing_keys} + + def _classify_target(self, target, data): + """Check to see whether the deserialization target object can + be classified into a subclass. + Once classification has been determined, initialize object. + + :param str target: The target object type to deserialize to. + :param str/dict data: The response data to deserialize. + :return: The classified target object and its class name. 
+ :rtype: tuple + """ + if target is None: + return None, None + + if isinstance(target, str): + try: + target = self.dependencies[target] + except KeyError: + return target, target + + try: + target = target._classify(data, self.dependencies) # type: ignore # pylint: disable=protected-access + except AttributeError: + pass # Target is not a Model, no classify + return target, target.__class__.__name__ # type: ignore + + def failsafe_deserialize(self, target_obj, data, content_type=None): + """Ignores any errors encountered in deserialization, + and falls back to not deserializing the object. Recommended + for use in error deserialization, as we want to return the + HttpResponseError to users, and not have them deal with + a deserialization error. + + :param str target_obj: The target object type to deserialize to. + :param str/dict data: The response data to deserialize. + :param str content_type: Swagger "produces" if available. + :return: Deserialized object. + :rtype: object + """ + try: + return self(target_obj, data, content_type=content_type) + except: # pylint: disable=bare-except + _LOGGER.debug( + "Ran into a deserialization error. Ignoring since this is failsafe deserialization", exc_info=True + ) + return None + + @staticmethod + def _unpack_content(raw_data, content_type=None): + """Extract the correct structure for deserialization. + + If raw_data is a PipelineResponse, try to extract the result of RawDeserializer. + if we can't, raise. Your Pipeline should have a RawDeserializer. + + If not a pipeline response and raw_data is bytes or string, use content-type + to decode it. If no content-type, try JSON. + + If raw_data is something else, bypass all logic and return it directly. + + :param obj raw_data: Data to be processed. + :param str content_type: How to parse if raw_data is a string/bytes. + :raises JSONDecodeError: If JSON is requested and parsing is impossible. + :raises UnicodeDecodeError: If bytes is not UTF8 + :rtype: object + :return: Unpacked content. + """ + # Assume this is enough to detect a Pipeline Response without importing it + context = getattr(raw_data, "context", {}) + if context: + if RawDeserializer.CONTEXT_NAME in context: + return context[RawDeserializer.CONTEXT_NAME] + raise ValueError("This pipeline didn't have the RawDeserializer policy; can't deserialize") + + # Assume this is enough to recognize universal_http.ClientResponse without importing it + if hasattr(raw_data, "body"): + return RawDeserializer.deserialize_from_http_generics(raw_data.text(), raw_data.headers) + + # Assume this enough to recognize requests.Response without importing it. + if hasattr(raw_data, "_content_consumed"): + return RawDeserializer.deserialize_from_http_generics(raw_data.text, raw_data.headers) + + if isinstance(raw_data, (str, bytes)) or hasattr(raw_data, "read"): + return RawDeserializer.deserialize_from_text(raw_data, content_type) # type: ignore + return raw_data + + def _instantiate_model(self, response, attrs, additional_properties=None): + """Instantiate a response model passing in deserialized args. + + :param Response response: The response model class. + :param dict attrs: The deserialized response attributes. + :param dict additional_properties: Additional properties to be set. + :rtype: Response + :return: The instantiated response model. 
+ """ + if callable(response): + subtype = getattr(response, "_subtype_map", {}) + try: + readonly = [ + k + for k, v in response._validation.items() # pylint: disable=protected-access # type: ignore + if v.get("readonly") + ] + const = [ + k + for k, v in response._validation.items() # pylint: disable=protected-access # type: ignore + if v.get("constant") + ] + kwargs = {k: v for k, v in attrs.items() if k not in subtype and k not in readonly + const} + response_obj = response(**kwargs) + for attr in readonly: + setattr(response_obj, attr, attrs.get(attr)) + if additional_properties: + response_obj.additional_properties = additional_properties # type: ignore + return response_obj + except TypeError as err: + msg = "Unable to deserialize {} into model {}. ".format(kwargs, response) # type: ignore + raise DeserializationError(msg + str(err)) from err + else: + try: + for attr, value in attrs.items(): + setattr(response, attr, value) + return response + except Exception as exp: + msg = "Unable to populate response model. " + msg += "Type: {}, Error: {}".format(type(response), exp) + raise DeserializationError(msg) from exp + + def deserialize_data(self, data, data_type): # pylint: disable=too-many-return-statements + """Process data for deserialization according to data type. + + :param str data: The response string to be deserialized. + :param str data_type: The type to deserialize to. + :raises DeserializationError: if deserialization fails. + :return: Deserialized object. + :rtype: object + """ + if data is None: + return data + + try: + if not data_type: + return data + if data_type in self.basic_types.values(): + return self.deserialize_basic(data, data_type) + if data_type in self.deserialize_type: + if isinstance(data, self.deserialize_expected_types.get(data_type, tuple())): + return data + + is_a_text_parsing_type = lambda x: x not in [ # pylint: disable=unnecessary-lambda-assignment + "object", + "[]", + r"{}", + ] + if isinstance(data, ET.Element) and is_a_text_parsing_type(data_type) and not data.text: + return None + data_val = self.deserialize_type[data_type](data) + return data_val + + iter_type = data_type[0] + data_type[-1] + if iter_type in self.deserialize_type: + return self.deserialize_type[iter_type](data, data_type[1:-1]) + + obj_type = self.dependencies[data_type] + if issubclass(obj_type, Enum): + if isinstance(data, ET.Element): + data = data.text + return self.deserialize_enum(data, obj_type) + + except (ValueError, TypeError, AttributeError) as err: + msg = "Unable to deserialize response data." + msg += " Data: {}, {}".format(data, data_type) + raise DeserializationError(msg) from err + return self._deserialize(obj_type, data) + + def deserialize_iter(self, attr, iter_type): + """Deserialize an iterable. + + :param list attr: Iterable to be deserialized. + :param str iter_type: The type of object in the iterable. + :return: Deserialized iterable. + :rtype: list + """ + if attr is None: + return None + if isinstance(attr, ET.Element): # If I receive an element here, get the children + attr = list(attr) + if not isinstance(attr, (list, set)): + raise DeserializationError("Cannot deserialize as [{}] an object of type {}".format(iter_type, type(attr))) + return [self.deserialize_data(a, iter_type) for a in attr] + + def deserialize_dict(self, attr, dict_type): + """Deserialize a dictionary. + + :param dict/list attr: Dictionary to be deserialized. Also accepts + a list of key, value pairs. + :param str dict_type: The object type of the items in the dictionary. 
+ :return: Deserialized dictionary. + :rtype: dict + """ + if isinstance(attr, list): + return {x["key"]: self.deserialize_data(x["value"], dict_type) for x in attr} + + if isinstance(attr, ET.Element): + # Transform <Key>value</Key> into {"Key": "value"} + attr = {el.tag: el.text for el in attr} + return {k: self.deserialize_data(v, dict_type) for k, v in attr.items()} + + def deserialize_object(self, attr, **kwargs): # pylint: disable=too-many-return-statements + """Deserialize a generic object. + This will be handled as a dictionary. + + :param dict attr: Dictionary to be deserialized. + :return: Deserialized object. + :rtype: dict + :raises TypeError: if non-builtin datatype encountered. + """ + if attr is None: + return None + if isinstance(attr, ET.Element): + # Do no recurse on XML, just return the tree as-is + return attr + if isinstance(attr, str): + return self.deserialize_basic(attr, "str") + obj_type = type(attr) + if obj_type in self.basic_types: + return self.deserialize_basic(attr, self.basic_types[obj_type]) + if obj_type is _long_type: + return self.deserialize_long(attr) + + if obj_type == dict: + deserialized = {} + for key, value in attr.items(): + try: + deserialized[key] = self.deserialize_object(value, **kwargs) + except ValueError: + deserialized[key] = None + return deserialized + + if obj_type == list: + deserialized = [] + for obj in attr: + try: + deserialized.append(self.deserialize_object(obj, **kwargs)) + except ValueError: + pass + return deserialized + + error = "Cannot deserialize generic object with type: " + raise TypeError(error + str(obj_type)) + + def deserialize_basic(self, attr, data_type): # pylint: disable=too-many-return-statements + """Deserialize basic builtin data type from string. + Will attempt to convert to str, int, float and bool. + This function will also accept '1', '0', 'true' and 'false' as + valid bool values. + + :param str attr: response string to be deserialized. + :param str data_type: deserialization data type. + :return: Deserialized basic type. + :rtype: str, int, float or bool + :raises TypeError: if string format is not valid. + """ + # If we're here, data is supposed to be a basic type. + # If it's still an XML node, take the text + if isinstance(attr, ET.Element): + attr = attr.text + if not attr: + if data_type == "str": + # None or '', node <a/> is empty string. + return "" + # None or '', node <a/> with a strong type is None. + # Don't try to model "empty bool" or "empty int" + return None + + if data_type == "bool": + if attr in [True, False, 1, 0]: + return bool(attr) + if isinstance(attr, str): + if attr.lower() in ["true", "1"]: + return True + if attr.lower() in ["false", "0"]: + return False + raise TypeError("Invalid boolean value: {}".format(attr)) + + if data_type == "str": + return self.deserialize_unicode(attr) + return eval(data_type)(attr) # nosec # pylint: disable=eval-used + + @staticmethod + def deserialize_unicode(data): + """Preserve unicode objects in Python 2, otherwise return data + as a string. + + :param str data: response string to be deserialized. + :return: Deserialized string. 
+ :rtype: str or unicode + """ + # We might be here because we have an enum modeled as string, + # and we try to deserialize a partial dict with enum inside + if isinstance(data, Enum): + return data + + # Consider this is real string + try: + if isinstance(data, unicode): # type: ignore + return data + except NameError: + return str(data) + return str(data) + + @staticmethod + def deserialize_enum(data, enum_obj): + """Deserialize string into enum object. + + If the string is not a valid enum value it will be returned as-is + and a warning will be logged. + + :param str data: Response string to be deserialized. If this value is + None or invalid it will be returned as-is. + :param Enum enum_obj: Enum object to deserialize to. + :return: Deserialized enum object. + :rtype: Enum + """ + if isinstance(data, enum_obj) or data is None: + return data + if isinstance(data, Enum): + data = data.value + if isinstance(data, int): + # Workaround. We might consider remove it in the future. + try: + return list(enum_obj.__members__.values())[data] + except IndexError as exc: + error = "{!r} is not a valid index for enum {!r}" + raise DeserializationError(error.format(data, enum_obj)) from exc + try: + return enum_obj(str(data)) + except ValueError: + for enum_value in enum_obj: + if enum_value.value.lower() == str(data).lower(): + return enum_value + # We don't fail anymore for unknown value, we deserialize as a string + _LOGGER.warning("Deserializer is not able to find %s as valid enum in %s", data, enum_obj) + return Deserializer.deserialize_unicode(data) + + @staticmethod + def deserialize_bytearray(attr): + """Deserialize string into bytearray. + + :param str attr: response string to be deserialized. + :return: Deserialized bytearray + :rtype: bytearray + :raises TypeError: if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + return bytearray(b64decode(attr)) # type: ignore + + @staticmethod + def deserialize_base64(attr): + """Deserialize base64 encoded string into string. + + :param str attr: response string to be deserialized. + :return: Deserialized base64 string + :rtype: bytearray + :raises TypeError: if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore + attr = attr + padding # type: ignore + encoded = attr.replace("-", "+").replace("_", "/") + return b64decode(encoded) + + @staticmethod + def deserialize_decimal(attr): + """Deserialize string into Decimal object. + + :param str attr: response string to be deserialized. + :return: Deserialized decimal + :raises DeserializationError: if string format invalid. + :rtype: decimal + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + return decimal.Decimal(str(attr)) # type: ignore + except decimal.DecimalException as err: + msg = "Invalid decimal {}".format(attr) + raise DeserializationError(msg) from err + + @staticmethod + def deserialize_long(attr): + """Deserialize string into long (Py2) or int (Py3). + + :param str attr: response string to be deserialized. + :return: Deserialized int + :rtype: long or int + :raises ValueError: if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + return _long_type(attr) # type: ignore + + @staticmethod + def deserialize_duration(attr): + """Deserialize ISO-8601 formatted string into TimeDelta object. + + :param str attr: response string to be deserialized. 
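+            For example, ``"PT1H30M"`` deserializes to a ``datetime.timedelta`` of 1 hour 30 minutes.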
+ :return: Deserialized duration + :rtype: TimeDelta + :raises DeserializationError: if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + duration = isodate.parse_duration(attr) + except (ValueError, OverflowError, AttributeError) as err: + msg = "Cannot deserialize duration object." + raise DeserializationError(msg) from err + return duration + + @staticmethod + def deserialize_date(attr): + """Deserialize ISO-8601 formatted string into Date object. + + :param str attr: response string to be deserialized. + :return: Deserialized date + :rtype: Date + :raises DeserializationError: if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + if re.search(r"[^\W\d_]", attr, re.I + re.U): # type: ignore + raise DeserializationError("Date must have only digits and -. Received: %s" % attr) + # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception. + return isodate.parse_date(attr, defaultmonth=0, defaultday=0) + + @staticmethod + def deserialize_time(attr): + """Deserialize ISO-8601 formatted string into time object. + + :param str attr: response string to be deserialized. + :return: Deserialized time + :rtype: datetime.time + :raises DeserializationError: if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + if re.search(r"[^\W\d_]", attr, re.I + re.U): # type: ignore + raise DeserializationError("Date must have only digits and -. Received: %s" % attr) + return isodate.parse_time(attr) + + @staticmethod + def deserialize_rfc(attr): + """Deserialize RFC-1123 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :return: Deserialized RFC datetime + :rtype: Datetime + :raises DeserializationError: if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + parsed_date = email.utils.parsedate_tz(attr) # type: ignore + date_obj = datetime.datetime( + *parsed_date[:6], tzinfo=datetime.timezone(datetime.timedelta(minutes=(parsed_date[9] or 0) / 60)) + ) + if not date_obj.tzinfo: + date_obj = date_obj.astimezone(tz=TZ_UTC) + except ValueError as err: + msg = "Cannot deserialize to rfc datetime object." + raise DeserializationError(msg) from err + return date_obj + + @staticmethod + def deserialize_iso(attr): + """Deserialize ISO-8601 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :return: Deserialized ISO datetime + :rtype: Datetime + :raises DeserializationError: if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + attr = attr.upper() # type: ignore + match = Deserializer.valid_date.match(attr) + if not match: + raise ValueError("Invalid datetime string: " + attr) + + check_decimal = attr.split(".") + if len(check_decimal) > 1: + decimal_str = "" + for digit in check_decimal[1]: + if digit.isdigit(): + decimal_str += digit + else: + break + if len(decimal_str) > 6: + attr = attr.replace(decimal_str, decimal_str[0:6]) + + date_obj = isodate.parse_datetime(attr) + test_utc = date_obj.utctimetuple() + if test_utc.tm_year > 9999 or test_utc.tm_year < 1: + raise OverflowError("Hit max or min date") + except (ValueError, OverflowError, AttributeError) as err: + msg = "Cannot deserialize datetime object." + raise DeserializationError(msg) from err + return date_obj + + @staticmethod + def deserialize_unix(attr): + """Serialize Datetime object into IntTime format. + This is represented as seconds. 
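deserialize_iso above upper-cases the input, validates it against Deserializer.valid_date, and truncates fractional seconds to six digits before calling isodate.parse_datetime, since datetime only stores microsecond precision. Roughly, the truncation step does the following (sample timestamp is illustrative):

    attr = "2024-05-01T12:34:56.1234567+00:00"
    parts = attr.split(".")
    if len(parts) > 1:
        frac = ""
        for ch in parts[1]:
            if not ch.isdigit():
                break
            frac += ch
        if len(frac) > 6:
            # Keep only microsecond precision before parsing.
            attr = attr.replace(frac, frac[:6])
    # attr == "2024-05-01T12:34:56.123456+00:00", now safe for isodate.parse_datetime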
+ + :param int attr: Object to be serialized. + :return: Deserialized datetime + :rtype: Datetime + :raises DeserializationError: if format invalid + """ + if isinstance(attr, ET.Element): + attr = int(attr.text) # type: ignore + try: + attr = int(attr) + date_obj = datetime.datetime.fromtimestamp(attr, TZ_UTC) + except ValueError as err: + msg = "Cannot deserialize to unix datetime object." + raise DeserializationError(msg) from err + return date_obj diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/_vendor.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/_vendor.py new file mode 100644 index 00000000..147e96be --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/_vendor.py @@ -0,0 +1,47 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- + +from abc import ABC +from typing import TYPE_CHECKING + +from ._configuration import ( + ChatCompletionsClientConfiguration, + EmbeddingsClientConfiguration, + ImageEmbeddingsClientConfiguration, +) + +if TYPE_CHECKING: + from azure.core import PipelineClient + + from ._serialization import Deserializer, Serializer + + +class ChatCompletionsClientMixinABC(ABC): + """DO NOT use this class. It is for internal typing use only.""" + + _client: "PipelineClient" + _config: ChatCompletionsClientConfiguration + _serialize: "Serializer" + _deserialize: "Deserializer" + + +class EmbeddingsClientMixinABC(ABC): + """DO NOT use this class. It is for internal typing use only.""" + + _client: "PipelineClient" + _config: EmbeddingsClientConfiguration + _serialize: "Serializer" + _deserialize: "Deserializer" + + +class ImageEmbeddingsClientMixinABC(ABC): + """DO NOT use this class. It is for internal typing use only.""" + + _client: "PipelineClient" + _config: ImageEmbeddingsClientConfiguration + _serialize: "Serializer" + _deserialize: "Deserializer" diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/_version.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/_version.py new file mode 100644 index 00000000..b1c2836b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/_version.py @@ -0,0 +1,9 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- + +VERSION = "1.0.0b9" diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/__init__.py new file mode 100644 index 00000000..668f989a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/__init__.py @@ -0,0 +1,33 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + +from ._client import ChatCompletionsClient # type: ignore +from ._client import EmbeddingsClient # type: ignore +from ._client import ImageEmbeddingsClient # type: ignore + +try: + from ._patch import __all__ as _patch_all + from ._patch import * +except ImportError: + _patch_all = [] +from ._patch import patch_sdk as _patch_sdk + +__all__ = [ + "ChatCompletionsClient", + "EmbeddingsClient", + "ImageEmbeddingsClient", +] +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore + +_patch_sdk() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_client.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_client.py new file mode 100644 index 00000000..88e6773b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_client.py @@ -0,0 +1,280 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- + +from copy import deepcopy +from typing import Any, Awaitable, TYPE_CHECKING, Union +from typing_extensions import Self + +from azure.core import AsyncPipelineClient +from azure.core.credentials import AzureKeyCredential +from azure.core.pipeline import policies +from azure.core.rest import AsyncHttpResponse, HttpRequest + +from .._serialization import Deserializer, Serializer +from ._configuration import ( + ChatCompletionsClientConfiguration, + EmbeddingsClientConfiguration, + ImageEmbeddingsClientConfiguration, +) +from ._operations import ( + ChatCompletionsClientOperationsMixin, + EmbeddingsClientOperationsMixin, + ImageEmbeddingsClientOperationsMixin, +) + +if TYPE_CHECKING: + from azure.core.credentials_async import AsyncTokenCredential + + +class ChatCompletionsClient(ChatCompletionsClientOperationsMixin): + """ChatCompletionsClient. + + :param endpoint: Service host. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a key + credential type or a token credential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials_async.AsyncTokenCredential + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. 
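For orientation, a minimal async usage sketch of the surface this generated client backs. The endpoint, key, and messages are placeholders, and the public complete method comes from the package's _patch customizations rather than the generated code shown here:

    import asyncio
    from azure.core.credentials import AzureKeyCredential
    from azure.ai.inference.aio import ChatCompletionsClient
    from azure.ai.inference.models import SystemMessage, UserMessage

    async def main() -> None:
        async with ChatCompletionsClient(
            endpoint="https://<your-endpoint>.inference.ai.azure.com",
            credential=AzureKeyCredential("<your-api-key>"),
        ) as client:
            response = await client.complete(
                messages=[
                    SystemMessage(content="You are a helpful assistant."),
                    UserMessage(content="How many feet are in a mile?"),
                ]
            )
            print(response.choices[0].message.content)

    asyncio.run(main())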
+ :paramtype api_version: str + """ + + def __init__( + self, endpoint: str, credential: Union[AzureKeyCredential, "AsyncTokenCredential"], **kwargs: Any + ) -> None: + _endpoint = "{endpoint}" + self._config = ChatCompletionsClientConfiguration(endpoint=endpoint, credential=credential, **kwargs) + _policies = kwargs.pop("policies", None) + if _policies is None: + _policies = [ + policies.RequestIdPolicy(**kwargs), + self._config.headers_policy, + self._config.user_agent_policy, + self._config.proxy_policy, + policies.ContentDecodePolicy(**kwargs), + self._config.redirect_policy, + self._config.retry_policy, + self._config.authentication_policy, + self._config.custom_hook_policy, + self._config.logging_policy, + policies.DistributedTracingPolicy(**kwargs), + policies.SensitiveHeaderCleanupPolicy(**kwargs) if self._config.redirect_policy else None, + self._config.http_logging_policy, + ] + self._client: AsyncPipelineClient = AsyncPipelineClient(base_url=_endpoint, policies=_policies, **kwargs) + + self._serialize = Serializer() + self._deserialize = Deserializer() + self._serialize.client_side_validation = False + + def send_request( + self, request: HttpRequest, *, stream: bool = False, **kwargs: Any + ) -> Awaitable[AsyncHttpResponse]: + """Runs the network request through the client's chained policies. + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest("GET", "https://www.example.org/") + <HttpRequest [GET], url: 'https://www.example.org/'> + >>> response = await client.send_request(request) + <AsyncHttpResponse: 200 OK> + + For more information on this code flow, see https://aka.ms/azsdk/dpcodegen/python/send_request + + :param request: The network request you want to make. Required. + :type request: ~azure.core.rest.HttpRequest + :keyword bool stream: Whether the response payload will be streamed. Defaults to False. + :return: The response of your network call. Does not do error handling on your response. + :rtype: ~azure.core.rest.AsyncHttpResponse + """ + + request_copy = deepcopy(request) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + request_copy.url = self._client.format_url(request_copy.url, **path_format_arguments) + return self._client.send_request(request_copy, stream=stream, **kwargs) # type: ignore + + async def close(self) -> None: + await self._client.close() + + async def __aenter__(self) -> Self: + await self._client.__aenter__() + return self + + async def __aexit__(self, *exc_details: Any) -> None: + await self._client.__aexit__(*exc_details) + + +class EmbeddingsClient(EmbeddingsClientOperationsMixin): + """EmbeddingsClient. + + :param endpoint: Service host. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a key + credential type or a token credential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials_async.AsyncTokenCredential + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. 
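The text-embeddings client follows the same construction pattern; a sketch of a typical call (placeholders throughout, with the public embed method again supplied by _patch):

    from azure.core.credentials import AzureKeyCredential
    from azure.ai.inference.aio import EmbeddingsClient

    async def embed_texts() -> None:
        async with EmbeddingsClient(
            endpoint="https://<your-endpoint>.inference.ai.azure.com",
            credential=AzureKeyCredential("<your-api-key>"),
        ) as client:
            result = await client.embed(input=["first phrase", "second phrase"])
            for item in result.data:
                print(item.index, len(item.embedding))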
+ :paramtype api_version: str + """ + + def __init__( + self, endpoint: str, credential: Union[AzureKeyCredential, "AsyncTokenCredential"], **kwargs: Any + ) -> None: + _endpoint = "{endpoint}" + self._config = EmbeddingsClientConfiguration(endpoint=endpoint, credential=credential, **kwargs) + _policies = kwargs.pop("policies", None) + if _policies is None: + _policies = [ + policies.RequestIdPolicy(**kwargs), + self._config.headers_policy, + self._config.user_agent_policy, + self._config.proxy_policy, + policies.ContentDecodePolicy(**kwargs), + self._config.redirect_policy, + self._config.retry_policy, + self._config.authentication_policy, + self._config.custom_hook_policy, + self._config.logging_policy, + policies.DistributedTracingPolicy(**kwargs), + policies.SensitiveHeaderCleanupPolicy(**kwargs) if self._config.redirect_policy else None, + self._config.http_logging_policy, + ] + self._client: AsyncPipelineClient = AsyncPipelineClient(base_url=_endpoint, policies=_policies, **kwargs) + + self._serialize = Serializer() + self._deserialize = Deserializer() + self._serialize.client_side_validation = False + + def send_request( + self, request: HttpRequest, *, stream: bool = False, **kwargs: Any + ) -> Awaitable[AsyncHttpResponse]: + """Runs the network request through the client's chained policies. + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest("GET", "https://www.example.org/") + <HttpRequest [GET], url: 'https://www.example.org/'> + >>> response = await client.send_request(request) + <AsyncHttpResponse: 200 OK> + + For more information on this code flow, see https://aka.ms/azsdk/dpcodegen/python/send_request + + :param request: The network request you want to make. Required. + :type request: ~azure.core.rest.HttpRequest + :keyword bool stream: Whether the response payload will be streamed. Defaults to False. + :return: The response of your network call. Does not do error handling on your response. + :rtype: ~azure.core.rest.AsyncHttpResponse + """ + + request_copy = deepcopy(request) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + request_copy.url = self._client.format_url(request_copy.url, **path_format_arguments) + return self._client.send_request(request_copy, stream=stream, **kwargs) # type: ignore + + async def close(self) -> None: + await self._client.close() + + async def __aenter__(self) -> Self: + await self._client.__aenter__() + return self + + async def __aexit__(self, *exc_details: Any) -> None: + await self._client.__aexit__(*exc_details) + + +class ImageEmbeddingsClient(ImageEmbeddingsClientOperationsMixin): + """ImageEmbeddingsClient. + + :param endpoint: Service host. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a key + credential type or a token credential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials_async.AsyncTokenCredential + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. 
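Image embeddings take a list of ImageEmbeddingInput items instead of raw strings. A sketch assuming the image is passed as a base64 data URL; the exact input encoding expected by a given model deployment may differ:

    import base64
    from azure.core.credentials import AzureKeyCredential
    from azure.ai.inference.aio import ImageEmbeddingsClient
    from azure.ai.inference.models import ImageEmbeddingInput

    async def embed_image(path: str) -> None:
        with open(path, "rb") as f:
            data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")
        async with ImageEmbeddingsClient(
            endpoint="https://<your-endpoint>.inference.ai.azure.com",
            credential=AzureKeyCredential("<your-api-key>"),
        ) as client:
            result = await client.embed(input=[ImageEmbeddingInput(image=data_url)])
            print(len(result.data[0].embedding))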
+ :paramtype api_version: str + """ + + def __init__( + self, endpoint: str, credential: Union[AzureKeyCredential, "AsyncTokenCredential"], **kwargs: Any + ) -> None: + _endpoint = "{endpoint}" + self._config = ImageEmbeddingsClientConfiguration(endpoint=endpoint, credential=credential, **kwargs) + _policies = kwargs.pop("policies", None) + if _policies is None: + _policies = [ + policies.RequestIdPolicy(**kwargs), + self._config.headers_policy, + self._config.user_agent_policy, + self._config.proxy_policy, + policies.ContentDecodePolicy(**kwargs), + self._config.redirect_policy, + self._config.retry_policy, + self._config.authentication_policy, + self._config.custom_hook_policy, + self._config.logging_policy, + policies.DistributedTracingPolicy(**kwargs), + policies.SensitiveHeaderCleanupPolicy(**kwargs) if self._config.redirect_policy else None, + self._config.http_logging_policy, + ] + self._client: AsyncPipelineClient = AsyncPipelineClient(base_url=_endpoint, policies=_policies, **kwargs) + + self._serialize = Serializer() + self._deserialize = Deserializer() + self._serialize.client_side_validation = False + + def send_request( + self, request: HttpRequest, *, stream: bool = False, **kwargs: Any + ) -> Awaitable[AsyncHttpResponse]: + """Runs the network request through the client's chained policies. + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest("GET", "https://www.example.org/") + <HttpRequest [GET], url: 'https://www.example.org/'> + >>> response = await client.send_request(request) + <AsyncHttpResponse: 200 OK> + + For more information on this code flow, see https://aka.ms/azsdk/dpcodegen/python/send_request + + :param request: The network request you want to make. Required. + :type request: ~azure.core.rest.HttpRequest + :keyword bool stream: Whether the response payload will be streamed. Defaults to False. + :return: The response of your network call. Does not do error handling on your response. + :rtype: ~azure.core.rest.AsyncHttpResponse + """ + + request_copy = deepcopy(request) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + request_copy.url = self._client.format_url(request_copy.url, **path_format_arguments) + return self._client.send_request(request_copy, stream=stream, **kwargs) # type: ignore + + async def close(self) -> None: + await self._client.close() + + async def __aenter__(self) -> Self: + await self._client.__aenter__() + return self + + async def __aexit__(self, *exc_details: Any) -> None: + await self._client.__aexit__(*exc_details) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_configuration.py new file mode 100644 index 00000000..f60e1125 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_configuration.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- + +from typing import Any, TYPE_CHECKING, Union + +from azure.core.credentials import AzureKeyCredential +from azure.core.pipeline import policies + +from .._version import VERSION + +if TYPE_CHECKING: + from azure.core.credentials_async import AsyncTokenCredential + + +class ChatCompletionsClientConfiguration: # pylint: disable=too-many-instance-attributes + """Configuration for ChatCompletionsClient. + + Note that all parameters used to create this instance are saved as instance + attributes. + + :param endpoint: Service host. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a key + credential type or a token credential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials_async.AsyncTokenCredential + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. + :paramtype api_version: str + """ + + def __init__( + self, endpoint: str, credential: Union[AzureKeyCredential, "AsyncTokenCredential"], **kwargs: Any + ) -> None: + api_version: str = kwargs.pop("api_version", "2024-05-01-preview") + + if endpoint is None: + raise ValueError("Parameter 'endpoint' must not be None.") + if credential is None: + raise ValueError("Parameter 'credential' must not be None.") + + self.endpoint = endpoint + self.credential = credential + self.api_version = api_version + self.credential_scopes = kwargs.pop("credential_scopes", ["https://ml.azure.com/.default"]) + kwargs.setdefault("sdk_moniker", "ai-inference/{}".format(VERSION)) + self.polling_interval = kwargs.get("polling_interval", 30) + self._configure(**kwargs) + + def _infer_policy(self, **kwargs): + if isinstance(self.credential, AzureKeyCredential): + return policies.AzureKeyCredentialPolicy(self.credential, "Authorization", prefix="Bearer", **kwargs) + if isinstance(self.credential, AzureKeyCredential): + return policies.AzureKeyCredentialPolicy(self.credential, "api-key", **kwargs) + if hasattr(self.credential, "get_token"): + return policies.AsyncBearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs) + raise TypeError(f"Unsupported credential: {self.credential}") + + def _configure(self, **kwargs: Any) -> None: + self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs) + self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs) + self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs) + self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs) + self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs) + self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs) + self.redirect_policy = kwargs.get("redirect_policy") or policies.AsyncRedirectPolicy(**kwargs) + self.retry_policy = kwargs.get("retry_policy") or policies.AsyncRetryPolicy(**kwargs) + self.authentication_policy = kwargs.get("authentication_policy") + if self.credential and not self.authentication_policy: + self.authentication_policy = self._infer_policy(**kwargs) + + +class EmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes + """Configuration for 
EmbeddingsClient. + + Note that all parameters used to create this instance are saved as instance + attributes. + + :param endpoint: Service host. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a key + credential type or a token credential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials_async.AsyncTokenCredential + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. + :paramtype api_version: str + """ + + def __init__( + self, endpoint: str, credential: Union[AzureKeyCredential, "AsyncTokenCredential"], **kwargs: Any + ) -> None: + api_version: str = kwargs.pop("api_version", "2024-05-01-preview") + + if endpoint is None: + raise ValueError("Parameter 'endpoint' must not be None.") + if credential is None: + raise ValueError("Parameter 'credential' must not be None.") + + self.endpoint = endpoint + self.credential = credential + self.api_version = api_version + self.credential_scopes = kwargs.pop("credential_scopes", ["https://ml.azure.com/.default"]) + kwargs.setdefault("sdk_moniker", "ai-inference/{}".format(VERSION)) + self.polling_interval = kwargs.get("polling_interval", 30) + self._configure(**kwargs) + + def _infer_policy(self, **kwargs): + if isinstance(self.credential, AzureKeyCredential): + return policies.AzureKeyCredentialPolicy(self.credential, "Authorization", prefix="Bearer", **kwargs) + if isinstance(self.credential, AzureKeyCredential): + return policies.AzureKeyCredentialPolicy(self.credential, "api-key", **kwargs) + if hasattr(self.credential, "get_token"): + return policies.AsyncBearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs) + raise TypeError(f"Unsupported credential: {self.credential}") + + def _configure(self, **kwargs: Any) -> None: + self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs) + self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs) + self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs) + self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs) + self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs) + self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs) + self.redirect_policy = kwargs.get("redirect_policy") or policies.AsyncRedirectPolicy(**kwargs) + self.retry_policy = kwargs.get("retry_policy") or policies.AsyncRetryPolicy(**kwargs) + self.authentication_policy = kwargs.get("authentication_policy") + if self.credential and not self.authentication_policy: + self.authentication_policy = self._infer_policy(**kwargs) + + +class ImageEmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes + """Configuration for ImageEmbeddingsClient. + + Note that all parameters used to create this instance are saved as instance + attributes. + + :param endpoint: Service host. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a key + credential type or a token credential type. Required. 
+ :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials_async.AsyncTokenCredential + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. + :paramtype api_version: str + """ + + def __init__( + self, endpoint: str, credential: Union[AzureKeyCredential, "AsyncTokenCredential"], **kwargs: Any + ) -> None: + api_version: str = kwargs.pop("api_version", "2024-05-01-preview") + + if endpoint is None: + raise ValueError("Parameter 'endpoint' must not be None.") + if credential is None: + raise ValueError("Parameter 'credential' must not be None.") + + self.endpoint = endpoint + self.credential = credential + self.api_version = api_version + self.credential_scopes = kwargs.pop("credential_scopes", ["https://ml.azure.com/.default"]) + kwargs.setdefault("sdk_moniker", "ai-inference/{}".format(VERSION)) + self.polling_interval = kwargs.get("polling_interval", 30) + self._configure(**kwargs) + + def _infer_policy(self, **kwargs): + if isinstance(self.credential, AzureKeyCredential): + return policies.AzureKeyCredentialPolicy(self.credential, "Authorization", prefix="Bearer", **kwargs) + if isinstance(self.credential, AzureKeyCredential): + return policies.AzureKeyCredentialPolicy(self.credential, "api-key", **kwargs) + if hasattr(self.credential, "get_token"): + return policies.AsyncBearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs) + raise TypeError(f"Unsupported credential: {self.credential}") + + def _configure(self, **kwargs: Any) -> None: + self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs) + self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs) + self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs) + self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs) + self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs) + self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs) + self.redirect_policy = kwargs.get("redirect_policy") or policies.AsyncRedirectPolicy(**kwargs) + self.retry_policy = kwargs.get("retry_policy") or policies.AsyncRetryPolicy(**kwargs) + self.authentication_policy = kwargs.get("authentication_policy") + if self.credential and not self.authentication_policy: + self.authentication_policy = self._infer_policy(**kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_operations/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_operations/__init__.py new file mode 100644 index 00000000..ab870887 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_operations/__init__.py @@ -0,0 +1,29 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
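All three configuration classes read their defaults from kwargs, so the credential scope, API version, and individual pipeline policies can be overridden at client construction time. A sketch using an Entra ID credential (DefaultAzureCredential lives in the separate azure-identity package; the scope and endpoint values are examples only):

    from azure.identity.aio import DefaultAzureCredential
    from azure.ai.inference.aio import ChatCompletionsClient

    client = ChatCompletionsClient(
        endpoint="https://<your-endpoint>.services.ai.azure.com/models",
        credential=DefaultAzureCredential(),
        credential_scopes=["https://cognitiveservices.azure.com/.default"],
        api_version="2024-05-01-preview",
    )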
+# -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + +from ._operations import ChatCompletionsClientOperationsMixin # type: ignore +from ._operations import EmbeddingsClientOperationsMixin # type: ignore +from ._operations import ImageEmbeddingsClientOperationsMixin # type: ignore + +from ._patch import __all__ as _patch_all +from ._patch import * +from ._patch import patch_sdk as _patch_sdk + +__all__ = [ + "ChatCompletionsClientOperationsMixin", + "EmbeddingsClientOperationsMixin", + "ImageEmbeddingsClientOperationsMixin", +] +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore +_patch_sdk() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_operations/_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_operations/_operations.py new file mode 100644 index 00000000..62ec772f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_operations/_operations.py @@ -0,0 +1,781 @@ +# pylint: disable=too-many-locals +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +from io import IOBase +import json +import sys +from typing import Any, Callable, Dict, IO, List, Optional, TypeVar, Union, overload + +from azure.core.exceptions import ( + ClientAuthenticationError, + HttpResponseError, + ResourceExistsError, + ResourceNotFoundError, + ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, + map_error, +) +from azure.core.pipeline import PipelineResponse +from azure.core.rest import AsyncHttpResponse, HttpRequest +from azure.core.tracing.decorator_async import distributed_trace_async +from azure.core.utils import case_insensitive_dict + +from ... 
import models as _models +from ..._model_base import SdkJSONEncoder, _deserialize +from ..._operations._operations import ( + build_chat_completions_complete_request, + build_chat_completions_get_model_info_request, + build_embeddings_embed_request, + build_embeddings_get_model_info_request, + build_image_embeddings_embed_request, + build_image_embeddings_get_model_info_request, +) +from .._vendor import ChatCompletionsClientMixinABC, EmbeddingsClientMixinABC, ImageEmbeddingsClientMixinABC + +if sys.version_info >= (3, 9): + from collections.abc import MutableMapping +else: + from typing import MutableMapping # type: ignore +JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object +_Unset: Any = object() +T = TypeVar("T") +ClsType = Optional[Callable[[PipelineResponse[HttpRequest, AsyncHttpResponse], T, Dict[str, Any]], Any]] + + +class ChatCompletionsClientOperationsMixin(ChatCompletionsClientMixinABC): + + @overload + async def _complete( + self, + *, + messages: List[_models._models.ChatRequestMessage], + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + frequency_penalty: Optional[float] = None, + stream_parameter: Optional[bool] = None, + presence_penalty: Optional[float] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + response_format: Optional[_models._models.ChatCompletionsResponseFormat] = None, + stop: Optional[List[str]] = None, + tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, + tool_choice: Optional[ + Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + ] = None, + seed: Optional[int] = None, + model: Optional[str] = None, + **kwargs: Any + ) -> _models.ChatCompletions: ... + @overload + async def _complete( + self, + body: JSON, + *, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + **kwargs: Any + ) -> _models.ChatCompletions: ... + @overload + async def _complete( + self, + body: IO[bytes], + *, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + **kwargs: Any + ) -> _models.ChatCompletions: ... + + @distributed_trace_async + async def _complete( + self, + body: Union[JSON, IO[bytes]] = _Unset, + *, + messages: List[_models._models.ChatRequestMessage] = _Unset, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + frequency_penalty: Optional[float] = None, + stream_parameter: Optional[bool] = None, + presence_penalty: Optional[float] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + response_format: Optional[_models._models.ChatCompletionsResponseFormat] = None, + stop: Optional[List[str]] = None, + tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, + tool_choice: Optional[ + Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + ] = None, + seed: Optional[int] = None, + model: Optional[str] = None, + **kwargs: Any + ) -> _models.ChatCompletions: + """Gets chat completions for the provided chat messages. + Completions support a wide variety of tasks and generate text that continues from or + "completes" + provided prompt data. The method makes a REST API call to the ``/chat/completions`` route + on the given endpoint. + + :param body: Is either a JSON type or a IO[bytes] type. 
Required. + :type body: JSON or IO[bytes] + :keyword messages: The collection of context messages associated with this chat completions + request. + Typical usage begins with a chat message for the System role that provides instructions for + the behavior of the assistant, followed by alternating messages between the User and + Assistant roles. Required. + :paramtype messages: list[~azure.ai.inference.models._models.ChatRequestMessage] + :keyword extra_params: Controls what happens if extra parameters, undefined by the REST API, + are passed in the JSON request payload. + This sets the HTTP request header ``extra-parameters``. Known values are: "error", "drop", and + "pass-through". Default value is None. + :paramtype extra_params: str or ~azure.ai.inference.models.ExtraParameters + :keyword frequency_penalty: A value that influences the probability of generated tokens + appearing based on their cumulative + frequency in generated text. + Positive values will make tokens less likely to appear as their frequency increases and + decrease the likelihood of the model repeating the same statements verbatim. + Supported range is [-2, 2]. Default value is None. + :paramtype frequency_penalty: float + :keyword stream_parameter: A value indicating whether chat completions should be streamed for + this request. Default value is None. + :paramtype stream_parameter: bool + :keyword presence_penalty: A value that influences the probability of generated tokens + appearing based on their existing + presence in generated text. + Positive values will make tokens less likely to appear when they already exist and increase + the + model's likelihood to output new topics. + Supported range is [-2, 2]. Default value is None. + :paramtype presence_penalty: float + :keyword temperature: The sampling temperature to use that controls the apparent creativity of + generated completions. + Higher values will make output more random while lower values will make results more focused + and deterministic. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. Default value is None. + :paramtype temperature: float + :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value + causes the + model to consider the results of tokens with the provided probability mass. As an example, a + value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be + considered. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. Default value is None. + :paramtype top_p: float + :keyword max_tokens: The maximum number of tokens to generate. Default value is None. + :paramtype max_tokens: int + :keyword response_format: An object specifying the format that the model must output. + + Setting to ``{ "type": "json_schema", "json_schema": {...} }`` enables Structured Outputs + which ensures the model will match your supplied JSON schema. + + Setting to ``{ "type": "json_object" }`` enables JSON mode, which ensures the message the + model generates is valid JSON. + + **Important:** when using JSON mode, you **must** also instruct the model to produce JSON + yourself via a system or user message. 
Without this, the model may generate an unending stream + of whitespace until the generation reaches the token limit, resulting in a long-running and + seemingly "stuck" request. Also note that the message content may be partially cut off if + ``finish_reason="length"``\\ , which indicates the generation exceeded ``max_tokens`` or the + conversation exceeded the max context length. Default value is None. + :paramtype response_format: ~azure.ai.inference.models._models.ChatCompletionsResponseFormat + :keyword stop: A collection of textual sequences that will end completions generation. Default + value is None. + :paramtype stop: list[str] + :keyword tools: A list of tools the model may request to call. Currently, only functions are + supported as a tool. The model + may response with a function call request and provide the input arguments in JSON format for + that function. Default value is None. + :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] + :keyword tool_choice: If specified, the model will configure which of the provided tools it can + use for the chat completions response. Is either a Union[str, + "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. + Default value is None. + :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or + ~azure.ai.inference.models.ChatCompletionsNamedToolChoice + :keyword seed: If specified, the system will make a best effort to sample deterministically + such that repeated requests with the + same seed and parameters should return the same result. Determinism is not guaranteed. Default + value is None. + :paramtype seed: int + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :return: ChatCompletions. 
The ChatCompletions is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.ChatCompletions + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.ChatCompletions] = kwargs.pop("cls", None) + + if body is _Unset: + if messages is _Unset: + raise TypeError("missing required argument: messages") + body = { + "frequency_penalty": frequency_penalty, + "max_tokens": max_tokens, + "messages": messages, + "model": model, + "presence_penalty": presence_penalty, + "response_format": response_format, + "seed": seed, + "stop": stop, + "stream": stream_parameter, + "temperature": temperature, + "tool_choice": tool_choice, + "tools": tools, + "top_p": top_p, + } + body = {k: v for k, v in body.items() if v is not None} + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_chat_completions_complete_request( + extra_params=extra_params, + content_type=content_type, + api_version=self._config.api_version, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize(_models.ChatCompletions, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: + """Returns information about the AI model. + The method makes a REST API call to the ``/info`` route on the given endpoint. + This method will only work when using Serverless API or Managed Compute endpoint. + It will not work for GitHub Models endpoint or Azure OpenAI endpoint. + + :return: ModelInfo. 
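When the keyword form of _complete is used, the request body is assembled as a plain dict and every key whose value is None is dropped before the dict is JSON-encoded with SdkJSONEncoder. For a minimal two-message request the wire payload therefore reduces to roughly (illustrative values):

    body = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "How many feet are in a mile?"},
        ],
    }
    # Unset options (temperature, tools, response_format, ...) never appear,
    # mirroring: body = {k: v for k, v in body.items() if v is not None}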
The ModelInfo is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.ModelInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.ModelInfo] = kwargs.pop("cls", None) + + _request = build_chat_completions_get_model_info_request( + api_version=self._config.api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize(_models.ModelInfo, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + +class EmbeddingsClientOperationsMixin(EmbeddingsClientMixinABC): + + @overload + async def _embed( + self, + *, + input: List[str], + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + **kwargs: Any + ) -> _models.EmbeddingsResult: ... + @overload + async def _embed( + self, + body: JSON, + *, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + **kwargs: Any + ) -> _models.EmbeddingsResult: ... + @overload + async def _embed( + self, + body: IO[bytes], + *, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + **kwargs: Any + ) -> _models.EmbeddingsResult: ... + + @distributed_trace_async + async def _embed( + self, + body: Union[JSON, IO[bytes]] = _Unset, + *, + input: List[str] = _Unset, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + **kwargs: Any + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given text prompts. + The method makes a REST API call to the ``/embeddings`` route on the given endpoint. + + :param body: Is either a JSON type or a IO[bytes] type. Required. + :type body: JSON or IO[bytes] + :keyword input: Input text to embed, encoded as a string or array of tokens. 
+ To embed multiple inputs in a single request, pass an array + of strings or array of token arrays. Required. + :paramtype input: list[str] + :keyword extra_params: Controls what happens if extra parameters, undefined by the REST API, + are passed in the JSON request payload. + This sets the HTTP request header ``extra-parameters``. Known values are: "error", "drop", and + "pass-through". Default value is None. + :paramtype extra_params: str or ~azure.ai.inference.models.ExtraParameters + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. + Passing null causes the model to use its default value. + Returns a 422 error if the model doesn't support the value or parameter. Default value is + None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. Known + values are: "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. + Returns a 422 error if the model doesn't support the value or parameter. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.EmbeddingsResult] = kwargs.pop("cls", None) + + if body is _Unset: + if input is _Unset: + raise TypeError("missing required argument: input") + body = { + "dimensions": dimensions, + "encoding_format": encoding_format, + "input": input, + "input_type": input_type, + "model": model, + } + body = {k: v for k, v in body.items() if v is not None} + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_embeddings_embed_request( + extra_params=extra_params, + content_type=content_type, + api_version=self._config.api_version, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + await response.read() # Load the body in memory and close the socket + except 
(StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize(_models.EmbeddingsResult, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: + """Returns information about the AI model. + The method makes a REST API call to the ``/info`` route on the given endpoint. + This method will only work when using Serverless API or Managed Compute endpoint. + It will not work for GitHub Models endpoint or Azure OpenAI endpoint. + + :return: ModelInfo. The ModelInfo is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.ModelInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.ModelInfo] = kwargs.pop("cls", None) + + _request = build_embeddings_get_model_info_request( + api_version=self._config.api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize(_models.ModelInfo, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + +class ImageEmbeddingsClientOperationsMixin(ImageEmbeddingsClientMixinABC): + + @overload + async def _embed( + self, + *, + input: List[_models.ImageEmbeddingInput], + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + **kwargs: Any + ) -> _models.EmbeddingsResult: ... + @overload + async def _embed( + self, + body: JSON, + *, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + **kwargs: Any + ) -> _models.EmbeddingsResult: ... 
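As in the chat and text-embeddings mixins, _embed is overloaded so the same operation accepts either keyword arguments, a pre-built JSON dict, or raw bytes (the IO[bytes] overload follows just below). The JSON form corresponding to a keyword call is simply the documented request body, for example (values are placeholders):

    body = {
        "input": [
            {"image": "data:image/png;base64,<...>", "text": "optional caption"},
        ],
        "model": "<deployment-name>",
    }
    # Passing this dict, or json.dumps(body).encode() as bytes, routes the call
    # through the JSON / IO[bytes] overloads instead of the keyword overload.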
+ @overload + async def _embed( + self, + body: IO[bytes], + *, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + content_type: str = "application/json", + **kwargs: Any + ) -> _models.EmbeddingsResult: ... + + @distributed_trace_async + async def _embed( + self, + body: Union[JSON, IO[bytes]] = _Unset, + *, + input: List[_models.ImageEmbeddingInput] = _Unset, + extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + **kwargs: Any + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given images. + The method makes a REST API call to the ``/images/embeddings`` route on the given endpoint. + + :param body: Is either a JSON type or a IO[bytes] type. Required. + :type body: JSON or IO[bytes] + :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an + array. + The input must not exceed the max input tokens for the model. Required. + :paramtype input: list[~azure.ai.inference.models.ImageEmbeddingInput] + :keyword extra_params: Controls what happens if extra parameters, undefined by the REST API, + are passed in the JSON request payload. + This sets the HTTP request header ``extra-parameters``. Known values are: "error", "drop", and + "pass-through". Default value is None. + :paramtype extra_params: str or ~azure.ai.inference.models.ExtraParameters + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. + Passing null causes the model to use its default value. + Returns a 422 error if the model doesn't support the value or parameter. Default value is + None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The number of dimensions the resulting output embeddings + should have. + Passing null causes the model to use its default value. + Returns a 422 error if the model doesn't support the value or parameter. Known values are: + "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. + Returns a 422 error if the model doesn't support the value or parameter. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :return: EmbeddingsResult. 
The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.EmbeddingsResult] = kwargs.pop("cls", None) + + if body is _Unset: + if input is _Unset: + raise TypeError("missing required argument: input") + body = { + "dimensions": dimensions, + "encoding_format": encoding_format, + "input": input, + "input_type": input_type, + "model": model, + } + body = {k: v for k, v in body.items() if v is not None} + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_image_embeddings_embed_request( + extra_params=extra_params, + content_type=content_type, + api_version=self._config.api_version, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize(_models.EmbeddingsResult, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: + """Returns information about the AI model. + The method makes a REST API call to the ``/info`` route on the given endpoint. + This method will only work when using Serverless API or Managed Compute endpoint. + It will not work for GitHub Models endpoint or Azure OpenAI endpoint. + + :return: ModelInfo. 
The ModelInfo is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.ModelInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.ModelInfo] = kwargs.pop("cls", None) + + _request = build_image_embeddings_get_model_info_request( + api_version=self._config.api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize(_models.ModelInfo, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_operations/_patch.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_operations/_patch.py new file mode 100644 index 00000000..f7dd3251 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_operations/_patch.py @@ -0,0 +1,20 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Customize generated code here. + +Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize +""" +from typing import List + +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level + + +def patch_sdk(): + """Do not remove from this file. + + `patch_sdk` is a last resort escape hatch that allows you to do customizations + you can't accomplish using the techniques described in + https://aka.ms/azsdk/python/dpcodegen/python/customize + """ diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_patch.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_patch.py new file mode 100644 index 00000000..2f987380 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_patch.py @@ -0,0 +1,1331 @@ +# pylint: disable=too-many-lines +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Customize generated code here. 
+ +Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize +""" +import json +import logging +import sys + +from io import IOBase +from typing import Any, Dict, Union, IO, List, Literal, Optional, overload, Type, TYPE_CHECKING, AsyncIterable + +from azure.core.pipeline import PipelineResponse +from azure.core.credentials import AzureKeyCredential +from azure.core.tracing.decorator_async import distributed_trace_async +from azure.core.utils import case_insensitive_dict +from azure.core.exceptions import ( + ClientAuthenticationError, + HttpResponseError, + map_error, + ResourceExistsError, + ResourceNotFoundError, + ResourceNotModifiedError, +) +from .. import models as _models +from .._model_base import SdkJSONEncoder, _deserialize +from ._client import ChatCompletionsClient as ChatCompletionsClientGenerated +from ._client import EmbeddingsClient as EmbeddingsClientGenerated +from ._client import ImageEmbeddingsClient as ImageEmbeddingsClientGenerated +from .._operations._operations import ( + build_chat_completions_complete_request, + build_embeddings_embed_request, + build_image_embeddings_embed_request, +) +from .._patch import _get_internal_response_format + +if TYPE_CHECKING: + # pylint: disable=unused-import,ungrouped-imports + from azure.core.credentials_async import AsyncTokenCredential + +if sys.version_info >= (3, 9): + from collections.abc import MutableMapping +else: + from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports + +JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object +_Unset: Any = object() +_LOGGER = logging.getLogger(__name__) + + +async def load_client( + endpoint: str, credential: Union[AzureKeyCredential, "AsyncTokenCredential"], **kwargs: Any +) -> Union["ChatCompletionsClient", "EmbeddingsClient", "ImageEmbeddingsClient"]: + """ + Load a client from a given endpoint URL. The method makes a REST API call to the `/info` route + on the given endpoint, to determine the model type and therefore which client to instantiate. + This method will only work when using Serverless API or Managed Compute endpoint. + It will not work for GitHub Models endpoint or Azure OpenAI endpoint. + Keyword arguments are passed through to the client constructor (you can set keywords such as + `api_version`, `user_agent`, `logging_enable` etc. on the client constructor). + + :param endpoint: Service endpoint URL for AI model inference. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a AsyncTokenCredential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials_async.AsyncTokenCredential + :return: The appropriate asynchronous client associated with the given endpoint + :rtype: ~azure.ai.inference.aio.ChatCompletionsClient or ~azure.ai.inference.aio.EmbeddingsClient + or ~azure.ai.inference.aio.ImageEmbeddingsClient + :raises ~azure.core.exceptions.HttpResponseError: + """ + + async with ChatCompletionsClient( + endpoint, credential, **kwargs + ) as client: # Pick any of the clients, it does not matter. + try: + model_info = await client.get_model_info() # type: ignore + except ResourceNotFoundError as error: + error.message = ( + "`load_client` function does not work on this endpoint (`/info` route not supported). " + "Please construct one of the clients (e.g. `ChatCompletionsClient`) directly." 
+        )
+        raise error
+
+    _LOGGER.info("model_info=%s", model_info)
+    if not model_info.model_type:
+        raise ValueError(
+            "The AI model information is missing a value for `model type`. Cannot create an appropriate client."
+        )
+
+    # TODO: Remove "completions", "chat-completions" and "embedding" once Mistral Large and Cohere fix their model type
+    if model_info.model_type in (
+        _models.ModelType.CHAT_COMPLETION,
+        "chat_completions",
+        "chat",
+        "completion",
+        "chat-completion",
+        "chat-completions",
+        "chat completion",
+        "chat completions",
+    ):
+        chat_completion_client = ChatCompletionsClient(endpoint, credential, **kwargs)
+        chat_completion_client._model_info = (  # pylint: disable=protected-access,attribute-defined-outside-init
+            model_info
+        )
+        return chat_completion_client
+
+    if model_info.model_type in (
+        _models.ModelType.EMBEDDINGS,
+        "embedding",
+        "text_embedding",
+        "text-embeddings",
+        "text embedding",
+        "text embeddings",
+    ):
+        embedding_client = EmbeddingsClient(endpoint, credential, **kwargs)
+        embedding_client._model_info = model_info  # pylint: disable=protected-access,attribute-defined-outside-init
+        return embedding_client
+
+    if model_info.model_type in (
+        _models.ModelType.IMAGE_EMBEDDINGS,
+        "image_embedding",
+        "image-embeddings",
+        "image-embedding",
+        "image embedding",
+        "image embeddings",
+    ):
+        image_embedding_client = ImageEmbeddingsClient(endpoint, credential, **kwargs)
+        image_embedding_client._model_info = (  # pylint: disable=protected-access,attribute-defined-outside-init
+            model_info
+        )
+        return image_embedding_client
+
+    raise ValueError(f"No client available to support AI model type `{model_info.model_type}`")
+
+
+class ChatCompletionsClient(ChatCompletionsClientGenerated):  # pylint: disable=too-many-instance-attributes
+    """ChatCompletionsClient.
+
+    :param endpoint: Service endpoint URL for AI model inference. Required.
+    :type endpoint: str
+    :param credential: Credential used to authenticate requests to the service. Is either a
+     AzureKeyCredential type or a AsyncTokenCredential type. Required.
+    :type credential: ~azure.core.credentials.AzureKeyCredential or
+     ~azure.core.credentials_async.AsyncTokenCredential
+    :keyword frequency_penalty: A value that influences the probability of generated tokens
+     appearing based on their cumulative frequency in generated text.
+     Positive values will make tokens less likely to appear as their frequency increases and
+     decrease the likelihood of the model repeating the same statements verbatim.
+     Supported range is [-2, 2].
+     Default value is None.
+    :paramtype frequency_penalty: float
+    :keyword presence_penalty: A value that influences the probability of generated tokens
+     appearing based on their existing
+     presence in generated text.
+     Positive values will make tokens less likely to appear when they already exist and increase
+     the model's likelihood to output new topics.
+     Supported range is [-2, 2].
+     Default value is None.
+    :paramtype presence_penalty: float
+    :keyword temperature: The sampling temperature to use that controls the apparent creativity of
+     generated completions.
+     Higher values will make output more random while lower values will make results more focused
+     and deterministic.
+     It is not recommended to modify temperature and top_p for the same completions request as the
+     interaction of these two settings is difficult to predict.
+     Supported range is [0, 1].
+     Default value is None.
+ :paramtype temperature: float + :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value + causes the + model to consider the results of tokens with the provided probability mass. As an example, a + value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be + considered. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. + Default value is None. + :paramtype top_p: float + :keyword max_tokens: The maximum number of tokens to generate. Default value is None. + :paramtype max_tokens: int + :keyword response_format: The format that the AI model must output. AI chat completions models typically output + unformatted text by default. This is equivalent to setting "text" as the response_format. + To output JSON format, without adhering to any schema, set to "json_object". + To output JSON format adhering to a provided schema, set this to an object of the class + ~azure.ai.inference.models.JsonSchemaFormat. Default value is None. + :paramtype response_format: Union[Literal['text', 'json_object'], ~azure.ai.inference.models.JsonSchemaFormat] + :keyword stop: A collection of textual sequences that will end completions generation. Default + value is None. + :paramtype stop: list[str] + :keyword tools: The available tool definitions that the chat completions request can use, + including caller-defined functions. Default value is None. + :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] + :keyword tool_choice: If specified, the model will configure which of the provided tools it can + use for the chat completions response. Is either a Union[str, + "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. + Default value is None. + :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or + ~azure.ai.inference.models.ChatCompletionsNamedToolChoice + :keyword seed: If specified, the system will make a best effort to sample deterministically + such that repeated requests with the + same seed and parameters should return the same result. Determinism is not guaranteed. + Default value is None. + :paramtype seed: int + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. 
+ :paramtype api_version: str + """ + + def __init__( + self, + endpoint: str, + credential: Union[AzureKeyCredential, "AsyncTokenCredential"], + *, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + stop: Optional[List[str]] = None, + tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, + tool_choice: Optional[ + Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + ] = None, + seed: Optional[int] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + + self._model_info: Optional[_models.ModelInfo] = None + + # Store default chat completions settings, to be applied in all future service calls + # unless overridden by arguments in the `complete` method. + self._frequency_penalty = frequency_penalty + self._presence_penalty = presence_penalty + self._temperature = temperature + self._top_p = top_p + self._max_tokens = max_tokens + self._internal_response_format = _get_internal_response_format(response_format) + self._stop = stop + self._tools = tools + self._tool_choice = tool_choice + self._seed = seed + self._model = model + self._model_extras = model_extras + + # For Key auth, we need to send these two auth HTTP request headers simultaneously: + # 1. "Authorization: Bearer <key>" + # 2. "api-key: <key>" + # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, + # and Azure OpenAI and the new Unified Inference endpoints support the second header. + # The first header will be taken care of by auto-generated code. + # The second one is added here. + if isinstance(credential, AzureKeyCredential): + headers = kwargs.pop("headers", {}) + if "api-key" not in headers: + headers["api-key"] = credential.key + kwargs["headers"] = headers + + super().__init__(endpoint, credential, **kwargs) + + @overload + async def complete( + self, + *, + messages: List[_models.ChatRequestMessage], + stream: Literal[False] = False, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + stop: Optional[List[str]] = None, + tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, + tool_choice: Optional[ + Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + ] = None, + seed: Optional[int] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> _models.ChatCompletions: ... 
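+    # Illustrative usage sketch (not part of the patched client itself): one plausible
+    # way to call the async `complete` method overloaded below. The endpoint URL and
+    # API key are placeholders.
+    #
+    #   import asyncio
+    #   from azure.core.credentials import AzureKeyCredential
+    #   from azure.ai.inference.aio import ChatCompletionsClient
+    #   from azure.ai.inference.models import SystemMessage, UserMessage
+    #
+    #   async def main():
+    #       async with ChatCompletionsClient(
+    #           endpoint="https://<your-endpoint>.inference.ai.azure.com",  # placeholder
+    #           credential=AzureKeyCredential("<your-api-key>"),  # placeholder
+    #       ) as client:
+    #           response = await client.complete(
+    #               messages=[
+    #                   SystemMessage(content="You are a helpful assistant."),
+    #                   UserMessage(content="How many feet are in a mile?"),
+    #               ]
+    #           )
+    #           print(response.choices[0].message.content)
+    #
+    #   asyncio.run(main())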
+ + @overload + async def complete( + self, + *, + messages: List[_models.ChatRequestMessage], + stream: Literal[True], + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + stop: Optional[List[str]] = None, + tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, + tool_choice: Optional[ + Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + ] = None, + seed: Optional[int] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> AsyncIterable[_models.StreamingChatCompletionsUpdate]: ... + + @overload + async def complete( + self, + *, + messages: List[_models.ChatRequestMessage], + stream: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + stop: Optional[List[str]] = None, + tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, + tool_choice: Optional[ + Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + ] = None, + seed: Optional[int] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Union[AsyncIterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: + # pylint: disable=line-too-long + """Gets chat completions for the provided chat messages. + Completions support a wide variety of tasks and generate text that continues from or + "completes" provided prompt data. The method makes a REST API call to the `/chat/completions` route + on the given endpoint. + When using this method with `stream=True`, the response is streamed + back to the client. Iterate over the resulting StreamingChatCompletions + object to get content updates as they arrive. By default, the response is a ChatCompletions object + (non-streaming). + + :keyword messages: The collection of context messages associated with this chat completions + request. + Typical usage begins with a chat message for the System role that provides instructions for + the behavior of the assistant, followed by alternating messages between the User and + Assistant roles. Required. + :paramtype messages: list[~azure.ai.inference.models.ChatRequestMessage] + :keyword stream: A value indicating whether chat completions should be streamed for this request. + Default value is False. If streaming is enabled, the response will be a StreamingChatCompletions. + Otherwise the response will be a ChatCompletions. + :paramtype stream: bool + :keyword frequency_penalty: A value that influences the probability of generated tokens + appearing based on their cumulative frequency in generated text. + Positive values will make tokens less likely to appear as their frequency increases and + decrease the likelihood of the model repeating the same statements verbatim. + Supported range is [-2, 2]. + Default value is None. + :paramtype frequency_penalty: float + :keyword presence_penalty: A value that influences the probability of generated tokens + appearing based on their existing + presence in generated text. 
+ Positive values will make tokens less likely to appear when they already exist and increase + the model's likelihood to output new topics. + Supported range is [-2, 2]. + Default value is None. + :paramtype presence_penalty: float + :keyword temperature: The sampling temperature to use that controls the apparent creativity of + generated completions. + Higher values will make output more random while lower values will make results more focused + and deterministic. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. + Default value is None. + :paramtype temperature: float + :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value + causes the + model to consider the results of tokens with the provided probability mass. As an example, a + value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be + considered. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. + Default value is None. + :paramtype top_p: float + :keyword max_tokens: The maximum number of tokens to generate. Default value is None. + :paramtype max_tokens: int + :keyword response_format: The format that the AI model must output. AI chat completions models typically output + unformatted text by default. This is equivalent to setting "text" as the response_format. + To output JSON format, without adhering to any schema, set to "json_object". + To output JSON format adhering to a provided schema, set this to an object of the class + ~azure.ai.inference.models.JsonSchemaFormat. Default value is None. + :paramtype response_format: Union[Literal['text', 'json_object'], ~azure.ai.inference.models.JsonSchemaFormat] + :keyword stop: A collection of textual sequences that will end completions generation. Default + value is None. + :paramtype stop: list[str] + :keyword tools: The available tool definitions that the chat completions request can use, + including caller-defined functions. Default value is None. + :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] + :keyword tool_choice: If specified, the model will configure which of the provided tools it can + use for the chat completions response. Is either a Union[str, + "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. + Default value is None. + :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or + ~azure.ai.inference.models.ChatCompletionsNamedToolChoice + :keyword seed: If specified, the system will make a best effort to sample deterministically + such that repeated requests with the + same seed and parameters should return the same result. Determinism is not guaranteed. + Default value is None. + :paramtype seed: int + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. 
+ :paramtype model_extras: dict[str, Any] + :return: ChatCompletions for non-streaming, or AsyncIterable[StreamingChatCompletionsUpdate] for streaming. + :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.AsyncStreamingChatCompletions + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def complete( + self, + body: JSON, + *, + content_type: str = "application/json", + **kwargs: Any, + ) -> Union[AsyncIterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: + # pylint: disable=line-too-long + """Gets chat completions for the provided chat messages. + Completions support a wide variety of tasks and generate text that continues from or + "completes" provided prompt data. + + :param body: An object of type MutableMapping[str, Any], such as a dictionary, that + specifies the full request payload. Required. + :type body: JSON + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: ChatCompletions for non-streaming, or AsyncIterable[StreamingChatCompletionsUpdate] for streaming. + :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.AsyncStreamingChatCompletions + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def complete( + self, + body: IO[bytes], + *, + content_type: str = "application/json", + **kwargs: Any, + ) -> Union[AsyncIterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: + # pylint: disable=line-too-long + """Gets chat completions for the provided chat messages. + Completions support a wide variety of tasks and generate text that continues from or + "completes" provided prompt data. + + :param body: Specifies the full request payload. Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: ChatCompletions for non-streaming, or AsyncIterable[StreamingChatCompletionsUpdate] for streaming. + :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.AsyncStreamingChatCompletions + :raises ~azure.core.exceptions.HttpResponseError: + """ + + # pylint:disable=client-method-missing-tracing-decorator-async + async def complete( + self, + body: Union[JSON, IO[bytes]] = _Unset, + *, + messages: List[_models.ChatRequestMessage] = _Unset, + stream: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + stop: Optional[List[str]] = None, + tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, + tool_choice: Optional[ + Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + ] = None, + seed: Optional[int] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Union[AsyncIterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: + # pylint: disable=line-too-long + # pylint: disable=too-many-locals + """Gets chat completions for the provided chat messages. 
+ Completions support a wide variety of tasks and generate text that continues from or + "completes" provided prompt data. When using this method with `stream=True`, the response is streamed + back to the client. Iterate over the resulting :class:`~azure.ai.inference.models.StreamingChatCompletions` + object to get content updates as they arrive. + + :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type + that specifies the full request payload. Required. + :type body: JSON or IO[bytes] + :keyword messages: The collection of context messages associated with this chat completions + request. + Typical usage begins with a chat message for the System role that provides instructions for + the behavior of the assistant, followed by alternating messages between the User and + Assistant roles. Required. + :paramtype messages: list[~azure.ai.inference.models.ChatRequestMessage] + :keyword stream: A value indicating whether chat completions should be streamed for this request. + Default value is False. If streaming is enabled, the response will be a StreamingChatCompletions. + Otherwise the response will be a ChatCompletions. + :paramtype stream: bool + :keyword frequency_penalty: A value that influences the probability of generated tokens + appearing based on their cumulative frequency in generated text. + Positive values will make tokens less likely to appear as their frequency increases and + decrease the likelihood of the model repeating the same statements verbatim. + Supported range is [-2, 2]. + Default value is None. + :paramtype frequency_penalty: float + :keyword presence_penalty: A value that influences the probability of generated tokens + appearing based on their existing + presence in generated text. + Positive values will make tokens less likely to appear when they already exist and increase + the model's likelihood to output new topics. + Supported range is [-2, 2]. + Default value is None. + :paramtype presence_penalty: float + :keyword temperature: The sampling temperature to use that controls the apparent creativity of + generated completions. + Higher values will make output more random while lower values will make results more focused + and deterministic. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. + Default value is None. + :paramtype temperature: float + :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value + causes the + model to consider the results of tokens with the provided probability mass. As an example, a + value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be + considered. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. + Default value is None. + :paramtype top_p: float + :keyword max_tokens: The maximum number of tokens to generate. Default value is None. + :paramtype max_tokens: int + :keyword response_format: The format that the AI model must output. AI chat completions models typically output + unformatted text by default. This is equivalent to setting "text" as the response_format. + To output JSON format, without adhering to any schema, set to "json_object". 
+ To output JSON format adhering to a provided schema, set this to an object of the class + ~azure.ai.inference.models.JsonSchemaFormat. Default value is None. + :paramtype response_format: Union[Literal['text', 'json_object'], ~azure.ai.inference.models.JsonSchemaFormat] + :keyword stop: A collection of textual sequences that will end completions generation. Default + value is None. + :paramtype stop: list[str] + :keyword tools: The available tool definitions that the chat completions request can use, + including caller-defined functions. Default value is None. + :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] + :keyword tool_choice: If specified, the model will configure which of the provided tools it can + use for the chat completions response. Is either a Union[str, + "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. + Default value is None. + :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or + ~azure.ai.inference.models.ChatCompletionsNamedToolChoice + :keyword seed: If specified, the system will make a best effort to sample deterministically + such that repeated requests with the + same seed and parameters should return the same result. Determinism is not guaranteed. + Default value is None. + :paramtype seed: int + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :return: ChatCompletions for non-streaming, or AsyncIterable[StreamingChatCompletionsUpdate] for streaming. 
+ :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.AsyncStreamingChatCompletions + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + _extra_parameters: Union[_models._enums.ExtraParameters, None] = None + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + + internal_response_format = _get_internal_response_format(response_format) + + if body is _Unset: + if messages is _Unset: + raise TypeError("missing required argument: messages") + body = { + "messages": messages, + "stream": stream, + "frequency_penalty": frequency_penalty if frequency_penalty is not None else self._frequency_penalty, + "max_tokens": max_tokens if max_tokens is not None else self._max_tokens, + "model": model if model is not None else self._model, + "presence_penalty": presence_penalty if presence_penalty is not None else self._presence_penalty, + "response_format": ( + internal_response_format if internal_response_format is not None else self._internal_response_format + ), + "seed": seed if seed is not None else self._seed, + "stop": stop if stop is not None else self._stop, + "temperature": temperature if temperature is not None else self._temperature, + "tool_choice": tool_choice if tool_choice is not None else self._tool_choice, + "tools": tools if tools is not None else self._tools, + "top_p": top_p if top_p is not None else self._top_p, + } + if model_extras is not None and bool(model_extras): + body.update(model_extras) + _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + elif self._model_extras is not None and bool(self._model_extras): + body.update(self._model_extras) + _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + body = {k: v for k, v in body.items() if v is not None} + elif isinstance(body, dict) and "stream" in body and isinstance(body["stream"], bool): + stream = body["stream"] + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_chat_completions_complete_request( + extra_params=_extra_parameters, + content_type=content_type, + api_version=self._config.api_version, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = stream or False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + await response.read() # Load the body in memory and close the socket + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + return _models.AsyncStreamingChatCompletions(response) + + return 
_deserialize(_models._patch.ChatCompletions, response.json()) # pylint: disable=protected-access + + @distributed_trace_async + async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: + # pylint: disable=line-too-long + """Returns information about the AI model. + The method makes a REST API call to the ``/info`` route on the given endpoint. + This method will only work when using Serverless API or Managed Compute endpoint. + It will not work for GitHub Models endpoint or Azure OpenAI endpoint. + + :return: ModelInfo. The ModelInfo is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.ModelInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + if not self._model_info: + try: + self._model_info = await self._get_model_info( + **kwargs + ) # pylint: disable=attribute-defined-outside-init + except ResourceNotFoundError as error: + error.message = "Model information is not available on this endpoint (`/info` route not supported)." + raise error + + return self._model_info + + def __str__(self) -> str: + # pylint: disable=client-method-name-no-double-underscore + return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() + + +class EmbeddingsClient(EmbeddingsClientGenerated): + """EmbeddingsClient. + + :param endpoint: Service endpoint URL for AI model inference. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a AsyncTokenCredential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials_async.AsyncTokenCredential + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. Default value is None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. + Known values are: + "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. 
+ :paramtype api_version: str + """ + + def __init__( + self, + endpoint: str, + credential: Union[AzureKeyCredential, "AsyncTokenCredential"], + *, + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + + self._model_info: Optional[_models.ModelInfo] = None + + # Store default embeddings settings, to be applied in all future service calls + # unless overridden by arguments in the `embed` method. + self._dimensions = dimensions + self._encoding_format = encoding_format + self._input_type = input_type + self._model = model + self._model_extras = model_extras + + # For Key auth, we need to send these two auth HTTP request headers simultaneously: + # 1. "Authorization: Bearer <key>" + # 2. "api-key: <key>" + # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, + # and Azure OpenAI and the new Unified Inference endpoints support the second header. + # The first header will be taken care of by auto-generated code. + # The second one is added here. + if isinstance(credential, AzureKeyCredential): + headers = kwargs.pop("headers", {}) + if "api-key" not in headers: + headers["api-key"] = credential.key + kwargs["headers"] = headers + + super().__init__(endpoint, credential, **kwargs) + + @overload + async def embed( + self, + *, + input: List[str], + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given text prompts. + The method makes a REST API call to the `/embeddings` route on the given endpoint. + + :keyword input: Input text to embed, encoded as a string or array of tokens. + To embed multiple inputs in a single request, pass an array + of strings or array of token arrays. Required. + :paramtype input: list[str] + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. Default value is None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. + Known values are: + "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :return: EmbeddingsResult. 
The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def embed( + self, + body: JSON, + *, + content_type: str = "application/json", + **kwargs: Any, + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given text prompts. + The method makes a REST API call to the `/embeddings` route on the given endpoint. + + :param body: An object of type MutableMapping[str, Any], such as a dictionary, that + specifies the full request payload. Required. + :type body: JSON + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def embed( + self, + body: IO[bytes], + *, + content_type: str = "application/json", + **kwargs: Any, + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given text prompts. + The method makes a REST API call to the `/embeddings` route on the given endpoint. + + :param body: Specifies the full request payload. Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def embed( + self, + body: Union[JSON, IO[bytes]] = _Unset, + *, + input: List[str] = _Unset, + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> _models.EmbeddingsResult: + # pylint: disable=line-too-long + """Return the embedding vectors for given text prompts. + The method makes a REST API call to the `/embeddings` route on the given endpoint. + + :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type + that specifies the full request payload. Required. + :type body: JSON or IO[bytes] + :keyword input: Input text to embed, encoded as a string or array of tokens. + To embed multiple inputs in a single request, pass an array + of strings or array of token arrays. Required. + :paramtype input: list[str] + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. Default value is None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. + Known values are: + "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. 
+ :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping[int, Type[HttpResponseError]] = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + _extra_parameters: Union[_models._enums.ExtraParameters, None] = None + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + + if body is _Unset: + if input is _Unset: + raise TypeError("missing required argument: input") + body = { + "input": input, + "dimensions": dimensions if dimensions is not None else self._dimensions, + "encoding_format": encoding_format if encoding_format is not None else self._encoding_format, + "input_type": input_type if input_type is not None else self._input_type, + "model": model if model is not None else self._model, + } + if model_extras is not None and bool(model_extras): + body.update(model_extras) + _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + elif self._model_extras is not None and bool(self._model_extras): + body.update(self._model_extras) + _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + body = {k: v for k, v in body.items() if v is not None} + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_embeddings_embed_request( + extra_params=_extra_parameters, + content_type=content_type, + api_version=self._config.api_version, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + await response.read() # Load the body in memory and close the socket + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize( + _models._patch.EmbeddingsResult, response.json() # pylint: disable=protected-access + ) + + return deserialized # type: ignore + + @distributed_trace_async + async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: + # pylint: disable=line-too-long + """Returns information about the AI 
model. + The method makes a REST API call to the ``/info`` route on the given endpoint. + This method will only work when using Serverless API or Managed Compute endpoint. + It will not work for GitHub Models endpoint or Azure OpenAI endpoint. + + :return: ModelInfo. The ModelInfo is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.ModelInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + if not self._model_info: + try: + self._model_info = await self._get_model_info( + **kwargs + ) # pylint: disable=attribute-defined-outside-init + except ResourceNotFoundError as error: + error.message = "Model information is not available on this endpoint (`/info` route not supported)." + raise error + + return self._model_info + + def __str__(self) -> str: + # pylint: disable=client-method-name-no-double-underscore + return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() + + +class ImageEmbeddingsClient(ImageEmbeddingsClientGenerated): + """ImageEmbeddingsClient. + + :param endpoint: Service endpoint URL for AI model inference. Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a AsyncTokenCredential type. Required. + :type credential: ~azure.core.credentials.AzureKeyCredential or + ~azure.core.credentials_async.AsyncTokenCredential + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. Default value is None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. + Known values are: + "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :keyword api_version: The API version to use for this operation. Default value is + "2024-05-01-preview". Note that overriding this default value may result in unsupported + behavior. + :paramtype api_version: str + """ + + def __init__( + self, + endpoint: str, + credential: Union[AzureKeyCredential, "AsyncTokenCredential"], + *, + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + + self._model_info: Optional[_models.ModelInfo] = None + + # Store default embeddings settings, to be applied in all future service calls + # unless overridden by arguments in the `embed` method. 
+ self._dimensions = dimensions + self._encoding_format = encoding_format + self._input_type = input_type + self._model = model + self._model_extras = model_extras + + # For Key auth, we need to send these two auth HTTP request headers simultaneously: + # 1. "Authorization: Bearer <key>" + # 2. "api-key: <key>" + # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, + # and Azure OpenAI and the new Unified Inference endpoints support the second header. + # The first header will be taken care of by auto-generated code. + # The second one is added here. + if isinstance(credential, AzureKeyCredential): + headers = kwargs.pop("headers", {}) + if "api-key" not in headers: + headers["api-key"] = credential.key + kwargs["headers"] = headers + + super().__init__(endpoint, credential, **kwargs) + + @overload + async def embed( + self, + *, + input: List[_models.ImageEmbeddingInput], + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given images. + The method makes a REST API call to the `/images/embeddings` route on the given endpoint. + + :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an + array. + The input must not exceed the max input tokens for the model. Required. + :paramtype input: list[~azure.ai.inference.models.ImageEmbeddingInput] + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. Default value is None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. + Known values are: + "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def embed( + self, + body: JSON, + *, + content_type: str = "application/json", + **kwargs: Any, + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given images. + The method makes a REST API call to the `/images/embeddings` route on the given endpoint. + + :param body: An object of type MutableMapping[str, Any], such as a dictionary, that + specifies the full request payload. Required. + :type body: JSON + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. 
+ Default value is "application/json". + :paramtype content_type: str + :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def embed( + self, + body: IO[bytes], + *, + content_type: str = "application/json", + **kwargs: Any, + ) -> _models.EmbeddingsResult: + """Return the embedding vectors for given images. + The method makes a REST API call to the `/images/embeddings` route on the given endpoint. + + :param body: Specifies the full request payload. Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def embed( + self, + body: Union[JSON, IO[bytes]] = _Unset, + *, + input: List[_models.ImageEmbeddingInput] = _Unset, + dimensions: Optional[int] = None, + encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, + input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, + model: Optional[str] = None, + model_extras: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> _models.EmbeddingsResult: + # pylint: disable=line-too-long + """Return the embedding vectors for given images. + The method makes a REST API call to the `/images/embeddings` route on the given endpoint. + + :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type + that specifies the full request payload. Required. + :type body: JSON or IO[bytes] + :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an + array. + The input must not exceed the max input tokens for the model. Required. + :paramtype input: list[~azure.ai.inference.models.ImageEmbeddingInput] + :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should + have. Default value is None. + :paramtype dimensions: int + :keyword encoding_format: Optional. The desired format for the returned embeddings. + Known values are: + "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. + :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat + :keyword input_type: Optional. The type of the input. Known values are: + "text", "query", and "document". Default value is None. + :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType + :keyword model: ID of the specific AI model to use, if more than one model is available on the + endpoint. Default value is None. + :paramtype model: str + :keyword model_extras: Additional, model-specific parameters that are not in the + standard request payload. They will be added as-is to the root of the JSON in the request body. + How the service handles these extra parameters depends on the value of the + ``extra-parameters`` request header. Default value is None. + :paramtype model_extras: dict[str, Any] + :return: EmbeddingsResult. 
The EmbeddingsResult is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.EmbeddingsResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping[int, Type[HttpResponseError]] = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + _extra_parameters: Union[_models._enums.ExtraParameters, None] = None + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + + if body is _Unset: + if input is _Unset: + raise TypeError("missing required argument: input") + body = { + "input": input, + "dimensions": dimensions if dimensions is not None else self._dimensions, + "encoding_format": encoding_format if encoding_format is not None else self._encoding_format, + "input_type": input_type if input_type is not None else self._input_type, + "model": model if model is not None else self._model, + } + if model_extras is not None and bool(model_extras): + body.update(model_extras) + _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + elif self._model_extras is not None and bool(self._model_extras): + body.update(self._model_extras) + _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + body = {k: v for k, v in body.items() if v is not None} + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_image_embeddings_embed_request( + extra_params=_extra_parameters, + content_type=content_type, + api_version=self._config.api_version, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + await response.read() # Load the body in memory and close the socket + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if _stream: + deserialized = response.iter_bytes() + else: + deserialized = _deserialize( + _models._patch.EmbeddingsResult, response.json() # pylint: disable=protected-access + ) + + return deserialized # type: ignore + + @distributed_trace_async + async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: + # pylint: disable=line-too-long + """Returns information about the AI model. + The method makes a REST API call to the ``/info`` route on the given endpoint. + This method will only work when using Serverless API or Managed Compute endpoint. + It will not work for GitHub Models endpoint or Azure OpenAI endpoint. + + :return: ModelInfo. 
The ModelInfo is compatible with MutableMapping + :rtype: ~azure.ai.inference.models.ModelInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + if not self._model_info: + try: + self._model_info = await self._get_model_info( + **kwargs + ) # pylint: disable=attribute-defined-outside-init + except ResourceNotFoundError as error: + error.message = "Model information is not available on this endpoint (`/info` route not supported)." + raise error + + return self._model_info + + def __str__(self) -> str: + # pylint: disable=client-method-name-no-double-underscore + return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() + + +__all__: List[str] = [ + "load_client", + "ChatCompletionsClient", + "EmbeddingsClient", + "ImageEmbeddingsClient", +] # Add all objects you want publicly available to users at this package level + + +def patch_sdk(): + """Do not remove from this file. + + `patch_sdk` is a last resort escape hatch that allows you to do customizations + you can't accomplish using the techniques described in + https://aka.ms/azsdk/python/dpcodegen/python/customize + """ diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_vendor.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_vendor.py new file mode 100644 index 00000000..b430582c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/aio/_vendor.py @@ -0,0 +1,47 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- + +from abc import ABC +from typing import TYPE_CHECKING + +from ._configuration import ( + ChatCompletionsClientConfiguration, + EmbeddingsClientConfiguration, + ImageEmbeddingsClientConfiguration, +) + +if TYPE_CHECKING: + from azure.core import AsyncPipelineClient + + from .._serialization import Deserializer, Serializer + + +class ChatCompletionsClientMixinABC(ABC): + """DO NOT use this class. It is for internal typing use only.""" + + _client: "AsyncPipelineClient" + _config: ChatCompletionsClientConfiguration + _serialize: "Serializer" + _deserialize: "Deserializer" + + +class EmbeddingsClientMixinABC(ABC): + """DO NOT use this class. It is for internal typing use only.""" + + _client: "AsyncPipelineClient" + _config: EmbeddingsClientConfiguration + _serialize: "Serializer" + _deserialize: "Deserializer" + + +class ImageEmbeddingsClientMixinABC(ABC): + """DO NOT use this class. It is for internal typing use only.""" + + _client: "AsyncPipelineClient" + _config: ImageEmbeddingsClientConfiguration + _serialize: "Serializer" + _deserialize: "Deserializer" diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/models/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/models/__init__.py new file mode 100644 index 00000000..66e62570 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/models/__init__.py @@ -0,0 +1,96 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + + +from ._models import ( # type: ignore + AudioContentItem, + ChatChoice, + ChatCompletions, + ChatCompletionsNamedToolChoice, + ChatCompletionsNamedToolChoiceFunction, + ChatCompletionsToolCall, + ChatCompletionsToolDefinition, + ChatResponseMessage, + CompletionsUsage, + ContentItem, + EmbeddingItem, + EmbeddingsResult, + EmbeddingsUsage, + FunctionCall, + FunctionDefinition, + ImageContentItem, + ImageEmbeddingInput, + ImageUrl, + InputAudio, + JsonSchemaFormat, + ModelInfo, + StreamingChatChoiceUpdate, + StreamingChatCompletionsUpdate, + StreamingChatResponseMessageUpdate, + StreamingChatResponseToolCallUpdate, + TextContentItem, +) + +from ._enums import ( # type: ignore + AudioContentFormat, + ChatCompletionsToolChoicePreset, + ChatRole, + CompletionsFinishReason, + EmbeddingEncodingFormat, + EmbeddingInputType, + ImageDetailLevel, + ModelType, +) +from ._patch import __all__ as _patch_all +from ._patch import * +from ._patch import patch_sdk as _patch_sdk + +__all__ = [ + "AudioContentItem", + "ChatChoice", + "ChatCompletions", + "ChatCompletionsNamedToolChoice", + "ChatCompletionsNamedToolChoiceFunction", + "ChatCompletionsToolCall", + "ChatCompletionsToolDefinition", + "ChatResponseMessage", + "CompletionsUsage", + "ContentItem", + "EmbeddingItem", + "EmbeddingsResult", + "EmbeddingsUsage", + "FunctionCall", + "FunctionDefinition", + "ImageContentItem", + "ImageEmbeddingInput", + "ImageUrl", + "InputAudio", + "JsonSchemaFormat", + "ModelInfo", + "StreamingChatChoiceUpdate", + "StreamingChatCompletionsUpdate", + "StreamingChatResponseMessageUpdate", + "StreamingChatResponseToolCallUpdate", + "TextContentItem", + "AudioContentFormat", + "ChatCompletionsToolChoicePreset", + "ChatRole", + "CompletionsFinishReason", + "EmbeddingEncodingFormat", + "EmbeddingInputType", + "ImageDetailLevel", + "ModelType", +] +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore +_patch_sdk() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/models/_enums.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/models/_enums.py new file mode 100644 index 00000000..6214f668 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/models/_enums.py @@ -0,0 +1,146 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- + +from enum import Enum +from azure.core import CaseInsensitiveEnumMeta + + +class AudioContentFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """A representation of the possible audio formats for audio.""" + + WAV = "wav" + """Specifies audio in WAV format.""" + MP3 = "mp3" + """Specifies audio in MP3 format.""" + + +class ChatCompletionsToolChoicePreset(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Represents a generic policy for how a chat completions tool may be selected.""" + + AUTO = "auto" + """Specifies that the model may either use any of the tools provided in this chat completions + request or + instead return a standard chat completions response as if no tools were provided.""" + NONE = "none" + """Specifies that the model should not respond with a tool call and should instead provide a + standard chat + completions response. Response content may still be influenced by the provided tool + definitions.""" + REQUIRED = "required" + """Specifies that the model should respond with a call to one or more tools.""" + + +class ChatRole(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """A description of the intended purpose of a message within a chat completions interaction.""" + + SYSTEM = "system" + """The role that instructs or sets the behavior of the assistant.""" + USER = "user" + """The role that provides input for chat completions.""" + ASSISTANT = "assistant" + """The role that provides responses to system-instructed, user-prompted input.""" + TOOL = "tool" + """The role that represents extension tool activity within a chat completions operation.""" + DEVELOPER = "developer" + """The role that instructs or sets the behavior of the assistant. Some AI models support this role + instead of the 'system' role.""" + + +class CompletionsFinishReason(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Representation of the manner in which a completions response concluded.""" + + STOPPED = "stop" + """Completions ended normally and reached its end of token generation.""" + TOKEN_LIMIT_REACHED = "length" + """Completions exhausted available token limits before generation could complete.""" + CONTENT_FILTERED = "content_filter" + """Completions generated a response that was identified as potentially sensitive per content + moderation policies.""" + TOOL_CALLS = "tool_calls" + """Completion ended with the model calling a provided tool for output.""" + + +class EmbeddingEncodingFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """The format of the embeddings result. + Returns a 422 error if the model doesn't support the value or parameter. 
+ """ + + BASE64 = "base64" + """Base64""" + BINARY = "binary" + """Binary""" + FLOAT = "float" + """Floating point""" + INT8 = "int8" + """Signed 8-bit integer""" + UBINARY = "ubinary" + """ubinary""" + UINT8 = "uint8" + """Unsigned 8-bit integer""" + + +class EmbeddingInputType(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Represents the input types used for embedding search.""" + + TEXT = "text" + """Indicates the input is a general text input.""" + QUERY = "query" + """Indicates the input represents a search query to find the most relevant documents in your + vector database.""" + DOCUMENT = "document" + """Indicates the input represents a document that is stored in a vector database.""" + + +class ExtraParameters(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Controls what happens if extra parameters, undefined by the REST API, are passed in the JSON + request payload. + """ + + ERROR = "error" + """The service will error if it detected extra parameters in the request payload. This is the + service default.""" + DROP = "drop" + """The service will ignore (drop) extra parameters in the request payload. It will only pass the + known parameters to the back-end AI model.""" + PASS_THROUGH = "pass-through" + """The service will pass extra parameters to the back-end AI model.""" + + +class ImageDetailLevel(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """A representation of the possible image detail levels for image-based chat completions message + content. + """ + + AUTO = "auto" + """Specifies that the model should determine which detail level to apply using heuristics like + image size.""" + LOW = "low" + """Specifies that image evaluation should be constrained to the 'low-res' model that may be faster + and consume fewer + tokens but may also be less accurate for highly detailed images.""" + HIGH = "high" + """Specifies that image evaluation should enable the 'high-res' model that may be more accurate + for highly detailed + images but may also be slower and consume more tokens.""" + + +class ModelType(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """The type of AI model.""" + + EMBEDDINGS = "embeddings" + """A model capable of generating embeddings from a text""" + IMAGE_GENERATION = "image_generation" + """A model capable of generating images from an image and text description""" + TEXT_GENERATION = "text_generation" + """A text generation model""" + IMAGE_EMBEDDINGS = "image_embeddings" + """A model capable of generating embeddings from an image""" + AUDIO_GENERATION = "audio_generation" + """A text-to-audio generative model""" + CHAT_COMPLETION = "chat_completion" + """A model capable of taking chat-formatted messages and generate responses""" diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/models/_models.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/models/_models.py new file mode 100644 index 00000000..53934528 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/models/_models.py @@ -0,0 +1,1458 @@ +# pylint: disable=too-many-lines +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +# pylint: disable=useless-super-delegation + +import datetime +from typing import Any, Dict, List, Literal, Mapping, Optional, TYPE_CHECKING, Union, overload + +from .. import _model_base +from .._model_base import rest_discriminator, rest_field +from ._enums import ChatRole + +if TYPE_CHECKING: + from .. import models as _models + + +class ContentItem(_model_base.Model): + """An abstract representation of a structured content item within a chat message. + + You probably want to use the sub-classes and not this class directly. Known sub-classes are: + ImageContentItem, AudioContentItem, TextContentItem + + :ivar type: The discriminated object type. Required. Default value is None. + :vartype type: str + """ + + __mapping__: Dict[str, _model_base.Model] = {} + type: str = rest_discriminator(name="type") + """The discriminated object type. Required. Default value is None.""" + + @overload + def __init__( + self, + *, + type: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class AudioContentItem(ContentItem, discriminator="input_audio"): + """A structured chat content item containing an audio content. + + :ivar type: The discriminated object type: always 'input_audio' for this type. Required. + Default value is "input_audio". + :vartype type: str + :ivar input_audio: The details of the input audio. Required. + :vartype input_audio: ~azure.ai.inference.models.InputAudio + """ + + type: Literal["input_audio"] = rest_discriminator(name="type") # type: ignore + """The discriminated object type: always 'input_audio' for this type. Required. Default value is + \"input_audio\".""" + input_audio: "_models.InputAudio" = rest_field() + """The details of the input audio. Required.""" + + @overload + def __init__( + self, + *, + input_audio: "_models.InputAudio", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, type="input_audio", **kwargs) + + +class ChatChoice(_model_base.Model): + """The representation of a single prompt completion as part of an overall chat completions + request. + Generally, ``n`` choices are generated per provided prompt with a default value of 1. + Token limits and other settings may limit the number of choices generated. + + + :ivar index: The ordered index associated with this chat completions choice. Required. + :vartype index: int + :ivar finish_reason: The reason that this chat completions choice completed its generated. + Required. Known values are: "stop", "length", "content_filter", and "tool_calls". + :vartype finish_reason: str or ~azure.ai.inference.models.CompletionsFinishReason + :ivar message: The chat message for a given chat completions prompt. Required. + :vartype message: ~azure.ai.inference.models.ChatResponseMessage + """ + + index: int = rest_field() + """The ordered index associated with this chat completions choice. Required.""" + finish_reason: Union[str, "_models.CompletionsFinishReason"] = rest_field() + """The reason that this chat completions choice completed its generated. Required. 
Known values + are: \"stop\", \"length\", \"content_filter\", and \"tool_calls\".""" + message: "_models.ChatResponseMessage" = rest_field() + """The chat message for a given chat completions prompt. Required.""" + + @overload + def __init__( + self, + *, + index: int, + finish_reason: Union[str, "_models.CompletionsFinishReason"], + message: "_models.ChatResponseMessage", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ChatCompletions(_model_base.Model): + """Representation of the response data from a chat completions request. + Completions support a wide variety of tasks and generate text that continues from or + "completes" + provided prompt data. + + + :ivar id: A unique identifier associated with this chat completions response. Required. + :vartype id: str + :ivar created: The first timestamp associated with generation activity for this completions + response, + represented as seconds since the beginning of the Unix epoch of 00:00 on 1 Jan 1970. Required. + :vartype created: ~datetime.datetime + :ivar model: The model used for the chat completion. Required. + :vartype model: str + :ivar choices: The collection of completions choices associated with this completions response. + Generally, ``n`` choices are generated per provided prompt with a default value of 1. + Token limits and other settings may limit the number of choices generated. Required. + :vartype choices: list[~azure.ai.inference.models.ChatChoice] + :ivar usage: Usage information for tokens processed and generated as part of this completions + operation. Required. + :vartype usage: ~azure.ai.inference.models.CompletionsUsage + """ + + id: str = rest_field() + """A unique identifier associated with this chat completions response. Required.""" + created: datetime.datetime = rest_field(format="unix-timestamp") + """The first timestamp associated with generation activity for this completions response, + represented as seconds since the beginning of the Unix epoch of 00:00 on 1 Jan 1970. Required.""" + model: str = rest_field() + """The model used for the chat completion. Required.""" + choices: List["_models.ChatChoice"] = rest_field() + """The collection of completions choices associated with this completions response. + Generally, ``n`` choices are generated per provided prompt with a default value of 1. + Token limits and other settings may limit the number of choices generated. Required.""" + usage: "_models.CompletionsUsage" = rest_field() + """Usage information for tokens processed and generated as part of this completions operation. + Required.""" + + @overload + def __init__( + self, + *, + id: str, # pylint: disable=redefined-builtin + created: datetime.datetime, + model: str, + choices: List["_models.ChatChoice"], + usage: "_models.CompletionsUsage", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ChatCompletionsNamedToolChoice(_model_base.Model): + """A tool selection of a specific, named function tool that will limit chat completions to using + the named function. + + :ivar type: The type of the tool. 
Currently, only ``function`` is supported. Required. Default + value is "function". + :vartype type: str + :ivar function: The function that should be called. Required. + :vartype function: ~azure.ai.inference.models.ChatCompletionsNamedToolChoiceFunction + """ + + type: Literal["function"] = rest_field() + """The type of the tool. Currently, only ``function`` is supported. Required. Default value is + \"function\".""" + function: "_models.ChatCompletionsNamedToolChoiceFunction" = rest_field() + """The function that should be called. Required.""" + + @overload + def __init__( + self, + *, + function: "_models.ChatCompletionsNamedToolChoiceFunction", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.type: Literal["function"] = "function" + + +class ChatCompletionsNamedToolChoiceFunction(_model_base.Model): + """A tool selection of a specific, named function tool that will limit chat completions to using + the named function. + + :ivar name: The name of the function that should be called. Required. + :vartype name: str + """ + + name: str = rest_field() + """The name of the function that should be called. Required.""" + + @overload + def __init__( + self, + *, + name: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ChatCompletionsResponseFormat(_model_base.Model): + """Represents the format that the model must output. Use this to enable JSON mode instead of the + default text mode. + Note that to enable JSON mode, some AI models may also require you to instruct the model to + produce JSON + via a system or user message. + + You probably want to use the sub-classes and not this class directly. Known sub-classes are: + ChatCompletionsResponseFormatJsonObject, ChatCompletionsResponseFormatJsonSchema, + ChatCompletionsResponseFormatText + + :ivar type: The response format type to use for chat completions. Required. Default value is + None. + :vartype type: str + """ + + __mapping__: Dict[str, _model_base.Model] = {} + type: str = rest_discriminator(name="type") + """The response format type to use for chat completions. Required. Default value is None.""" + + @overload + def __init__( + self, + *, + type: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ChatCompletionsResponseFormatJsonObject(ChatCompletionsResponseFormat, discriminator="json_object"): + """A response format for Chat Completions that restricts responses to emitting valid JSON objects. + Note that to enable JSON mode, some AI models may also require you to instruct the model to + produce JSON + via a system or user message. + + :ivar type: Response format type: always 'json_object' for this object. Required. Default value + is "json_object". 
+ :vartype type: str + """ + + type: Literal["json_object"] = rest_discriminator(name="type") # type: ignore + """Response format type: always 'json_object' for this object. Required. Default value is + \"json_object\".""" + + @overload + def __init__( + self, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, type="json_object", **kwargs) + + +class ChatCompletionsResponseFormatJsonSchema(ChatCompletionsResponseFormat, discriminator="json_schema"): + """A response format for Chat Completions that restricts responses to emitting valid JSON objects, + with a + JSON schema specified by the caller. + + :ivar type: The type of response format being defined: ``json_schema``. Required. Default value + is "json_schema". + :vartype type: str + :ivar json_schema: The definition of the required JSON schema in the response, and associated + metadata. Required. + :vartype json_schema: ~azure.ai.inference.models.JsonSchemaFormat + """ + + type: Literal["json_schema"] = rest_discriminator(name="type") # type: ignore + """The type of response format being defined: ``json_schema``. Required. Default value is + \"json_schema\".""" + json_schema: "_models.JsonSchemaFormat" = rest_field() + """The definition of the required JSON schema in the response, and associated metadata. Required.""" + + @overload + def __init__( + self, + *, + json_schema: "_models.JsonSchemaFormat", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, type="json_schema", **kwargs) + + +class ChatCompletionsResponseFormatText(ChatCompletionsResponseFormat, discriminator="text"): + """A response format for Chat Completions that emits text responses. This is the default response + format. + + :ivar type: Response format type: always 'text' for this object. Required. Default value is + "text". + :vartype type: str + """ + + type: Literal["text"] = rest_discriminator(name="type") # type: ignore + """Response format type: always 'text' for this object. Required. Default value is \"text\".""" + + @overload + def __init__( + self, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, type="text", **kwargs) + + +class ChatCompletionsToolCall(_model_base.Model): + """A function tool call requested by the AI model. + + :ivar id: The ID of the tool call. Required. + :vartype id: str + :ivar type: The type of tool call. Currently, only ``function`` is supported. Required. Default + value is "function". + :vartype type: str + :ivar function: The details of the function call requested by the AI model. Required. + :vartype function: ~azure.ai.inference.models.FunctionCall + """ + + id: str = rest_field() + """The ID of the tool call. Required.""" + type: Literal["function"] = rest_field() + """The type of tool call. Currently, only ``function`` is supported. Required. 
Default value is + \"function\".""" + function: "_models.FunctionCall" = rest_field() + """The details of the function call requested by the AI model. Required.""" + + @overload + def __init__( + self, + *, + id: str, # pylint: disable=redefined-builtin + function: "_models.FunctionCall", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.type: Literal["function"] = "function" + + +class ChatCompletionsToolDefinition(_model_base.Model): + """The definition of a chat completions tool that can call a function. + + :ivar type: The type of the tool. Currently, only ``function`` is supported. Required. Default + value is "function". + :vartype type: str + :ivar function: The function definition details for the function tool. Required. + :vartype function: ~azure.ai.inference.models.FunctionDefinition + """ + + type: Literal["function"] = rest_field() + """The type of the tool. Currently, only ``function`` is supported. Required. Default value is + \"function\".""" + function: "_models.FunctionDefinition" = rest_field() + """The function definition details for the function tool. Required.""" + + @overload + def __init__( + self, + *, + function: "_models.FunctionDefinition", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.type: Literal["function"] = "function" + + +class ChatRequestMessage(_model_base.Model): + """An abstract representation of a chat message as provided in a request. + + You probably want to use the sub-classes and not this class directly. Known sub-classes are: + ChatRequestAssistantMessage, ChatRequestDeveloperMessage, ChatRequestSystemMessage, + ChatRequestToolMessage, ChatRequestUserMessage + + :ivar role: The chat role associated with this message. Required. Known values are: "system", + "user", "assistant", "tool", and "developer". + :vartype role: str or ~azure.ai.inference.models.ChatRole + """ + + __mapping__: Dict[str, _model_base.Model] = {} + role: str = rest_discriminator(name="role") + """The chat role associated with this message. Required. Known values are: \"system\", \"user\", + \"assistant\", \"tool\", and \"developer\".""" + + @overload + def __init__( + self, + *, + role: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ChatRequestAssistantMessage(ChatRequestMessage, discriminator="assistant"): + """A request chat message representing response or action from the assistant. + + :ivar role: The chat role associated with this message, which is always 'assistant' for + assistant messages. Required. The role that provides responses to system-instructed, + user-prompted input. + :vartype role: str or ~azure.ai.inference.models.ASSISTANT + :ivar content: The content of the message. 
+ :vartype content: str + :ivar tool_calls: The tool calls that must be resolved and have their outputs appended to + subsequent input messages for the chat + completions request to resolve as configured. + :vartype tool_calls: list[~azure.ai.inference.models.ChatCompletionsToolCall] + """ + + role: Literal[ChatRole.ASSISTANT] = rest_discriminator(name="role") # type: ignore + """The chat role associated with this message, which is always 'assistant' for assistant messages. + Required. The role that provides responses to system-instructed, user-prompted input.""" + content: Optional[str] = rest_field() + """The content of the message.""" + tool_calls: Optional[List["_models.ChatCompletionsToolCall"]] = rest_field() + """The tool calls that must be resolved and have their outputs appended to subsequent input + messages for the chat + completions request to resolve as configured.""" + + @overload + def __init__( + self, + *, + content: Optional[str] = None, + tool_calls: Optional[List["_models.ChatCompletionsToolCall"]] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, role=ChatRole.ASSISTANT, **kwargs) + + +class ChatRequestDeveloperMessage(ChatRequestMessage, discriminator="developer"): + """A request chat message containing system instructions that influence how the model will + generate a chat completions + response. Some AI models support a developer message instead of a system message. + + :ivar role: The chat role associated with this message, which is always 'developer' for + developer messages. Required. The role that instructs or sets the behavior of the assistant. + Some AI models support this role instead of the 'system' role. + :vartype role: str or ~azure.ai.inference.models.DEVELOPER + :ivar content: The contents of the developer message. Required. + :vartype content: str + """ + + role: Literal[ChatRole.DEVELOPER] = rest_discriminator(name="role") # type: ignore + """The chat role associated with this message, which is always 'developer' for developer messages. + Required. The role that instructs or sets the behavior of the assistant. Some AI models support + this role instead of the 'system' role.""" + content: str = rest_field() + """The contents of the developer message. Required.""" + + @overload + def __init__( + self, + *, + content: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, role=ChatRole.DEVELOPER, **kwargs) + + +class ChatRequestSystemMessage(ChatRequestMessage, discriminator="system"): + """A request chat message containing system instructions that influence how the model will + generate a chat completions + response. + + :ivar role: The chat role associated with this message, which is always 'system' for system + messages. Required. The role that instructs or sets the behavior of the assistant. + :vartype role: str or ~azure.ai.inference.models.SYSTEM + :ivar content: The contents of the system message. Required. 
+ :vartype content: str + """ + + role: Literal[ChatRole.SYSTEM] = rest_discriminator(name="role") # type: ignore + """The chat role associated with this message, which is always 'system' for system messages. + Required. The role that instructs or sets the behavior of the assistant.""" + content: str = rest_field() + """The contents of the system message. Required.""" + + @overload + def __init__( + self, + *, + content: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, role=ChatRole.SYSTEM, **kwargs) + + +class ChatRequestToolMessage(ChatRequestMessage, discriminator="tool"): + """A request chat message representing requested output from a configured tool. + + :ivar role: The chat role associated with this message, which is always 'tool' for tool + messages. Required. The role that represents extension tool activity within a chat completions + operation. + :vartype role: str or ~azure.ai.inference.models.TOOL + :ivar content: The content of the message. + :vartype content: str + :ivar tool_call_id: The ID of the tool call resolved by the provided content. Required. + :vartype tool_call_id: str + """ + + role: Literal[ChatRole.TOOL] = rest_discriminator(name="role") # type: ignore + """The chat role associated with this message, which is always 'tool' for tool messages. Required. + The role that represents extension tool activity within a chat completions operation.""" + content: Optional[str] = rest_field() + """The content of the message.""" + tool_call_id: str = rest_field() + """The ID of the tool call resolved by the provided content. Required.""" + + @overload + def __init__( + self, + *, + tool_call_id: str, + content: Optional[str] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, role=ChatRole.TOOL, **kwargs) + + +class ChatRequestUserMessage(ChatRequestMessage, discriminator="user"): + """A request chat message representing user input to the assistant. + + :ivar role: The chat role associated with this message, which is always 'user' for user + messages. Required. The role that provides input for chat completions. + :vartype role: str or ~azure.ai.inference.models.USER + :ivar content: The contents of the user message, with available input types varying by selected + model. Required. Is either a str type or a [ContentItem] type. + :vartype content: str or list[~azure.ai.inference.models.ContentItem] + """ + + role: Literal[ChatRole.USER] = rest_discriminator(name="role") # type: ignore + """The chat role associated with this message, which is always 'user' for user messages. Required. + The role that provides input for chat completions.""" + content: Union["str", List["_models.ContentItem"]] = rest_field() + """The contents of the user message, with available input types varying by selected model. + Required. Is either a str type or a [ContentItem] type.""" + + @overload + def __init__( + self, + *, + content: Union[str, List["_models.ContentItem"]], + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. 
+ :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, role=ChatRole.USER, **kwargs) + + +class ChatResponseMessage(_model_base.Model): + """A representation of a chat message as received in a response. + + + :ivar role: The chat role associated with the message. Required. Known values are: "system", + "user", "assistant", "tool", and "developer". + :vartype role: str or ~azure.ai.inference.models.ChatRole + :ivar content: The content of the message. Required. + :vartype content: str + :ivar tool_calls: The tool calls that must be resolved and have their outputs appended to + subsequent input messages for the chat + completions request to resolve as configured. + :vartype tool_calls: list[~azure.ai.inference.models.ChatCompletionsToolCall] + """ + + role: Union[str, "_models.ChatRole"] = rest_field() + """The chat role associated with the message. Required. Known values are: \"system\", \"user\", + \"assistant\", \"tool\", and \"developer\".""" + content: str = rest_field() + """The content of the message. Required.""" + tool_calls: Optional[List["_models.ChatCompletionsToolCall"]] = rest_field() + """The tool calls that must be resolved and have their outputs appended to subsequent input + messages for the chat + completions request to resolve as configured.""" + + @overload + def __init__( + self, + *, + role: Union[str, "_models.ChatRole"], + content: str, + tool_calls: Optional[List["_models.ChatCompletionsToolCall"]] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class CompletionsUsage(_model_base.Model): + """Representation of the token counts processed for a completions request. + Counts consider all tokens across prompts, choices, choice alternates, best_of generations, and + other consumers. + + + :ivar completion_tokens: The number of tokens generated across all completions emissions. + Required. + :vartype completion_tokens: int + :ivar prompt_tokens: The number of tokens in the provided prompts for the completions request. + Required. + :vartype prompt_tokens: int + :ivar total_tokens: The total number of tokens processed for the completions request and + response. Required. + :vartype total_tokens: int + """ + + completion_tokens: int = rest_field() + """The number of tokens generated across all completions emissions. Required.""" + prompt_tokens: int = rest_field() + """The number of tokens in the provided prompts for the completions request. Required.""" + total_tokens: int = rest_field() + """The total number of tokens processed for the completions request and response. Required.""" + + @overload + def __init__( + self, + *, + completion_tokens: int, + prompt_tokens: int, + total_tokens: int, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class EmbeddingItem(_model_base.Model): + """Representation of a single embeddings relatedness comparison. + + + :ivar embedding: List of embedding values for the input prompt. These represent a measurement + of the + vector-based relatedness of the provided input. 
Or a base64 encoded string of the embedding + vector. Required. Is either a str type or a [float] type. + :vartype embedding: str or list[float] + :ivar index: Index of the prompt to which the EmbeddingItem corresponds. Required. + :vartype index: int + """ + + embedding: Union["str", List[float]] = rest_field() + """List of embedding values for the input prompt. These represent a measurement of the + vector-based relatedness of the provided input. Or a base64 encoded string of the embedding + vector. Required. Is either a str type or a [float] type.""" + index: int = rest_field() + """Index of the prompt to which the EmbeddingItem corresponds. Required.""" + + @overload + def __init__( + self, + *, + embedding: Union[str, List[float]], + index: int, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class EmbeddingsResult(_model_base.Model): + """Representation of the response data from an embeddings request. + Embeddings measure the relatedness of text strings and are commonly used for search, + clustering, + recommendations, and other similar scenarios. + + + :ivar id: Unique identifier for the embeddings result. Required. + :vartype id: str + :ivar data: Embedding values for the prompts submitted in the request. Required. + :vartype data: list[~azure.ai.inference.models.EmbeddingItem] + :ivar usage: Usage counts for tokens input using the embeddings API. Required. + :vartype usage: ~azure.ai.inference.models.EmbeddingsUsage + :ivar model: The model ID used to generate this result. Required. + :vartype model: str + """ + + id: str = rest_field() + """Unique identifier for the embeddings result. Required.""" + data: List["_models.EmbeddingItem"] = rest_field() + """Embedding values for the prompts submitted in the request. Required.""" + usage: "_models.EmbeddingsUsage" = rest_field() + """Usage counts for tokens input using the embeddings API. Required.""" + model: str = rest_field() + """The model ID used to generate this result. Required.""" + + @overload + def __init__( + self, + *, + id: str, # pylint: disable=redefined-builtin + data: List["_models.EmbeddingItem"], + usage: "_models.EmbeddingsUsage", + model: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class EmbeddingsUsage(_model_base.Model): + """Measurement of the amount of tokens used in this request and response. + + + :ivar prompt_tokens: Number of tokens in the request. Required. + :vartype prompt_tokens: int + :ivar total_tokens: Total number of tokens transacted in this request/response. Should equal + the + number of tokens in the request. Required. + :vartype total_tokens: int + """ + + prompt_tokens: int = rest_field() + """Number of tokens in the request. Required.""" + total_tokens: int = rest_field() + """Total number of tokens transacted in this request/response. Should equal the + number of tokens in the request. Required.""" + + @overload + def __init__( + self, + *, + prompt_tokens: int, + total_tokens: int, + ) -> None: ... 
+ + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class FunctionCall(_model_base.Model): + """The name and arguments of a function that should be called, as generated by the model. + + + :ivar name: The name of the function to call. Required. + :vartype name: str + :ivar arguments: The arguments to call the function with, as generated by the model in JSON + format. + Note that the model does not always generate valid JSON, and may hallucinate parameters + not defined by your function schema. Validate the arguments in your code before calling + your function. Required. + :vartype arguments: str + """ + + name: str = rest_field() + """The name of the function to call. Required.""" + arguments: str = rest_field() + """The arguments to call the function with, as generated by the model in JSON format. + Note that the model does not always generate valid JSON, and may hallucinate parameters + not defined by your function schema. Validate the arguments in your code before calling + your function. Required.""" + + @overload + def __init__( + self, + *, + name: str, + arguments: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class FunctionDefinition(_model_base.Model): + """The definition of a caller-specified function that chat completions may invoke in response to + matching user input. + + :ivar name: The name of the function to be called. Required. + :vartype name: str + :ivar description: A description of what the function does. The model will use this description + when selecting the function and + interpreting its parameters. + :vartype description: str + :ivar parameters: The parameters the function accepts, described as a JSON Schema object. + :vartype parameters: any + """ + + name: str = rest_field() + """The name of the function to be called. Required.""" + description: Optional[str] = rest_field() + """A description of what the function does. The model will use this description when selecting the + function and + interpreting its parameters.""" + parameters: Optional[Any] = rest_field() + """The parameters the function accepts, described as a JSON Schema object.""" + + @overload + def __init__( + self, + *, + name: str, + description: Optional[str] = None, + parameters: Optional[Any] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ImageContentItem(ContentItem, discriminator="image_url"): + """A structured chat content item containing an image reference. + + :ivar type: The discriminated object type: always 'image_url' for this type. Required. Default + value is "image_url". + :vartype type: str + :ivar image_url: An internet location, which must be accessible to the model,from which the + image may be retrieved. Required. 
+ :vartype image_url: ~azure.ai.inference.models.ImageUrl + """ + + type: Literal["image_url"] = rest_discriminator(name="type") # type: ignore + """The discriminated object type: always 'image_url' for this type. Required. Default value is + \"image_url\".""" + image_url: "_models.ImageUrl" = rest_field() + """An internet location, which must be accessible to the model,from which the image may be + retrieved. Required.""" + + @overload + def __init__( + self, + *, + image_url: "_models.ImageUrl", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, type="image_url", **kwargs) + + +class ImageEmbeddingInput(_model_base.Model): + """Represents an image with optional text. + + :ivar image: The input image encoded in base64 string as a data URL. Example: + ``data:image/{format};base64,{data}``. Required. + :vartype image: str + :ivar text: Optional. The text input to feed into the model (like DINO, CLIP). + Returns a 422 error if the model doesn't support the value or parameter. + :vartype text: str + """ + + image: str = rest_field() + """The input image encoded in base64 string as a data URL. Example: + ``data:image/{format};base64,{data}``. Required.""" + text: Optional[str] = rest_field() + """Optional. The text input to feed into the model (like DINO, CLIP). + Returns a 422 error if the model doesn't support the value or parameter.""" + + @overload + def __init__( + self, + *, + image: str, + text: Optional[str] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ImageUrl(_model_base.Model): + """An internet location from which the model may retrieve an image. + + :ivar url: The URL of the image. Required. + :vartype url: str + :ivar detail: The evaluation quality setting to use, which controls relative prioritization of + speed, token consumption, and + accuracy. Known values are: "auto", "low", and "high". + :vartype detail: str or ~azure.ai.inference.models.ImageDetailLevel + """ + + url: str = rest_field() + """The URL of the image. Required.""" + detail: Optional[Union[str, "_models.ImageDetailLevel"]] = rest_field() + """The evaluation quality setting to use, which controls relative prioritization of speed, token + consumption, and + accuracy. Known values are: \"auto\", \"low\", and \"high\".""" + + @overload + def __init__( + self, + *, + url: str, + detail: Optional[Union[str, "_models.ImageDetailLevel"]] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class InputAudio(_model_base.Model): + """The details of an audio chat message content part. + + :ivar data: Base64 encoded audio data. Required. + :vartype data: str + :ivar format: The audio format of the audio content. Required. Known values are: "wav" and + "mp3". + :vartype format: str or ~azure.ai.inference.models.AudioContentFormat + """ + + data: str = rest_field() + """Base64 encoded audio data. 
Required.""" + format: Union[str, "_models.AudioContentFormat"] = rest_field() + """The audio format of the audio content. Required. Known values are: \"wav\" and \"mp3\".""" + + @overload + def __init__( + self, + *, + data: str, + format: Union[str, "_models.AudioContentFormat"], + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class JsonSchemaFormat(_model_base.Model): + """Defines the response format for chat completions as JSON with a given schema. + The AI model will need to adhere to this schema when generating completions. + + :ivar name: A name that labels this JSON schema. Must be a-z, A-Z, 0-9, or contain underscores + and dashes, with a maximum length of 64. Required. + :vartype name: str + :ivar schema: The definition of the JSON schema. See + https://json-schema.org/overview/what-is-jsonschema. + Note that AI models usually only support a subset of the keywords defined by JSON schema. + Consult your AI model documentation to determine what is supported. Required. + :vartype schema: dict[str, any] + :ivar description: A description of the response format, used by the AI model to determine how + to generate responses in this format. + :vartype description: str + :ivar strict: If set to true, the service will error out if the provided JSON schema contains + keywords + not supported by the AI model. An example of such keyword may be ``maxLength`` for JSON type + ``string``. + If false, and the provided JSON schema contains keywords not supported by the AI model, + the AI model will not error out. Instead it will ignore the unsupported keywords. + :vartype strict: bool + """ + + name: str = rest_field() + """A name that labels this JSON schema. Must be a-z, A-Z, 0-9, or contain underscores and dashes, + with a maximum length of 64. Required.""" + schema: Dict[str, Any] = rest_field() + """The definition of the JSON schema. See https://json-schema.org/overview/what-is-jsonschema. + Note that AI models usually only support a subset of the keywords defined by JSON schema. + Consult your AI model documentation to determine what is supported. Required.""" + description: Optional[str] = rest_field() + """A description of the response format, used by the AI model to determine how to generate + responses in this format.""" + strict: Optional[bool] = rest_field() + """If set to true, the service will error out if the provided JSON schema contains keywords + not supported by the AI model. An example of such keyword may be ``maxLength`` for JSON type + ``string``. + If false, and the provided JSON schema contains keywords not supported by the AI model, + the AI model will not error out. Instead it will ignore the unsupported keywords.""" + + @overload + def __init__( + self, + *, + name: str, + schema: Dict[str, Any], + description: Optional[str] = None, + strict: Optional[bool] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ModelInfo(_model_base.Model): + """Represents some basic information about the AI model. + + + :ivar model_name: The name of the AI model. For example: ``Phi21``. Required. 
+ :vartype model_name: str + :ivar model_type: The type of the AI model. A Unique identifier for the profile. Required. + Known values are: "embeddings", "image_generation", "text_generation", "image_embeddings", + "audio_generation", and "chat_completion". + :vartype model_type: str or ~azure.ai.inference.models.ModelType + :ivar model_provider_name: The model provider name. For example: ``Microsoft Research``. + Required. + :vartype model_provider_name: str + """ + + model_name: str = rest_field() + """The name of the AI model. For example: ``Phi21``. Required.""" + model_type: Union[str, "_models.ModelType"] = rest_field() + """The type of the AI model. A Unique identifier for the profile. Required. Known values are: + \"embeddings\", \"image_generation\", \"text_generation\", \"image_embeddings\", + \"audio_generation\", and \"chat_completion\".""" + model_provider_name: str = rest_field() + """The model provider name. For example: ``Microsoft Research``. Required.""" + + @overload + def __init__( + self, + *, + model_name: str, + model_type: Union[str, "_models.ModelType"], + model_provider_name: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class StreamingChatChoiceUpdate(_model_base.Model): + """Represents an update to a single prompt completion when the service is streaming updates + using Server Sent Events (SSE). + Generally, ``n`` choices are generated per provided prompt with a default value of 1. + Token limits and other settings may limit the number of choices generated. + + + :ivar index: The ordered index associated with this chat completions choice. Required. + :vartype index: int + :ivar finish_reason: The reason that this chat completions choice completed its generated. + Required. Known values are: "stop", "length", "content_filter", and "tool_calls". + :vartype finish_reason: str or ~azure.ai.inference.models.CompletionsFinishReason + :ivar delta: An update to the chat message for a given chat completions prompt. Required. + :vartype delta: ~azure.ai.inference.models.StreamingChatResponseMessageUpdate + """ + + index: int = rest_field() + """The ordered index associated with this chat completions choice. Required.""" + finish_reason: Union[str, "_models.CompletionsFinishReason"] = rest_field() + """The reason that this chat completions choice completed its generated. Required. Known values + are: \"stop\", \"length\", \"content_filter\", and \"tool_calls\".""" + delta: "_models.StreamingChatResponseMessageUpdate" = rest_field() + """An update to the chat message for a given chat completions prompt. Required.""" + + @overload + def __init__( + self, + *, + index: int, + finish_reason: Union[str, "_models.CompletionsFinishReason"], + delta: "_models.StreamingChatResponseMessageUpdate", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class StreamingChatCompletionsUpdate(_model_base.Model): + """Represents a response update to a chat completions request, when the service is streaming + updates + using Server Sent Events (SSE). 
+ Completions support a wide variety of tasks and generate text that continues from or + "completes" + provided prompt data. + + + :ivar id: A unique identifier associated with this chat completions response. Required. + :vartype id: str + :ivar created: The first timestamp associated with generation activity for this completions + response, + represented as seconds since the beginning of the Unix epoch of 00:00 on 1 Jan 1970. Required. + :vartype created: ~datetime.datetime + :ivar model: The model used for the chat completion. Required. + :vartype model: str + :ivar choices: An update to the collection of completion choices associated with this + completions response. + Generally, ``n`` choices are generated per provided prompt with a default value of 1. + Token limits and other settings may limit the number of choices generated. Required. + :vartype choices: list[~azure.ai.inference.models.StreamingChatChoiceUpdate] + :ivar usage: Usage information for tokens processed and generated as part of this completions + operation. + :vartype usage: ~azure.ai.inference.models.CompletionsUsage + """ + + id: str = rest_field() + """A unique identifier associated with this chat completions response. Required.""" + created: datetime.datetime = rest_field(format="unix-timestamp") + """The first timestamp associated with generation activity for this completions response, + represented as seconds since the beginning of the Unix epoch of 00:00 on 1 Jan 1970. Required.""" + model: str = rest_field() + """The model used for the chat completion. Required.""" + choices: List["_models.StreamingChatChoiceUpdate"] = rest_field() + """An update to the collection of completion choices associated with this completions response. + Generally, ``n`` choices are generated per provided prompt with a default value of 1. + Token limits and other settings may limit the number of choices generated. Required.""" + usage: Optional["_models.CompletionsUsage"] = rest_field() + """Usage information for tokens processed and generated as part of this completions operation.""" + + @overload + def __init__( + self, + *, + id: str, # pylint: disable=redefined-builtin + created: datetime.datetime, + model: str, + choices: List["_models.StreamingChatChoiceUpdate"], + usage: Optional["_models.CompletionsUsage"] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class StreamingChatResponseMessageUpdate(_model_base.Model): + """A representation of a chat message update as received in a streaming response. + + :ivar role: The chat role associated with the message. If present, should always be + 'assistant'. Known values are: "system", "user", "assistant", "tool", and "developer". + :vartype role: str or ~azure.ai.inference.models.ChatRole + :ivar content: The content of the message. + :vartype content: str + :ivar tool_calls: The tool calls that must be resolved and have their outputs appended to + subsequent input messages for the chat + completions request to resolve as configured. + :vartype tool_calls: list[~azure.ai.inference.models.StreamingChatResponseToolCallUpdate] + """ + + role: Optional[Union[str, "_models.ChatRole"]] = rest_field() + """The chat role associated with the message. If present, should always be 'assistant'. 
Known + values are: \"system\", \"user\", \"assistant\", \"tool\", and \"developer\".""" + content: Optional[str] = rest_field() + """The content of the message.""" + tool_calls: Optional[List["_models.StreamingChatResponseToolCallUpdate"]] = rest_field() + """The tool calls that must be resolved and have their outputs appended to subsequent input + messages for the chat + completions request to resolve as configured.""" + + @overload + def __init__( + self, + *, + role: Optional[Union[str, "_models.ChatRole"]] = None, + content: Optional[str] = None, + tool_calls: Optional[List["_models.StreamingChatResponseToolCallUpdate"]] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class StreamingChatResponseToolCallUpdate(_model_base.Model): + """An update to the function tool call information requested by the AI model. + + + :ivar id: The ID of the tool call. Required. + :vartype id: str + :ivar function: Updates to the function call requested by the AI model. Required. + :vartype function: ~azure.ai.inference.models.FunctionCall + """ + + id: str = rest_field() + """The ID of the tool call. Required.""" + function: "_models.FunctionCall" = rest_field() + """Updates to the function call requested by the AI model. Required.""" + + @overload + def __init__( + self, + *, + id: str, # pylint: disable=redefined-builtin + function: "_models.FunctionCall", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class TextContentItem(ContentItem, discriminator="text"): + """A structured chat content item containing plain text. + + :ivar type: The discriminated object type: always 'text' for this type. Required. Default value + is "text". + :vartype type: str + :ivar text: The content of the message. Required. + :vartype text: str + """ + + type: Literal["text"] = rest_discriminator(name="type") # type: ignore + """The discriminated object type: always 'text' for this type. Required. Default value is + \"text\".""" + text: str = rest_field() + """The content of the message. Required.""" + + @overload + def __init__( + self, + *, + text: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, type="text", **kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/models/_patch.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/models/_patch.py new file mode 100644 index 00000000..1bc06799 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/models/_patch.py @@ -0,0 +1,576 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Customize generated code here. 
+ +Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize +""" +import base64 +import json +import logging +import queue +import re +import sys + +from typing import Mapping, Literal, Any, List, AsyncIterator, Iterator, Optional, Union, overload +from azure.core.rest import HttpResponse, AsyncHttpResponse +from ._enums import ChatRole +from .._model_base import rest_discriminator, rest_field +from ._models import ChatRequestMessage +from ._models import ImageUrl as ImageUrlGenerated +from ._models import ChatCompletions as ChatCompletionsGenerated +from ._models import EmbeddingsResult as EmbeddingsResultGenerated +from ._models import ImageEmbeddingInput as EmbeddingInputGenerated +from ._models import InputAudio as InputAudioGenerated +from .. import models as _models + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +logger = logging.getLogger(__name__) + + +class UserMessage(ChatRequestMessage, discriminator="user"): + """A request chat message representing user input to the assistant. + + :ivar role: The chat role associated with this message, which is always 'user' for user + messages. Required. The role that provides input for chat completions. + :vartype role: str or ~azure.ai.inference.models.USER + :ivar content: The contents of the user message, with available input types varying by selected + model. Required. Is either a str type or a [ContentItem] type. + :vartype content: str or list[~azure.ai.inference.models.ContentItem] + """ + + role: Literal[ChatRole.USER] = rest_discriminator(name="role") # type: ignore + """The chat role associated with this message, which is always 'user' for user messages. Required. + The role that provides input for chat completions.""" + content: Union["str", List["_models.ContentItem"]] = rest_field() + """The contents of the user message, with available input types varying by selected model. + Required. Is either a str type or a [ContentItem] type.""" + + @overload + def __init__( + self, + content: Union[str, List["_models.ContentItem"]], + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 1 and isinstance(args[0], (List, str)): + if kwargs.get("content") is not None: + raise ValueError("content cannot be provided as positional and keyword arguments") + kwargs["content"] = args[0] + args = tuple() + super().__init__(*args, role=ChatRole.USER, **kwargs) + + +class SystemMessage(ChatRequestMessage, discriminator="system"): + """A request chat message containing system instructions that influence how the model will + generate a chat completions response. + + :ivar role: The chat role associated with this message, which is always 'system' for system + messages. Required. + :vartype role: str or ~azure.ai.inference.models.SYSTEM + :ivar content: The contents of the system message. Required. + :vartype content: str + """ + + role: Literal[ChatRole.SYSTEM] = rest_discriminator(name="role") # type: ignore + """The chat role associated with this message, which is always 'system' for system messages. + Required.""" + content: str = rest_field() + """The contents of the system message. Required.""" + + @overload + def __init__( + self, + content: str, + ) -> None: ... 
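+ # Note: the overload above mirrors the runtime __init__ further below, which accepts the
+ # system message content either positionally or as the `content` keyword argument.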
+ + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 1 and isinstance(args[0], str): + if kwargs.get("content") is not None: + raise ValueError("content cannot be provided as positional and keyword arguments") + kwargs["content"] = args[0] + args = tuple() + super().__init__(*args, role=ChatRole.SYSTEM, **kwargs) + + +class DeveloperMessage(ChatRequestMessage, discriminator="developer"): + """A request chat message containing developer instructions that influence how the model will + generate a chat completions response. Some AI models support developer messages instead + of system messages. + + :ivar role: The chat role associated with this message, which is always 'developer' for developer + messages. Required. + :vartype role: str or ~azure.ai.inference.models.DEVELOPER + :ivar content: The contents of the developer message. Required. + :vartype content: str + """ + + role: Literal[ChatRole.DEVELOPER] = rest_discriminator(name="role") # type: ignore + """The chat role associated with this message, which is always 'developer' for developer messages. + Required.""" + content: str = rest_field() + """The contents of the developer message. Required.""" + + @overload + def __init__( + self, + content: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 1 and isinstance(args[0], str): + if kwargs.get("content") is not None: + raise ValueError("content cannot be provided as positional and keyword arguments") + kwargs["content"] = args[0] + args = tuple() + super().__init__(*args, role=ChatRole.DEVELOPER, **kwargs) + + +class AssistantMessage(ChatRequestMessage, discriminator="assistant"): + """A request chat message representing response or action from the assistant. + + :ivar role: The chat role associated with this message, which is always 'assistant' for + assistant messages. Required. The role that provides responses to system-instructed, + user-prompted input. + :vartype role: str or ~azure.ai.inference.models.ASSISTANT + :ivar content: The content of the message. + :vartype content: str + :ivar tool_calls: The tool calls that must be resolved and have their outputs appended to + subsequent input messages for the chat + completions request to resolve as configured. + :vartype tool_calls: list[~azure.ai.inference.models.ChatCompletionsToolCall] + """ + + role: Literal[ChatRole.ASSISTANT] = rest_discriminator(name="role") # type: ignore + """The chat role associated with this message, which is always 'assistant' for assistant messages. + Required. The role that provides responses to system-instructed, user-prompted input.""" + content: Optional[str] = rest_field() + """The content of the message.""" + tool_calls: Optional[List["_models.ChatCompletionsToolCall"]] = rest_field() + """The tool calls that must be resolved and have their outputs appended to subsequent input + messages for the chat + completions request to resolve as configured.""" + + @overload + def __init__( + self, + content: Optional[str] = None, + *, + tool_calls: Optional[List["_models.ChatCompletionsToolCall"]] = None, + ) -> None: ... 
+ + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 1 and isinstance(args[0], str): + if kwargs.get("content") is not None: + raise ValueError("content cannot be provided as positional and keyword arguments") + kwargs["content"] = args[0] + args = tuple() + super().__init__(*args, role=ChatRole.ASSISTANT, **kwargs) + + +class ToolMessage(ChatRequestMessage, discriminator="tool"): + """A request chat message representing requested output from a configured tool. + + :ivar role: The chat role associated with this message, which is always 'tool' for tool + messages. Required. The role that represents extension tool activity within a chat completions + operation. + :vartype role: str or ~azure.ai.inference.models.TOOL + :ivar content: The content of the message. + :vartype content: str + :ivar tool_call_id: The ID of the tool call resolved by the provided content. Required. + :vartype tool_call_id: str + """ + + role: Literal[ChatRole.TOOL] = rest_discriminator(name="role") # type: ignore + """The chat role associated with this message, which is always 'tool' for tool messages. Required. + The role that represents extension tool activity within a chat completions operation.""" + content: Optional[str] = rest_field() + """The content of the message.""" + tool_call_id: str = rest_field() + """The ID of the tool call resolved by the provided content. Required.""" + + @overload + def __init__( + self, + content: Optional[str] = None, + *, + tool_call_id: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 1 and isinstance(args[0], str): + if kwargs.get("content") is not None: + raise ValueError("content cannot be provided as positional and keyword arguments") + kwargs["content"] = args[0] + args = tuple() + super().__init__(*args, role=ChatRole.TOOL, **kwargs) + + +class ChatCompletions(ChatCompletionsGenerated): + """Representation of the response data from a chat completions request. + Completions support a wide variety of tasks and generate text that continues from or + "completes" + provided prompt data. + + + :ivar id: A unique identifier associated with this chat completions response. Required. + :vartype id: str + :ivar created: The first timestamp associated with generation activity for this completions + response, + represented as seconds since the beginning of the Unix epoch of 00:00 on 1 Jan 1970. Required. + :vartype created: ~datetime.datetime + :ivar model: The model used for the chat completion. Required. + :vartype model: str + :ivar usage: Usage information for tokens processed and generated as part of this completions + operation. Required. + :vartype usage: ~azure.ai.inference.models.CompletionsUsage + :ivar choices: The collection of completions choices associated with this completions response. + Generally, ``n`` choices are generated per provided prompt with a default value of 1. + Token limits and other settings may limit the number of choices generated. Required. 
+ :vartype choices: list[~azure.ai.inference.models.ChatChoice] + """ + + def __str__(self) -> str: + # pylint: disable=client-method-name-no-double-underscore + return json.dumps(self.as_dict(), indent=2) + + +class EmbeddingsResult(EmbeddingsResultGenerated): + """Representation of the response data from an embeddings request. + Embeddings measure the relatedness of text strings and are commonly used for search, + clustering, + recommendations, and other similar scenarios. + + + :ivar data: Embedding values for the prompts submitted in the request. Required. + :vartype data: list[~azure.ai.inference.models.EmbeddingItem] + :ivar usage: Usage counts for tokens input using the embeddings API. Required. + :vartype usage: ~azure.ai.inference.models.EmbeddingsUsage + :ivar model: The model ID used to generate this result. Required. + :vartype model: str + """ + + def __str__(self) -> str: + # pylint: disable=client-method-name-no-double-underscore + return json.dumps(self.as_dict(), indent=2) + + +class ImageUrl(ImageUrlGenerated): + + @classmethod + def load( + cls, *, image_file: str, image_format: str, detail: Optional[Union[str, "_models.ImageDetailLevel"]] = None + ) -> Self: + """ + Create an ImageUrl object from a local image file. The method reads the image + file and encodes it as a base64 string, which together with the image format + is then used to format the JSON `url` value passed in the request payload. + + :keyword image_file: The name of the local image file to load. Required. + :paramtype image_file: str + :keyword image_format: The MIME type format of the image. For example: "jpeg", "png". Required. + :paramtype image_format: str + :keyword detail: The evaluation quality setting to use, which controls relative prioritization of + speed, token consumption, and accuracy. Known values are: "auto", "low", and "high". + :paramtype detail: str or ~azure.ai.inference.models.ImageDetailLevel + :return: An ImageUrl object with the image data encoded as a base64 string. + :rtype: ~azure.ai.inference.models.ImageUrl + :raises FileNotFoundError: when the image file could not be opened. + """ + with open(image_file, "rb") as f: + image_data = base64.b64encode(f.read()).decode("utf-8") + url = f"data:image/{image_format};base64,{image_data}" + return cls(url=url, detail=detail) + + +class ImageEmbeddingInput(EmbeddingInputGenerated): + + @classmethod + def load(cls, *, image_file: str, image_format: str, text: Optional[str] = None) -> Self: + """ + Create an ImageEmbeddingInput object from a local image file. The method reads the image + file and encodes it as a base64 string, which together with the image format + is then used to format the JSON `url` value passed in the request payload. + + :keyword image_file: The name of the local image file to load. Required. + :paramtype image_file: str + :keyword image_format: The MIME type format of the image. For example: "jpeg", "png". Required. + :paramtype image_format: str + :keyword text: Optional. The text input to feed into the model (like DINO, CLIP). + Returns a 422 error if the model doesn't support the value or parameter. + :paramtype text: str + :return: An ImageEmbeddingInput object with the image data encoded as a base64 string. + :rtype: ~azure.ai.inference.models.EmbeddingsInput + :raises FileNotFoundError: when the image file could not be opened. 
+ """ + with open(image_file, "rb") as f: + image_data = base64.b64encode(f.read()).decode("utf-8") + image_uri = f"data:image/{image_format};base64,{image_data}" + return cls(image=image_uri, text=text) + + +class BaseStreamingChatCompletions: + """A base class for the sync and async streaming chat completions responses, holding any common code + to deserializes the Server Sent Events (SSE) response stream into chat completions updates, each one + represented by a StreamingChatCompletionsUpdate object. + """ + + # Enable detailed logs of SSE parsing. For development only, should be `False` by default. + _ENABLE_CLASS_LOGS = False + + # The prefix of each line in the SSE stream that contains a JSON string + # to deserialize into a StreamingChatCompletionsUpdate object + _SSE_DATA_EVENT_PREFIX = b"data: " + + # The line indicating the end of the SSE stream + _SSE_DATA_EVENT_DONE = b"data: [DONE]" + + def __init__(self): + self._queue: "queue.Queue[_models.StreamingChatCompletionsUpdate]" = queue.Queue() + self._incomplete_line = b"" + self._done = False # Will be set to True when reading 'data: [DONE]' line + + # See https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream + def _deserialize_and_add_to_queue(self, element: bytes) -> bool: + + if self._ENABLE_CLASS_LOGS: + logger.debug("[Original element] %s", repr(element)) + + # Clear the queue of StreamingChatCompletionsUpdate before processing the next block + self._queue.queue.clear() + + # Split the single input bytes object at new line characters, and get a list of bytes objects, each + # representing a single "line". The bytes object at the end of the list may be a partial "line" that + # does not contain a new line character at the end. + # Note 1: DO NOT try to use something like this here: + # line_list: List[str] = re.split(r"(?<=\n)", element.decode("utf-8")) + # to do full UTF8 decoding of the whole input bytes object, as the last line in the list may be partial, and + # as such may contain a partial UTF8 Chinese character (for example). `decode("utf-8")` will raise an + # exception for such a case. See GitHub issue https://github.com/Azure/azure-sdk-for-python/issues/39565 + # Note 2: Consider future re-write and simplifications of this code by using: + # `codecs.getincrementaldecoder("utf-8")` + line_list: List[bytes] = re.split(re.compile(b"(?<=\n)"), element) + for index, line in enumerate(line_list): + + if self._ENABLE_CLASS_LOGS: + logger.debug("[Original line] %s", repr(line)) + + if index == 0: + line = self._incomplete_line + line + self._incomplete_line = b"" + + if index == len(line_list) - 1 and not line.endswith(b"\n"): + self._incomplete_line = line + return False + + if self._ENABLE_CLASS_LOGS: + logger.debug("[Modified line] %s", repr(line)) + + if line == b"\n": # Empty line, indicating flush output to client + continue + + if not line.startswith(self._SSE_DATA_EVENT_PREFIX): + raise ValueError(f"SSE event not supported (line `{repr(line)}`)") + + if line.startswith(self._SSE_DATA_EVENT_DONE): + if self._ENABLE_CLASS_LOGS: + logger.debug("[Done]") + return True + + # If you reached here, the line should contain `data: {...}\n` + # where the curly braces contain a valid JSON object. + # It is now safe to do UTF8 decoding of the line. + line_str = line.decode("utf-8") + + # Deserialize it into a StreamingChatCompletionsUpdate object + # and add it to the queue. 
+ # pylint: disable=W0212 # Access to a protected member _deserialize of a client class + update = _models.StreamingChatCompletionsUpdate._deserialize( + json.loads(line_str[len(self._SSE_DATA_EVENT_PREFIX) : -1]), [] + ) + + # We skip any update that has a None or empty choices list, and does not have token usage info. + # (this is what OpenAI Python SDK does) + if update.choices or update.usage: + self._queue.put(update) + + if self._ENABLE_CLASS_LOGS: + logger.debug("[Added to queue]") + + return False + + +class StreamingChatCompletions(BaseStreamingChatCompletions): + """Represents an interator over StreamingChatCompletionsUpdate objects. It can be used for either synchronous or + asynchronous iterations. The class deserializes the Server Sent Events (SSE) response stream + into chat completions updates, each one represented by a StreamingChatCompletionsUpdate object. + """ + + def __init__(self, response: HttpResponse): + super().__init__() + self._response = response + self._bytes_iterator: Iterator[bytes] = response.iter_bytes() + + def __iter__(self) -> Any: + return self + + def __next__(self) -> "_models.StreamingChatCompletionsUpdate": + while self._queue.empty() and not self._done: + self._done = self._read_next_block() + if self._queue.empty(): + raise StopIteration + return self._queue.get() + + def _read_next_block(self) -> bool: + if self._ENABLE_CLASS_LOGS: + logger.debug("[Reading next block]") + try: + element = self._bytes_iterator.__next__() + except StopIteration: + self.close() + return True + return self._deserialize_and_add_to_queue(element) + + def __enter__(self): + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # type: ignore + self.close() + + def close(self) -> None: + self._response.close() + + +class AsyncStreamingChatCompletions(BaseStreamingChatCompletions): + """Represents an async interator over StreamingChatCompletionsUpdate objects. + It can be used for either synchronous or asynchronous iterations. The class + deserializes the Server Sent Events (SSE) response stream into chat + completions updates, each one represented by a StreamingChatCompletionsUpdate object. + """ + + def __init__(self, response: AsyncHttpResponse): + super().__init__() + self._response = response + self._bytes_iterator: AsyncIterator[bytes] = response.iter_bytes() + + def __aiter__(self) -> Any: + return self + + async def __anext__(self) -> "_models.StreamingChatCompletionsUpdate": + while self._queue.empty() and not self._done: + self._done = await self._read_next_block_async() + if self._queue.empty(): + raise StopAsyncIteration + return self._queue.get() + + async def _read_next_block_async(self) -> bool: + if self._ENABLE_CLASS_LOGS: + logger.debug("[Reading next block]") + try: + element = await self._bytes_iterator.__anext__() + except StopAsyncIteration: + await self.aclose() + return True + return self._deserialize_and_add_to_queue(element) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # type: ignore + await self.aclose() + + async def aclose(self) -> None: + await self._response.close() + + +class InputAudio(InputAudioGenerated): + + @classmethod + def load( + cls, + *, + audio_file: str, + audio_format: str, + ) -> Self: + """ + Create an InputAudio object from a local audio file. 
The method reads the audio + file and encodes it as a base64 string, which together with the audio format + is then used to create the InputAudio object passed to the request payload. + + :keyword audio_file: The name of the local audio file to load. Required. + :vartype audio_file: str + :keyword audio_format: The MIME type format of the audio. For example: "wav", "mp3". Required. + :vartype audio_format: str + :return: An InputAudio object with the audio data encoded as a base64 string. + :rtype: ~azure.ai.inference.models.InputAudio + :raises FileNotFoundError: when the image file could not be opened. + """ + with open(audio_file, "rb") as f: + audio_data = base64.b64encode(f.read()).decode("utf-8") + return cls(data=audio_data, format=audio_format) + + +__all__: List[str] = [ + "AssistantMessage", + "AsyncStreamingChatCompletions", + "ChatCompletions", + "ChatRequestMessage", + "EmbeddingsResult", + "ImageEmbeddingInput", + "ImageUrl", + "InputAudio", + "StreamingChatCompletions", + "SystemMessage", + "ToolMessage", + "UserMessage", + "DeveloperMessage", +] # Add all objects you want publicly available to users at this package level + + +def patch_sdk(): + """Do not remove from this file. + + `patch_sdk` is a last resort escape hatch that allows you to do customizations + you can't accomplish using the techniques described in + https://aka.ms/azsdk/python/dpcodegen/python/customize + """ diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/__init__.py new file mode 100644 index 00000000..2e11b31c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/__init__.py @@ -0,0 +1,8 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +# pylint: disable=unused-import +from ._patch import patch_sdk as _patch_sdk, PromptTemplate + +_patch_sdk() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_core.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_core.py new file mode 100644 index 00000000..ec670299 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_core.py @@ -0,0 +1,312 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# ------------------------------------ +# mypy: disable-error-code="assignment,attr-defined,index,arg-type" +# pylint: disable=line-too-long,R,consider-iterating-dictionary,raise-missing-from,dangerous-default-value +from __future__ import annotations +import os +from dataclasses import dataclass, field, asdict +from pathlib import Path +from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Union +from ._tracer import Tracer, to_dict +from ._utils import load_json + + +@dataclass +class ToolCall: + id: str + name: str + arguments: str + + +@dataclass +class PropertySettings: + """PropertySettings class to define the properties of the model + + Attributes + ---------- + type : str + The type of the property + default : Any + The default value of the property + description : str + The description of the property + """ + + type: Literal["string", "number", "array", "object", "boolean"] + default: Union[str, int, float, List, Dict, bool, None] = field(default=None) + description: str = field(default="") + + +@dataclass +class ModelSettings: + """ModelSettings class to define the model of the prompty + + Attributes + ---------- + api : str + The api of the model + configuration : Dict + The configuration of the model + parameters : Dict + The parameters of the model + response : Dict + The response of the model + """ + + api: str = field(default="") + configuration: Dict = field(default_factory=dict) + parameters: Dict = field(default_factory=dict) + response: Dict = field(default_factory=dict) + + +@dataclass +class TemplateSettings: + """TemplateSettings class to define the template of the prompty + + Attributes + ---------- + type : str + The type of the template + parser : str + The parser of the template + """ + + type: str = field(default="mustache") + parser: str = field(default="") + + +@dataclass +class Prompty: + """Prompty class to define the prompty + + Attributes + ---------- + name : str + The name of the prompty + description : str + The description of the prompty + authors : List[str] + The authors of the prompty + tags : List[str] + The tags of the prompty + version : str + The version of the prompty + base : str + The base of the prompty + basePrompty : Prompty + The base prompty + model : ModelSettings + The model of the prompty + sample : Dict + The sample of the prompty + inputs : Dict[str, PropertySettings] + The inputs of the prompty + outputs : Dict[str, PropertySettings] + The outputs of the prompty + template : TemplateSettings + The template of the prompty + file : FilePath + The file of the prompty + content : Union[str, List[str], Dict] + The content of the prompty + """ + + # metadata + name: str = field(default="") + description: str = field(default="") + authors: List[str] = field(default_factory=list) + tags: List[str] = field(default_factory=list) + version: str = field(default="") + base: str = field(default="") + basePrompty: Union[Prompty, None] = field(default=None) + # model + model: ModelSettings = field(default_factory=ModelSettings) + + # sample + sample: Dict = field(default_factory=dict) + + # input / output + inputs: Dict[str, PropertySettings] = field(default_factory=dict) + outputs: Dict[str, PropertySettings] = field(default_factory=dict) + + # template + template: TemplateSettings = field(default_factory=TemplateSettings) + + file: Union[Path, str] = field(default="") + content: Union[str, List[str], Dict] = field(default="") + + def to_safe_dict(self) -> Dict[str, Any]: + d = {} + if self.model: + d["model"] = asdict(self.model) 
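+ # Mask any model configuration values whose keys look like secrets (e.g. "key", "secret")
+ # before the dictionary is returned; see _mask_secrets at the bottom of this module.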
+ _mask_secrets(d, ["model", "configuration"]) + if self.template: + d["template"] = asdict(self.template) + if self.inputs: + d["inputs"] = {k: asdict(v) for k, v in self.inputs.items()} + if self.outputs: + d["outputs"] = {k: asdict(v) for k, v in self.outputs.items()} + if self.file: + d["file"] = str(self.file.as_posix()) if isinstance(self.file, Path) else self.file + return d + + @staticmethod + def hoist_base_prompty(top: Prompty, base: Prompty) -> Prompty: + top.name = base.name if top.name == "" else top.name + top.description = base.description if top.description == "" else top.description + top.authors = list(set(base.authors + top.authors)) + top.tags = list(set(base.tags + top.tags)) + top.version = base.version if top.version == "" else top.version + + top.model.api = base.model.api if top.model.api == "" else top.model.api + top.model.configuration = param_hoisting(top.model.configuration, base.model.configuration) + top.model.parameters = param_hoisting(top.model.parameters, base.model.parameters) + top.model.response = param_hoisting(top.model.response, base.model.response) + + top.sample = param_hoisting(top.sample, base.sample) + + top.basePrompty = base + + return top + + @staticmethod + def _process_file(file: str, parent: Path) -> Any: + file_path = Path(parent / Path(file)).resolve().absolute() + if file_path.exists(): + items = load_json(file_path) + if isinstance(items, list): + return [Prompty.normalize(value, parent) for value in items] + elif isinstance(items, Dict): + return {key: Prompty.normalize(value, parent) for key, value in items.items()} + else: + return items + else: + raise FileNotFoundError(f"File {file} not found") + + @staticmethod + def _process_env(variable: str, env_error=True, default: Union[str, None] = None) -> Any: + if variable in os.environ.keys(): + return os.environ[variable] + else: + if default: + return default + if env_error: + raise ValueError(f"Variable {variable} not found in environment") + + return "" + + @staticmethod + def normalize(attribute: Any, parent: Path, env_error=True) -> Any: + if isinstance(attribute, str): + attribute = attribute.strip() + if attribute.startswith("${") and attribute.endswith("}"): + # check if env or file + variable = attribute[2:-1].split(":") + if variable[0] == "env" and len(variable) > 1: + return Prompty._process_env( + variable[1], + env_error, + variable[2] if len(variable) > 2 else None, + ) + elif variable[0] == "file" and len(variable) > 1: + return Prompty._process_file(variable[1], parent) + else: + raise ValueError(f"Invalid attribute format ({attribute})") + else: + return attribute + elif isinstance(attribute, list): + return [Prompty.normalize(value, parent) for value in attribute] + elif isinstance(attribute, Dict): + return {key: Prompty.normalize(value, parent) for key, value in attribute.items()} + else: + return attribute + + +def param_hoisting(top: Dict[str, Any], bottom: Dict[str, Any], top_key: Union[str, None] = None) -> Dict[str, Any]: + if top_key: + new_dict = {**top[top_key]} if top_key in top else {} + else: + new_dict = {**top} + for key, value in bottom.items(): + if not key in new_dict: + new_dict[key] = value + return new_dict + + +class PromptyStream(Iterator): + """PromptyStream class to iterate over LLM stream. 
+ Necessary for Prompty to handle streaming data when tracing.""" + + def __init__(self, name: str, iterator: Iterator): + self.name = name + self.iterator = iterator + self.items: List[Any] = [] + self.__name__ = "PromptyStream" + + def __iter__(self): + return self + + def __next__(self): + try: + # enumerate but add to list + o = self.iterator.__next__() + self.items.append(o) + return o + + except StopIteration: + # StopIteration is raised + # contents are exhausted + if len(self.items) > 0: + with Tracer.start("PromptyStream") as trace: + trace("signature", f"{self.name}.PromptyStream") + trace("inputs", "None") + trace("result", [to_dict(s) for s in self.items]) + + raise StopIteration + + +class AsyncPromptyStream(AsyncIterator): + """AsyncPromptyStream class to iterate over LLM stream. + Necessary for Prompty to handle streaming data when tracing.""" + + def __init__(self, name: str, iterator: AsyncIterator): + self.name = name + self.iterator = iterator + self.items: List[Any] = [] + self.__name__ = "AsyncPromptyStream" + + def __aiter__(self): + return self + + async def __anext__(self): + try: + # enumerate but add to list + o = await self.iterator.__anext__() + self.items.append(o) + return o + + except StopAsyncIteration: + # StopIteration is raised + # contents are exhausted + if len(self.items) > 0: + with Tracer.start("AsyncPromptyStream") as trace: + trace("signature", f"{self.name}.AsyncPromptyStream") + trace("inputs", "None") + trace("result", [to_dict(s) for s in self.items]) + + raise StopAsyncIteration + + +def _mask_secrets(d: Dict[str, Any], path: list[str], patterns: list[str] = ["key", "secret"]) -> bool: + sub_d = d + for key in path: + if key not in sub_d: + return False + sub_d = sub_d[key] + + for k, v in sub_d.items(): + if any([pattern in k.lower() for pattern in patterns]): + sub_d[k] = "*" * len(v) + return True diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_invoker.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_invoker.py new file mode 100644 index 00000000..d682662e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_invoker.py @@ -0,0 +1,295 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# ------------------------------------ +# mypy: disable-error-code="return-value,operator" +# pylint: disable=line-too-long,R,docstring-missing-param,docstring-missing-return,docstring-missing-rtype,unnecessary-pass +import abc +from typing import Any, Callable, Dict, Literal +from ._tracer import trace +from ._core import Prompty + + +class Invoker(abc.ABC): + """Abstract class for Invoker + + Attributes + ---------- + prompty : Prompty + The prompty object + name : str + The name of the invoker + + """ + + def __init__(self, prompty: Prompty) -> None: + self.prompty = prompty + self.name = self.__class__.__name__ + + @abc.abstractmethod + def invoke(self, data: Any) -> Any: + """Abstract method to invoke the invoker + + Parameters + ---------- + data : Any + The data to be invoked + + Returns + ------- + Any + The invoked + """ + pass + + @abc.abstractmethod + async def invoke_async(self, data: Any) -> Any: + """Abstract method to invoke the invoker asynchronously + + Parameters + ---------- + data : Any + The data to be invoked + + Returns + ------- + Any + The invoked + """ + pass + + @trace + def run(self, data: Any) -> Any: + """Method to run the invoker + + Parameters + ---------- + data : Any + The data to be invoked + + Returns + ------- + Any + The invoked + """ + return self.invoke(data) + + @trace + async def run_async(self, data: Any) -> Any: + """Method to run the invoker asynchronously + + Parameters + ---------- + data : Any + The data to be invoked + + Returns + ------- + Any + The invoked + """ + return await self.invoke_async(data) + + +class InvokerFactory: + """Factory class for Invoker""" + + _renderers: Dict[str, Invoker] = {} + _parsers: Dict[str, Invoker] = {} + _executors: Dict[str, Invoker] = {} + _processors: Dict[str, Invoker] = {} + + @classmethod + def add_renderer(cls, name: str, invoker: Invoker) -> None: + cls._renderers[name] = invoker + + @classmethod + def add_parser(cls, name: str, invoker: Invoker) -> None: + cls._parsers[name] = invoker + + @classmethod + def add_executor(cls, name: str, invoker: Invoker) -> None: + cls._executors[name] = invoker + + @classmethod + def add_processor(cls, name: str, invoker: Invoker) -> None: + cls._processors[name] = invoker + + @classmethod + def register_renderer(cls, name: str) -> Callable: + def inner_wrapper(wrapped_class: Invoker) -> Callable: + cls._renderers[name] = wrapped_class + return wrapped_class # type: ignore + + return inner_wrapper + + @classmethod + def register_parser(cls, name: str) -> Callable: + def inner_wrapper(wrapped_class: Invoker) -> Callable: + cls._parsers[name] = wrapped_class + return wrapped_class # type: ignore + + return inner_wrapper + + @classmethod + def register_executor(cls, name: str) -> Callable: + def inner_wrapper(wrapped_class: Invoker) -> Callable: + cls._executors[name] = wrapped_class + return wrapped_class # type: ignore + + return inner_wrapper + + @classmethod + def register_processor(cls, name: str) -> Callable: + def inner_wrapper(wrapped_class: Invoker) -> Callable: + cls._processors[name] = wrapped_class + return wrapped_class # type: ignore + + return inner_wrapper + + @classmethod + def _get_name( + cls, + type: Literal["renderer", "parser", "executor", "processor"], + prompty: Prompty, + ) -> str: + if type == "renderer": + return prompty.template.type + elif type == "parser": + return f"{prompty.template.parser}.{prompty.model.api}" + elif type == "executor": + return prompty.model.configuration["type"] + elif type == "processor": + return 
prompty.model.configuration["type"] + else: + raise ValueError(f"Type {type} not found") + + @classmethod + def _get_invoker( + cls, + type: Literal["renderer", "parser", "executor", "processor"], + prompty: Prompty, + ) -> Invoker: + if type == "renderer": + name = prompty.template.type + if name not in cls._renderers: + raise ValueError(f"Renderer {name} not found") + + return cls._renderers[name](prompty) # type: ignore + + elif type == "parser": + name = f"{prompty.template.parser}.{prompty.model.api}" + if name not in cls._parsers: + raise ValueError(f"Parser {name} not found") + + return cls._parsers[name](prompty) # type: ignore + + elif type == "executor": + name = prompty.model.configuration["type"] + if name not in cls._executors: + raise ValueError(f"Executor {name} not found") + + return cls._executors[name](prompty) # type: ignore + + elif type == "processor": + name = prompty.model.configuration["type"] + if name not in cls._processors: + raise ValueError(f"Processor {name} not found") + + return cls._processors[name](prompty) # type: ignore + + else: + raise ValueError(f"Type {type} not found") + + @classmethod + def run( + cls, + type: Literal["renderer", "parser", "executor", "processor"], + prompty: Prompty, + data: Any, + default: Any = None, + ): + name = cls._get_name(type, prompty) + if name.startswith("NOOP") and default is not None: + return default + elif name.startswith("NOOP"): + return data + + invoker = cls._get_invoker(type, prompty) + value = invoker.run(data) + return value + + @classmethod + async def run_async( + cls, + type: Literal["renderer", "parser", "executor", "processor"], + prompty: Prompty, + data: Any, + default: Any = None, + ): + name = cls._get_name(type, prompty) + if name.startswith("NOOP") and default is not None: + return default + elif name.startswith("NOOP"): + return data + invoker = cls._get_invoker(type, prompty) + value = await invoker.run_async(data) + return value + + @classmethod + def run_renderer(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: + return cls.run("renderer", prompty, data, default) + + @classmethod + async def run_renderer_async(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: + return await cls.run_async("renderer", prompty, data, default) + + @classmethod + def run_parser(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: + return cls.run("parser", prompty, data, default) + + @classmethod + async def run_parser_async(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: + return await cls.run_async("parser", prompty, data, default) + + @classmethod + def run_executor(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: + return cls.run("executor", prompty, data, default) + + @classmethod + async def run_executor_async(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: + return await cls.run_async("executor", prompty, data, default) + + @classmethod + def run_processor(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: + return cls.run("processor", prompty, data, default) + + @classmethod + async def run_processor_async(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: + return await cls.run_async("processor", prompty, data, default) + + +class InvokerException(Exception): + """Exception class for Invoker""" + + def __init__(self, message: str, type: str) -> None: + super().__init__(message) + self.type = type + + def __str__(self) -> str: + return f"{super().__str__()}. 
Make sure to pip install any necessary package extras (i.e. could be something like `pip install prompty[{self.type}]`) for {self.type} as well as import the appropriate invokers (i.e. could be something like `import prompty.{self.type}`)." + + +@InvokerFactory.register_renderer("NOOP") +@InvokerFactory.register_parser("NOOP") +@InvokerFactory.register_executor("NOOP") +@InvokerFactory.register_processor("NOOP") +@InvokerFactory.register_parser("prompty.embedding") +@InvokerFactory.register_parser("prompty.image") +@InvokerFactory.register_parser("prompty.completion") +class NoOp(Invoker): + def invoke(self, data: Any) -> Any: + return data + + async def invoke_async(self, data: str) -> Any: + return self.invoke(data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_mustache.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_mustache.py new file mode 100644 index 00000000..f7a0c21d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_mustache.py @@ -0,0 +1,671 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +# pylint: disable=line-too-long,R,consider-using-dict-items,docstring-missing-return,docstring-missing-rtype,docstring-missing-param,global-statement,unused-argument,global-variable-not-assigned,protected-access,logging-fstring-interpolation,deprecated-method +from __future__ import annotations +import logging +from collections.abc import Iterator, Sequence +from types import MappingProxyType +from typing import ( + Any, + Dict, + List, + Literal, + Mapping, + Optional, + Union, + cast, +) +from typing_extensions import TypeAlias + +logger = logging.getLogger(__name__) + + +Scopes: TypeAlias = List[Union[Literal[False, 0], Mapping[str, Any]]] + + +# Globals +_CURRENT_LINE = 1 +_LAST_TAG_LINE = None + + +class ChevronError(SyntaxError): + """Custom exception for Chevron errors.""" + + +# +# Helper functions +# + + +def grab_literal(template: str, l_del: str) -> tuple[str, str]: + """Parse a literal from the template. + + Args: + template: The template to parse. + l_del: The left delimiter. + + Returns: + Tuple[str, str]: The literal and the template. + """ + + global _CURRENT_LINE + + try: + # Look for the next tag and move the template to it + literal, template = template.split(l_del, 1) + _CURRENT_LINE += literal.count("\n") + return (literal, template) + + # There are no more tags in the template? + except ValueError: + # Then the rest of the template is a literal + return (template, "") + + +def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool: + """Do a preliminary check to see if a tag could be a standalone. + + Args: + template: The template. (Not used.) + literal: The literal. + is_standalone: Whether the tag is standalone. + + Returns: + bool: Whether the tag could be a standalone. + """ + + # If there is a newline, or the previous tag was a standalone + if literal.find("\n") != -1 or is_standalone: + padding = literal.split("\n")[-1] + + # If all the characters since the last newline are spaces + # Then the next tag could be a standalone + # Otherwise it can't be + return padding.isspace() or padding == "" + else: + return False + + +def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool: + """Do a final check to see if a tag could be a standalone. + + Args: + template: The template. + tag_type: The type of the tag. 
+ is_standalone: Whether the tag is standalone. + + Returns: + bool: Whether the tag could be a standalone. + """ + + # Check right side if we might be a standalone + if is_standalone and tag_type not in ["variable", "no escape"]: + on_newline = template.split("\n", 1) + + # If the stuff to the right of us are spaces we're a standalone + return on_newline[0].isspace() or not on_newline[0] + + # If we're a tag can't be a standalone + else: + return False + + +def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], str]: + """Parse a tag from a template. + + Args: + template: The template. + l_del: The left delimiter. + r_del: The right delimiter. + + Returns: + Tuple[Tuple[str, str], str]: The tag and the template. + + Raises: + ChevronError: If the tag is unclosed. + ChevronError: If the set delimiter tag is unclosed. + """ + global _CURRENT_LINE + global _LAST_TAG_LINE + + tag_types = { + "!": "comment", + "#": "section", + "^": "inverted section", + "/": "end", + ">": "partial", + "=": "set delimiter?", + "{": "no escape?", + "&": "no escape", + } + + # Get the tag + try: + tag, template = template.split(r_del, 1) + except ValueError as e: + msg = "unclosed tag " f"at line {_CURRENT_LINE}" + raise ChevronError(msg) from e + + # Find the type meaning of the first character + tag_type = tag_types.get(tag[0], "variable") + + # If the type is not a variable + if tag_type != "variable": + # Then that first character is not needed + tag = tag[1:] + + # If we might be a set delimiter tag + if tag_type == "set delimiter?": + # Double check to make sure we are + if tag.endswith("="): + tag_type = "set delimiter" + # Remove the equal sign + tag = tag[:-1] + + # Otherwise we should complain + else: + msg = "unclosed set delimiter tag\n" f"at line {_CURRENT_LINE}" + raise ChevronError(msg) + + elif ( + # If we might be a no html escape tag + tag_type == "no escape?" + # And we have a third curly brace + # (And are using curly braces as delimiters) + and l_del == "{{" + and r_del == "}}" + and template.startswith("}") + ): + # Then we are a no html escape tag + template = template[1:] + tag_type = "no escape" + + # Strip the whitespace off the key and return + return ((tag_type, tag.strip()), template) + + +# +# The main tokenizing function +# + + +def tokenize(template: str, def_ldel: str = "{{", def_rdel: str = "}}") -> Iterator[tuple[str, str]]: + """Tokenize a mustache template. + + Tokenizes a mustache template in a generator fashion, + using file-like objects. It also accepts a string containing + the template. + + + Arguments: + + template -- a file-like object, or a string of a mustache template + + def_ldel -- The default left delimiter + ("{{" by default, as in spec compliant mustache) + + def_rdel -- The default right delimiter + ("}}" by default, as in spec compliant mustache) + + + Returns: + + A generator of mustache tags in the form of a tuple + + -- (tag_type, tag_key) + + Where tag_type is one of: + * literal + * section + * inverted section + * end + * partial + * no escape + + And tag_key is either the key or in the case of a literal tag, + the literal itself. 
+ """ + + global _CURRENT_LINE, _LAST_TAG_LINE + _CURRENT_LINE = 1 + _LAST_TAG_LINE = None + + is_standalone = True + open_sections = [] + l_del = def_ldel + r_del = def_rdel + + while template: + literal, template = grab_literal(template, l_del) + + # If the template is completed + if not template: + # Then yield the literal and leave + yield ("literal", literal) + break + + # Do the first check to see if we could be a standalone + is_standalone = l_sa_check(template, literal, is_standalone) + + # Parse the tag + tag, template = parse_tag(template, l_del, r_del) + tag_type, tag_key = tag + + # Special tag logic + + # If we are a set delimiter tag + if tag_type == "set delimiter": + # Then get and set the delimiters + dels = tag_key.strip().split(" ") + l_del, r_del = dels[0], dels[-1] + + # If we are a section tag + elif tag_type in ["section", "inverted section"]: + # Then open a new section + open_sections.append(tag_key) + _LAST_TAG_LINE = _CURRENT_LINE + + # If we are an end tag + elif tag_type == "end": + # Then check to see if the last opened section + # is the same as us + try: + last_section = open_sections.pop() + except IndexError as e: + msg = f'Trying to close tag "{tag_key}"\n' "Looks like it was not opened.\n" f"line {_CURRENT_LINE + 1}" + raise ChevronError(msg) from e + if tag_key != last_section: + # Otherwise we need to complain + msg = ( + f'Trying to close tag "{tag_key}"\n' + f'last open tag is "{last_section}"\n' + f"line {_CURRENT_LINE + 1}" + ) + raise ChevronError(msg) + + # Do the second check to see if we're a standalone + is_standalone = r_sa_check(template, tag_type, is_standalone) + + # Which if we are + if is_standalone: + # Remove the stuff before the newline + template = template.split("\n", 1)[-1] + + # Partials need to keep the spaces on their left + if tag_type != "partial": + # But other tags don't + literal = literal.rstrip(" ") + + # Start yielding + # Ignore literals that are empty + if literal != "": + yield ("literal", literal) + + # Ignore comments and set delimiters + if tag_type not in ["comment", "set delimiter?"]: + yield (tag_type, tag_key) + + # If there are any open sections when we're done + if open_sections: + # Then we need to complain + msg = ( + "Unexpected EOF\n" + f'the tag "{open_sections[-1]}" was never closed\n' + f"was opened at line {_LAST_TAG_LINE}" + ) + raise ChevronError(msg) + + +# +# Helper functions +# + + +def _html_escape(string: str) -> str: + """HTML escape all of these " & < >""" + + html_codes = { + '"': """, + "<": "<", + ">": ">", + } + + # & must be handled first + string = string.replace("&", "&") + for char in html_codes: + string = string.replace(char, html_codes[char]) + return string + + +def _get_key( + key: str, + scopes: Scopes, + warn: bool, + keep: bool, + def_ldel: str, + def_rdel: str, +) -> Any: + """Get a key from the current scope""" + + # If the key is a dot + if key == ".": + # Then just return the current scope + return scopes[0] + + # Loop through the scopes + for scope in scopes: + try: + # Return an empty string if falsy, with two exceptions + # 0 should return 0, and False should return False + if scope in (0, False): + return scope + + # For every dot separated key + for child in key.split("."): + # Return an empty string if falsy, with two exceptions + # 0 should return 0, and False should return False + if scope in (0, False): + return scope + # Move into the scope + try: + # Try subscripting (Normal dictionaries) + scope = cast(Dict[str, Any], scope)[child] + except (TypeError, 
AttributeError): + try: + scope = getattr(scope, child) + except (TypeError, AttributeError): + # Try as a list + scope = scope[int(child)] # type: ignore + + try: + # This allows for custom falsy data types + # https://github.com/noahmorrison/chevron/issues/35 + if scope._CHEVRON_return_scope_when_falsy: # type: ignore + return scope + except AttributeError: + if scope in (0, False): + return scope + return scope or "" + except (AttributeError, KeyError, IndexError, ValueError): + # We couldn't find the key in the current scope + # We'll try again on the next pass + pass + + # We couldn't find the key in any of the scopes + + if warn: + logger.warn(f"Could not find key '{key}'") + + if keep: + return f"{def_ldel} {key} {def_rdel}" + + return "" + + +def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str: + """Load a partial""" + try: + # Maybe the partial is in the dictionary + return partials_dict[name] + except KeyError: + return "" + + +# +# The main rendering function +# +g_token_cache: Dict[str, List[tuple[str, str]]] = {} + +EMPTY_DICT: MappingProxyType[str, str] = MappingProxyType({}) + + +def render( + template: Union[str, List[tuple[str, str]]] = "", + data: Mapping[str, Any] = EMPTY_DICT, + partials_dict: Mapping[str, str] = EMPTY_DICT, + padding: str = "", + def_ldel: str = "{{", + def_rdel: str = "}}", + scopes: Optional[Scopes] = None, + warn: bool = False, + keep: bool = False, +) -> str: + """Render a mustache template. + + Renders a mustache template with a data scope and inline partial capability. + + Arguments: + + template -- A file-like object or a string containing the template. + + data -- A python dictionary with your data scope. + + partials_path -- The path to where your partials are stored. + If set to None, then partials won't be loaded from the file system + (defaults to '.'). + + partials_ext -- The extension that you want the parser to look for + (defaults to 'mustache'). + + partials_dict -- A python dictionary which will be search for partials + before the filesystem is. {'include': 'foo'} is the same + as a file called include.mustache + (defaults to {}). + + padding -- This is for padding partials, and shouldn't be used + (but can be if you really want to). + + def_ldel -- The default left delimiter + ("{{" by default, as in spec compliant mustache). + + def_rdel -- The default right delimiter + ("}}" by default, as in spec compliant mustache). + + scopes -- The list of scopes that get_key will look through. + + warn -- Log a warning when a template substitution isn't found in the data + + keep -- Keep unreplaced tags when a substitution isn't found in the data. + + + Returns: + + A string containing the rendered template. 
+ """ + + # If the template is a sequence but not derived from a string + if isinstance(template, Sequence) and not isinstance(template, str): + # Then we don't need to tokenize it + # But it does need to be a generator + tokens: Iterator[tuple[str, str]] = (token for token in template) + else: + if template in g_token_cache: + tokens = (token for token in g_token_cache[template]) + else: + # Otherwise make a generator + tokens = tokenize(template, def_ldel, def_rdel) + + output = "" + + if scopes is None: + scopes = [data] + + # Run through the tokens + for tag, key in tokens: + # Set the current scope + current_scope = scopes[0] + + # If we're an end tag + if tag == "end": + # Pop out of the latest scope + del scopes[0] + + # If the current scope is falsy and not the only scope + elif not current_scope and len(scopes) != 1: + if tag in ["section", "inverted section"]: + # Set the most recent scope to a falsy value + scopes.insert(0, False) + + # If we're a literal tag + elif tag == "literal": + # Add padding to the key and add it to the output + output += key.replace("\n", "\n" + padding) + + # If we're a variable tag + elif tag == "variable": + # Add the html escaped key to the output + thing = _get_key(key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel) + if thing is True and key == ".": + # if we've coerced into a boolean by accident + # (inverted tags do this) + # then get the un-coerced object (next in the stack) + thing = scopes[1] + if not isinstance(thing, str): + thing = str(thing) + output += _html_escape(thing) + + # If we're a no html escape tag + elif tag == "no escape": + # Just lookup the key and add it + thing = _get_key(key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel) + if not isinstance(thing, str): + thing = str(thing) + output += thing + + # If we're a section tag + elif tag == "section": + # Get the sections scope + scope = _get_key(key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel) + + # If the scope is a callable (as described in + # https://mustache.github.io/mustache.5.html) + if callable(scope): + # Generate template text from tags + text = "" + tags: List[tuple[str, str]] = [] + for token in tokens: + if token == ("end", key): + break + + tags.append(token) + tag_type, tag_key = token + if tag_type == "literal": + text += tag_key + elif tag_type == "no escape": + text += f"{def_ldel}& {tag_key} {def_rdel}" + else: + text += "{}{} {}{}".format( + def_ldel, + { + "comment": "!", + "section": "#", + "inverted section": "^", + "end": "/", + "partial": ">", + "set delimiter": "=", + "no escape": "&", + "variable": "", + }[tag_type], + tag_key, + def_rdel, + ) + + g_token_cache[text] = tags + + rend = scope( + text, + lambda template, data=None: render( + template, + data={}, + partials_dict=partials_dict, + padding=padding, + def_ldel=def_ldel, + def_rdel=def_rdel, + scopes=data and [data] + scopes or scopes, + warn=warn, + keep=keep, + ), + ) + + output += rend # type: ignore[reportOperatorIssue] + + # If the scope is a sequence, an iterator or generator but not + # derived from a string + elif isinstance(scope, (Sequence, Iterator)) and not isinstance(scope, str): + # Then we need to do some looping + + # Gather up all the tags inside the section + # (And don't be tricked by nested end tags with the same key) + # TODO: This feels like it still has edge cases, no? 
+ tags = [] + tags_with_same_key = 0 + for token in tokens: + if token == ("section", key): + tags_with_same_key += 1 + if token == ("end", key): + tags_with_same_key -= 1 + if tags_with_same_key < 0: + break + tags.append(token) + + # For every item in the scope + for thing in scope: + # Append it as the most recent scope and render + new_scope = [thing] + scopes + rend = render( + template=tags, + scopes=new_scope, + padding=padding, + partials_dict=partials_dict, + def_ldel=def_ldel, + def_rdel=def_rdel, + warn=warn, + keep=keep, + ) + + output += rend + + else: + # Otherwise we're just a scope section + scopes.insert(0, scope) # type: ignore[reportArgumentType] + + # If we're an inverted section + elif tag == "inverted section": + # Add the flipped scope to the scopes + scope = _get_key(key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel) + scopes.insert(0, cast(Literal[False], not scope)) + + # If we're a partial + elif tag == "partial": + # Load the partial + partial = _get_partial(key, partials_dict) + + # Find what to pad the partial with + left = output.rpartition("\n")[2] + part_padding = padding + if left.isspace(): + part_padding += left + + # Render the partial + part_out = render( + template=partial, + partials_dict=partials_dict, + def_ldel=def_ldel, + def_rdel=def_rdel, + padding=part_padding, + scopes=scopes, + warn=warn, + keep=keep, + ) + + # If the partial was indented + if left.isspace(): + # then remove the spaces from the end + part_out = part_out.rstrip(" \t") + + # Add the partials output to the output + output += part_out + + return output diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_parsers.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_parsers.py new file mode 100644 index 00000000..de3c570e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_parsers.py @@ -0,0 +1,156 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# ------------------------------------ +# mypy: disable-error-code="union-attr,return-value" +# pylint: disable=line-too-long,R,consider-using-enumerate,docstring-missing-param,docstring-missing-return,docstring-missing-rtype +import re +import base64 +from pathlib import Path +from typing import Any, Union +from ._core import Prompty +from ._invoker import Invoker, InvokerFactory + + +ROLES = ["assistant", "function", "system", "user"] + + +@InvokerFactory.register_parser("prompty.chat") +class PromptyChatParser(Invoker): + """Prompty Chat Parser""" + + def __init__(self, prompty: Prompty) -> None: + super().__init__(prompty) + self.path = Path(self.prompty.file).parent + + def invoke(self, data: str) -> Any: + return invoke_parser(self.path, data) + + async def invoke_async(self, data: str) -> Any: + """Invoke the Prompty Chat Parser (Async) + + Parameters + ---------- + data : str + The data to parse + + Returns + ------- + str + The parsed data + """ + return self.invoke(data) + + +def _inline_image(path: Union[Path, None], image_item: str) -> str: + """Inline Image + + Parameters + ---------- + image_item : str + The image item to inline + + Returns + ------- + str + The inlined image + """ + # pass through if it's a url or base64 encoded or the path is None + if image_item.startswith("http") or image_item.startswith("data") or path is None: + return image_item + # otherwise, it's a local file - need to base64 encode it + else: + image_path = (path if path is not None else Path(".")) / image_item + with open(image_path, "rb") as f: + base64_image = base64.b64encode(f.read()).decode("utf-8") + + if image_path.suffix == ".png": + return f"data:image/png;base64,{base64_image}" + elif image_path.suffix == ".jpg": + return f"data:image/jpeg;base64,{base64_image}" + elif image_path.suffix == ".jpeg": + return f"data:image/jpeg;base64,{base64_image}" + else: + raise ValueError( + f"Invalid image format {image_path.suffix} - currently only .png and .jpg / .jpeg are supported." 
+ ) + + +def _parse_content(path: Union[Path, None], content: str): + """for parsing inline images + + Parameters + ---------- + content : str + The content to parse + + Returns + ------- + any + The parsed content + """ + # regular expression to parse markdown images + image = r"(?P<alt>!\[[^\]]*\])\((?P<filename>.*?)(?=\"|\))\)" + matches = re.findall(image, content, flags=re.MULTILINE) + if len(matches) > 0: + content_items = [] + content_chunks = re.split(image, content, flags=re.MULTILINE) + current_chunk = 0 + for i in range(len(content_chunks)): + # image entry + if current_chunk < len(matches) and content_chunks[i] == matches[current_chunk][0]: + content_items.append( + { + "type": "image_url", + "image_url": {"url": _inline_image(path, matches[current_chunk][1].split(" ")[0].strip())}, + } + ) + # second part of image entry + elif current_chunk < len(matches) and content_chunks[i] == matches[current_chunk][1]: + current_chunk += 1 + # text entry + else: + if len(content_chunks[i].strip()) > 0: + content_items.append({"type": "text", "text": content_chunks[i].strip()}) + return content_items + else: + return content + + +def invoke_parser(path: Union[Path, None], data: str) -> Any: + """Invoke the Prompty Chat Parser + + Parameters + ---------- + data : str + The data to parse + + Returns + ------- + str + The parsed data + """ + messages = [] + separator = r"(?i)^\s*#?\s*(" + "|".join(ROLES) + r")\s*:\s*\n" + + # get valid chunks - remove empty items + chunks = [item for item in re.split(separator, data, flags=re.MULTILINE) if len(item.strip()) > 0] + + # if no starter role, then inject system role + if not chunks[0].strip().lower() in ROLES: + chunks.insert(0, "system") + + # if last chunk is role entry, then remove (no content?) + if chunks[-1].strip().lower() in ROLES: + chunks.pop() + + if len(chunks) % 2 != 0: + raise ValueError("Invalid prompt format") + + # create messages + for i in range(0, len(chunks), 2): + role = chunks[i].strip().lower() + content = chunks[i + 1].strip() + messages.append({"role": role, "content": _parse_content(path, content)}) + + return messages diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_patch.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_patch.py new file mode 100644 index 00000000..14ad4f62 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_patch.py @@ -0,0 +1,124 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +# pylint: disable=line-too-long,R +"""Customize generated code here. + +Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize +""" + +import traceback +from pathlib import Path +from typing import Any, Dict, List, Optional +from typing_extensions import Self +from ._core import Prompty +from ._mustache import render +from ._parsers import invoke_parser +from ._prompty_utils import load, prepare +from ._utils import remove_leading_empty_space + + +class PromptTemplate: + """The helper class which takes variant of inputs, e.g. Prompty format or string, and returns the parsed prompt in an array.""" + + @classmethod + def from_prompty(cls, file_path: str) -> Self: + """Initialize a PromptTemplate object from a prompty file. + + :param file_path: The path to the prompty file. + :type file_path: str + :return: The PromptTemplate object. 
+ :rtype: PromptTemplate + """ + if not file_path: + raise ValueError("Please provide file_path") + + # Get the absolute path of the file by `traceback.extract_stack()`, it's "-2" because: + # In the stack, the last function is the current function. + # The second last function is the caller function, which is the root of the file_path. + stack = traceback.extract_stack() + caller = Path(stack[-2].filename) + abs_file_path = Path(caller.parent / Path(file_path)).resolve().absolute() + + prompty = load(str(abs_file_path)) + return cls(prompty=prompty) + + @classmethod + def from_string(cls, prompt_template: str, api: str = "chat", model_name: Optional[str] = None) -> Self: + """Initialize a PromptTemplate object from a message template. + + :param prompt_template: The prompt template string. + :type prompt_template: str + :param api: The API type, e.g. "chat" or "completion". + :type api: str + :param model_name: The model name, e.g. "gpt-4o-mini". + :type model_name: str + :return: The PromptTemplate object. + :rtype: PromptTemplate + """ + return cls( + api=api, + prompt_template=prompt_template, + model_name=model_name, + prompty=None, + ) + + def __init__( + self, + *, + api: str = "chat", + prompty: Optional[Prompty] = None, + prompt_template: Optional[str] = None, + model_name: Optional[str] = None, + ) -> None: + self.prompty = prompty + if self.prompty is not None: + self.model_name = ( + self.prompty.model.configuration["azure_deployment"] + if "azure_deployment" in self.prompty.model.configuration + else None + ) + self.parameters = self.prompty.model.parameters + self._config = {} + elif prompt_template is not None: + self.model_name = model_name + self.parameters = {} + # _config is a dict to hold the internal configuration + self._config = { + "api": api if api is not None else "chat", + "prompt_template": prompt_template, + } + else: + raise ValueError("Please pass valid arguments for PromptTemplate") + + def create_messages(self, data: Optional[Dict[str, Any]] = None, **kwargs) -> List[Dict[str, Any]]: + """Render the prompt template with the given data. + + :param data: The data to render the prompt template with. + :type data: Optional[Dict[str, Any]] + :return: The rendered prompt template. + :rtype: List[Dict[str, Any]] + """ + if data is None: + data = kwargs + + if self.prompty is not None: + parsed = prepare(self.prompty, data) + return parsed + elif "prompt_template" in self._config: + prompt_template = remove_leading_empty_space(self._config["prompt_template"]) + system_prompt_str = render(prompt_template, data) + parsed = invoke_parser(None, system_prompt_str) + return parsed + else: + raise ValueError("Please provide valid prompt template") + + +def patch_sdk(): + """Do not remove from this file. + + `patch_sdk` is a last resort escape hatch that allows you to do customizations + you can't accomplish using the techniques described in + https://aka.ms/azsdk/python/dpcodegen/python/customize + """ diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_prompty_utils.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_prompty_utils.py new file mode 100644 index 00000000..5ea38bda --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_prompty_utils.py @@ -0,0 +1,415 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# ------------------------------------ +# mypy: disable-error-code="assignment" +# pylint: disable=R,docstring-missing-param,docstring-missing-return,docstring-missing-rtype,dangerous-default-value,redefined-outer-name,unused-wildcard-import,wildcard-import,raise-missing-from +import traceback +from pathlib import Path +from typing import Any, Dict, List, Union +from ._tracer import trace +from ._invoker import InvokerFactory +from ._core import ( + ModelSettings, + Prompty, + PropertySettings, + TemplateSettings, + param_hoisting, +) +from ._utils import ( + load_global_config, + load_prompty, +) + +from ._renderers import * +from ._parsers import * + + +@trace(description="Create a headless prompty object for programmatic use.") +def headless( + api: str, + content: Union[str, List[str], dict], + configuration: Dict[str, Any] = {}, + parameters: Dict[str, Any] = {}, + connection: str = "default", +) -> Prompty: + """Create a headless prompty object for programmatic use. + + Parameters + ---------- + api : str + The API to use for the model + content : Union[str, List[str], dict] + The content to process + configuration : Dict[str, Any], optional + The configuration to use, by default {} + parameters : Dict[str, Any], optional + The parameters to use, by default {} + connection : str, optional + The connection to use, by default "default" + + Returns + ------- + Prompty + The headless prompty object + + Example + ------- + >>> import prompty + >>> p = prompty.headless( + api="embedding", + configuration={"type": "azure", "azure_deployment": "text-embedding-ada-002"}, + content="hello world", + ) + >>> emb = prompty.execute(p) + + """ + + # get caller's path (to get relative path for prompty.json) + caller = Path(traceback.extract_stack()[-2].filename) + templateSettings = TemplateSettings(type="NOOP", parser="NOOP") + modelSettings = ModelSettings( + api=api, + configuration=Prompty.normalize( + param_hoisting(configuration, load_global_config(caller.parent, connection)), + caller.parent, + ), + parameters=parameters, + ) + + return Prompty(model=modelSettings, template=templateSettings, content=content) + + +def _load_raw_prompty(attributes: dict, content: str, p: Path, global_config: dict): + if "model" not in attributes: + attributes["model"] = {} + + if "configuration" not in attributes["model"]: + attributes["model"]["configuration"] = global_config + else: + attributes["model"]["configuration"] = param_hoisting( + attributes["model"]["configuration"], + global_config, + ) + + # pull model settings out of attributes + try: + model = ModelSettings(**attributes.pop("model")) + except Exception as e: + raise ValueError(f"Error in model settings: {e}") + + # pull template settings + try: + if "template" in attributes: + t = attributes.pop("template") + if isinstance(t, dict): + template = TemplateSettings(**t) + # has to be a string denoting the type + else: + template = TemplateSettings(type=t, parser="prompty") + else: + template = TemplateSettings(type="mustache", parser="prompty") + except Exception as e: + raise ValueError(f"Error in template loader: {e}") + + # formalize inputs and outputs + if "inputs" in attributes: + try: + inputs = {k: PropertySettings(**v) for (k, v) in attributes.pop("inputs").items()} + except Exception as e: + raise ValueError(f"Error in inputs: {e}") + else: + inputs = {} + if "outputs" in attributes: + try: + outputs = {k: PropertySettings(**v) for (k, v) in attributes.pop("outputs").items()} + except Exception as e: + raise ValueError(f"Error in 
outputs: {e}") + else: + outputs = {} + + prompty = Prompty( + **attributes, + model=model, + inputs=inputs, + outputs=outputs, + template=template, + content=content, + file=p, + ) + + return prompty + + +@trace(description="Load a prompty file.") +def load(prompty_file: Union[str, Path], configuration: str = "default") -> Prompty: + """Load a prompty file. + + Parameters + ---------- + prompty_file : Union[str, Path] + The path to the prompty file + configuration : str, optional + The configuration to use, by default "default" + + Returns + ------- + Prompty + The loaded prompty object + + Example + ------- + >>> import prompty + >>> p = prompty.load("prompts/basic.prompty") + >>> print(p) + """ + + p = Path(prompty_file) + if not p.is_absolute(): + # get caller's path (take into account trace frame) + caller = Path(traceback.extract_stack()[-3].filename) + p = Path(caller.parent / p).resolve().absolute() + + # load dictionary from prompty file + matter = load_prompty(p) + + attributes = matter["attributes"] + content = matter["body"] + + # normalize attribute dictionary resolve keys and files + attributes = Prompty.normalize(attributes, p.parent) + + # load global configuration + global_config = Prompty.normalize(load_global_config(p.parent, configuration), p.parent) + + prompty = _load_raw_prompty(attributes, content, p, global_config) + + # recursive loading of base prompty + if "base" in attributes: + # load the base prompty from the same directory as the current prompty + base = load(p.parent / attributes["base"]) + prompty = Prompty.hoist_base_prompty(prompty, base) + + return prompty + + +@trace(description="Prepare the inputs for the prompt.") +def prepare( + prompt: Prompty, + inputs: Dict[str, Any] = {}, +): + """Prepare the inputs for the prompt. + + Parameters + ---------- + prompt : Prompty + The prompty object + inputs : Dict[str, Any], optional + The inputs to the prompt, by default {} + + Returns + ------- + dict + The prepared and hidrated template shaped to the LLM model + + Example + ------- + >>> import prompty + >>> p = prompty.load("prompts/basic.prompty") + >>> inputs = {"name": "John Doe"} + >>> content = prompty.prepare(p, inputs) + """ + inputs = param_hoisting(inputs, prompt.sample) + + render = InvokerFactory.run_renderer(prompt, inputs, prompt.content) + result = InvokerFactory.run_parser(prompt, render) + + return result + + +@trace(description="Prepare the inputs for the prompt.") +async def prepare_async( + prompt: Prompty, + inputs: Dict[str, Any] = {}, +): + """Prepare the inputs for the prompt. + + Parameters + ---------- + prompt : Prompty + The prompty object + inputs : Dict[str, Any], optional + The inputs to the prompt, by default {} + + Returns + ------- + dict + The prepared and hidrated template shaped to the LLM model + + Example + ------- + >>> import prompty + >>> p = prompty.load("prompts/basic.prompty") + >>> inputs = {"name": "John Doe"} + >>> content = await prompty.prepare_async(p, inputs) + """ + inputs = param_hoisting(inputs, prompt.sample) + + render = await InvokerFactory.run_renderer_async(prompt, inputs, prompt.content) + result = await InvokerFactory.run_parser_async(prompt, render) + + return result + + +@trace(description="Run the prepared Prompty content against the model.") +def run( + prompt: Prompty, + content: Union[dict, list, str], + configuration: Dict[str, Any] = {}, + parameters: Dict[str, Any] = {}, + raw: bool = False, +): + """Run the prepared Prompty content. 
+ + Parameters + ---------- + prompt : Prompty + The prompty object + content : Union[dict, list, str] + The content to process + configuration : Dict[str, Any], optional + The configuration to use, by default {} + parameters : Dict[str, Any], optional + The parameters to use, by default {} + raw : bool, optional + Whether to skip processing, by default False + + Returns + ------- + Any + The result of the prompt + + Example + ------- + >>> import prompty + >>> p = prompty.load("prompts/basic.prompty") + >>> inputs = {"name": "John Doe"} + >>> content = prompty.prepare(p, inputs) + >>> result = prompty.run(p, content) + """ + + if configuration != {}: + prompt.model.configuration = param_hoisting(configuration, prompt.model.configuration) + + if parameters != {}: + prompt.model.parameters = param_hoisting(parameters, prompt.model.parameters) + + result = InvokerFactory.run_executor(prompt, content) + if not raw: + result = InvokerFactory.run_processor(prompt, result) + + return result + + +@trace(description="Run the prepared Prompty content against the model.") +async def run_async( + prompt: Prompty, + content: Union[dict, list, str], + configuration: Dict[str, Any] = {}, + parameters: Dict[str, Any] = {}, + raw: bool = False, +): + """Run the prepared Prompty content. + + Parameters + ---------- + prompt : Prompty + The prompty object + content : Union[dict, list, str] + The content to process + configuration : Dict[str, Any], optional + The configuration to use, by default {} + parameters : Dict[str, Any], optional + The parameters to use, by default {} + raw : bool, optional + Whether to skip processing, by default False + + Returns + ------- + Any + The result of the prompt + + Example + ------- + >>> import prompty + >>> p = prompty.load("prompts/basic.prompty") + >>> inputs = {"name": "John Doe"} + >>> content = await prompty.prepare_async(p, inputs) + >>> result = await prompty.run_async(p, content) + """ + + if configuration != {}: + prompt.model.configuration = param_hoisting(configuration, prompt.model.configuration) + + if parameters != {}: + prompt.model.parameters = param_hoisting(parameters, prompt.model.parameters) + + result = await InvokerFactory.run_executor_async(prompt, content) + if not raw: + result = await InvokerFactory.run_processor_async(prompt, result) + + return result + + +@trace(description="Execute a prompty") +def execute( + prompt: Union[str, Prompty], + configuration: Dict[str, Any] = {}, + parameters: Dict[str, Any] = {}, + inputs: Dict[str, Any] = {}, + raw: bool = False, + config_name: str = "default", +): + """Execute a prompty. 
+ + Parameters + ---------- + prompt : Union[str, Prompty] + The prompty object or path to the prompty file + configuration : Dict[str, Any], optional + The configuration to use, by default {} + parameters : Dict[str, Any], optional + The parameters to use, by default {} + inputs : Dict[str, Any], optional + The inputs to the prompt, by default {} + raw : bool, optional + Whether to skip processing, by default False + connection : str, optional + The connection to use, by default "default" + + Returns + ------- + Any + The result of the prompt + + Example + ------- + >>> import prompty + >>> inputs = {"name": "John Doe"} + >>> result = prompty.execute("prompts/basic.prompty", inputs=inputs) + """ + if isinstance(prompt, str): + path = Path(prompt) + if not path.is_absolute(): + # get caller's path (take into account trace frame) + caller = Path(traceback.extract_stack()[-3].filename) + path = Path(caller.parent / path).resolve().absolute() + prompt = load(path, config_name) + + # prepare content + content = prepare(prompt, inputs) + + # run LLM model + result = run(prompt, content, configuration, parameters, raw) + + return result diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_renderers.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_renderers.py new file mode 100644 index 00000000..0d682a7f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_renderers.py @@ -0,0 +1,30 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +# mypy: disable-error-code="union-attr,assignment,arg-type" +from pathlib import Path +from ._core import Prompty +from ._invoker import Invoker, InvokerFactory +from ._mustache import render + + +@InvokerFactory.register_renderer("mustache") +class MustacheRenderer(Invoker): + """Render a mustache template.""" + + def __init__(self, prompty: Prompty) -> None: + super().__init__(prompty) + self.templates = {} + cur_prompt = self.prompty + while cur_prompt: + self.templates[Path(cur_prompt.file).name] = cur_prompt.content + cur_prompt = cur_prompt.basePrompty + self.name = Path(self.prompty.file).name + + def invoke(self, data: str) -> str: + generated = render(self.prompty.content, data) # type: ignore + return generated + + async def invoke_async(self, data: str) -> str: + return self.invoke(data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_tracer.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_tracer.py new file mode 100644 index 00000000..24f800b4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_tracer.py @@ -0,0 +1,316 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# ------------------------------------ +# mypy: disable-error-code="union-attr,arg-type,misc,return-value,assignment,func-returns-value" +# pylint: disable=R,redefined-outer-name,bare-except,unspecified-encoding +import os +import json +import inspect +import traceback +import importlib +import contextlib +from pathlib import Path +from numbers import Number +from datetime import datetime +from functools import wraps, partial +from typing import Any, Callable, Dict, Iterator, List, Union + + +# clean up key value pairs for sensitive values +def sanitize(key: str, value: Any) -> Any: + if isinstance(value, str) and any([s in key.lower() for s in ["key", "token", "secret", "password", "credential"]]): + return len(str(value)) * "*" + + if isinstance(value, dict): + return {k: sanitize(k, v) for k, v in value.items()} + + return value + + +class Tracer: + _tracers: Dict[str, Callable[[str], Iterator[Callable[[str, Any], None]]]] = {} + + @classmethod + def add(cls, name: str, tracer: Callable[[str], Iterator[Callable[[str, Any], None]]]) -> None: + cls._tracers[name] = tracer + + @classmethod + def clear(cls) -> None: + cls._tracers = {} + + @classmethod + @contextlib.contextmanager + def start(cls, name: str) -> Iterator[Callable[[str, Any], None]]: + with contextlib.ExitStack() as stack: + traces: List[Any] = [stack.enter_context(tracer(name)) for tracer in cls._tracers.values()] # type: ignore + yield lambda key, value: [ # type: ignore + # normalize and sanitize any trace values + trace(key, sanitize(key, to_dict(value))) + for trace in traces + ] + + +def to_dict(obj: Any) -> Union[Dict[str, Any], List[Dict[str, Any]], str, Number, bool]: + # simple json types + if isinstance(obj, str) or isinstance(obj, Number) or isinstance(obj, bool): + return obj + + # datetime + if isinstance(obj, datetime): + return obj.isoformat() + + # safe Prompty obj serialization + if type(obj).__name__ == "Prompty": + return obj.to_safe_dict() + + # safe PromptyStream obj serialization + if type(obj).__name__ == "PromptyStream": + return "PromptyStream" + + if type(obj).__name__ == "AsyncPromptyStream": + return "AsyncPromptyStream" + + # recursive list and dict + if isinstance(obj, List): + return [to_dict(item) for item in obj] # type: ignore + + if isinstance(obj, Dict): + return {k: v if isinstance(v, str) else to_dict(v) for k, v in obj.items()} + + if isinstance(obj, Path): + return str(obj) + + # cast to string otherwise... 
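+ # (e.g. tuples, sets, exception args, or other client objects end up here, so trace
+ # serialization degrades to a readable string instead of raising for unknown types)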
+ return str(obj) + + +def _name(func: Callable, args): + if hasattr(func, "__qualname__"): + signature = f"{func.__module__}.{func.__qualname__}" + else: + signature = f"{func.__module__}.{func.__name__}" + + # core invoker gets special treatment prompty.invoker.Invoker + core_invoker = signature.startswith("prompty.invoker.Invoker.run") + if core_invoker: + name = type(args[0]).__name__ + if signature.endswith("async"): + signature = f"{args[0].__module__}.{args[0].__class__.__name__}.invoke_async" + else: + signature = f"{args[0].__module__}.{args[0].__class__.__name__}.invoke" + else: + name = func.__name__ + + return name, signature + + +def _inputs(func: Callable, args, kwargs) -> dict: + ba = inspect.signature(func).bind(*args, **kwargs) + ba.apply_defaults() + + inputs = {k: to_dict(v) for k, v in ba.arguments.items() if k != "self"} + + return inputs + + +def _results(result: Any) -> Union[Dict, List[Dict], str, Number, bool]: + return to_dict(result) if result is not None else "None" + + +def _trace_sync(func: Union[Callable, None] = None, **okwargs: Any) -> Callable: + + @wraps(func) # type: ignore + def wrapper(*args, **kwargs): + name, signature = _name(func, args) # type: ignore + with Tracer.start(name) as trace: + trace("signature", signature) + + # support arbitrary keyword + # arguments for trace decorator + for k, v in okwargs.items(): + trace(k, to_dict(v)) + + inputs = _inputs(func, args, kwargs) # type: ignore + trace("inputs", inputs) + + try: + result = func(*args, **kwargs) # type: ignore + trace("result", _results(result)) + except Exception as e: + trace( + "result", + { + "exception": { + "type": type(e), + "traceback": (traceback.format_tb(tb=e.__traceback__) if e.__traceback__ else None), + "message": str(e), + "args": to_dict(e.args), + } + }, + ) + raise e + + return result + + return wrapper + + +def _trace_async(func: Union[Callable, None] = None, **okwargs: Any) -> Callable: + + @wraps(func) # type: ignore + async def wrapper(*args, **kwargs): + name, signature = _name(func, args) # type: ignore + with Tracer.start(name) as trace: + trace("signature", signature) + + # support arbitrary keyword + # arguments for trace decorator + for k, v in okwargs.items(): + trace(k, to_dict(v)) + + inputs = _inputs(func, args, kwargs) # type: ignore + trace("inputs", inputs) + try: + result = await func(*args, **kwargs) # type: ignore + trace("result", _results(result)) + except Exception as e: + trace( + "result", + { + "exception": { + "type": type(e), + "traceback": (traceback.format_tb(tb=e.__traceback__) if e.__traceback__ else None), + "message": str(e), + "args": to_dict(e.args), + } + }, + ) + raise e + + return result + + return wrapper + + +def trace(func: Union[Callable, None] = None, **kwargs: Any) -> Callable: + if func is None: + return partial(trace, **kwargs) + wrapped_method = _trace_async if inspect.iscoroutinefunction(func) else _trace_sync + return wrapped_method(func, **kwargs) + + +class PromptyTracer: + def __init__(self, output_dir: Union[str, None] = None) -> None: + if output_dir: + self.output = Path(output_dir).resolve().absolute() + else: + self.output = Path(Path(os.getcwd()) / ".runs").resolve().absolute() + + if not self.output.exists(): + self.output.mkdir(parents=True, exist_ok=True) + + self.stack: List[Dict[str, Any]] = [] + + @contextlib.contextmanager + def tracer(self, name: str) -> Iterator[Callable[[str, Any], None]]: + try: + self.stack.append({"name": name}) + frame = self.stack[-1] + frame["__time"] = { + "start": 
datetime.now(), + } + + def add(key: str, value: Any) -> None: + if key not in frame: + frame[key] = value + # multiple values creates list + else: + if isinstance(frame[key], list): + frame[key].append(value) + else: + frame[key] = [frame[key], value] + + yield add + finally: + frame = self.stack.pop() + start: datetime = frame["__time"]["start"] + end: datetime = datetime.now() + + # add duration to frame + frame["__time"] = { + "start": start.strftime("%Y-%m-%dT%H:%M:%S.%f"), + "end": end.strftime("%Y-%m-%dT%H:%M:%S.%f"), + "duration": int((end - start).total_seconds() * 1000), + } + + # hoist usage to parent frame + if "result" in frame and isinstance(frame["result"], dict): + if "usage" in frame["result"]: + frame["__usage"] = self.hoist_item( + frame["result"]["usage"], + frame["__usage"] if "__usage" in frame else {}, + ) + + # streamed results may have usage as well + if "result" in frame and isinstance(frame["result"], list): + for result in frame["result"]: + if isinstance(result, dict) and "usage" in result and isinstance(result["usage"], dict): + frame["__usage"] = self.hoist_item( + result["usage"], + frame["__usage"] if "__usage" in frame else {}, + ) + + # add any usage frames from below + if "__frames" in frame: + for child in frame["__frames"]: + if "__usage" in child: + frame["__usage"] = self.hoist_item( + child["__usage"], + frame["__usage"] if "__usage" in frame else {}, + ) + + # if stack is empty, dump the frame + if len(self.stack) == 0: + self.write_trace(frame) + # otherwise, append the frame to the parent + else: + if "__frames" not in self.stack[-1]: + self.stack[-1]["__frames"] = [] + self.stack[-1]["__frames"].append(frame) + + def hoist_item(self, src: Dict[str, Any], cur: Dict[str, Any]) -> Dict[str, Any]: + for key, value in src.items(): + if value is None or isinstance(value, list) or isinstance(value, dict): + continue + try: + if key not in cur: + cur[key] = value + else: + cur[key] += value + except: + continue + + return cur + + def write_trace(self, frame: Dict[str, Any]) -> None: + trace_file = self.output / f"{frame['name']}.{datetime.now().strftime('%Y%m%d.%H%M%S')}.tracy" + + v = importlib.metadata.version("prompty") # type: ignore + enriched_frame = { + "runtime": "python", + "version": v, + "trace": frame, + } + + with open(trace_file, "w") as f: + json.dump(enriched_frame, f, indent=4) + + +@contextlib.contextmanager +def console_tracer(name: str) -> Iterator[Callable[[str, Any], None]]: + try: + print(f"Starting {name}") + yield lambda key, value: print(f"{key}:\n{json.dumps(to_dict(value), indent=4)}") + finally: + print(f"Ending {name}") diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_utils.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_utils.py new file mode 100644 index 00000000..22f28418 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_utils.py @@ -0,0 +1,100 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# ------------------------------------ +# mypy: disable-error-code="import-untyped,return-value" +# pylint: disable=line-too-long,R,wrong-import-order,global-variable-not-assigned) +import json +import os +import re +import sys +from typing import Any, Dict +from pathlib import Path + + +_yaml_regex = re.compile( + r"^\s*" + r"(?:---|\+\+\+)" + r"(.*?)" + r"(?:---|\+\+\+)" + r"\s*(.+)$", + re.S | re.M, +) + + +def load_text(file_path, encoding="utf-8"): + with open(file_path, "r", encoding=encoding) as file: + return file.read() + + +def load_json(file_path, encoding="utf-8"): + return json.loads(load_text(file_path, encoding=encoding)) + + +def load_global_config(prompty_path: Path = Path.cwd(), configuration: str = "default") -> Dict[str, Any]: + prompty_config_path = prompty_path.joinpath("prompty.json") + if os.path.exists(prompty_config_path): + c = load_json(prompty_config_path) + if configuration in c: + return c[configuration] + else: + raise ValueError(f'Item "{configuration}" not found in "{prompty_config_path}"') + else: + return {} + + +def load_prompty(file_path, encoding="utf-8") -> Dict[str, Any]: + contents = load_text(file_path, encoding=encoding) + return parse(contents) + + +def parse(contents): + try: + import yaml # type: ignore + except ImportError as exc: + raise ImportError("Please install pyyaml to use this function. Run `pip install pyyaml`.") from exc + + global _yaml_regex + + fmatter = "" + body = "" + result = _yaml_regex.search(contents) + + if result: + fmatter = result.group(1) + body = result.group(2) + return { + "attributes": yaml.load(fmatter, Loader=yaml.SafeLoader), + "body": body, + "frontmatter": fmatter, + } + + +def remove_leading_empty_space(multiline_str: str) -> str: + """ + Processes a multiline string by: + 1. Removing empty lines + 2. Finding the minimum leading spaces + 3. Indenting all lines to the minimum level + + :param multiline_str: The input multiline string. + :type multiline_str: str + :return: The processed multiline string. + :rtype: str + """ + lines = multiline_str.splitlines() + start_index = 0 + while start_index < len(lines) and lines[start_index].strip() == "": + start_index += 1 + + # Find the minimum number of leading spaces + min_spaces = sys.maxsize + for line in lines[start_index:]: + if len(line.strip()) == 0: + continue + spaces = len(line) - len(line.lstrip()) + spaces += line.lstrip().count("\t") * 2 # Count tabs as 2 spaces + min_spaces = min(min_spaces, spaces) + + # Remove leading spaces and indent to the minimum level + processed_lines = [] + for line in lines[start_index:]: + processed_lines.append(line[min_spaces:]) + + return "\n".join(processed_lines) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/py.typed b/.venv/lib/python3.12/site-packages/azure/ai/inference/py.typed new file mode 100644 index 00000000..e5aff4f8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561.
\ No newline at end of file diff --git a/.venv/lib/python3.12/site-packages/azure/ai/inference/tracing.py b/.venv/lib/python3.12/site-packages/azure/ai/inference/tracing.py new file mode 100644 index 00000000..f7937a99 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/inference/tracing.py @@ -0,0 +1,850 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import copy +from enum import Enum +import functools +import json +import importlib +import logging +import os +from time import time_ns +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union +from urllib.parse import urlparse + +# pylint: disable = no-name-in-module +from azure.core import CaseInsensitiveEnumMeta # type: ignore +from azure.core.settings import settings +from . import models as _models + +try: + # pylint: disable = no-name-in-module + from azure.core.tracing import AbstractSpan, SpanKind # type: ignore + from opentelemetry.trace import StatusCode, Span + + _tracing_library_available = True +except ModuleNotFoundError: + + _tracing_library_available = False + + +__all__ = [ + "AIInferenceInstrumentor", +] + + +_inference_traces_enabled: bool = False +_trace_inference_content: bool = False +_INFERENCE_GEN_AI_SYSTEM_NAME = "az.ai.inference" + + +class TraceType(str, Enum, metaclass=CaseInsensitiveEnumMeta): # pylint: disable=C4747 + """An enumeration class to represent different types of traces.""" + + INFERENCE = "Inference" + + +class AIInferenceInstrumentor: + """ + A class for managing the trace instrumentation of AI Inference. + + This class allows enabling or disabling tracing for AI Inference. + and provides functionality to check whether instrumentation is active. + + """ + + def __init__(self): + if not _tracing_library_available: + raise ModuleNotFoundError( + "Azure Core Tracing Opentelemetry is not installed. " + "Please install it using 'pip install azure-core-tracing-opentelemetry'" + ) + # In the future we could support different versions from the same library + # and have a parameter that specifies the version to use. + self._impl = _AIInferenceInstrumentorPreview() + + def instrument(self, enable_content_recording: Optional[bool] = None) -> None: + """ + Enable trace instrumentation for AI Inference. + + :param enable_content_recording: Whether content recording is enabled as part + of the traces or not. Content in this context refers to chat message content + and function call tool related function names, function parameter names and + values. True will enable content recording, False will disable it. If no value + s provided, then the value read from environment variable + AZURE_TRACING_GEN_AI_CONTENT_RECORDING_ENABLED is used. If the environment variable + is not found, then the value will default to False. Please note that successive calls + to instrument will always apply the content recording value provided with the most + recent call to instrument (including applying the environment variable if no value is + provided and defaulting to false if the environment variable is not found), even if + instrument was already previously called without uninstrument being called in between + the instrument calls. + + :type enable_content_recording: bool, optional + """ + self._impl.instrument(enable_content_recording=enable_content_recording) + + def uninstrument(self) -> None: + """ + Disable trace instrumentation for AI Inference. 
+ + Raises: + RuntimeError: If instrumentation is not currently enabled. + + This method removes any active instrumentation, stopping the tracing + of AI Inference. + """ + self._impl.uninstrument() + + def is_instrumented(self) -> bool: + """ + Check if trace instrumentation for AI Inference is currently enabled. + + :return: True if instrumentation is active, False otherwise. + :rtype: bool + """ + return self._impl.is_instrumented() + + def is_content_recording_enabled(self) -> bool: + """ + This function gets the content recording value. + + :return: A bool value indicating whether content recording is enabled. + :rtype: bool + """ + return self._impl.is_content_recording_enabled() + + +class _AIInferenceInstrumentorPreview: + """ + A class for managing the trace instrumentation of AI Inference. + + This class allows enabling or disabling tracing for AI Inference. + and provides functionality to check whether instrumentation is active. + """ + + def _str_to_bool(self, s): + if s is None: + return False + return str(s).lower() == "true" + + def instrument(self, enable_content_recording: Optional[bool] = None): + """ + Enable trace instrumentation for AI Inference. + + :param enable_content_recording: Whether content recording is enabled as part + of the traces or not. Content in this context refers to chat message content + and function call tool related function names, function parameter names and + values. True will enable content recording, False will disable it. If no value + is provided, then the value read from environment variable + AZURE_TRACING_GEN_AI_CONTENT_RECORDING_ENABLED is used. If the environment variable + is not found, then the value will default to False. + + :type enable_content_recording: bool, optional + """ + if enable_content_recording is None: + var_value = os.environ.get("AZURE_TRACING_GEN_AI_CONTENT_RECORDING_ENABLED") + enable_content_recording = self._str_to_bool(var_value) + if not self.is_instrumented(): + self._instrument_inference(enable_content_recording) + else: + self._set_content_recording_enabled(enable_content_recording=enable_content_recording) + + def uninstrument(self): + """ + Disable trace instrumentation for AI Inference. + + This method removes any active instrumentation, stopping the tracing + of AI Inference. + """ + if self.is_instrumented(): + self._uninstrument_inference() + + def is_instrumented(self): + """ + Check if trace instrumentation for AI Inference is currently enabled. + + :return: True if instrumentation is active, False otherwise. + :rtype: bool + """ + return self._is_instrumented() + + def set_content_recording_enabled(self, enable_content_recording: bool = False) -> None: + """This function sets the content recording value. + + :param enable_content_recording: Indicates whether tracing of message content should be enabled. + This also controls whether function call tool function names, + parameter names and parameter values are traced. + :type enable_content_recording: bool + """ + self._set_content_recording_enabled(enable_content_recording=enable_content_recording) + + def is_content_recording_enabled(self) -> bool: + """This function gets the content recording value. + + :return: A bool value indicating whether content tracing is enabled. 
+ :rtype bool + """ + return self._is_content_recording_enabled() + + def _set_attributes(self, span: "AbstractSpan", *attrs: Tuple[str, Any]) -> None: + for attr in attrs: + key, value = attr + if value is not None: + span.add_attribute(key, value) + + def _add_request_chat_message_events(self, span: "AbstractSpan", **kwargs: Any) -> int: + timestamp = 0 + for message in kwargs.get("messages", []): + try: + message = message.as_dict() + except AttributeError: + pass + + if message.get("role"): + timestamp = self._record_event( + span, + f"gen_ai.{message.get('role')}.message", + { + "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME, + "gen_ai.event.content": json.dumps(message), + }, + timestamp, + ) + + return timestamp + + def _parse_url(self, url): + parsed = urlparse(url) + server_address = parsed.hostname + port = parsed.port + return server_address, port + + def _add_request_chat_attributes(self, span: "AbstractSpan", *args: Any, **kwargs: Any) -> None: + client = args[0] + endpoint = client._config.endpoint # pylint: disable=protected-access + server_address, port = self._parse_url(endpoint) + model = "chat" + if kwargs.get("model") is not None: + model_value = kwargs.get("model") + if model_value is not None: + model = model_value + + self._set_attributes( + span, + ("gen_ai.operation.name", "chat"), + ("gen_ai.system", _INFERENCE_GEN_AI_SYSTEM_NAME), + ("gen_ai.request.model", model), + ("gen_ai.request.max_tokens", kwargs.get("max_tokens")), + ("gen_ai.request.temperature", kwargs.get("temperature")), + ("gen_ai.request.top_p", kwargs.get("top_p")), + ("server.address", server_address), + ) + if port is not None and port != 443: + span.add_attribute("server.port", port) + + def _remove_function_call_names_and_arguments(self, tool_calls: list) -> list: + tool_calls_copy = copy.deepcopy(tool_calls) + for tool_call in tool_calls_copy: + if "function" in tool_call: + if "name" in tool_call["function"]: + del tool_call["function"]["name"] + if "arguments" in tool_call["function"]: + del tool_call["function"]["arguments"] + if not tool_call["function"]: + del tool_call["function"] + return tool_calls_copy + + def _get_finish_reasons(self, result) -> Optional[List[str]]: + if hasattr(result, "choices") and result.choices: + finish_reasons: List[str] = [] + for choice in result.choices: + finish_reason = getattr(choice, "finish_reason", None) + + if finish_reason is None: + # If finish_reason is None, default to "none" + finish_reasons.append("none") + elif hasattr(finish_reason, "value"): + # If finish_reason has a 'value' attribute (i.e., it's an enum), use it + finish_reasons.append(finish_reason.value) + elif isinstance(finish_reason, str): + # If finish_reason is a string, use it directly + finish_reasons.append(finish_reason) + else: + # Default to "none" + finish_reasons.append("none") + + return finish_reasons + return None + + def _get_finish_reason_for_choice(self, choice): + finish_reason = getattr(choice, "finish_reason", None) + if finish_reason is not None: + return finish_reason.value + + return "none" + + def _add_response_chat_message_events( + self, span: "AbstractSpan", result: _models.ChatCompletions, last_event_timestamp_ns: int + ) -> None: + for choice in result.choices: + attributes = {} + if _trace_inference_content: + full_response: Dict[str, Any] = { + "message": {"content": choice.message.content}, + "finish_reason": self._get_finish_reason_for_choice(choice), + "index": choice.index, + } + if choice.message.tool_calls: + full_response["message"]["tool_calls"] = 
[tool.as_dict() for tool in choice.message.tool_calls] + attributes = { + "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME, + "gen_ai.event.content": json.dumps(full_response), + } + else: + response: Dict[str, Any] = { + "finish_reason": self._get_finish_reason_for_choice(choice), + "index": choice.index, + } + if choice.message.tool_calls: + response["message"] = {} + tool_calls_function_names_and_arguments_removed = self._remove_function_call_names_and_arguments( + choice.message.tool_calls + ) + response["message"]["tool_calls"] = [ + tool.as_dict() for tool in tool_calls_function_names_and_arguments_removed + ] + + attributes = { + "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME, + "gen_ai.event.content": json.dumps(response), + } + last_event_timestamp_ns = self._record_event(span, "gen_ai.choice", attributes, last_event_timestamp_ns) + + def _add_response_chat_attributes( + self, + span: "AbstractSpan", + result: Union[_models.ChatCompletions, _models.StreamingChatCompletionsUpdate], + ) -> None: + self._set_attributes( + span, + ("gen_ai.response.id", result.id), + ("gen_ai.response.model", result.model), + ( + "gen_ai.usage.input_tokens", + (result.usage.prompt_tokens if hasattr(result, "usage") and result.usage else None), + ), + ( + "gen_ai.usage.output_tokens", + (result.usage.completion_tokens if hasattr(result, "usage") and result.usage else None), + ), + ) + finish_reasons = self._get_finish_reasons(result) + if not finish_reasons is None: + span.add_attribute("gen_ai.response.finish_reasons", finish_reasons) # type: ignore + + def _add_request_details(self, span: "AbstractSpan", args: Any, kwargs: Any) -> int: + self._add_request_chat_attributes(span, *args, **kwargs) + if _trace_inference_content: + return self._add_request_chat_message_events(span, **kwargs) + return 0 + + def _add_response_details(self, span: "AbstractSpan", result: object, last_event_timestamp_ns: int) -> None: + if isinstance(result, _models.ChatCompletions): + self._add_response_chat_attributes(span, result) + self._add_response_chat_message_events(span, result, last_event_timestamp_ns) + # TODO add more models here + + def _accumulate_response(self, item, accumulate: Dict[str, Any]) -> None: + if item.finish_reason: + accumulate["finish_reason"] = item.finish_reason + if item.index: + accumulate["index"] = item.index + if item.delta.content: + accumulate.setdefault("message", {}) + accumulate["message"].setdefault("content", "") + accumulate["message"]["content"] += item.delta.content + if item.delta.tool_calls: + accumulate.setdefault("message", {}) + accumulate["message"].setdefault("tool_calls", []) + if item.delta.tool_calls is not None: + for tool_call in item.delta.tool_calls: + if tool_call.id: + accumulate["message"]["tool_calls"].append( + { + "id": tool_call.id, + "type": "", + "function": {"name": "", "arguments": ""}, + } + ) + if tool_call.function: + accumulate["message"]["tool_calls"][-1]["type"] = "function" + if tool_call.function and tool_call.function.name: + accumulate["message"]["tool_calls"][-1]["function"]["name"] = tool_call.function.name + if tool_call.function and tool_call.function.arguments: + accumulate["message"]["tool_calls"][-1]["function"]["arguments"] += tool_call.function.arguments + + def _accumulate_async_streaming_response(self, item, accumulate: Dict[str, Any]) -> None: + if not "choices" in item: + return + if "finish_reason" in item["choices"][0] and item["choices"][0]["finish_reason"]: + accumulate["finish_reason"] = item["choices"][0]["finish_reason"] + if 
"index" in item["choices"][0] and item["choices"][0]["index"]: + accumulate["index"] = item["choices"][0]["index"] + if not "delta" in item["choices"][0]: + return + if "content" in item["choices"][0]["delta"] and item["choices"][0]["delta"]["content"]: + accumulate.setdefault("message", {}) + accumulate["message"].setdefault("content", "") + accumulate["message"]["content"] += item["choices"][0]["delta"]["content"] + if "tool_calls" in item["choices"][0]["delta"] and item["choices"][0]["delta"]["tool_calls"]: + accumulate.setdefault("message", {}) + accumulate["message"].setdefault("tool_calls", []) + if item["choices"][0]["delta"]["tool_calls"] is not None: + for tool_call in item["choices"][0]["delta"]["tool_calls"]: + if tool_call.id: + accumulate["message"]["tool_calls"].append( + { + "id": tool_call.id, + "type": "", + "function": {"name": "", "arguments": ""}, + } + ) + if tool_call.function: + accumulate["message"]["tool_calls"][-1]["type"] = "function" + if tool_call.function and tool_call.function.name: + accumulate["message"]["tool_calls"][-1]["function"]["name"] = tool_call.function.name + if tool_call.function and tool_call.function.arguments: + accumulate["message"]["tool_calls"][-1]["function"]["arguments"] += tool_call.function.arguments + + def _wrapped_stream( + self, stream_obj: _models.StreamingChatCompletions, span: "AbstractSpan", previous_event_timestamp: int + ) -> _models.StreamingChatCompletions: + class StreamWrapper(_models.StreamingChatCompletions): + def __init__(self, stream_obj, instrumentor): + super().__init__(stream_obj._response) + self._instrumentor = instrumentor + + def __iter__( # pyright: ignore [reportIncompatibleMethodOverride] + self, + ) -> Iterator[_models.StreamingChatCompletionsUpdate]: + accumulate: Dict[str, Any] = {} + try: + chunk = None + for chunk in stream_obj: + for item in chunk.choices: + self._instrumentor._accumulate_response(item, accumulate) + yield chunk + + if chunk is not None: + self._instrumentor._add_response_chat_attributes(span, chunk) + + except Exception as exc: + # Set the span status to error + if isinstance(span.span_instance, Span): # pyright: ignore [reportPossiblyUnboundVariable] + span.span_instance.set_status( + StatusCode.ERROR, # pyright: ignore [reportPossiblyUnboundVariable] + description=str(exc), + ) + module = exc.__module__ if hasattr(exc, "__module__") and exc.__module__ != "builtins" else "" + error_type = f"{module}.{type(exc).__name__}" if module else type(exc).__name__ + self._instrumentor._set_attributes(span, ("error.type", error_type)) + raise + + finally: + if stream_obj._done is False: + if accumulate.get("finish_reason") is None: + accumulate["finish_reason"] = "error" + else: + # Only one choice expected with streaming + accumulate["index"] = 0 + # Delete message if content tracing is not enabled + if not _trace_inference_content: + if "message" in accumulate: + if "content" in accumulate["message"]: + del accumulate["message"]["content"] + if not accumulate["message"]: + del accumulate["message"] + if "message" in accumulate: + if "tool_calls" in accumulate["message"]: + tool_calls_function_names_and_arguments_removed = ( + self._instrumentor._remove_function_call_names_and_arguments( + accumulate["message"]["tool_calls"] + ) + ) + accumulate["message"]["tool_calls"] = list( + tool_calls_function_names_and_arguments_removed + ) + attributes = { + "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME, + "gen_ai.event.content": json.dumps(accumulate), + } + self._instrumentor._record_event(span, 
"gen_ai.choice", attributes, previous_event_timestamp) + span.finish() + + return StreamWrapper(stream_obj, self) + + def _async_wrapped_stream( + self, stream_obj: _models.AsyncStreamingChatCompletions, span: "AbstractSpan", last_event_timestamp_ns: int + ) -> _models.AsyncStreamingChatCompletions: + class AsyncStreamWrapper(_models.AsyncStreamingChatCompletions): + def __init__(self, stream_obj, instrumentor, span, last_event_timestamp_ns): + super().__init__(stream_obj._response) + self._instrumentor = instrumentor + self._accumulate: Dict[str, Any] = {} + self._stream_obj = stream_obj + self.span = span + self._last_result = None + self._last_event_timestamp_ns = last_event_timestamp_ns + + async def __anext__(self) -> "_models.StreamingChatCompletionsUpdate": + try: + result = await super().__anext__() + self._instrumentor._accumulate_async_streaming_response( # pylint: disable=protected-access, line-too-long # pyright: ignore [reportFunctionMemberAccess] + result, self._accumulate + ) + self._last_result = result + except StopAsyncIteration as exc: + self._trace_stream_content() + raise exc + return result + + def _trace_stream_content(self) -> None: + if self._last_result: + self._instrumentor._add_response_chat_attributes( # pylint: disable=protected-access, line-too-long # pyright: ignore [reportFunctionMemberAccess] + span, self._last_result + ) + # Only one choice expected with streaming + self._accumulate["index"] = 0 + # Delete message if content tracing is not enabled + if not _trace_inference_content: + if "message" in self._accumulate: + if "content" in self._accumulate["message"]: + del self._accumulate["message"]["content"] + if not self._accumulate["message"]: + del self._accumulate["message"] + if "message" in self._accumulate: + if "tool_calls" in self._accumulate["message"]: + tools_no_recording = self._instrumentor._remove_function_call_names_and_arguments( # pylint: disable=protected-access, line-too-long # pyright: ignore [reportFunctionMemberAccess] + self._accumulate["message"]["tool_calls"] + ) + self._accumulate["message"]["tool_calls"] = list(tools_no_recording) + attributes = { + "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME, + "gen_ai.event.content": json.dumps(self._accumulate), + } + self._last_event_timestamp_ns = self._instrumentor._record_event( # pylint: disable=protected-access, line-too-long # pyright: ignore [reportFunctionMemberAccess] + span, "gen_ai.choice", attributes, self._last_event_timestamp_ns + ) + span.finish() + + async_stream_wrapper = AsyncStreamWrapper(stream_obj, self, span, last_event_timestamp_ns) + return async_stream_wrapper + + def _record_event( + self, span: "AbstractSpan", name: str, attributes: Dict[str, Any], last_event_timestamp_ns: int + ) -> int: + timestamp = time_ns() + + # we're recording multiple events, some of them are emitted within (hundreds of) nanoseconds of each other. + # time.time_ns resolution is not high enough on windows to guarantee unique timestamps for each message. + # Also Azure Monitor truncates resolution to microseconds and some other backends truncate to milliseconds. + # + # But we need to give users a way to restore event order, so we're incrementing the timestamp + # by 1 microsecond for each message. 
+ # + # This is a workaround, we'll find a generic and better solution - see + # https://github.com/open-telemetry/semantic-conventions/issues/1701 + if last_event_timestamp_ns > 0 and timestamp <= (last_event_timestamp_ns + 1000): + timestamp = last_event_timestamp_ns + 1000 + + span.span_instance.add_event(name=name, attributes=attributes, timestamp=timestamp) + + return timestamp + + def _trace_sync_function( + self, + function: Callable, + *, + _args_to_ignore: Optional[List[str]] = None, + _trace_type=TraceType.INFERENCE, + _name: Optional[str] = None, + ) -> Callable: + """ + Decorator that adds tracing to a synchronous function. + + :param function: The function to be traced. + :type function: Callable + :param args_to_ignore: A list of argument names to be ignored in the trace. + Defaults to None. + :type: args_to_ignore: [List[str]], optional + :param trace_type: The type of the trace. Defaults to TraceType.INFERENCE. + :type trace_type: TraceType, optional + :param name: The name of the trace, will set to func name if not provided. + :type name: str, optional + :return: The traced function. + :rtype: Callable + """ + + @functools.wraps(function) + def inner(*args, **kwargs): + + span_impl_type = settings.tracing_implementation() + if span_impl_type is None: + return function(*args, **kwargs) + + class_function_name = function.__qualname__ + + if class_function_name.startswith("ChatCompletionsClient.complete"): + if kwargs.get("model") is None: + span_name = "chat" + else: + model = kwargs.get("model") + span_name = f"chat {model}" + + span = span_impl_type( + name=span_name, + kind=SpanKind.CLIENT, # pyright: ignore [reportPossiblyUnboundVariable] + ) + + try: + # tracing events not supported in azure-core-tracing-opentelemetry + # so need to access the span instance directly + with span_impl_type.change_context(span.span_instance): + last_event_timestamp_ns = self._add_request_details(span, args, kwargs) + result = function(*args, **kwargs) + if kwargs.get("stream") is True: + return self._wrapped_stream(result, span, last_event_timestamp_ns) + self._add_response_details(span, result, last_event_timestamp_ns) + except Exception as exc: + # Set the span status to error + if isinstance(span.span_instance, Span): # pyright: ignore [reportPossiblyUnboundVariable] + span.span_instance.set_status( + StatusCode.ERROR, # pyright: ignore [reportPossiblyUnboundVariable] + description=str(exc), + ) + module = getattr(exc, "__module__", "") + module = module if module != "builtins" else "" + error_type = f"{module}.{type(exc).__name__}" if module else type(exc).__name__ + self._set_attributes(span, ("error.type", error_type)) + span.finish() + raise + + span.finish() + return result + + # Handle the default case (if the function name does not match) + return None # Ensure all paths return + + return inner + + def _trace_async_function( + self, + function: Callable, + *, + _args_to_ignore: Optional[List[str]] = None, + _trace_type=TraceType.INFERENCE, + _name: Optional[str] = None, + ) -> Callable: + """ + Decorator that adds tracing to an asynchronous function. + + :param function: The function to be traced. + :type function: Callable + :param args_to_ignore: A list of argument names to be ignored in the trace. + Defaults to None. + :type: args_to_ignore: [List[str]], optional + :param trace_type: The type of the trace. Defaults to TraceType.INFERENCE. + :type trace_type: TraceType, optional + :param name: The name of the trace, will set to func name if not provided. 
+ :type name: str, optional + :return: The traced function. + :rtype: Callable + """ + + @functools.wraps(function) + async def inner(*args, **kwargs): + span_impl_type = settings.tracing_implementation() + if span_impl_type is None: + return await function(*args, **kwargs) + + class_function_name = function.__qualname__ + + if class_function_name.startswith("ChatCompletionsClient.complete"): + if kwargs.get("model") is None: + span_name = "chat" + else: + model = kwargs.get("model") + span_name = f"chat {model}" + + span = span_impl_type( + name=span_name, + kind=SpanKind.CLIENT, # pyright: ignore [reportPossiblyUnboundVariable] + ) + try: + # tracing events not supported in azure-core-tracing-opentelemetry + # so need to access the span instance directly + with span_impl_type.change_context(span.span_instance): + last_event_timestamp_ns = self._add_request_details(span, args, kwargs) + result = await function(*args, **kwargs) + if kwargs.get("stream") is True: + return self._async_wrapped_stream(result, span, last_event_timestamp_ns) + self._add_response_details(span, result, last_event_timestamp_ns) + + except Exception as exc: + # Set the span status to error + if isinstance(span.span_instance, Span): # pyright: ignore [reportPossiblyUnboundVariable] + span.span_instance.set_status( + StatusCode.ERROR, # pyright: ignore [reportPossiblyUnboundVariable] + description=str(exc), + ) + module = getattr(exc, "__module__", "") + module = module if module != "builtins" else "" + error_type = f"{module}.{type(exc).__name__}" if module else type(exc).__name__ + self._set_attributes(span, ("error.type", error_type)) + span.finish() + raise + + span.finish() + return result + + # Handle the default case (if the function name does not match) + return None # Ensure all paths return + + return inner + + def _inject_async(self, f, _trace_type, _name): + wrapper_fun = self._trace_async_function(f) + wrapper_fun._original = f # pylint: disable=protected-access # pyright: ignore [reportFunctionMemberAccess] + return wrapper_fun + + def _inject_sync(self, f, _trace_type, _name): + wrapper_fun = self._trace_sync_function(f) + wrapper_fun._original = f # pylint: disable=protected-access # pyright: ignore [reportFunctionMemberAccess] + return wrapper_fun + + def _inference_apis(self): + sync_apis = ( + ( + "azure.ai.inference", + "ChatCompletionsClient", + "complete", + TraceType.INFERENCE, + "inference_chat_completions_complete", + ), + ) + async_apis = ( + ( + "azure.ai.inference.aio", + "ChatCompletionsClient", + "complete", + TraceType.INFERENCE, + "inference_chat_completions_complete", + ), + ) + return sync_apis, async_apis + + def _inference_api_list(self): + sync_apis, async_apis = self._inference_apis() + yield sync_apis, self._inject_sync + yield async_apis, self._inject_async + + def _generate_api_and_injector(self, apis): + for api, injector in apis: + for module_name, class_name, method_name, trace_type, name in api: + try: + module = importlib.import_module(module_name) + api = getattr(module, class_name) + if hasattr(api, method_name): + yield api, method_name, trace_type, injector, name + except AttributeError as e: + # Log the attribute exception with the missing class information + logging.warning( + "AttributeError: The module '%s' does not have the class '%s'. 
%s", + module_name, + class_name, + str(e), + ) + except Exception as e: # pylint: disable=broad-except + # Log other exceptions as a warning, as we're not sure what they might be + logging.warning("An unexpected error occurred: '%s'", str(e)) + + def _available_inference_apis_and_injectors(self): + """ + Generates a sequence of tuples containing Inference API classes, method names, and + corresponding injector functions. + + :return: A generator yielding tuples. + :rtype: tuple + """ + yield from self._generate_api_and_injector(self._inference_api_list()) + + def _instrument_inference(self, enable_content_tracing: bool = False): + """This function modifies the methods of the Inference API classes to + inject logic before calling the original methods. + The original methods are stored as _original attributes of the methods. + + :param enable_content_tracing: Indicates whether tracing of message content should be enabled. + This also controls whether function call tool function names, + parameter names and parameter values are traced. + :type enable_content_tracing: bool + """ + # pylint: disable=W0603 + global _inference_traces_enabled + global _trace_inference_content + if _inference_traces_enabled: + raise RuntimeError("Traces already started for azure.ai.inference") + _inference_traces_enabled = True + _trace_inference_content = enable_content_tracing + for ( + api, + method, + trace_type, + injector, + name, + ) in self._available_inference_apis_and_injectors(): + # Check if the method of the api class has already been modified + if not hasattr(getattr(api, method), "_original"): + setattr(api, method, injector(getattr(api, method), trace_type, name)) + + def _uninstrument_inference(self): + """This function restores the original methods of the Inference API classes + by assigning them back from the _original attributes of the modified methods. + """ + # pylint: disable=W0603 + global _inference_traces_enabled + global _trace_inference_content + _trace_inference_content = False + for api, method, _, _, _ in self._available_inference_apis_and_injectors(): + if hasattr(getattr(api, method), "_original"): + setattr(api, method, getattr(getattr(api, method), "_original")) + _inference_traces_enabled = False + + def _is_instrumented(self): + """This function returns True if Inference libary has already been instrumented + for tracing and False if it has not been instrumented. + + :return: A value indicating whether the Inference library is currently instrumented or not. + :rtype: bool + """ + return _inference_traces_enabled + + def _set_content_recording_enabled(self, enable_content_recording: bool = False) -> None: + """This function sets the content recording value. + + :param enable_content_recording: Indicates whether tracing of message content should be enabled. + This also controls whether function call tool function names, + parameter names and parameter values are traced. + :type enable_content_recording: bool + """ + global _trace_inference_content # pylint: disable=W0603 + _trace_inference_content = enable_content_recording + + def _is_content_recording_enabled(self) -> bool: + """This function gets the content recording value. + + :return: A bool value indicating whether content tracing is enabled. + :rtype bool + """ + return _trace_inference_content |