import logging
import os
from typing import Any, Optional

from azure.ai.inference import (
    ChatCompletionsClient as AzureChatCompletionsClient,
)
from azure.ai.inference.aio import (
    ChatCompletionsClient as AsyncAzureChatCompletionsClient,
)
from azure.core.credentials import AzureKeyCredential

from core.base.abstractions import GenerationConfig
from core.base.providers.llm import CompletionConfig, CompletionProvider

logger = logging.getLogger(__name__)

# API version used when AZURE_FOUNDRY_API_VERSION is not set.
_DEFAULT_API_VERSION = "2024-05-01-preview"

# GenerationConfig attributes that are forwarded only when not None.
_OPTIONAL_GENERATION_FIELDS = ("functions", "tools", "response_format")


class AzureFoundryCompletionProvider(CompletionProvider):
    """Completion provider backed by the Azure AI Inference chat clients.

    A synchronous and an asynchronous ``ChatCompletionsClient`` are created
    at construction time when both ``AZURE_FOUNDRY_API_KEY`` and
    ``AZURE_FOUNDRY_API_ENDPOINT`` are present in the environment. When
    either is missing, both client attributes stay ``None`` and executing a
    task raises ``ValueError``.
    """

    def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
        """Initialize the provider and, if possible, the Azure clients.

        Args:
            config: Completion configuration passed to the base class.
            *args: Accepted for signature compatibility; not used here.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        super().__init__(config)
        self.azure_foundry_client: Optional[AzureChatCompletionsClient] = None
        self.async_azure_foundry_client: Optional[
            AsyncAzureChatCompletionsClient
        ] = None

        # Initialize Azure Foundry clients if credentials exist.
        api_key = os.getenv("AZURE_FOUNDRY_API_KEY")
        endpoint = os.getenv("AZURE_FOUNDRY_API_ENDPOINT")
        if not (api_key and endpoint):
            logger.debug(
                "Azure Foundry credentials not set; clients not initialized"
            )
            return

        # Read the version once so both clients are guaranteed to agree.
        api_version = os.getenv(
            "AZURE_FOUNDRY_API_VERSION", _DEFAULT_API_VERSION
        )
        self.azure_foundry_client = AzureChatCompletionsClient(
            endpoint=endpoint,
            credential=AzureKeyCredential(api_key),
            api_version=api_version,
        )
        self.async_azure_foundry_client = AsyncAzureChatCompletionsClient(
            endpoint=endpoint,
            credential=AzureKeyCredential(api_key),
            api_version=api_version,
        )
        logger.debug("Azure Foundry clients initialized successfully")

    def _get_base_args(
        self, generation_config: GenerationConfig
    ) -> dict[str, Any]:
        """Translate a generation config into client keyword arguments.

        Args:
            generation_config: Source of sampling and tool settings.

        Returns:
            A dict that always holds ``top_p``, ``stream``, ``max_tokens``
            and ``temperature``, plus ``functions``, ``tools`` and
            ``response_format`` for each of those that is not ``None``.
        """
        # Construct arguments similar to the other providers.
        args: dict[str, Any] = {
            "top_p": generation_config.top_p,
            "stream": generation_config.stream,
            "max_tokens": generation_config.max_tokens_to_sample,
            "temperature": generation_config.temperature,
        }
        # NOTE(review): `functions` is forwarded as a plain keyword; confirm
        # that the Azure client's `complete` accepts it.
        for field_name in _OPTIONAL_GENERATION_FIELDS:
            value = getattr(generation_config, field_name)
            if value is not None:
                args[field_name] = value
        return args

    def _build_request_args(self, task: dict[str, Any]) -> dict[str, Any]:
        """Assemble the keyword arguments for one ``complete`` call.

        Shared by the sync and async execution paths so they cannot drift.

        Args:
            task: Mapping with ``"messages"``, ``"generation_config"`` and
                ``"kwargs"`` entries.

        Returns:
            The base arguments plus ``messages``, with entries from
            ``task["kwargs"]`` overriding any key of the same name.

        Raises:
            KeyError: If one of the three required task entries is missing.
        """
        messages = task["messages"]
        generation_config = task["generation_config"]
        extra_kwargs = task["kwargs"]

        args = self._get_base_args(generation_config)
        # Azure Foundry does not require a "model" argument; the endpoint
        # is fixed.
        args["messages"] = messages
        return {**args, **extra_kwargs}

    async def _execute_task(self, task: dict[str, Any]):
        """Run a completion task through the asynchronous Azure client.

        Args:
            task: Mapping with ``"messages"``, ``"generation_config"`` and
                ``"kwargs"`` entries.

        Returns:
            Whatever the async client's ``complete`` call returns.

        Raises:
            ValueError: If the async client was never initialized.
            Exception: Any error raised by the client is logged and
                re-raised unchanged.
        """
        args = self._build_request_args(task)

        logger.debug("Executing async Azure Foundry task with args: %s", args)
        try:
            if self.async_azure_foundry_client is None:
                raise ValueError("Azure Foundry client is not initialized")

            response = await self.async_azure_foundry_client.complete(**args)
            logger.debug("Async Azure Foundry task executed successfully")
            return response
        except Exception as e:
            logger.error("Async Azure Foundry task execution failed: %s", e)
            raise

    def _execute_task_sync(self, task: dict[str, Any]):
        """Run a completion task through the synchronous Azure client.

        Args:
            task: Mapping with ``"messages"``, ``"generation_config"`` and
                ``"kwargs"`` entries.

        Returns:
            Whatever the sync client's ``complete`` call returns.

        Raises:
            ValueError: If the sync client was never initialized.
            Exception: Any error raised by the client is logged and
                re-raised unchanged.
        """
        args = self._build_request_args(task)

        logger.debug("Executing sync Azure Foundry task with args: %s", args)
        try:
            if self.azure_foundry_client is None:
                raise ValueError("Azure Foundry client is not initialized")

            response = self.azure_foundry_client.complete(**args)
            logger.debug("Sync Azure Foundry task executed successfully")
            return response
        except Exception as e:
            logger.error("Sync Azure Foundry task execution failed: %s", e)
            raise