about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py
diff options
context:
space:
mode:
author	S. Solomon Darnell	2025-03-28 21:52:21 -0500
committer	S. Solomon Darnell	2025-03-28 21:52:21 -0500
commit	4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree	ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py
parent	cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two versions of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py')
-rw-r--r--	.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py	483
1 file changed, 483 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py
new file mode 100644
index 00000000..28e409c7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py
@@ -0,0 +1,483 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+import json
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+
+from azure.ai.ml._artifacts._artifact_utilities import _check_and_upload_path
+
+# cspell:disable-next-line
+from azure.ai.ml._restclient.azure_ai_assets_v2024_04_01.azureaiassetsv20240401 import (
+    MachineLearningServicesClient as AzureAiAssetsClient042024,
+)
+
+# cspell:disable-next-line
+from azure.ai.ml._restclient.azure_ai_assets_v2024_04_01.azureaiassetsv20240401.models import Index as RestIndex
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ListViewType
+from azure.ai.ml._scope_dependent_operations import (
+    OperationConfig,
+    OperationsContainer,
+    OperationScope,
+    _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._asset_utils import (
+    _resolve_label_to_asset,
+    _validate_auto_delete_setting_in_data_output,
+    _validate_workspace_managed_datastore,
+)
+from azure.ai.ml._utils._http_utils import HttpPipeline
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils.utils import _get_base_urls_from_discovery_service
+from azure.ai.ml.constants._common import AssetTypes, AzureMLResourceType, WorkspaceDiscoveryUrlKey
+from azure.ai.ml.dsl import pipeline
+from azure.ai.ml.entities import PipelineJob, PipelineJobSettings
+from azure.ai.ml.entities._assets import Index
+from azure.ai.ml.entities._credentials import ManagedIdentityConfiguration, UserIdentityConfiguration
+from azure.ai.ml.entities._indexes import (
+    AzureAISearchConfig,
+    GitSource,
+    IndexDataSource,
+    LocalSource,
+    ModelConfiguration,
+)
+from azure.ai.ml.entities._indexes.data_index_func import index_data as index_data_func
+from azure.ai.ml.entities._indexes.entities.data_index import (
+    CitationRegex,
+    Data,
+    DataIndex,
+    Embedding,
+    IndexSource,
+    IndexStore,
+)
+from azure.ai.ml.entities._indexes.utils._open_ai_utils import build_connection_id, build_open_ai_protocol
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+from azure.ai.ml.operations._datastore_operations import DatastoreOperations
+from azure.core.credentials import TokenCredential
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
class IndexOperations(_ScopeDependentOperations):
    """Represents a client for performing operations on index assets.

    You should not instantiate this class directly. Instead, you should create MLClient and use this client via the
    property MLClient.index
    """

    def __init__(
        self,
        *,
        operation_scope: OperationScope,
        operation_config: OperationConfig,
        credential: TokenCredential,
        datastore_operations: DatastoreOperations,
        all_operations: OperationsContainer,
        **kwargs: Any,
    ) -> None:
        """Initialize the index operations client.

        :keyword operation_scope: Scope (subscription, resource group, workspace) used by all operations.
        :paramtype operation_scope: OperationScope
        :keyword operation_config: Common operation configuration (e.g. whether to show upload progress).
        :paramtype operation_config: OperationConfig
        :keyword credential: Credential used to authenticate the dataplane service clients.
        :paramtype credential: ~azure.core.credentials.TokenCredential
        :keyword datastore_operations: Datastore operations client; used when uploading local index artifacts.
        :paramtype datastore_operations: DatastoreOperations
        :keyword all_operations: Container holding sibling operation clients (workspace, job, ...).
        :paramtype all_operations: OperationsContainer
        """
        super().__init__(operation_scope, operation_config)
        ops_logger.update_filter()
        self._credential = credential
        # Dataplane service clients are lazily created as they are needed
        self.__azure_ai_assets_client: Optional[AzureAiAssetsClient042024] = None
        # Kwargs to propagate to dataplane service clients
        self._service_client_kwargs: Dict[str, Any] = kwargs.pop("_service_client_kwargs", {})
        self._all_operations = all_operations

        self._datastore_operation: DatastoreOperations = datastore_operations
        # NOTE: "requests_pipeline" is required — a KeyError here means the caller did not supply it.
        self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline")

        # Maps a label to a function which given an asset name,
        # returns the asset associated with the label
        self._managed_label_resolver: Dict[str, Callable[[str], Index]] = {"latest": self._get_latest_version}

    @property
    def _azure_ai_assets(self) -> AzureAiAssetsClient042024:
        """Lazily instantiated client for azure ai assets api.

        .. note::

            Property is lazily instantiated since the api's base url depends on the user's workspace, and
            must be fetched remotely.
        """
        if self.__azure_ai_assets_client is None:
            workspace_operations = self._all_operations.all_operations[AzureMLResourceType.WORKSPACE]

            # Resolve the dataplane API base URL for this workspace via the discovery service.
            endpoint = _get_base_urls_from_discovery_service(
                workspace_operations, self._operation_scope.workspace_name, self._requests_pipeline
            )[WorkspaceDiscoveryUrlKey.API]

            self.__azure_ai_assets_client = AzureAiAssetsClient042024(
                endpoint=endpoint,
                subscription_id=self._operation_scope.subscription_id,
                resource_group_name=self._operation_scope.resource_group_name,
                workspace_name=self._operation_scope.workspace_name,
                credential=self._credential,
                **self._service_client_kwargs,
            )

        return self.__azure_ai_assets_client

    @monitor_with_activity(ops_logger, "Index.CreateOrUpdate", ActivityType.PUBLICAPI)
    def create_or_update(self, index: Index, **kwargs) -> Index:
        """Returns created or updated index asset.

        If not already in storage, asset will be uploaded to the workspace's default datastore.

        :param index: Index asset object.
        :type index: Index
        :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the index has no name, or has no
            version and one could not be determined automatically.
        :return: Index asset object.
        :rtype: ~azure.ai.ml.entities.Index
        """

        if not index.name:
            msg = "Must specify a name."
            raise ValidationException(
                message=msg,
                target=ErrorTarget.INDEX,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.MISSING_FIELD,
            )

        if not index.version:
            # A missing version is only acceptable when the asset opts into auto-incrementing.
            if not index._auto_increment_version:
                msg = "Must specify a version."
                raise ValidationException(
                    message=msg,
                    target=ErrorTarget.INDEX,
                    no_personal_data_message=msg,
                    error_category=ErrorCategory.USER_ERROR,
                    error_type=ValidationErrorType.MISSING_FIELD,
                )

            # Ask the service for the next available version of this index name.
            next_version = self._azure_ai_assets.indexes.get_next_version(index.name).next_version

            if next_version is None:
                msg = "Version not specified, could not automatically increment version. Set a version to resolve."
                raise ValidationException(
                    message=msg,
                    target=ErrorTarget.INDEX,
                    no_personal_data_message=msg,
                    error_category=ErrorCategory.SYSTEM_ERROR,
                    error_type=ValidationErrorType.MISSING_FIELD,
                )

            index.version = str(next_version)

        # Upload local index files (if any) to the asset's datastore before registering it.
        _ = _check_and_upload_path(
            artifact=index,
            asset_operations=self,
            datastore_name=index.datastore,
            artifact_type=ErrorTarget.INDEX,
            show_progress=self._show_progress,
        )

        return Index._from_rest_object(
            self._azure_ai_assets.indexes.create_or_update(
                name=index.name, version=index.version, body=index._to_rest_object(), **kwargs
            )
        )

    @monitor_with_activity(ops_logger, "Index.Get", ActivityType.PUBLICAPI)
    def get(self, name: str, *, version: Optional[str] = None, label: Optional[str] = None, **kwargs) -> Index:
        """Returns information about the specified index asset.

        :param str name: Name of the index asset.
        :keyword Optional[str] version: Version of the index asset. Mutually exclusive with `label`.
        :keyword Optional[str] label: The label of the index asset. Mutually exclusive with `version`.
        :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Index cannot be successfully validated.
            Details will be provided in the error message.
        :return: Index asset object.
        :rtype: ~azure.ai.ml.entities.Index
        """
        # Exactly one of version/label must be provided.
        if version and label:
            msg = "Cannot specify both version and label."
            raise ValidationException(
                message=msg,
                target=ErrorTarget.INDEX,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.INVALID_VALUE,
            )

        if label:
            # Resolve the label (e.g. "latest") to a concrete asset via _managed_label_resolver.
            return _resolve_label_to_asset(self, name, label)

        if not version:
            msg = "Must provide either version or label."
            raise ValidationException(
                message=msg,
                target=ErrorTarget.INDEX,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.MISSING_FIELD,
            )

        index_version_resource = self._azure_ai_assets.indexes.get(name=name, version=version, **kwargs)

        return Index._from_rest_object(index_version_resource)

    def _get_latest_version(self, name: str) -> Index:
        """Return the latest version of the named index. Backs the "latest" label resolver."""
        return Index._from_rest_object(self._azure_ai_assets.indexes.get_latest(name))

    @monitor_with_activity(ops_logger, "Index.List", ActivityType.PUBLICAPI)
    def list(
        self, name: Optional[str] = None, *, list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, **kwargs
    ) -> Iterable[Index]:
        """List all Index assets in workspace.

        If an Index is specified by name, list all version of that Index.

        :param name: Name of the model.
        :type name: Optional[str]
        :keyword list_view_type: View type for including/excluding (for example) archived models.
            Defaults to :attr:`ListViewType.ACTIVE_ONLY`.
        :paramtype list_view_type: ListViewType
        :return: An iterator like instance of Index objects
        :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Index]
        """

        # Deserialization callback: maps each page of REST objects to Index entities.
        def cls(rest_indexes: Iterable[RestIndex]) -> List[Index]:
            return [Index._from_rest_object(i) for i in rest_indexes]

        if name is None:
            # No name given: list the latest version of every index in the workspace.
            return self._azure_ai_assets.indexes.list_latest(cls=cls, **kwargs)

        return self._azure_ai_assets.indexes.list(name, list_view_type=list_view_type, cls=cls, **kwargs)

    def build_index(
        self,
        *,
        ######## required args ##########
        name: str,
        embeddings_model_config: ModelConfiguration,
        ######## chunking information ##########
        data_source_citation_url: Optional[str] = None,
        tokens_per_chunk: Optional[int] = None,
        token_overlap_across_chunks: Optional[int] = None,
        input_glob: Optional[str] = None,
        ######## other generic args ########
        document_path_replacement_regex: Optional[str] = None,
        ######## Azure AI Search index info ########
        index_config: Optional[AzureAISearchConfig] = None,  # todo better name?
        ######## data source info ########
        input_source: Union[IndexDataSource, str],
        input_source_credential: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None,
    ) -> Union["Index", "Job"]:  # type: ignore[name-defined]
        """Builds an index on the cloud using the Azure AI Resources service.

        :keyword name: The name of the index to be created.
        :paramtype name: str
        :keyword embeddings_model_config: Model config for the embedding model.
        :paramtype embeddings_model_config: ~azure.ai.ml.entities._indexes.ModelConfiguration
        :keyword data_source_citation_url: The URL of the data source.
        :paramtype data_source_citation_url: Optional[str]
        :keyword tokens_per_chunk: The size of chunks to be used for indexing.
        :paramtype tokens_per_chunk: Optional[int]
        :keyword token_overlap_across_chunks: The amount of overlap between chunks.
        :paramtype token_overlap_across_chunks: Optional[int]
        :keyword input_glob: The glob pattern to be used for indexing.
        :paramtype input_glob: Optional[str]
        :keyword document_path_replacement_regex: The regex pattern for replacing document paths, as a JSON
            string with "match_pattern" and "replacement_pattern" keys.
        :paramtype document_path_replacement_regex: Optional[str]
        :keyword index_config: The configuration for the ACS output.
        :paramtype index_config: Optional[~azure.ai.ml.entities._indexes.AzureAISearchConfig]
        :keyword input_source: The input source for the index.
        :paramtype input_source: Union[~azure.ai.ml.entities._indexes.IndexDataSource, str]
        :keyword input_source_credential: The identity to be used for the index.
        :paramtype input_source_credential: Optional[Union[~azure.ai.ml.entities.ManagedIdentityConfiguration,
            ~azure.ai.ml.entities.UserIdentityConfiguration]]
        :return: If the `input_source` is a GitSource, returns a created DataIndex Job object.
        :rtype: Union[~azure.ai.ml.entities.Index, ~azure.ai.ml.entities.Job]
        :raises ValueError: If the `input_source` is not type ~typing.Str or
            ~azure.ai.ml.entities._indexes.LocalSource.
        """
        if document_path_replacement_regex:
            # Arrives as a JSON string; decode it into the dict consumed by CitationRegex below.
            document_path_replacement_regex = json.loads(document_path_replacement_regex)

        data_index = DataIndex(
            name=name,
            source=IndexSource(
                input_data=Data(
                    type="uri_folder",
                    path=".",
                ),
                input_glob=input_glob,
                chunk_size=tokens_per_chunk,
                chunk_overlap=token_overlap_across_chunks,
                citation_url=data_source_citation_url,
                citation_url_replacement_regex=(
                    CitationRegex(
                        match_pattern=document_path_replacement_regex["match_pattern"],  # type: ignore[index]
                        replacement_pattern=document_path_replacement_regex[
                            "replacement_pattern"  # type: ignore[index]
                        ],
                    )
                    if document_path_replacement_regex
                    else None
                ),
            ),
            embedding=Embedding(
                model=build_open_ai_protocol(
                    model=embeddings_model_config.model_name, deployment=embeddings_model_config.deployment_name
                ),
                connection=build_connection_id(embeddings_model_config.connection_name, self._operation_scope),
            ),
            # With an AzureAISearchConfig the output index is "acs" (Azure AI Search); otherwise a local faiss index.
            index=(
                IndexStore(
                    type="acs",
                    connection=build_connection_id(index_config.connection_id, self._operation_scope),
                    name=index_config.index_name,
                )
                if index_config is not None
                else IndexStore(type="faiss")
            ),
            # name is replaced with a unique value each time the job is run
            path=f"azureml://datastores/workspaceblobstore/paths/indexes/{name}/{{name}}",
        )

        # Local folder source: point the job's input at the local path and submit.
        if isinstance(input_source, LocalSource):
            data_index.source.input_data = Data(
                type="uri_folder",
                path=input_source.input_data.path,
            )

            return self._create_data_indexing_job(data_index=data_index, identity=input_source_credential)

        # Plain string source: treated as a uri_folder path as-is.
        if isinstance(input_source, str):
            data_index.source.input_data = Data(
                type="uri_folder",
                path=input_source,
            )

            return self._create_data_indexing_job(data_index=data_index, identity=input_source_credential)

        if isinstance(input_source, GitSource):
            from azure.ai.ml import MLClient

            # The llm_rag_git_clone component is published in the shared "azureml" registry.
            ml_registry = MLClient(credential=self._credential, registry_name="azureml")
            git_clone_component = ml_registry.components.get("llm_rag_git_clone", label="latest")

            # Clone Git Repo and use as input to index_job
            @pipeline(default_compute="serverless")  # type: ignore[call-overload]
            def git_to_index(
                git_url,
                branch_name="",
                git_connection_id="",
            ):
                git_clone = git_clone_component(git_repository=git_url, branch_name=branch_name)
                git_clone.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_GIT"] = git_connection_id

                index_job = index_data_func(
                    description=data_index.description,
                    data_index=data_index,
                    input_data_override=git_clone.outputs.output_data,
                    ml_client=MLClient(
                        subscription_id=self._subscription_id,
                        resource_group_name=self._resource_group_name,
                        workspace_name=self._workspace_name,
                        credential=self._credential,
                    ),
                )
                # pylint: disable=no-member
                return index_job.outputs

            git_index_job = git_to_index(
                git_url=input_source.url,
                branch_name=input_source.branch_name,
                git_connection_id=input_source.connection_id,
            )
            # Ensure repo cloned each run to get latest, comment out to have first clone reused.
            git_index_job.settings.force_rerun = True

            # Submit the DataIndex Job
            return self._all_operations.all_operations[AzureMLResourceType.JOB].create_or_update(git_index_job)
        raise ValueError(f"Unsupported input source type {type(input_source)}")

    def _create_data_indexing_job(
        self,
        data_index: DataIndex,
        identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None,
        compute: str = "serverless",
        serverless_instance_type: Optional[str] = None,
        input_data_override: Optional[Input] = None,
        submit_job: bool = True,
        **kwargs,
    ) -> PipelineJob:
        """
        Returns the data import job that is creating the data asset.

        :param data_index: DataIndex object.
        :type data_index: azure.ai.ml.entities._dataindex
        :param identity: Identity configuration for the job.
        :type identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]]
        :param compute: The compute target to use for the job. Default: "serverless".
        :type compute: str
        :param serverless_instance_type: The instance type to use for serverless compute.
        :type serverless_instance_type: Optional[str]
        :param input_data_override: Input data override for the job.
            Used to pipe output of step into DataIndex Job in a pipeline.
        :type input_data_override: Optional[Input]
        :param submit_job: Whether to submit the job to the service. Default: True.
        :type submit_job: bool
        :return: data import job object.
        :rtype: ~azure.ai.ml.entities.PipelineJob.
        """
        # pylint: disable=no-member
        from azure.ai.ml import MLClient

        default_name = "data_index_" + data_index.name if data_index.name is not None else ""
        experiment_name = kwargs.pop("experiment_name", None) or default_name
        data_index.type = AssetTypes.URI_FOLDER

        # avoid specifying auto_delete_setting in job output now
        _validate_auto_delete_setting_in_data_output(data_index.auto_delete_setting)

        # block customer specified path on managed datastore
        data_index.path = _validate_workspace_managed_datastore(data_index.path)

        # Guarantee the output path carries the per-run "name" placeholder so each run writes to a unique folder.
        if "${{name}}" not in str(data_index.path) and "{name}" not in str(data_index.path):
            data_index.path = str(data_index.path).rstrip("/") + "/${{name}}"

        index_job = index_data_func(
            description=data_index.description or kwargs.pop("description", None) or default_name,
            name=data_index.name or kwargs.pop("name", None),
            display_name=kwargs.pop("display_name", None) or default_name,
            experiment_name=experiment_name,
            compute=compute,
            serverless_instance_type=serverless_instance_type,
            data_index=data_index,
            ml_client=MLClient(
                subscription_id=self._subscription_id,
                resource_group_name=self._resource_group_name,
                workspace_name=self._workspace_name,
                credential=self._credential,
            ),
            identity=identity,
            input_data_override=input_data_override,
            **kwargs,
        )
        # Wrap the single index step in a pipeline job so it can be submitted/tracked as one unit.
        index_pipeline = PipelineJob(
            description=index_job.description,
            tags=index_job.tags,
            name=index_job.name,
            display_name=index_job.display_name,
            experiment_name=experiment_name,
            properties=index_job.properties or {},
            settings=PipelineJobSettings(force_rerun=True, default_compute=compute),
            jobs={default_name: index_job},
        )
        # Properties consumed by the service/UI to associate the job with the produced index asset.
        index_pipeline.properties["azureml.mlIndexAssetName"] = data_index.name
        index_pipeline.properties["azureml.mlIndexAssetKind"] = data_index.index.type
        index_pipeline.properties["azureml.mlIndexAssetSource"] = kwargs.pop("mlindex_asset_source", "Data Asset")

        if submit_job:
            return self._all_operations.all_operations[AzureMLResourceType.JOB].create_or_update(
                job=index_pipeline, skip_validation=True, **kwargs
            )
        return index_pipeline