| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-master.tar.gz | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes')
13 files changed, 1598 insertions, 0 deletions
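Taken together, the files below add a small Retrieval Augmented Generation (RAG) surface to the SDK: `IndexSource` describes what to index, `Embedding` how to embed it, `IndexStore` where to write it, and `index_data` compiles a `DataIndex` built from those pieces into a `PipelineJob`. The following is a minimal sketch of the intended usage, inferred from the code in this commit; the asset path, connection names, and deployment are placeholders:

```python
from azure.ai.ml.entities import Data
from azure.ai.ml.entities._indexes.data_index_func import index_data
from azure.ai.ml.entities._indexes.entities.data_index import (
    DataIndex,
    Embedding,
    IndexSource,
    IndexStore,
)

data_index = DataIndex(
    name="product-docs-index",  # placeholder asset name
    source=IndexSource(
        input_data=Data(
            type="uri_folder",
            path="azureml://datastores/workspaceblobstore/paths/docs/",  # placeholder
        ),
        input_glob="**/*.md",
        chunk_size=1024,
    ),
    embedding=Embedding(
        model="azure_open_ai://deployment/ada-002/model/text-embedding-ada-002",
        connection="my-aoai-connection",  # placeholder workspace connection
    ),
    index=IndexStore(type="acs", connection="my-acs-connection"),  # placeholder
)

# With an authenticated MLClient:
# job = index_data(data_index=data_index, ml_client=ml_client)
# Passing incremental_update=True instead routes "acs"/"pinecone" indexes through
# the incremental-update pipeline defined below.
# ml_client.jobs.create_or_update(job)
```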
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/__init__.py new file mode 100644 index 00000000..43f615c3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/__init__.py @@ -0,0 +1,16 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""AzureML Retrieval Augmented Generation (RAG) utilities.""" + +from .input._ai_search_config import AzureAISearchConfig +from .input._index_data_source import IndexDataSource, GitSource, LocalSource +from .model_config import ModelConfiguration + +__all__ = [ + "ModelConfiguration", + "AzureAISearchConfig", + "IndexDataSource", + "GitSource", + "LocalSource", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/data_index_func.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/data_index_func.py new file mode 100644 index 00000000..884faf82 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/data_index_func.py @@ -0,0 +1,748 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint: disable=protected-access +# pylint: disable=no-member + +import json +import re +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import AssetTypes, LegacyAssetTypes +from azure.ai.ml.entities import PipelineJob +from azure.ai.ml.entities._builders.base_node import pipeline_node_decorator +from azure.ai.ml.entities._credentials import ManagedIdentityConfiguration, UserIdentityConfiguration +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job.pipeline._io import PipelineInput +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException +from azure.ai.ml.constants._common import DataIndexTypes +from azure.ai.ml.constants._component import LLMRAGComponentUri +from azure.ai.ml.entities._indexes.entities.data_index import DataIndex + +SUPPORTED_INPUTS = [ + LegacyAssetTypes.PATH, + AssetTypes.URI_FILE, + AssetTypes.URI_FOLDER, + AssetTypes.MLTABLE, +] + + +def _build_data_index(io_dict: Union[Dict, DataIndex]): + if io_dict is None: + return io_dict + if isinstance(io_dict, DataIndex): + component_io = io_dict + else: + if isinstance(io_dict, dict): + component_io = DataIndex(**io_dict) + else: + msg = "data_index only support dict and DataIndex" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.DATA, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + return component_io + + +@experimental +@pipeline_node_decorator +def index_data( + *, + data_index: DataIndex, + description: Optional[str] = None, + tags: Optional[Dict] = None, + name: Optional[str] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + compute: Optional[str] = None, + serverless_instance_type: Optional[str] = None, + ml_client: Optional[Any] = None, + identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None, + input_data_override: Optional[Input] = None, + **kwargs, +) -> PipelineJob: + """ + 
Create a PipelineJob object which can be used inside dsl.pipeline. + + :keyword data_index: The data index configuration. + :type data_index: DataIndex + :keyword description: Description of the job. + :type description: str + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :keyword name: Name of the job. + :type name: str + :keyword display_name: Display name of the job. + :type display_name: str + :keyword experiment_name: Name of the experiment the job will be created under. + :type experiment_name: str + :keyword compute: The compute resource the job runs on. + :type compute: str + :keyword serverless_instance_type: The instance type to use for serverless compute. + :type serverless_instance_type: Optional[str] + :keyword ml_client: The ml client to use for the job. + :type ml_client: Any + :keyword identity: Identity configuration for the job. + :type identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] + :keyword input_data_override: Input data override for the job. + Used to pipe output of step into DataIndex Job in a pipeline. + :type input_data_override: Optional[Input] + :return: A PipelineJob object. + :rtype: ~azure.ai.ml.entities.PipelineJob. + """ + data_index = _build_data_index(data_index) + + if data_index.index.type == DataIndexTypes.FAISS: + configured_component = data_index_faiss( + ml_client, + data_index, + description, + tags, + name, + display_name, + experiment_name, + compute, + serverless_instance_type, + identity, + input_data_override, + ) + elif data_index.index.type in (DataIndexTypes.ACS, DataIndexTypes.PINECONE): + if kwargs.get("incremental_update", False): + configured_component = data_index_incremental_update_hosted( + ml_client, + data_index, + description, + tags, + name, + display_name, + experiment_name, + compute, + serverless_instance_type, + identity, + input_data_override, + ) + else: + configured_component = data_index_hosted( + ml_client, + data_index, + description, + tags, + name, + display_name, + experiment_name, + compute, + serverless_instance_type, + identity, + input_data_override, + ) + else: + raise ValueError(f"Unsupported index type: {data_index.index.type}") + + configured_component.properties["azureml.mlIndexAssetName"] = data_index.name + configured_component.properties["azureml.mlIndexAssetKind"] = data_index.index.type + configured_component.properties["azureml.mlIndexAssetSource"] = "Data Asset" + + return configured_component + + +# pylint: disable=too-many-statements +def data_index_incremental_update_hosted( + ml_client: Any, + data_index: DataIndex, + description: Optional[str] = None, + tags: Optional[Dict] = None, + name: Optional[str] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + compute: Optional[str] = None, + serverless_instance_type: Optional[str] = None, + identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None, + input_data_override: Optional[Input] = None, +): + from azure.ai.ml.entities._indexes.utils import build_model_protocol, pipeline + + crack_and_chunk_and_embed_component = get_component_obj( + ml_client, LLMRAGComponentUri.LLM_RAG_CRACK_AND_CHUNK_AND_EMBED + ) + + if data_index.index.type == DataIndexTypes.ACS: + update_index_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_UPDATE_ACS_INDEX) + elif data_index.index.type == DataIndexTypes.PINECONE: + update_index_component = get_component_obj(ml_client, 
LLMRAGComponentUri.LLM_RAG_UPDATE_PINECONE_INDEX) + else: + raise ValueError(f"Unsupported hosted index type: {data_index.index.type}") + + register_mlindex_asset_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_REGISTER_MLINDEX_ASSET) + + @pipeline( # type: ignore [call-overload] + name=name if name else f"data_index_incremental_update_{data_index.index.type}", + description=description, + tags=tags, + display_name=( + display_name if display_name else f"LLM - Data to {data_index.index.type.upper()} (Incremental Update)" + ), + experiment_name=experiment_name, + compute=compute, + get_component=True, + ) + def data_index_pipeline( + input_data: Input, + embeddings_model: str, + index_config: str, + index_connection_id: str, + chunk_size: int = 768, + chunk_overlap: int = 0, + input_glob: str = "**/*", + citation_url: Optional[str] = None, + citation_replacement_regex: Optional[str] = None, + aoai_connection_id: Optional[str] = None, + embeddings_container: Optional[Input] = None, + ): + """ + Generate embeddings for a `input_data` source and + push them into a hosted index (such as Azure Cognitive Search and Pinecone). + + :param input_data: The input data to be indexed. + :type input_data: Input + :param embeddings_model: The embedding model to use when processing source data chunks. + :type embeddings_model: str + :param index_config: The configuration for the hosted index. + :type index_config: str + :param index_connection_id: The connection ID for the hosted index. + :type index_connection_id: str + :param chunk_size: The size of the chunks to break the input data into. + :type chunk_size: int + :param chunk_overlap: The number of tokens to overlap between chunks. + :type chunk_overlap: int + :param input_glob: The glob pattern to use when searching for input data. + :type input_glob: str + :param citation_url: The URL to use when generating citations for the input data. + :type citation_url: str + :param citation_replacement_regex: The regex to use when generating citations for the input data. + :type citation_replacement_regex: str + :param aoai_connection_id: The connection ID for the Azure Open AI service. + :type aoai_connection_id: str + :param embeddings_container: The container to use when caching embeddings. + :type embeddings_container: Input + :return: The URI of the generated Azure Cognitive Search index. + :rtype: str. 
+ """ + crack_and_chunk_and_embed = crack_and_chunk_and_embed_component( + input_data=input_data, + input_glob=input_glob, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + citation_url=citation_url, + citation_replacement_regex=citation_replacement_regex, + embeddings_container=embeddings_container, + embeddings_model=embeddings_model, + embeddings_connection_id=aoai_connection_id, + ) + if compute is None or compute == "serverless": + use_automatic_compute(crack_and_chunk_and_embed, instance_type=serverless_instance_type) + if optional_pipeline_input_provided(embeddings_container): # type: ignore [arg-type] + crack_and_chunk_and_embed.outputs.embeddings = Output( + type="uri_folder", path=f"{embeddings_container.path}/{{name}}" # type: ignore [union-attr] + ) + if identity: + crack_and_chunk_and_embed.identity = identity + + if data_index.index.type == DataIndexTypes.ACS: + update_index = update_index_component( + embeddings=crack_and_chunk_and_embed.outputs.embeddings, acs_config=index_config + ) + update_index.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_ACS"] = index_connection_id + elif data_index.index.type == DataIndexTypes.PINECONE: + update_index = update_index_component( + embeddings=crack_and_chunk_and_embed.outputs.embeddings, pinecone_config=index_config + ) + update_index.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_PINECONE"] = index_connection_id + else: + raise ValueError(f"Unsupported hosted index type: {data_index.index.type}") + if compute is None or compute == "serverless": + use_automatic_compute(update_index, instance_type=serverless_instance_type) + if identity: + update_index.identity = identity + + register_mlindex_asset = register_mlindex_asset_component( + storage_uri=update_index.outputs.index, + asset_name=data_index.name, + ) + if compute is None or compute == "serverless": + use_automatic_compute(register_mlindex_asset, instance_type=serverless_instance_type) + if identity: + register_mlindex_asset.identity = identity + return { + "mlindex_asset_uri": update_index.outputs.index, + "mlindex_asset_id": register_mlindex_asset.outputs.asset_id, + } + + if input_data_override is not None: + input_data = input_data_override + else: + input_data = Input( + type=data_index.source.input_data.type, path=data_index.source.input_data.path # type: ignore [arg-type] + ) + + index_config = { + "index_name": data_index.index.name if data_index.index.name is not None else data_index.name, + "full_sync": True, + } + if data_index.index.config is not None: + index_config.update(data_index.index.config) + + component = data_index_pipeline( + input_data=input_data, + input_glob=data_index.source.input_glob, # type: ignore [arg-type] + chunk_size=data_index.source.chunk_size, # type: ignore [arg-type] + chunk_overlap=data_index.source.chunk_overlap, # type: ignore [arg-type] + citation_url=data_index.source.citation_url, + citation_replacement_regex=( + json.dumps(data_index.source.citation_url_replacement_regex._to_dict()) + if data_index.source.citation_url_replacement_regex + else None + ), + embeddings_model=build_model_protocol(data_index.embedding.model), + aoai_connection_id=_resolve_connection_id(ml_client, data_index.embedding.connection), + embeddings_container=( + Input(type=AssetTypes.URI_FOLDER, path=data_index.embedding.cache_path) + if data_index.embedding.cache_path + else None + ), + index_config=json.dumps(index_config), + index_connection_id=_resolve_connection_id(ml_client, data_index.index.connection), # type: ignore [arg-type] + 
) + # Hack until full Component classes are implemented that can annotate the optional parameters properly + component.inputs["input_glob"]._meta.optional = True + component.inputs["chunk_size"]._meta.optional = True + component.inputs["chunk_overlap"]._meta.optional = True + component.inputs["citation_url"]._meta.optional = True + component.inputs["citation_replacement_regex"]._meta.optional = True + component.inputs["aoai_connection_id"]._meta.optional = True + component.inputs["embeddings_container"]._meta.optional = True + + if data_index.path: + component.outputs.mlindex_asset_uri = Output( # type: ignore [attr-defined] + type=AssetTypes.URI_FOLDER, path=data_index.path # type: ignore [arg-type] + ) + + return component + + +def data_index_faiss( + ml_client: Any, + data_index: DataIndex, + description: Optional[str] = None, + tags: Optional[Dict] = None, + name: Optional[str] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + compute: Optional[str] = None, + serverless_instance_type: Optional[str] = None, + identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None, + input_data_override: Optional[Input] = None, +): + from azure.ai.ml.entities._indexes.utils import build_model_protocol, pipeline + + crack_and_chunk_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_CRACK_AND_CHUNK) + generate_embeddings_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_GENERATE_EMBEDDINGS) + create_faiss_index_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_CREATE_FAISS_INDEX) + register_mlindex_asset_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_REGISTER_MLINDEX_ASSET) + + @pipeline( # type: ignore [call-overload] + name=name if name else "data_index_faiss", + description=description, + tags=tags, + display_name=display_name if display_name else "LLM - Data to Faiss", + experiment_name=experiment_name, + compute=compute, + get_component=True, + ) + def data_index_faiss_pipeline( + input_data: Input, + embeddings_model: str, + chunk_size: int = 1024, + data_source_glob: str = None, # type: ignore [assignment] + data_source_url: str = None, # type: ignore [assignment] + document_path_replacement_regex: str = None, # type: ignore [assignment] + aoai_connection_id: str = None, # type: ignore [assignment] + embeddings_container: Input = None, # type: ignore [assignment] + ): + """ + Generate embeddings for a `input_data` source and create a Faiss index from them. + + :param input_data: The input data to be indexed. + :type input_data: Input + :param embeddings_model: The embedding model to use when processing source data chunks. + :type embeddings_model: str + :param chunk_size: The size of the chunks to break the input data into. + :type chunk_size: int + :param data_source_glob: The glob pattern to use when searching for input data. + :type data_source_glob: str + :param data_source_url: The URL to use when generating citations for the input data. + :type data_source_url: str + :param document_path_replacement_regex: The regex to use when generating citations for the input data. + :type document_path_replacement_regex: str + :param aoai_connection_id: The connection ID for the Azure Open AI service. + :type aoai_connection_id: str + :param embeddings_container: The container to use when caching embeddings. + :type embeddings_container: Input + :return: The URI of the generated Faiss index. + :rtype: str. 
+ """ + crack_and_chunk = crack_and_chunk_component( + input_data=input_data, + input_glob=data_source_glob, + chunk_size=chunk_size, + data_source_url=data_source_url, + document_path_replacement_regex=document_path_replacement_regex, + ) + if compute is None or compute == "serverless": + use_automatic_compute(crack_and_chunk, instance_type=serverless_instance_type) + if identity: + crack_and_chunk.identity = identity + + generate_embeddings = generate_embeddings_component( + chunks_source=crack_and_chunk.outputs.output_chunks, + embeddings_container=embeddings_container, + embeddings_model=embeddings_model, + ) + if compute is None or compute == "serverless": + use_automatic_compute(generate_embeddings, instance_type=serverless_instance_type) + if optional_pipeline_input_provided(aoai_connection_id): # type: ignore [arg-type] + generate_embeddings.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_AOAI"] = aoai_connection_id + if optional_pipeline_input_provided(embeddings_container): # type: ignore [arg-type] + generate_embeddings.outputs.embeddings = Output( + type="uri_folder", path=f"{embeddings_container.path}/{{name}}" + ) + if identity: + generate_embeddings.identity = identity + + create_faiss_index = create_faiss_index_component(embeddings=generate_embeddings.outputs.embeddings) + if compute is None or compute == "serverless": + use_automatic_compute(create_faiss_index, instance_type=serverless_instance_type) + if identity: + create_faiss_index.identity = identity + + register_mlindex_asset = register_mlindex_asset_component( + storage_uri=create_faiss_index.outputs.index, + asset_name=data_index.name, + ) + if compute is None or compute == "serverless": + use_automatic_compute(register_mlindex_asset, instance_type=serverless_instance_type) + if identity: + register_mlindex_asset.identity = identity + return { + "mlindex_asset_uri": create_faiss_index.outputs.index, + "mlindex_asset_id": register_mlindex_asset.outputs.asset_id, + } + + if input_data_override is not None: + input_data = input_data_override + else: + input_data = Input( + type=data_index.source.input_data.type, path=data_index.source.input_data.path # type: ignore [arg-type] + ) + + component = data_index_faiss_pipeline( + input_data=input_data, + embeddings_model=build_model_protocol(data_index.embedding.model), + chunk_size=data_index.source.chunk_size, # type: ignore [arg-type] + data_source_glob=data_index.source.input_glob, # type: ignore [arg-type] + data_source_url=data_index.source.citation_url, # type: ignore [arg-type] + document_path_replacement_regex=( + json.dumps(data_index.source.citation_url_replacement_regex._to_dict()) # type: ignore [arg-type] + if data_index.source.citation_url_replacement_regex + else None + ), + aoai_connection_id=_resolve_connection_id( + ml_client, data_index.embedding.connection + ), # type: ignore [arg-type] + embeddings_container=( + Input(type=AssetTypes.URI_FOLDER, path=data_index.embedding.cache_path) # type: ignore [arg-type] + if data_index.embedding.cache_path + else None + ), + ) + # Hack until full Component classes are implemented that can annotate the optional parameters properly + component.inputs["data_source_glob"]._meta.optional = True + component.inputs["data_source_url"]._meta.optional = True + component.inputs["document_path_replacement_regex"]._meta.optional = True + component.inputs["aoai_connection_id"]._meta.optional = True + component.inputs["embeddings_container"]._meta.optional = True + if data_index.path: + component.outputs.mlindex_asset_uri 
= Output( + type=AssetTypes.URI_FOLDER, path=data_index.path # type: ignore [arg-type] + ) + + return component + + +# pylint: disable=too-many-statements +def data_index_hosted( + ml_client: Any, + data_index: DataIndex, + description: Optional[str] = None, + tags: Optional[Dict] = None, + name: Optional[str] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + compute: Optional[str] = None, + serverless_instance_type: Optional[str] = None, + identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None, + input_data_override: Optional[Input] = None, +): + from azure.ai.ml.entities._indexes.utils import build_model_protocol, pipeline + + crack_and_chunk_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_CRACK_AND_CHUNK) + generate_embeddings_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_GENERATE_EMBEDDINGS) + + if data_index.index.type == DataIndexTypes.ACS: + update_index_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_UPDATE_ACS_INDEX) + elif data_index.index.type == DataIndexTypes.PINECONE: + update_index_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_UPDATE_PINECONE_INDEX) + else: + raise ValueError(f"Unsupported hosted index type: {data_index.index.type}") + + register_mlindex_asset_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_REGISTER_MLINDEX_ASSET) + + @pipeline( # type: ignore [call-overload] + name=name if name else f"data_index_{data_index.index.type}", + description=description, + tags=tags, + display_name=display_name if display_name else f"LLM - Data to {data_index.index.type.upper()}", + experiment_name=experiment_name, + compute=compute, + get_component=True, + ) + def data_index_pipeline( + input_data: Input, + embeddings_model: str, + index_config: str, + index_connection_id: str, + chunk_size: int = 1024, + data_source_glob: str = None, # type: ignore [assignment] + data_source_url: str = None, # type: ignore [assignment] + document_path_replacement_regex: str = None, # type: ignore [assignment] + aoai_connection_id: str = None, # type: ignore [assignment] + embeddings_container: Input = None, # type: ignore [assignment] + ): + """ + Generate embeddings for a `input_data` source + and push them into a hosted index (such as Azure Cognitive Search and Pinecone). + + :param input_data: The input data to be indexed. + :type input_data: Input + :param embeddings_model: The embedding model to use when processing source data chunks. + :type embeddings_model: str + :param index_config: The configuration for the hosted index. + :type index_config: str + :param index_connection_id: The connection ID for the hosted index. + :type index_connection_id: str + :param chunk_size: The size of the chunks to break the input data into. + :type chunk_size: int + :param data_source_glob: The glob pattern to use when searching for input data. + :type data_source_glob: str + :param data_source_url: The URL to use when generating citations for the input data. + :type data_source_url: str + :param document_path_replacement_regex: The regex to use when generating citations for the input data. + :type document_path_replacement_regex: str + :param aoai_connection_id: The connection ID for the Azure Open AI service. + :type aoai_connection_id: str + :param embeddings_container: The container to use when caching embeddings. + :type embeddings_container: Input + :return: The URI of the generated Azure Cognitive Search index. 
+ :rtype: str. + """ + crack_and_chunk = crack_and_chunk_component( + input_data=input_data, + input_glob=data_source_glob, + chunk_size=chunk_size, + data_source_url=data_source_url, + document_path_replacement_regex=document_path_replacement_regex, + ) + if compute is None or compute == "serverless": + use_automatic_compute(crack_and_chunk, instance_type=serverless_instance_type) + if identity: + crack_and_chunk.identity = identity + + generate_embeddings = generate_embeddings_component( + chunks_source=crack_and_chunk.outputs.output_chunks, + embeddings_container=embeddings_container, + embeddings_model=embeddings_model, + ) + if compute is None or compute == "serverless": + use_automatic_compute(generate_embeddings, instance_type=serverless_instance_type) + if optional_pipeline_input_provided(aoai_connection_id): # type: ignore [arg-type] + generate_embeddings.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_AOAI"] = aoai_connection_id + if optional_pipeline_input_provided(embeddings_container): # type: ignore [arg-type] + generate_embeddings.outputs.embeddings = Output( + type="uri_folder", path=f"{embeddings_container.path}/{{name}}" + ) + if identity: + generate_embeddings.identity = identity + + if data_index.index.type == DataIndexTypes.ACS: + update_index = update_index_component( + embeddings=generate_embeddings.outputs.embeddings, acs_config=index_config + ) + update_index.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_ACS"] = index_connection_id + elif data_index.index.type == DataIndexTypes.PINECONE: + update_index = update_index_component( + embeddings=generate_embeddings.outputs.embeddings, pinecone_config=index_config + ) + update_index.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_PINECONE"] = index_connection_id + else: + raise ValueError(f"Unsupported hosted index type: {data_index.index.type}") + if compute is None or compute == "serverless": + use_automatic_compute(update_index, instance_type=serverless_instance_type) + if identity: + update_index.identity = identity + + register_mlindex_asset = register_mlindex_asset_component( + storage_uri=update_index.outputs.index, + asset_name=data_index.name, + ) + if compute is None or compute == "serverless": + use_automatic_compute(register_mlindex_asset, instance_type=serverless_instance_type) + if identity: + register_mlindex_asset.identity = identity + return { + "mlindex_asset_uri": update_index.outputs.index, + "mlindex_asset_id": register_mlindex_asset.outputs.asset_id, + } + + if input_data_override is not None: + input_data = input_data_override + else: + input_data = Input( + type=data_index.source.input_data.type, path=data_index.source.input_data.path # type: ignore [arg-type] + ) + + index_config = { + "index_name": data_index.index.name if data_index.index.name is not None else data_index.name, + } + if data_index.index.config is not None: + index_config.update(data_index.index.config) + + component = data_index_pipeline( + input_data=input_data, + embeddings_model=build_model_protocol(data_index.embedding.model), + index_config=json.dumps(index_config), + index_connection_id=_resolve_connection_id(ml_client, data_index.index.connection), # type: ignore [arg-type] + chunk_size=data_index.source.chunk_size, # type: ignore [arg-type] + data_source_glob=data_index.source.input_glob, # type: ignore [arg-type] + data_source_url=data_index.source.citation_url, # type: ignore [arg-type] + document_path_replacement_regex=( + json.dumps(data_index.source.citation_url_replacement_regex._to_dict()) # 
type: ignore [arg-type] + if data_index.source.citation_url_replacement_regex + else None + ), + aoai_connection_id=_resolve_connection_id( + ml_client, data_index.embedding.connection # type: ignore [arg-type] + ), + embeddings_container=( + Input(type=AssetTypes.URI_FOLDER, path=data_index.embedding.cache_path) # type: ignore [arg-type] + if data_index.embedding.cache_path + else None + ), + ) + # Hack until full Component classes are implemented that can annotate the optional parameters properly + component.inputs["data_source_glob"]._meta.optional = True + component.inputs["data_source_url"]._meta.optional = True + component.inputs["document_path_replacement_regex"]._meta.optional = True + component.inputs["aoai_connection_id"]._meta.optional = True + component.inputs["embeddings_container"]._meta.optional = True + + if data_index.path: + component.outputs.mlindex_asset_uri = Output( + type=AssetTypes.URI_FOLDER, path=data_index.path # type: ignore [arg-type] + ) + + return component + + +def optional_pipeline_input_provided(input: Optional[PipelineInput]): + """ + Checks if optional pipeline inputs are provided. + + :param input: The pipeline input to check. + :type input: Optional[PipelineInput] + :return: True if the input is not None and has a value, False otherwise. + :rtype: bool. + """ + return input is not None and input._data is not None + + +def use_automatic_compute(component, instance_count=1, instance_type=None): + """ + Configure input `component` to use automatic compute with `instance_count` and `instance_type`. + + This avoids the need to provision a compute cluster to run the component. + :param component: The component to configure. + :type component: Any + :param instance_count: The number of instances to use. + :type instance_count: int + :param instance_type: The type of instance to use. + :type instance_type: str + :return: The configured component. + :rtype: Any. 
+ """ + component.set_resources( + instance_count=instance_count, + instance_type=instance_type, + properties={"compute_specification": {"automatic": True}}, + ) + return component + + +def get_component_obj(ml_client, component_uri): + from azure.ai.ml import MLClient + + if not isinstance(component_uri, str): + # Assume Component object + return component_uri + + matches = re.match( + r"azureml://registries/(?P<registry_name>.*)/components/(?P<component_name>.*)" + r"/(?P<identifier_type>.*)/(?P<identifier_name>.*)", + component_uri, + ) + if matches is None: + from azure.ai.ml import load_component + + # Assume local path to component + return load_component(source=component_uri) + + registry_name = matches.group("registry_name") + registry_client = MLClient( + subscription_id=ml_client.subscription_id, + resource_group_name=ml_client.resource_group_name, + credential=ml_client._credential, + registry_name=registry_name, + ) + component_obj = registry_client.components.get( + matches.group("component_name"), + **{matches.group("identifier_type").rstrip("s"): matches.group("identifier_name")}, + ) + return component_obj + + +def _resolve_connection_id(ml_client, connection: Optional[str] = None) -> Optional[str]: + if connection is None: + return None + + if isinstance(connection, str): + from azure.ai.ml._utils._arm_id_utils import AMLNamedArmId + + connection_name = AMLNamedArmId(connection).asset_name + + connection = ml_client.connections.get(connection_name) + if connection is None: + return None + return connection.id # type: ignore [attr-defined] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/data_index.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/data_index.py new file mode 100644 index 00000000..094d19aa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/data_index.py @@ -0,0 +1,243 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""DataIndex entities.""" + +from typing import Dict, Optional + +from azure.ai.ml.constants._common import DataIndexTypes +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.entities._assets import Data +from azure.ai.ml.entities._inputs_outputs.utils import _remove_empty_values +from azure.ai.ml.entities._mixins import DictMixin + + +@experimental +class CitationRegex(DictMixin): + """ + :keyword match_pattern: Regex to match citation in the citation_url + input file path. + e.g. '(.*)/articles/(.*)(\\.[^.]+)$'. + :type match_pattern: str + :keyword replacement_pattern: Replacement string for citation. e.g. '\\1/\\2'. 
+    :type replacement_pattern: str
+    """
+
+    def __init__(
+        self,
+        *,
+        match_pattern: str,
+        replacement_pattern: str,
+    ):
+        """Initialize a CitationRegex object."""
+        self.match_pattern = match_pattern
+        self.replacement_pattern = replacement_pattern
+
+    def _to_dict(self) -> Dict:
+        """Convert the CitationRegex object to a dict.
+        :return: The dictionary representation of the class
+        :rtype: Dict
+        """
+        keys = [
+            "match_pattern",
+            "replacement_pattern",
+        ]
+        result = {key: getattr(self, key) for key in keys}
+        return _remove_empty_values(result)
+
+
+@experimental
+class IndexSource(DictMixin):
+    """Configuration for the source data to process into the index.
+    :keyword input_data: Input Data to index files from. MLTable type inputs will use `mode: eval_mount`.
+    :type input_data: Data
+    :keyword input_glob: The glob pattern to use when searching for input data.
+    :type input_glob: str, optional
+    :keyword chunk_size: Maximum number of tokens to put in each chunk.
+    :type chunk_size: int, optional
+    :keyword chunk_overlap: Number of tokens to overlap between chunks.
+    :type chunk_overlap: int, optional
+    :keyword citation_url: Base URL to join with file paths to create full source file URL for chunk metadata.
+    :type citation_url: str, optional
+    :keyword citation_url_replacement_regex: Regex match and replacement patterns for citation url. Useful if the paths
+        in `input_data` don't match the desired citation format.
+    :type citation_url_replacement_regex: CitationRegex, optional
+    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the IndexSource object cannot be validated.
+        Details will be provided in the error message.
+    """
+
+    def __init__(
+        self,
+        *,
+        input_data: Data,
+        input_glob: Optional[str] = None,
+        chunk_size: Optional[int] = None,
+        chunk_overlap: Optional[int] = None,
+        citation_url: Optional[str] = None,
+        citation_url_replacement_regex: Optional[CitationRegex] = None,
+    ):
+        """Initialize an IndexSource object."""
+        self.input_data = input_data
+        self.input_glob = input_glob
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.citation_url = citation_url
+        self.citation_url_replacement_regex = citation_url_replacement_regex
+
+    def _to_dict(self) -> Dict:
+        """Convert the IndexSource object to a dict.
+        :return: The dictionary representation of the class
+        :rtype: Dict
+        """
+        keys = [
+            "input_data",
+            "input_glob",
+            "chunk_size",
+            "chunk_overlap",
+            "citation_url",
+            "citation_url_replacement_regex",
+        ]
+        result = {key: getattr(self, key) for key in keys}
+        return _remove_empty_values(result)
+
+
+@experimental
+class Embedding(DictMixin):
+    """Configuration for the model used to embed source data chunks.
+    :keyword model: The model to use to embed data. E.g. 'hugging_face://model/sentence-transformers/all-mpnet-base-v2'
+        or 'azure_open_ai://deployment/{deployment_name}/model/{model_name}'
+    :type model: str
+    :keyword connection: Connection reference to use for embedding model information,
+        only needed for hosted embeddings models (such as Azure OpenAI).
+    :type connection: str, optional
+    :keyword cache_path: Folder containing previously generated embeddings.
+        Should be the parent folder of the 'embeddings' output path used for this component.
+        Will compare input data to existing embeddings and only embed changed/new data, reusing existing chunks.
+    :type cache_path: str, optional
+    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the Embedding object cannot be validated.
+        Details will be provided in the error message.
+    """
+
+    def __init__(
+        self,
+        *,
+        model: str,
+        connection: Optional[str] = None,
+        cache_path: Optional[str] = None,
+    ):
+        """Initialize an Embedding object."""
+        self.model = model
+        self.connection = connection
+        self.cache_path = cache_path
+
+    def _to_dict(self) -> Dict:
+        """Convert the Embedding object to a dict.
+        :return: The dictionary representation of the class
+        :rtype: Dict
+        """
+        keys = [
+            "model",
+            "connection",
+            "cache_path",
+        ]
+        result = {key: getattr(self, key) for key in keys}
+        return _remove_empty_values(result)
+
+
+@experimental
+class IndexStore(DictMixin):
+    """Configuration for the destination index to write processed data to.
+    :keyword type: The type of index to write to. Currently supported types are 'acs', 'pinecone', and 'faiss'.
+    :type type: str
+    :keyword name: Name of index to update/create, only needed for hosted indexes
+        (such as Azure Cognitive Search and Pinecone).
+    :type name: str, optional
+    :keyword connection: Connection reference to use for index information,
+        only needed for hosted indexes (such as Azure Cognitive Search and Pinecone).
+    :type connection: str, optional
+    :keyword config: Configuration for the index. Primary use is to configure AI Search and Pinecone
+        specific settings, such as a custom `field_mapping` for known field types.
+    :type config: dict, optional
+    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the IndexStore object cannot be validated.
+        Details will be provided in the error message.
+    """
+
+    def __init__(
+        self,
+        *,
+        type: str = DataIndexTypes.FAISS,
+        name: Optional[str] = None,
+        connection: Optional[str] = None,
+        config: Optional[Dict] = None,
+    ):
+        """Initialize an IndexStore object."""
+        self.type = type
+        self.name = name
+        self.connection = connection
+        self.config = config
+
+    def _to_dict(self) -> Dict:
+        """Convert the IndexStore object to a dict.
+        :return: The dictionary representation of the class
+        :rtype: Dict
+        """
+        keys = ["type", "name", "connection", "config"]
+        result = {key: getattr(self, key) for key in keys}
+        return _remove_empty_values(result)
+
+
+@experimental
+class DataIndex(Data):
+    """Data asset with an accompanying data index job that creates it.
+    :param name: Name of the asset.
+    :type name: str
+    :param path: The path to the asset being created by the data index job.
+    :type path: str
+    :param source: The source data to be indexed.
+    :type source: IndexSource
+    :param embedding: The embedding model to use when processing source data chunks.
+    :type embedding: Embedding
+    :param index: The destination index to write processed data to.
+    :type index: IndexStore
+    :param version: Version of the asset created by running this DataIndex Job.
+    :type version: str
+    :param description: Description of the resource.
+    :type description: str
+    :param tags: Tag dictionary. Tags can be added, removed, and updated.
+    :type tags: dict[str, str]
+    :param properties: The asset property dictionary.
+    :type properties: dict[str, str]
+    :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict + """ + + def __init__( + self, + *, + name: str, + source: IndexSource, + embedding: Embedding, + index: IndexStore, + incremental_update: bool = False, + path: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + **kwargs, + ): + """Initialize a DataIndex object.""" + super().__init__( + name=name, + version=version, + description=description, + tags=tags, + properties=properties, + path=path, + **kwargs, + ) + self.source = source + self.embedding = embedding + self.index = index + self.incremental_update = incremental_update diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_ai_search_config.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_ai_search_config.py new file mode 100644 index 00000000..b2163c40 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_ai_search_config.py @@ -0,0 +1,31 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# General todo: need to determine which args are required or optional when parsed out into groups like this. +# General todo: move these to more permanent locations? + +# Defines stuff related to the resulting created index, like the index type. + +from typing import Optional +from azure.ai.ml._utils._experimental import experimental + + +@experimental +class AzureAISearchConfig: + """Config class for creating an Azure AI Search index. + + :param index_name: The name of the Azure AI Search index. + :type index_name: Optional[str] + :param connection_id: The Azure AI Search connection ID. + :type connection_id: Optional[str] + """ + + def __init__( + self, + *, + index_name: Optional[str] = None, + connection_id: Optional[str] = None, + ) -> None: + self.index_name = index_name + self.connection_id = connection_id diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_config.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_config.py new file mode 100644 index 00000000..0eec691a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_config.py @@ -0,0 +1,47 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Optional + + +class IndexConfig: # pylint: disable=too-many-instance-attributes + """Convenience class that contains all config values that for index creation that are + NOT specific to the index source data or the created index type. Meant for internal use only + to simplify function headers. 
The user-entry point is a function that + should still contain all the fields in this class as individual function parameters. + + Params omitted for brevity and to avoid maintaining duplicate docs. See index creation function + for actual parameter descriptions. + """ + + def __init__( + self, + *, + output_index_name: str, + vector_store: str, + data_source_url: Optional[str] = None, + chunk_size: Optional[int] = None, + chunk_overlap: Optional[int] = None, + input_glob: Optional[str] = None, + max_sample_files: Optional[int] = None, + chunk_prepend_summary: Optional[bool] = None, + document_path_replacement_regex: Optional[str] = None, + embeddings_container: Optional[str] = None, + embeddings_model: str, + aoai_connection_id: str, + _dry_run: bool = False + ): + self.output_index_name = output_index_name + self.vector_store = vector_store + self.data_source_url = data_source_url + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.input_glob = input_glob + self.max_sample_files = max_sample_files + self.chunk_prepend_summary = chunk_prepend_summary + self.document_path_replacement_regex = document_path_replacement_regex + self.embeddings_container = embeddings_container + self.embeddings_model = embeddings_model + self.aoai_connection_id = aoai_connection_id + self._dry_run = _dry_run diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_data_source.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_data_source.py new file mode 100644 index 00000000..92b62b6b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_data_source.py @@ -0,0 +1,62 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Union + +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.constants._common import IndexInputType + + +# General todo: need to determine which args are required or optional when parsed out into groups like this. +# General todo: move these to more permanent locations? + + +# Defines stuff related to supplying inputs for an index AKA the base data. +@experimental +class IndexDataSource: + """Base class for configs that define data that will be processed into an ML index. + This class should not be instantiated directly. Use one of its child classes instead. + + :param input_type: A type enum describing the source of the index. Used to avoid + direct type checking. + :type input_type: Union[str, ~azure.ai.ml.constants._common.IndexInputType] + """ + + def __init__(self, *, input_type: Union[str, IndexInputType]): + self.input_type = input_type + + +# Field bundle for creating an index from files located in a Git repo. +# TODO Does git_url need to specifically be an SSH or HTTPS style link? +# TODO What is git connection id? +@experimental +class GitSource(IndexDataSource): + """Config class for creating an ML index from files located in a git repository. + + :param url: A link to the repository to use. + :type url: str + :param branch_name: The name of the branch to use from the target repository. 
+    :type branch_name: str
+    :param connection_id: The connection ID for GitHub.
+    :type connection_id: str
+    """
+
+    def __init__(self, *, url: str, branch_name: str, connection_id: str):
+        self.url = url
+        self.branch_name = branch_name
+        self.connection_id = connection_id
+        super().__init__(input_type=IndexInputType.GIT)
+
+
+@experimental
+class LocalSource(IndexDataSource):
+    """Config class for creating an ML index from a collection of local files.
+
+    :param input_data: An input object describing the local location of index source files.
+    :type input_data: ~azure.ai.ml.Input
+    """
+
+    def __init__(self, *, input_data: str):  # todo Make sure type of input_data is correct
+        self.input_data = Input(type="uri_folder", path=input_data)
+        super().__init__(input_type=IndexInputType.LOCAL)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py
new file mode 100644
index 00000000..c9e54da4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py
@@ -0,0 +1,122 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection
+from azure.ai.ml.entities._workspace.connections.connection_subtypes import (
+    AzureOpenAIConnection,
+    AadCredentialConfiguration,
+)
+
+
+@experimental
+@dataclass
+class ModelConfiguration:
+    """Configuration for an embedding model.
+
+    :param api_base: The base URL for the API.
+    :type api_base: Optional[str]
+    :param api_key: The API key.
+    :type api_key: Optional[str]
+    :param api_version: The API version.
+    :type api_version: Optional[str]
+    :param model_name: The name of the model.
+    :type model_name: Optional[str]
+    :param deployment_name: The deployment name of the model.
+    :type deployment_name: Optional[str]
+    :param connection_name: The name of the workspace connection of this model.
+    :type connection_name: Optional[str]
+    :param connection_type: The type of the workspace connection of this model.
+    :type connection_type: Optional[str]
+    :param model_kwargs: Additional keyword arguments for the model.
+ :type model_kwargs: Dict[str, Any] + """ + + api_base: Optional[str] + api_key: Optional[str] + api_version: Optional[str] + connection_name: Optional[str] + connection_type: Optional[str] + model_name: Optional[str] + deployment_name: Optional[str] + model_kwargs: Dict[str, Any] + + def __init__( + self, + *, + api_base: Optional[str], + api_key: Optional[str], + api_version: Optional[str], + connection_name: Optional[str], + connection_type: Optional[str], + model_name: Optional[str], + deployment_name: Optional[str], + model_kwargs: Dict[str, Any] + ): + self.api_base = api_base + self.api_key = api_key + self.api_version = api_version + self.connection_name = connection_name + self.connection_type = connection_type + self.model_name = model_name + self.deployment_name = deployment_name + self.model_kwargs = model_kwargs + + @staticmethod + def from_connection( + connection: WorkspaceConnection, + model_name: Optional[str] = None, + deployment_name: Optional[str] = None, + **kwargs + ) -> "ModelConfiguration": + """Create an model configuration from a Connection. + + :param connection: The WorkspaceConnection object. + :type connection: ~azure.ai.ml.entities.WorkspaceConnection + :param model_name: The name of the model. + :type model_name: Optional[str] + :param deployment_name: The name of the deployment. + :type deployment_name: Optional[str] + :return: The model configuration. + :rtype: ~azure.ai.ml.entities._indexes.entities.ModelConfiguration + :raises TypeError: If the connection is not an AzureOpenAIConnection. + :raises ValueError: If the connection does not contain an OpenAI key. + """ + if isinstance(connection, AzureOpenAIConnection) or camel_to_snake(connection.type) == "azure_open_ai": + connection_type = "azure_open_ai" + api_version = connection.api_version # type: ignore[attr-defined] + if not model_name or not deployment_name: + raise ValueError("Please specify model_name and deployment_name.") + elif connection.type and connection.type.lower() == "serverless": + connection_type = "serverless" + api_version = None + if not connection.id: + raise TypeError("The connection id is missing from the serverless connection object.") + else: + raise TypeError("Connection object is not supported.") + + if isinstance(connection.credentials, AadCredentialConfiguration): + key = None + else: + key = connection.credentials.get("key") # type: ignore[union-attr] + if key is None and connection_type == "azure_open_ai": + import os + + if "AZURE_OPENAI_API_KEY" in os.environ: + key = os.getenv("AZURE_OPENAI_API_KEY") + else: + raise ValueError("Unable to retrieve openai key from connection object or env variable.") + + return ModelConfiguration( + api_base=connection.target, + api_key=key, + api_version=api_version, + connection_name=connection.name, + connection_type=connection_type, + model_name=model_name, + deployment_name=deployment_name, + model_kwargs=kwargs, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/__init__.py new file mode 100644 index 00000000..f65f5505 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/__init__.py @@ -0,0 +1,10 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# ---------------------------------------------------------
+"""AzureML Retrieval Augmented Generation (RAG) utilities."""
+
+from ._models import build_model_protocol
+from ._open_ai_utils import build_open_ai_protocol, build_connection_id
+from ._pipeline_decorator import pipeline
+
+__all__ = ["build_model_protocol", "build_open_ai_protocol", "build_connection_id", "pipeline"]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_models.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_models.py
new file mode 100644
index 00000000..d3e8c952
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_models.py
@@ -0,0 +1,25 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+"""DataIndex embedding model helpers."""
+import re
+from typing import Optional
+
+OPEN_AI_PROTOCOL_TEMPLATE = "azure_open_ai://deployment/{}/model/{}"
+OPEN_AI_PROTOCOL_REGEX_PATTERN = OPEN_AI_PROTOCOL_TEMPLATE.format(".*", ".*")
+OPEN_AI_SHORT_FORM_PROTOCOL_TEMPLATE = "azure_open_ai://deployments?/{}"
+OPEN_AI_SHORT_FORM_PROTOCOL_REGEX_PATTERN = OPEN_AI_SHORT_FORM_PROTOCOL_TEMPLATE.format(".*")
+
+HUGGINGFACE_PROTOCOL_TEMPLATE = "hugging_face://model/{}"
+HUGGINGFACE_PROTOCOL_REGEX_PATTERN = HUGGINGFACE_PROTOCOL_TEMPLATE.format(".*")
+
+
+def build_model_protocol(model: Optional[str] = None):
+    if not model or re.match(OPEN_AI_PROTOCOL_REGEX_PATTERN, model, re.IGNORECASE):
+        return model
+    if re.match(OPEN_AI_SHORT_FORM_PROTOCOL_REGEX_PATTERN, model, re.IGNORECASE):
+        return model
+    if re.match(HUGGINGFACE_PROTOCOL_REGEX_PATTERN, model, re.IGNORECASE):
+        return model
+
+    return OPEN_AI_PROTOCOL_TEMPLATE.format(model, model)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_open_ai_utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_open_ai_utils.py
new file mode 100644
index 00000000..d38a447f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_open_ai_utils.py
@@ -0,0 +1,36 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
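With `_models.py` in view, the protocol helper's behavior is easy to pin down: bare model names are promoted to the long Azure OpenAI form (assuming the deployment shares the model's name), while already-qualified URIs pass through untouched. Expected results, shown as comments:

```python
from azure.ai.ml.entities._indexes.utils._models import build_model_protocol

build_model_protocol("text-embedding-ada-002")
# 'azure_open_ai://deployment/text-embedding-ada-002/model/text-embedding-ada-002'

build_model_protocol("hugging_face://model/sentence-transformers/all-mpnet-base-v2")
# returned unchanged: already protocol-qualified

build_model_protocol(None)
# None
```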
+# --------------------------------------------------------- + +from typing import Optional + +from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource +from azure.ai.ml._scope_dependent_operations import OperationScope + +OPEN_AI_PROTOCOL_TEMPLATE = "azure_open_ai://deployment/{}/model/{}" + + +def build_open_ai_protocol( + model: Optional[str] = None, + deployment: Optional[str] = None, +): + if not deployment or not model: + return None + return OPEN_AI_PROTOCOL_TEMPLATE.format(deployment, model) + + +def build_connection_id(id: Optional[str], scope: OperationScope): + if not id or not scope.subscription_id or not scope.resource_group_name or not scope.workspace_name: + return id + + if is_ARM_id_for_resource(id, "connections", True): + return id + + # pylint: disable=line-too-long + template = "/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.MachineLearningServices/workspaces/{workspace_name}/connections/{id}" + return template.format( + subscription_id=scope.subscription_id, + resource_group_name=scope.resource_group_name, + workspace_name=scope.workspace_name, + id=id, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_pipeline_decorator.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_pipeline_decorator.py new file mode 100644 index 00000000..e70f97f2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_pipeline_decorator.py @@ -0,0 +1,248 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import inspect +import logging +from functools import wraps +from pathlib import Path +from typing import Any, Callable, Dict, Optional, TypeVar, Union, overload + +from typing_extensions import ParamSpec + +from azure.ai.ml.entities import Data, Model, PipelineJob, PipelineJobSettings +from azure.ai.ml.entities._builders.pipeline import Pipeline +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput, _GroupAttrDict +from azure.ai.ml.entities._job.pipeline._pipeline_expression import PipelineExpression +from azure.ai.ml.exceptions import UserErrorException + +from azure.ai.ml.dsl._pipeline_component_builder import PipelineComponentBuilder, _is_inside_dsl_pipeline_func +from azure.ai.ml.dsl._pipeline_decorator import _validate_args +from azure.ai.ml.dsl._settings import _dsl_settings_stack +from azure.ai.ml.dsl._utils import _resolve_source_file + +SUPPORTED_INPUT_TYPES = ( + PipelineInput, + NodeOutput, + Input, + Model, + Data, # For the case use a Data object as an input, we will convert it to Input object + Pipeline, # For the case use a pipeline node as the input, we use its only one output as the real input. + str, + bool, + int, + float, + PipelineExpression, + _GroupAttrDict, +) +module_logger = logging.getLogger(__name__) + +T = TypeVar("T") +P = ParamSpec("P") + + +# Overload the returns a decorator when func is None +@overload +def pipeline( + func: None, + *, + name: Optional[str] = None, + version: Optional[str] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + experiment_name: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + **kwargs: Any, +) -> Callable[[Callable[P, T]], Callable[P, PipelineJob]]: ... 
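Stepping back to the `_open_ai_utils.py` helpers just above: `build_open_ai_protocol` only formats the URI when both parts are present, and `build_connection_id` expands a bare connection name into a full ARM ID using the caller's scope. A sketch of the expected behavior; the scope values are placeholders, and in practice the `OperationScope` comes from an `MLClient` rather than being built by hand:

```python
from azure.ai.ml.entities._indexes.utils._open_ai_utils import (
    build_connection_id,
    build_open_ai_protocol,
)

build_open_ai_protocol(model="text-embedding-ada-002", deployment="ada-002")
# 'azure_open_ai://deployment/ada-002/model/text-embedding-ada-002'

build_open_ai_protocol(model="text-embedding-ada-002")  # no deployment given
# None

# Given a scope for subscription "sub", resource group "rg", workspace "ws":
# build_connection_id("my-aoai-connection", scope)
# '/subscriptions/sub/resourceGroups/rg/providers/Microsoft.MachineLearningServices'
# '/workspaces/ws/connections/my-aoai-connection'
```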
+

+# Overload: returns the decorated function when func is not None.
+@overload
+def pipeline(
+    func: Callable[P, T],
+    *,
+    name: Optional[str] = None,
+    version: Optional[str] = None,
+    display_name: Optional[str] = None,
+    description: Optional[str] = None,
+    experiment_name: Optional[str] = None,
+    tags: Optional[Dict[str, str]] = None,
+    **kwargs: Any,
+) -> Callable[P, PipelineJob]: ...


+def pipeline(
+    func: Optional[Callable[P, T]] = None,
+    *,
+    name: Optional[str] = None,
+    version: Optional[str] = None,
+    display_name: Optional[str] = None,
+    description: Optional[str] = None,
+    experiment_name: Optional[str] = None,
+    tags: Optional[Union[Dict[str, str], str]] = None,
+    **kwargs: Any,
+) -> Union[Callable[[Callable[P, T]], Callable[P, PipelineJob]], Callable[P, PipelineJob]]:
+    """Build a pipeline that contains all component nodes defined in this function.

+    :param func: The user pipeline function to be decorated.
+    :type func: types.FunctionType
+    :keyword name: The name of the pipeline component; defaults to the function name.
+    :paramtype name: str
+    :keyword version: The version of the pipeline component; defaults to "1".
+    :paramtype version: str
+    :keyword display_name: The display name of the pipeline component; defaults to the function name.
+    :paramtype display_name: str
+    :keyword description: The description of the built pipeline.
+    :paramtype description: str
+    :keyword experiment_name: Name of the experiment the job will be created under; \
+        if None is provided, the name of the current directory is used.
+    :paramtype experiment_name: str
+    :keyword tags: The tags of the pipeline component.
+    :paramtype tags: dict[str, str]
+    :return: Either
+      * A decorator, if `func` is None
+      * The decorated `func`
+    :rtype: Union[
+        Callable[[Callable], Callable[..., ~azure.ai.ml.entities.PipelineJob]],
+        Callable[P, ~azure.ai.ml.entities.PipelineJob]
+      ]

+    .. admonition:: Example:

+        .. literalinclude:: ../../../../samples/ml_samples_pipeline_job_configurations.py
+            :start-after: [START configure_pipeline]
+            :end-before: [END configure_pipeline]
+            :language: python
+            :dedent: 8
+            :caption: Shows how to create a pipeline using this decorator.
+    """

+    # get_component forces pipeline to return a Pipeline instead of a PipelineJob so that
+    # optional arguments can be set. TODO: remove get_component and rely on azure.ai.ml.dsl.pipeline.
+    get_component = kwargs.get("get_component", False)

+    def pipeline_decorator(func: Callable[P, T]) -> Callable:
+        if not isinstance(func, Callable):  # type: ignore
+            raise UserErrorException(f"Dsl pipeline decorator accepts only a function type, got {type(func)}.")

+        non_pipeline_inputs = kwargs.get("non_pipeline_inputs", []) or kwargs.get("non_pipeline_parameters", [])
+        # The compute kwarg name changed over time: default_compute_target -> compute -> default_compute.
+        # To support legacy usage, all of them are accepted and resolved in priority order.
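+        # For example (hypothetical cluster names), if a caller passes both
+        #     pipeline(func, default_compute="cpu-cluster", default_compute_target="gpu-cluster")
+        # the newer "default_compute" wins and "cpu-cluster" becomes the default.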
+        compute = kwargs.get("compute", None)
+        default_compute_target = kwargs.get("default_compute_target", None)
+        default_compute_target = kwargs.get("default_compute", None) or default_compute_target
+        continue_on_step_failure = kwargs.get("continue_on_step_failure", None)
+        on_init = kwargs.get("on_init", None)
+        on_finalize = kwargs.get("on_finalize", None)

+        default_datastore = kwargs.get("default_datastore", None)
+        force_rerun = kwargs.get("force_rerun", None)
+        job_settings = {
+            "default_datastore": default_datastore,
+            "continue_on_step_failure": continue_on_step_failure,
+            "force_rerun": force_rerun,
+            "default_compute": default_compute_target,
+            "on_init": on_init,
+            "on_finalize": on_finalize,
+        }
+        func_entry_path = _resolve_source_file()
+        if not func_entry_path:
+            func_path = Path(inspect.getfile(func))
+            # In a notebook, func_path may be a fake path, and resolving it would raise an error.
+            if func_path.exists():
+                func_entry_path = func_path.resolve().absolute()

+        job_settings = {k: v for k, v in job_settings.items() if v is not None}
+        pipeline_builder = PipelineComponentBuilder(
+            func=func,
+            name=name,
+            version=version,
+            display_name=display_name,
+            description=description,
+            default_datastore=default_datastore,
+            tags=tags,
+            source_path=str(func_entry_path),
+            non_pipeline_inputs=non_pipeline_inputs,
+        )

+        @wraps(func)
+        def wrapper(*args: P.args, **kwargs: P.kwargs) -> Union[Pipeline, PipelineJob]:
+            # Default args will be added here.
+            # Note: push/pop the stack here instead of inside build(), because dsl settings
+            # should only be enabled on the top-level pipeline.
+            _dsl_settings_stack.push()  # use this stack to track on_init/on_finalize settings
+            try:
+                # Convert args to kwargs.
+                provided_positional_kwargs = _validate_args(func, args, kwargs, non_pipeline_inputs)

+                # Since the pipeline supports variable parameters, update the pipeline component
+                # to include the inputs passed via **kwargs.
+                pipeline_parameters = {
+                    k: v for k, v in provided_positional_kwargs.items() if k not in non_pipeline_inputs
+                }
+                pipeline_builder._update_inputs(pipeline_parameters)

+                non_pipeline_params_dict = {
+                    k: v for k, v in provided_positional_kwargs.items() if k in non_pipeline_inputs
+                }

+                # TODO: cache the built pipeline component.
+                pipeline_component = pipeline_builder.build(
+                    user_provided_kwargs=provided_positional_kwargs,
+                    non_pipeline_inputs_dict=non_pipeline_params_dict,
+                    non_pipeline_inputs=non_pipeline_inputs,
+                )
+            finally:
+                # Use `finally` to ensure the stack is always popped.
+                dsl_settings = _dsl_settings_stack.pop()

+            # Update on_init/on_finalize settings if an init/finalize job is set.
+            if dsl_settings.init_job_set:
+                job_settings["on_init"] = dsl_settings.init_job_name(pipeline_component.jobs)
+            if dsl_settings.finalize_job_set:
+                job_settings["on_finalize"] = dsl_settings.finalize_job_name(pipeline_component.jobs)

+            # TODO: pass compute & default_compute separately?
+            common_init_args: Any = {
+                "experiment_name": experiment_name,
+                "component": pipeline_component,
+                "inputs": pipeline_parameters,
+                "tags": tags,
+            }
+            built_pipeline: Any = None
+            if _is_inside_dsl_pipeline_func() or get_component:
+                # on_init/on_finalize are not supported for a pipeline component.
+                if job_settings.get("on_init") is not None or job_settings.get("on_finalize") is not None:
+                    raise UserErrorException("On_init/on_finalize is not supported for pipeline component.")
+                # Build a pipeline node instead of a pipeline job when inside the dsl.
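+                # For example (hypothetical sketch): when one decorated function calls
+                # another, the inner call takes this branch and yields a Pipeline node
+                # that becomes a step of the outer pipeline:
+                #
+                #     @pipeline
+                #     def inner(): ...
+                #
+                #     @pipeline
+                #     def outer():
+                #         inner()  # built as a Pipeline node, not a PipelineJob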
+                built_pipeline = Pipeline(_from_component_func=True, **common_init_args)
+                if job_settings:
+                    module_logger.warning(
+                        "Job settings %s on pipeline function %r are ignored when used inside a PipelineJob.",
+                        job_settings,
+                        func.__name__,
+                    )
+            else:
+                built_pipeline = PipelineJob(
+                    jobs=pipeline_component.jobs,
+                    compute=compute,
+                    settings=PipelineJobSettings(**job_settings),
+                    **common_init_args,
+                )

+            return built_pipeline

+        # Bug Item number: 2883169
+        wrapper._is_dsl_func = True  # type: ignore
+        wrapper._job_settings = job_settings  # type: ignore
+        wrapper._pipeline_builder = pipeline_builder  # type: ignore
+        return wrapper

+    # Enable using the decorator without "()" when all arguments have default values.
+    if func is not None:
+        return pipeline_decorator(func)
+    return pipeline_decorator
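
Taken together, these utilities normalize embedding-model references and turn a plain function into a submittable pipeline. Below is a minimal usage sketch, not authoritative documentation: the deployment, model, component YAML path, input parameter name, and datastore path are hypothetical stand-ins, and build_connection_id is omitted because it additionally requires a real OperationScope.

    from azure.ai.ml import load_component
    from azure.ai.ml.entities._indexes.utils import (
        build_model_protocol,
        build_open_ai_protocol,
        pipeline,
    )

    # Bare names are canonicalized into the long-form protocol URI (the name is
    # used for both the deployment and model segments); strings that already
    # carry a known protocol pass through unchanged.
    assert build_model_protocol("ada-002") == "azure_open_ai://deployment/ada-002/model/ada-002"
    assert build_model_protocol("hugging_face://model/bert-base") == "hugging_face://model/bert-base"

    # A deployment/model pair is combined only when both parts are present.
    assert build_open_ai_protocol(model="text-embedding-ada-002", deployment="ada") == (
        "azure_open_ai://deployment/ada/model/text-embedding-ada-002"
    )
    assert build_open_ai_protocol(deployment="ada") is None

    # The decorator turns a function of component calls into a PipelineJob.
    # "./ingest.yml" and the "path" parameter are hypothetical.
    ingest = load_component(source="./ingest.yml")

    @pipeline(name="docs_index_pipeline", description="hypothetical indexing pipeline")
    def docs_index_pipeline(data_path: str):
        step = ingest(path=data_path)
        return {"index": step.outputs.output}

    # Called outside another dsl pipeline, this returns a PipelineJob ready to submit.
    job = docs_index_pipeline(data_path="azureml://datastores/workspaceblobstore/paths/docs")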
