about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py122
1 files changed, 122 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py
new file mode 100644
index 00000000..c9e54da4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py
@@ -0,0 +1,122 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection
+from azure.ai.ml.entities._workspace.connections.connection_subtypes import (
+    AzureOpenAIConnection,
+    AadCredentialConfiguration,
+)
+
+
+@experimental
+@dataclass
+class ModelConfiguration:
+    """Configuration for a embedding model.
+
+    :param api_base: The base URL for the API.
+    :type api_base: Optional[str]
+    :param api_key: The API key.
+    :type api_key: Optional[str]
+    :param api_version: The API version.
+    :type api_version: Optional[str]
+    :param model_name: The name of the model.
+    :type model_name: Optional[str]
+    :param model_name: The deployment name of the model.
+    :type model_name: Optional[str]
+    :param connection_name: The name of the workspace connection of this model.
+    :type connection_name: Optional[str]
+    :param connection_type: The type of the workspace connection of this model.
+    :type connection_type: Optional[str]
+    :param model_kwargs: Additional keyword arguments for the model.
+    :type model_kwargs: Dict[str, Any]
+    """
+
+    api_base: Optional[str]
+    api_key: Optional[str]
+    api_version: Optional[str]
+    connection_name: Optional[str]
+    connection_type: Optional[str]
+    model_name: Optional[str]
+    deployment_name: Optional[str]
+    model_kwargs: Dict[str, Any]
+
+    def __init__(
+        self,
+        *,
+        api_base: Optional[str],
+        api_key: Optional[str],
+        api_version: Optional[str],
+        connection_name: Optional[str],
+        connection_type: Optional[str],
+        model_name: Optional[str],
+        deployment_name: Optional[str],
+        model_kwargs: Dict[str, Any]
+    ):
+        self.api_base = api_base
+        self.api_key = api_key
+        self.api_version = api_version
+        self.connection_name = connection_name
+        self.connection_type = connection_type
+        self.model_name = model_name
+        self.deployment_name = deployment_name
+        self.model_kwargs = model_kwargs
+
+    @staticmethod
+    def from_connection(
+        connection: WorkspaceConnection,
+        model_name: Optional[str] = None,
+        deployment_name: Optional[str] = None,
+        **kwargs
+    ) -> "ModelConfiguration":
+        """Create an model configuration from a Connection.
+
+        :param connection: The WorkspaceConnection object.
+        :type connection: ~azure.ai.ml.entities.WorkspaceConnection
+        :param model_name: The name of the model.
+        :type model_name: Optional[str]
+        :param deployment_name: The name of the deployment.
+        :type deployment_name: Optional[str]
+        :return: The model configuration.
+        :rtype: ~azure.ai.ml.entities._indexes.entities.ModelConfiguration
+        :raises TypeError: If the connection is not an AzureOpenAIConnection.
+        :raises ValueError: If the connection does not contain an OpenAI key.
+        """
+        if isinstance(connection, AzureOpenAIConnection) or camel_to_snake(connection.type) == "azure_open_ai":
+            connection_type = "azure_open_ai"
+            api_version = connection.api_version  # type: ignore[attr-defined]
+            if not model_name or not deployment_name:
+                raise ValueError("Please specify model_name and deployment_name.")
+        elif connection.type and connection.type.lower() == "serverless":
+            connection_type = "serverless"
+            api_version = None
+            if not connection.id:
+                raise TypeError("The connection id is missing from the serverless connection object.")
+        else:
+            raise TypeError("Connection object is not supported.")
+
+        if isinstance(connection.credentials, AadCredentialConfiguration):
+            key = None
+        else:
+            key = connection.credentials.get("key")  # type: ignore[union-attr]
+            if key is None and connection_type == "azure_open_ai":
+                import os
+
+                if "AZURE_OPENAI_API_KEY" in os.environ:
+                    key = os.getenv("AZURE_OPENAI_API_KEY")
+                else:
+                    raise ValueError("Unable to retrieve openai key from connection object or env variable.")
+
+        return ModelConfiguration(
+            api_base=connection.target,
+            api_key=key,
+            api_version=api_version,
+            connection_name=connection.name,
+            connection_type=connection_type,
+            model_name=model_name,
+            deployment_name=deployment_name,
+            model_kwargs=kwargs,
+        )