diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py | 483 |
1 files changed, 483 insertions, 0 deletions
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=protected-access
import json
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from azure.ai.ml._artifacts._artifact_utilities import _check_and_upload_path

# cspell:disable-next-line
from azure.ai.ml._restclient.azure_ai_assets_v2024_04_01.azureaiassetsv20240401 import (
    MachineLearningServicesClient as AzureAiAssetsClient042024,
)

# cspell:disable-next-line
from azure.ai.ml._restclient.azure_ai_assets_v2024_04_01.azureaiassetsv20240401.models import Index as RestIndex
from azure.ai.ml._restclient.v2023_04_01_preview.models import ListViewType
from azure.ai.ml._scope_dependent_operations import (
    OperationConfig,
    OperationsContainer,
    OperationScope,
    _ScopeDependentOperations,
)
from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
from azure.ai.ml._utils._asset_utils import (
    _resolve_label_to_asset,
    _validate_auto_delete_setting_in_data_output,
    _validate_workspace_managed_datastore,
)
from azure.ai.ml._utils._http_utils import HttpPipeline
from azure.ai.ml._utils._logger_utils import OpsLogger
from azure.ai.ml._utils.utils import _get_base_urls_from_discovery_service
from azure.ai.ml.constants._common import AssetTypes, AzureMLResourceType, WorkspaceDiscoveryUrlKey
from azure.ai.ml.dsl import pipeline
from azure.ai.ml.entities import PipelineJob, PipelineJobSettings
from azure.ai.ml.entities._assets import Index
from azure.ai.ml.entities._credentials import ManagedIdentityConfiguration, UserIdentityConfiguration
from azure.ai.ml.entities._indexes import (
    AzureAISearchConfig,
    GitSource,
    IndexDataSource,
    LocalSource,
    ModelConfiguration,
)
from azure.ai.ml.entities._indexes.data_index_func import index_data as index_data_func
from azure.ai.ml.entities._indexes.entities.data_index import (
    CitationRegex,
    Data,
    DataIndex,
    Embedding,
    IndexSource,
    IndexStore,
)
from azure.ai.ml.entities._indexes.utils._open_ai_utils import build_connection_id, build_open_ai_protocol
from azure.ai.ml.entities._inputs_outputs import Input
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
from azure.ai.ml.operations._datastore_operations import DatastoreOperations
from azure.core.credentials import TokenCredential

ops_logger = OpsLogger(__name__)
module_logger = ops_logger.module_logger


class IndexOperations(_ScopeDependentOperations):
    """Represents a client for performing operations on index assets.

    You should not instantiate this class directly. Instead, you should create MLClient and use this client via the
    property MLClient.index
    """

    def __init__(
        self,
        *,
        operation_scope: OperationScope,
        operation_config: OperationConfig,
        credential: TokenCredential,
        datastore_operations: DatastoreOperations,
        all_operations: OperationsContainer,
        **kwargs: Any,
    ):
        """Construct the operations client.

        :keyword operation_scope: Scope (subscription/resource group/workspace) the operations run against.
        :paramtype operation_scope: OperationScope
        :keyword operation_config: Common operation configuration (e.g. show_progress).
        :paramtype operation_config: OperationConfig
        :keyword credential: Credential used to authenticate dataplane service clients.
        :paramtype credential: ~azure.core.credentials.TokenCredential
        :keyword datastore_operations: Datastore operations client, used when uploading index artifacts.
        :paramtype datastore_operations: DatastoreOperations
        :keyword all_operations: Container of sibling operation clients (workspace, job, ...), used for
            discovery-URL lookup and job submission.
        :paramtype all_operations: OperationsContainer
        """
        super().__init__(operation_scope, operation_config)
        ops_logger.update_filter()
        self._credential = credential
        # Dataplane service clients are lazily created as they are needed
        self.__azure_ai_assets_client: Optional[AzureAiAssetsClient042024] = None
        # Kwargs to propagate to dataplane service clients
        self._service_client_kwargs: Dict[str, Any] = kwargs.pop("_service_client_kwargs", {})
        self._all_operations = all_operations

        self._datastore_operation: DatastoreOperations = datastore_operations
        # NOTE(review): "requests_pipeline" is required here — a missing kwarg raises KeyError at construction.
        self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline")

        # Maps a label to a function which given an asset name,
        # returns the asset associated with the label
        self._managed_label_resolver: Dict[str, Callable[[str], Index]] = {"latest": self._get_latest_version}

    @property
    def _azure_ai_assets(self) -> AzureAiAssetsClient042024:
        """Lazily instantiated client for azure ai assets api.

        .. note::

            Property is lazily instantiated since the api's base url depends on the user's workspace, and
            must be fetched remotely.
        """
        if self.__azure_ai_assets_client is None:
            workspace_operations = self._all_operations.all_operations[AzureMLResourceType.WORKSPACE]

            # Resolve the dataplane endpoint from the workspace discovery service (remote call).
            endpoint = _get_base_urls_from_discovery_service(
                workspace_operations, self._operation_scope.workspace_name, self._requests_pipeline
            )[WorkspaceDiscoveryUrlKey.API]

            self.__azure_ai_assets_client = AzureAiAssetsClient042024(
                endpoint=endpoint,
                subscription_id=self._operation_scope.subscription_id,
                resource_group_name=self._operation_scope.resource_group_name,
                workspace_name=self._operation_scope.workspace_name,
                credential=self._credential,
                **self._service_client_kwargs,
            )

        return self.__azure_ai_assets_client

    @monitor_with_activity(ops_logger, "Index.CreateOrUpdate", ActivityType.PUBLICAPI)
    def create_or_update(self, index: Index, **kwargs) -> Index:
        """Returns created or updated index asset.

        If not already in storage, asset will be uploaded to the workspace's default datastore.

        :param index: Index asset object.
        :type index: Index
        :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the index has no name, or has no
            version and a version cannot be auto-incremented.
        :return: Index asset object.
        :rtype: ~azure.ai.ml.entities.Index
        """

        if not index.name:
            msg = "Must specify a name."
            raise ValidationException(
                message=msg,
                target=ErrorTarget.INDEX,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.MISSING_FIELD,
            )

        if not index.version:
            # Version omitted: only allowed when the asset opts into auto-increment.
            if not index._auto_increment_version:
                msg = "Must specify a version."
                raise ValidationException(
                    message=msg,
                    target=ErrorTarget.INDEX,
                    no_personal_data_message=msg,
                    error_category=ErrorCategory.USER_ERROR,
                    error_type=ValidationErrorType.MISSING_FIELD,
                )

            # Ask the service for the next available version of this index name.
            next_version = self._azure_ai_assets.indexes.get_next_version(index.name).next_version

            if next_version is None:
                msg = "Version not specified, could not automatically increment version. Set a version to resolve."
                raise ValidationException(
                    message=msg,
                    target=ErrorTarget.INDEX,
                    no_personal_data_message=msg,
                    error_category=ErrorCategory.SYSTEM_ERROR,
                    error_type=ValidationErrorType.MISSING_FIELD,
                )

            index.version = str(next_version)

        # Upload local artifact content to the index's datastore (no-op if already in storage).
        _ = _check_and_upload_path(
            artifact=index,
            asset_operations=self,
            datastore_name=index.datastore,
            artifact_type=ErrorTarget.INDEX,
            show_progress=self._show_progress,
        )

        return Index._from_rest_object(
            self._azure_ai_assets.indexes.create_or_update(
                name=index.name, version=index.version, body=index._to_rest_object(), **kwargs
            )
        )

    @monitor_with_activity(ops_logger, "Index.Get", ActivityType.PUBLICAPI)
    def get(self, name: str, *, version: Optional[str] = None, label: Optional[str] = None, **kwargs) -> Index:
        """Returns information about the specified index asset.

        :param str name: Name of the index asset.
        :keyword Optional[str] version: Version of the index asset. Mutually exclusive with `label`.
        :keyword Optional[str] label: The label of the index asset. Mutually exclusive with `version`.
        :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Index cannot be successfully validated.
            Details will be provided in the error message.
        :return: Index asset object.
        :rtype: ~azure.ai.ml.entities.Index
        """
        if version and label:
            msg = "Cannot specify both version and label."
            raise ValidationException(
                message=msg,
                target=ErrorTarget.INDEX,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.INVALID_VALUE,
            )

        if label:
            # Labels (e.g. "latest") are resolved via self._managed_label_resolver.
            return _resolve_label_to_asset(self, name, label)

        if not version:
            msg = "Must provide either version or label."
            raise ValidationException(
                message=msg,
                target=ErrorTarget.INDEX,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.MISSING_FIELD,
            )

        index_version_resource = self._azure_ai_assets.indexes.get(name=name, version=version, **kwargs)

        return Index._from_rest_object(index_version_resource)

    def _get_latest_version(self, name: str) -> Index:
        # Resolver backing the "latest" label in self._managed_label_resolver.
        return Index._from_rest_object(self._azure_ai_assets.indexes.get_latest(name))

    @monitor_with_activity(ops_logger, "Index.List", ActivityType.PUBLICAPI)
    def list(
        self, name: Optional[str] = None, *, list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, **kwargs
    ) -> Iterable[Index]:
        """List all Index assets in workspace.

        If an Index is specified by name, list all version of that Index.

        :param name: Name of the index asset. If omitted, the latest version of every index is listed.
        :type name: Optional[str]
        :keyword list_view_type: View type for including/excluding (for example) archived indexes.
            Defaults to :attr:`ListViewType.ACTIVE_ONLY`. Only applied when listing versions of a named index.
        :paramtype list_view_type: ListViewType
        :return: An iterator like instance of Index objects
        :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Index]
        """

        # Deserialization hook: converts each page of REST objects to SDK Index entities.
        def cls(rest_indexes: Iterable[RestIndex]) -> List[Index]:
            return [Index._from_rest_object(i) for i in rest_indexes]

        if name is None:
            return self._azure_ai_assets.indexes.list_latest(cls=cls, **kwargs)

        return self._azure_ai_assets.indexes.list(name, list_view_type=list_view_type, cls=cls, **kwargs)

    def build_index(
        self,
        *,
        ######## required args ##########
        name: str,
        embeddings_model_config: ModelConfiguration,
        ######## chunking information ##########
        data_source_citation_url: Optional[str] = None,
        tokens_per_chunk: Optional[int] = None,
        token_overlap_across_chunks: Optional[int] = None,
        input_glob: Optional[str] = None,
        ######## other generic args ########
        document_path_replacement_regex: Optional[str] = None,
        ######## Azure AI Search index info ########
        index_config: Optional[AzureAISearchConfig] = None,  # todo better name?
        ######## data source info ########
        input_source: Union[IndexDataSource, str],
        input_source_credential: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None,
    ) -> Union["Index", "Job"]:  # type: ignore[name-defined]
        """Builds an index on the cloud using the Azure AI Resources service.

        :keyword name: The name of the index to be created.
        :paramtype name: str
        :keyword embeddings_model_config: Model config for the embedding model.
        :paramtype embeddings_model_config: ~azure.ai.ml.entities._indexes.ModelConfiguration
        :keyword data_source_citation_url: The URL of the data source.
        :paramtype data_source_citation_url: Optional[str]
        :keyword tokens_per_chunk: The size of chunks to be used for indexing.
        :paramtype tokens_per_chunk: Optional[int]
        :keyword token_overlap_across_chunks: The amount of overlap between chunks.
        :paramtype token_overlap_across_chunks: Optional[int]
        :keyword input_glob: The glob pattern to be used for indexing.
        :paramtype input_glob: Optional[str]
        :keyword document_path_replacement_regex: The regex pattern for replacing document paths, given as a
            JSON string with "match_pattern" and "replacement_pattern" keys.
        :paramtype document_path_replacement_regex: Optional[str]
        :keyword index_config: The configuration for the ACS output. When omitted, a local FAISS index is built.
        :paramtype index_config: Optional[~azure.ai.ml.entities._indexes.AzureAISearchConfig]
        :keyword input_source: The input source for the index.
        :paramtype input_source: Union[~azure.ai.ml.entities._indexes.IndexDataSource, str]
        :keyword input_source_credential: The identity to be used for the index.
        :paramtype input_source_credential: Optional[Union[~azure.ai.ml.entities.ManagedIdentityConfiguration,
            ~azure.ai.ml.entities.UserIdentityConfiguration]]
        :return: If the `source_input` is a GitSource, returns a created DataIndex Job object.
        :rtype: Union[~azure.ai.ml.entities.Index, ~azure.ai.ml.entities.Job]
        :raises ValueError: If the `input_source` is not a str, ~azure.ai.ml.entities._indexes.LocalSource,
            or ~azure.ai.ml.entities._indexes.GitSource.
        """
        if document_path_replacement_regex:
            # NOTE(review): the parameter is typed str but is rebound to the parsed dict here;
            # the CitationRegex construction below relies on the dict form.
            document_path_replacement_regex = json.loads(document_path_replacement_regex)

        data_index = DataIndex(
            name=name,
            source=IndexSource(
                input_data=Data(
                    type="uri_folder",
                    path=".",
                ),
                input_glob=input_glob,
                chunk_size=tokens_per_chunk,
                chunk_overlap=token_overlap_across_chunks,
                citation_url=data_source_citation_url,
                citation_url_replacement_regex=(
                    CitationRegex(
                        match_pattern=document_path_replacement_regex["match_pattern"],  # type: ignore[index]
                        replacement_pattern=document_path_replacement_regex[
                            "replacement_pattern"  # type: ignore[index]
                        ],
                    )
                    if document_path_replacement_regex
                    else None
                ),
            ),
            embedding=Embedding(
                model=build_open_ai_protocol(
                    model=embeddings_model_config.model_name, deployment=embeddings_model_config.deployment_name
                ),
                connection=build_connection_id(embeddings_model_config.connection_name, self._operation_scope),
            ),
            index=(
                IndexStore(
                    type="acs",
                    connection=build_connection_id(index_config.connection_id, self._operation_scope),
                    name=index_config.index_name,
                )
                if index_config is not None
                else IndexStore(type="faiss")
            ),
            # name is replaced with a unique value each time the job is run
            path=f"azureml://datastores/workspaceblobstore/paths/indexes/{name}/{{name}}",
        )

        if isinstance(input_source, LocalSource):
            data_index.source.input_data = Data(
                type="uri_folder",
                path=input_source.input_data.path,
            )

            return self._create_data_indexing_job(data_index=data_index, identity=input_source_credential)

        if isinstance(input_source, str):
            data_index.source.input_data = Data(
                type="uri_folder",
                path=input_source,
            )

            return self._create_data_indexing_job(data_index=data_index, identity=input_source_credential)

        if isinstance(input_source, GitSource):
            # Imported locally to avoid a circular import with the top-level client module.
            from azure.ai.ml import MLClient

            # The git-clone component is pulled from the shared "azureml" registry.
            ml_registry = MLClient(credential=self._credential, registry_name="azureml")
            git_clone_component = ml_registry.components.get("llm_rag_git_clone", label="latest")

            # Clone Git Repo and use as input to index_job
            @pipeline(default_compute="serverless")  # type: ignore[call-overload]
            def git_to_index(
                git_url,
                branch_name="",
                git_connection_id="",
            ):
                git_clone = git_clone_component(git_repository=git_url, branch_name=branch_name)
                git_clone.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_GIT"] = git_connection_id

                index_job = index_data_func(
                    description=data_index.description,
                    data_index=data_index,
                    input_data_override=git_clone.outputs.output_data,
                    ml_client=MLClient(
                        subscription_id=self._subscription_id,
                        resource_group_name=self._resource_group_name,
                        workspace_name=self._workspace_name,
                        credential=self._credential,
                    ),
                )
                # pylint: disable=no-member
                return index_job.outputs

            git_index_job = git_to_index(
                git_url=input_source.url,
                branch_name=input_source.branch_name,
                git_connection_id=input_source.connection_id,
            )
            # Ensure repo cloned each run to get latest, comment out to have first clone reused.
            git_index_job.settings.force_rerun = True

            # Submit the DataIndex Job
            return self._all_operations.all_operations[AzureMLResourceType.JOB].create_or_update(git_index_job)
        raise ValueError(f"Unsupported input source type {type(input_source)}")

    def _create_data_indexing_job(
        self,
        data_index: DataIndex,
        identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None,
        compute: str = "serverless",
        serverless_instance_type: Optional[str] = None,
        input_data_override: Optional[Input] = None,
        submit_job: bool = True,
        **kwargs,
    ) -> PipelineJob:
        """
        Returns the data import job that is creating the data asset.

        :param data_index: DataIndex object.
        :type data_index: azure.ai.ml.entities._dataindex
        :param identity: Identity configuration for the job.
        :type identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]]
        :param compute: The compute target to use for the job. Default: "serverless".
        :type compute: str
        :param serverless_instance_type: The instance type to use for serverless compute.
        :type serverless_instance_type: Optional[str]
        :param input_data_override: Input data override for the job.
            Used to pipe output of step into DataIndex Job in a pipeline.
        :type input_data_override: Optional[Input]
        :param submit_job: Whether to submit the job to the service. Default: True.
        :type submit_job: bool
        :return: data import job object.
        :rtype: ~azure.ai.ml.entities.PipelineJob.
        """
        # pylint: disable=no-member
        # Imported locally to avoid a circular import with the top-level client module.
        from azure.ai.ml import MLClient

        default_name = "data_index_" + data_index.name if data_index.name is not None else ""
        experiment_name = kwargs.pop("experiment_name", None) or default_name
        data_index.type = AssetTypes.URI_FOLDER

        # avoid specifying auto_delete_setting in job output now
        _validate_auto_delete_setting_in_data_output(data_index.auto_delete_setting)

        # block customer specified path on managed datastore
        data_index.path = _validate_workspace_managed_datastore(data_index.path)

        # Append the per-run "name" placeholder so each run writes to a unique output location.
        if "${{name}}" not in str(data_index.path) and "{name}" not in str(data_index.path):
            data_index.path = str(data_index.path).rstrip("/") + "/${{name}}"

        index_job = index_data_func(
            description=data_index.description or kwargs.pop("description", None) or default_name,
            name=data_index.name or kwargs.pop("name", None),
            display_name=kwargs.pop("display_name", None) or default_name,
            experiment_name=experiment_name,
            compute=compute,
            serverless_instance_type=serverless_instance_type,
            data_index=data_index,
            ml_client=MLClient(
                subscription_id=self._subscription_id,
                resource_group_name=self._resource_group_name,
                workspace_name=self._workspace_name,
                credential=self._credential,
            ),
            identity=identity,
            input_data_override=input_data_override,
            **kwargs,
        )
        # Wrap the single index step in a pipeline job so it can be submitted stand-alone.
        index_pipeline = PipelineJob(
            description=index_job.description,
            tags=index_job.tags,
            name=index_job.name,
            display_name=index_job.display_name,
            experiment_name=experiment_name,
            properties=index_job.properties or {},
            settings=PipelineJobSettings(force_rerun=True, default_compute=compute),
            jobs={default_name: index_job},
        )
        # Properties consumed by the service/UI to surface the resulting index asset.
        index_pipeline.properties["azureml.mlIndexAssetName"] = data_index.name
        index_pipeline.properties["azureml.mlIndexAssetKind"] = data_index.index.type
        index_pipeline.properties["azureml.mlIndexAssetSource"] = kwargs.pop("mlindex_asset_source", "Data Asset")

        if submit_job:
            return self._all_operations.all_operations[AzureMLResourceType.JOB].create_or_update(
                job=index_pipeline, skip_validation=True, **kwargs
            )
        return index_pipeline