1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
|
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from dataclasses import dataclass
from typing import Any, Dict, Optional
from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml._utils.utils import camel_to_snake
from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection
from azure.ai.ml.entities._workspace.connections.connection_subtypes import (
AzureOpenAIConnection,
AadCredentialConfiguration,
)
@experimental
@dataclass
class ModelConfiguration:
    """Configuration for an embedding model.

    :param api_base: The base URL for the API.
    :type api_base: Optional[str]
    :param api_key: The API key.
    :type api_key: Optional[str]
    :param api_version: The API version.
    :type api_version: Optional[str]
    :param connection_name: The name of the workspace connection of this model.
    :type connection_name: Optional[str]
    :param connection_type: The type of the workspace connection of this model.
    :type connection_type: Optional[str]
    :param model_name: The name of the model.
    :type model_name: Optional[str]
    :param deployment_name: The deployment name of the model.
    :type deployment_name: Optional[str]
    :param model_kwargs: Additional keyword arguments for the model.
    :type model_kwargs: Dict[str, Any]
    """

    api_base: Optional[str]
    api_key: Optional[str]
    api_version: Optional[str]
    connection_name: Optional[str]
    connection_type: Optional[str]
    model_name: Optional[str]
    deployment_name: Optional[str]
    model_kwargs: Dict[str, Any]

    # The explicit __init__ makes every argument keyword-only; @dataclass does not
    # replace an __init__ that is already defined in the class body, but still
    # generates __repr__ / __eq__ from the field annotations above.
    def __init__(
        self,
        *,
        api_base: Optional[str],
        api_key: Optional[str],
        api_version: Optional[str],
        connection_name: Optional[str],
        connection_type: Optional[str],
        model_name: Optional[str],
        deployment_name: Optional[str],
        model_kwargs: Dict[str, Any],
    ):
        self.api_base = api_base
        self.api_key = api_key
        self.api_version = api_version
        self.connection_name = connection_name
        self.connection_type = connection_type
        self.model_name = model_name
        self.deployment_name = deployment_name
        self.model_kwargs = model_kwargs

    @staticmethod
    def from_connection(
        connection: WorkspaceConnection,
        model_name: Optional[str] = None,
        deployment_name: Optional[str] = None,
        **kwargs,
    ) -> "ModelConfiguration":
        """Create a model configuration from a workspace connection.

        Two kinds of connection are accepted:

        * Azure OpenAI: an ``AzureOpenAIConnection`` instance, or any connection
          whose ``type`` converts (via ``camel_to_snake``) to ``"azure_open_ai"``.
          Both ``model_name`` and ``deployment_name`` are required, and an API key
          must be available from the connection credentials or from the
          ``AZURE_OPENAI_API_KEY`` environment variable.
        * Serverless: a connection whose ``type`` is ``"serverless"``
          (case-insensitive). It must have an ``id``; a missing key is allowed and
          leaves ``api_key`` as None.

        :param connection: The WorkspaceConnection object.
        :type connection: ~azure.ai.ml.entities.WorkspaceConnection
        :param model_name: The name of the model. Required for Azure OpenAI connections.
        :type model_name: Optional[str]
        :param deployment_name: The name of the deployment. Required for Azure OpenAI
            connections.
        :type deployment_name: Optional[str]
        :keyword kwargs: Extra keyword arguments, stored as ``model_kwargs`` on the result.
        :return: The model configuration.
        :rtype: ~azure.ai.ml.entities._indexes.entities.ModelConfiguration
        :raises TypeError: If the connection is neither an Azure OpenAI nor a serverless
            connection, or if a serverless connection has no id.
        :raises ValueError: If ``model_name`` or ``deployment_name`` is missing for an
            Azure OpenAI connection, or if no API key can be found for an Azure OpenAI
            connection in either the connection credentials or the
            ``AZURE_OPENAI_API_KEY`` environment variable.
        """
        import os

        # Guard the type string before normalizing it, mirroring the serverless branch,
        # so a connection without a type reaches the "not supported" error below.
        if isinstance(connection, AzureOpenAIConnection) or (
            connection.type and camel_to_snake(connection.type) == "azure_open_ai"
        ):
            connection_type = "azure_open_ai"
            # A connection matched only by its type string need not be an
            # AzureOpenAIConnection, so fall back to None if it has no api_version.
            api_version = getattr(connection, "api_version", None)
            if not model_name or not deployment_name:
                raise ValueError("Please specify model_name and deployment_name.")
        elif connection.type and connection.type.lower() == "serverless":
            connection_type = "serverless"
            api_version = None
            if not connection.id:
                raise TypeError("The connection id is missing from the serverless connection object.")
        else:
            raise TypeError("Connection object is not supported.")

        credentials = connection.credentials
        if credentials is None or isinstance(credentials, AadCredentialConfiguration):
            # No key is read from AAD credential configurations (or absent credentials).
            key = None
        else:
            # NOTE(review): assumes non-AAD credential objects expose a dict-like
            # ``get``; confirm against the credential configuration classes.
            key = credentials.get("key")  # type: ignore[union-attr]

        # Azure OpenAI requires a key; fall back to the environment before giving up.
        if key is None and connection_type == "azure_open_ai":
            key = os.environ.get("AZURE_OPENAI_API_KEY")
            if key is None:
                raise ValueError("Unable to retrieve openai key from connection object or env variable.")

        return ModelConfiguration(
            api_base=connection.target,
            api_key=key,
            api_version=api_version,
            connection_name=connection.name,
            connection_type=connection_type,
            model_name=model_name,
            deployment_name=deployment_name,
            model_kwargs=kwargs,
        )
|