import logging
from typing import Any

from core.base.abstractions import GenerationConfig
from core.base.providers.llm import CompletionConfig, CompletionProvider

from .anthropic import AnthropicCompletionProvider
from .azure_foundry import AzureFoundryCompletionProvider
from .litellm import LiteLLMCompletionProvider
from .openai import OpenAICompletionProvider

logger = logging.getLogger(__name__)


class R2RCompletionProvider(CompletionProvider):
    """A provider that routes each request to the right LLM sub-provider (R2R):

    - If `generation_config.model` starts with "anthropic/", call
      AnthropicCompletionProvider.
    - If it starts with "azure-foundry/", call AzureFoundryCompletionProvider.
    - If it starts with one of the other OpenAI-like prefixes ("openai/",
      "azure/", "deepseek/", "ollama/", "lmstudio/") or has no prefix
      (e.g. "gpt-4", "gpt-3.5"), call OpenAICompletionProvider.
    - Otherwise, fall back to LiteLLMCompletionProvider.
    """

    def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
        """Initialize sub-providers for OpenAI, Anthropic, LiteLLM, and
        Azure Foundry."""
        super().__init__(config)
        self.config = config

        logger.info("Initializing R2RCompletionProvider...")
        self._openai_provider = OpenAICompletionProvider(
            self.config, *args, **kwargs
        )
        self._anthropic_provider = AnthropicCompletionProvider(
            self.config, *args, **kwargs
        )
        self._litellm_provider = LiteLLMCompletionProvider(
            self.config, *args, **kwargs
        )
        self._azure_foundry_provider = AzureFoundryCompletionProvider(
            self.config, *args, **kwargs
        )
        logger.debug(
            "R2RCompletionProvider initialized with OpenAI, Anthropic, "
            "LiteLLM, and Azure Foundry sub-providers."
        )

    def _choose_subprovider_by_model(
        self, model_name: str, is_streaming: bool = False
    ) -> CompletionProvider:
        """Decide which underlying sub-provider to call based on the model
        name's prefix."""
        # Route to Anthropic if appropriate.
        if model_name.startswith("anthropic/"):
            return self._anthropic_provider

        # Route to Azure Foundry explicitly.
        if model_name.startswith("azure-foundry/"):
            return self._azure_foundry_provider

        # OpenAI-like prefixes, plus bare model names with no prefix at all.
        openai_like_prefixes = [
            "openai/",
            "azure/",
            "deepseek/",
            "ollama/",
            "lmstudio/",
        ]
        if (
            any(
                model_name.startswith(prefix)
                for prefix in openai_like_prefixes
            )
            or "/" not in model_name
        ):
            return self._openai_provider

        # Fall back to LiteLLM for any other prefixed model name.
        return self._litellm_provider

    async def _execute_task(self, task: dict[str, Any]):
        """Pick the sub-provider based on the model name and forward the
        async call."""
        generation_config: GenerationConfig = task["generation_config"]
        model_name = generation_config.model
        sub_provider = self._choose_subprovider_by_model(model_name or "")
        return await sub_provider._execute_task(task)

    def _execute_task_sync(self, task: dict[str, Any]):
        """Pick the sub-provider based on the model name and forward the
        sync call."""
        generation_config: GenerationConfig = task["generation_config"]
        model_name = generation_config.model
        sub_provider = self._choose_subprovider_by_model(model_name or "")
        return sub_provider._execute_task_sync(task)
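

# A minimal, self-contained illustration of the routing rules above (a
# sketch, not part of the provider API): constructing the real sub-providers
# requires credentials, so __init__ is bypassed and sentinel strings stand in
# for them. The model names below are illustrative examples of each prefix.
if __name__ == "__main__":
    router = R2RCompletionProvider.__new__(R2RCompletionProvider)
    router._openai_provider = "openai-like"  # type: ignore[assignment]
    router._anthropic_provider = "anthropic"  # type: ignore[assignment]
    router._litellm_provider = "litellm"  # type: ignore[assignment]
    router._azure_foundry_provider = "azure-foundry"  # type: ignore[assignment]

    for name in (
        "anthropic/claude-3-5-sonnet",  # "anthropic/" prefix
        "azure-foundry/my-deployment",  # "azure-foundry/" prefix
        "openai/gpt-4o",                # OpenAI-like prefix
        "gpt-4",                        # no prefix -> OpenAI sub-provider
        "vertex_ai/gemini-pro",         # unknown prefix -> LiteLLM fallback
    ):
        print(f"{name!r} -> {router._choose_subprovider_by_model(name)}")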