Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/router.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/router.py | 6249 |
1 file changed, 6249 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/router.py b/.venv/lib/python3.12/site-packages/litellm/router.py new file mode 100644 index 00000000..a395c851 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router.py @@ -0,0 +1,6249 @@ +# +-----------------------------------------------+ +# | | +# | Give Feedback / Get Help | +# | https://github.com/BerriAI/litellm/issues/new | +# | | +# +-----------------------------------------------+ +# +# Thank you ! We ❤️ you! - Krrish & Ishaan + +import asyncio +import copy +import enum +import hashlib +import inspect +import json +import logging +import threading +import time +import traceback +import uuid +from collections import defaultdict +from functools import lru_cache +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + Union, + cast, +) + +import httpx +import openai +from openai import AsyncOpenAI +from pydantic import BaseModel +from typing_extensions import overload + +import litellm +import litellm.litellm_core_utils +import litellm.litellm_core_utils.exception_mapping_utils +from litellm import get_secret_str +from litellm._logging import verbose_router_logger +from litellm.caching.caching import ( + DualCache, + InMemoryCache, + RedisCache, + RedisClusterCache, +) +from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.asyncify import run_async_function +from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs +from litellm.litellm_core_utils.dd_tracing import tracer +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging +from litellm.router_strategy.budget_limiter import RouterBudgetLimiting +from litellm.router_strategy.least_busy import LeastBusyLoggingHandler +from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler +from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler +from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler +from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2 +from litellm.router_strategy.simple_shuffle import simple_shuffle +from litellm.router_strategy.tag_based_routing import get_deployments_for_tag +from litellm.router_utils.add_retry_fallback_headers import ( + add_fallback_headers_to_response, + add_retry_headers_to_response, +) +from litellm.router_utils.batch_utils import ( + _get_router_metadata_variable_name, + replace_model_in_jsonl, +) +from litellm.router_utils.client_initalization_utils import InitalizeCachedClient +from litellm.router_utils.clientside_credential_handler import ( + get_dynamic_litellm_params, + is_clientside_credential, +) +from litellm.router_utils.cooldown_cache import CooldownCache +from litellm.router_utils.cooldown_handlers import ( + DEFAULT_COOLDOWN_TIME_SECONDS, + _async_get_cooldown_deployments, + _async_get_cooldown_deployments_with_debug_info, + _get_cooldown_deployments, + _set_cooldown_deployments, +) +from litellm.router_utils.fallback_event_handlers import ( + _check_non_standard_fallback_format, + get_fallback_model_group, + run_async_fallback, +) +from litellm.router_utils.get_retry_from_policy import ( + get_num_retries_from_retry_policy as _get_num_retries_from_retry_policy, +) +from litellm.router_utils.handle_error import ( + async_raise_no_deployment_exception, + send_llm_exception_alert, +) +from litellm.router_utils.pre_call_checks.prompt_caching_deployment_check import ( + PromptCachingDeploymentCheck, 
+) +from litellm.router_utils.router_callbacks.track_deployment_metrics import ( + increment_deployment_failures_for_current_minute, + increment_deployment_successes_for_current_minute, +) +from litellm.scheduler import FlowItem, Scheduler +from litellm.types.llms.openai import AllMessageValues, Batch, FileObject, FileTypes +from litellm.types.router import ( + CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS, + VALID_LITELLM_ENVIRONMENTS, + AlertingConfig, + AllowedFailsPolicy, + AssistantsTypedDict, + CredentialLiteLLMParams, + CustomRoutingStrategyBase, + Deployment, + DeploymentTypedDict, + LiteLLM_Params, + ModelGroupInfo, + OptionalPreCallChecks, + RetryPolicy, + RouterCacheEnum, + RouterGeneralSettings, + RouterModelGroupAliasItem, + RouterRateLimitError, + RouterRateLimitErrorBasic, + RoutingStrategy, +) +from litellm.types.services import ServiceTypes +from litellm.types.utils import GenericBudgetConfigType +from litellm.types.utils import ModelInfo as ModelMapInfo +from litellm.types.utils import StandardLoggingPayload +from litellm.utils import ( + CustomStreamWrapper, + EmbeddingResponse, + ModelResponse, + Rules, + function_setup, + get_llm_provider, + get_non_default_completion_params, + get_secret, + get_utc_datetime, + is_region_allowed, +) + +from .router_utils.pattern_match_deployments import PatternMatchRouter + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + + +class RoutingArgs(enum.Enum): + ttl = 60 # 1min (RPM/TPM expire key) + + +class Router: + model_names: List = [] + cache_responses: Optional[bool] = False + default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour + tenacity = None + leastbusy_logger: Optional[LeastBusyLoggingHandler] = None + lowesttpm_logger: Optional[LowestTPMLoggingHandler] = None + + def __init__( # noqa: PLR0915 + self, + model_list: Optional[ + Union[List[DeploymentTypedDict], List[Dict[str, Any]]] + ] = None, + ## ASSISTANTS API ## + assistants_config: Optional[AssistantsTypedDict] = None, + ## CACHING ## + redis_url: Optional[str] = None, + redis_host: Optional[str] = None, + redis_port: Optional[int] = None, + redis_password: Optional[str] = None, + cache_responses: Optional[bool] = False, + cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py) + caching_groups: Optional[ + List[tuple] + ] = None, # if you want to cache across model groups + client_ttl: int = 3600, # ttl for cached clients - will re-initialize after this time in seconds + ## SCHEDULER ## + polling_interval: Optional[float] = None, + default_priority: Optional[int] = None, + ## RELIABILITY ## + num_retries: Optional[int] = None, + max_fallbacks: Optional[ + int + ] = None, # max fallbacks to try before exiting the call. Defaults to 5. 
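A usage sketch for the constructor parameters above (caching, scheduler and reliability settings). All endpoint, key and Redis values below are placeholders, not values from this diff.

```python
from litellm import Router

# Illustrative only: deployment name, keys, host and port are placeholders.
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",            # model group alias
            "litellm_params": {
                "model": "azure/my-deployment",        # hypothetical deployment
                "api_key": "placeholder-key",
                "api_base": "https://example.openai.azure.com",
            },
        },
    ],
    redis_host="localhost",        # enables the Redis side of the DualCache
    redis_port=6379,
    redis_password="placeholder",
    cache_responses=True,          # also initializes litellm.cache if it is unset
    num_retries=2,                 # retries before fallbacks are attempted
    max_fallbacks=5,               # max fallback deployments to try
    timeout=30,                    # per-request timeout, in seconds
)
```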
+ timeout: Optional[float] = None, + stream_timeout: Optional[float] = None, + default_litellm_params: Optional[ + dict + ] = None, # default params for Router.chat.completion.create + default_max_parallel_requests: Optional[int] = None, + set_verbose: bool = False, + debug_level: Literal["DEBUG", "INFO"] = "INFO", + default_fallbacks: Optional[ + List[str] + ] = None, # generic fallbacks, works across all deployments + fallbacks: List = [], + context_window_fallbacks: List = [], + content_policy_fallbacks: List = [], + model_group_alias: Optional[ + Dict[str, Union[str, RouterModelGroupAliasItem]] + ] = {}, + enable_pre_call_checks: bool = False, + enable_tag_filtering: bool = False, + retry_after: int = 0, # min time to wait before retrying a failed request + retry_policy: Optional[ + Union[RetryPolicy, dict] + ] = None, # set custom retries for different exceptions + model_group_retry_policy: Dict[ + str, RetryPolicy + ] = {}, # set custom retry policies based on model group + allowed_fails: Optional[ + int + ] = None, # Number of times a deployment can failbefore being added to cooldown + allowed_fails_policy: Optional[ + AllowedFailsPolicy + ] = None, # set custom allowed fails policy + cooldown_time: Optional[ + float + ] = None, # (seconds) time to cooldown a deployment after failure + disable_cooldowns: Optional[bool] = None, + routing_strategy: Literal[ + "simple-shuffle", + "least-busy", + "usage-based-routing", + "latency-based-routing", + "cost-based-routing", + "usage-based-routing-v2", + ] = "simple-shuffle", + optional_pre_call_checks: Optional[OptionalPreCallChecks] = None, + routing_strategy_args: dict = {}, # just for latency-based + provider_budget_config: Optional[GenericBudgetConfigType] = None, + alerting_config: Optional[AlertingConfig] = None, + router_general_settings: Optional[ + RouterGeneralSettings + ] = RouterGeneralSettings(), + ) -> None: + """ + Initialize the Router class with the given parameters for caching, reliability, and routing strategy. + + Args: + model_list (Optional[list]): List of models to be used. Defaults to None. + redis_url (Optional[str]): URL of the Redis server. Defaults to None. + redis_host (Optional[str]): Hostname of the Redis server. Defaults to None. + redis_port (Optional[int]): Port of the Redis server. Defaults to None. + redis_password (Optional[str]): Password of the Redis server. Defaults to None. + cache_responses (Optional[bool]): Flag to enable caching of responses. Defaults to False. + cache_kwargs (dict): Additional kwargs to pass to RedisCache. Defaults to {}. + caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None. + client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600. + polling_interval: (Optional[float]): frequency of polling queue. Only for '.scheduler_acompletion()'. Default is 3ms. + default_priority: (Optional[int]): the default priority for a request. Only for '.scheduler_acompletion()'. Default is None. + num_retries (Optional[int]): Number of retries for failed requests. Defaults to 2. + timeout (Optional[float]): Timeout for requests. Defaults to None. + default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}. + set_verbose (bool): Flag to set verbose mode. Defaults to False. + debug_level (Literal["DEBUG", "INFO"]): Debug level for logging. Defaults to "INFO". + fallbacks (List): List of fallback options. Defaults to []. 
+ context_window_fallbacks (List): List of context window fallback options. Defaults to []. + enable_pre_call_checks (boolean): Filter out deployments which are outside context window limits for a given prompt + model_group_alias (Optional[dict]): Alias for model groups. Defaults to {}. + retry_after (int): Minimum time to wait before retrying a failed request. Defaults to 0. + allowed_fails (Optional[int]): Number of allowed fails before adding to cooldown. Defaults to None. + cooldown_time (float): Time to cooldown a deployment after failure in seconds. Defaults to 1. + routing_strategy (Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing", "cost-based-routing"]): Routing strategy. Defaults to "simple-shuffle". + routing_strategy_args (dict): Additional args for latency-based routing. Defaults to {}. + alerting_config (AlertingConfig): Slack alerting configuration. Defaults to None. + provider_budget_config (ProviderBudgetConfig): Provider budget configuration. Use this to set llm_provider budget limits. example $100/day to OpenAI, $100/day to Azure, etc. Defaults to None. + Returns: + Router: An instance of the litellm.Router class. + + Example Usage: + ```python + from litellm import Router + model_list = [ + { + "model_name": "azure-gpt-3.5-turbo", # model alias + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/<your-deployment-name-1>", + "api_key": <your-api-key>, + "api_version": <your-api-version>, + "api_base": <your-api-base> + }, + }, + { + "model_name": "azure-gpt-3.5-turbo", # model alias + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/<your-deployment-name-2>", + "api_key": <your-api-key>, + "api_version": <your-api-version>, + "api_base": <your-api-base> + }, + }, + { + "model_name": "openai-gpt-3.5-turbo", # model alias + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo", + "api_key": <your-api-key>, + }, + ] + + router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}]) + ``` + """ + + from litellm._service_logger import ServiceLogging + + self.set_verbose = set_verbose + self.debug_level = debug_level + self.enable_pre_call_checks = enable_pre_call_checks + self.enable_tag_filtering = enable_tag_filtering + litellm.suppress_debug_info = True # prevents 'Give Feedback/Get help' message from being emitted on Router - Relevant Issue: https://github.com/BerriAI/litellm/issues/5942 + if self.set_verbose is True: + if debug_level == "INFO": + verbose_router_logger.setLevel(logging.INFO) + elif debug_level == "DEBUG": + verbose_router_logger.setLevel(logging.DEBUG) + self.router_general_settings: RouterGeneralSettings = ( + router_general_settings or RouterGeneralSettings() + ) + + self.assistants_config = assistants_config + self.deployment_names: List = ( + [] + ) # names of models under litellm_params. ex. 
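A sketch of selecting a non-default routing strategy. The strategy string comes from the `Literal` in the signature above; the Redis settings are placeholders and mainly matter when several router instances should share usage counters through the router cache.

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "placeholder-key"},
        },
    ],
    routing_strategy="usage-based-routing-v2",
    redis_host="localhost",          # placeholder; used for shared TPM/RPM tracking
    redis_port=6379,
    enable_pre_call_checks=True,     # drop deployments whose context window can't fit the prompt
)
```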
azure/chatgpt-v-2 + self.deployment_latency_map = {} + ### CACHING ### + cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = ( + "local" # default to an in-memory cache + ) + redis_cache = None + cache_config: Dict[str, Any] = {} + + self.client_ttl = client_ttl + if redis_url is not None or (redis_host is not None and redis_port is not None): + cache_type = "redis" + + if redis_url is not None: + cache_config["url"] = redis_url + + if redis_host is not None: + cache_config["host"] = redis_host + + if redis_port is not None: + cache_config["port"] = str(redis_port) # type: ignore + + if redis_password is not None: + cache_config["password"] = redis_password + + # Add additional key-value pairs from cache_kwargs + cache_config.update(cache_kwargs) + redis_cache = self._create_redis_cache(cache_config) + + if cache_responses: + if litellm.cache is None: + # the cache can be initialized on the proxy server. We should not overwrite it + litellm.cache = litellm.Cache(type=cache_type, **cache_config) # type: ignore + self.cache_responses = cache_responses + self.cache = DualCache( + redis_cache=redis_cache, in_memory_cache=InMemoryCache() + ) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc. + + ### SCHEDULER ### + self.scheduler = Scheduler( + polling_interval=polling_interval, redis_cache=redis_cache + ) + self.default_priority = default_priority + self.default_deployment = None # use this to track the users default deployment, when they want to use model = * + self.default_max_parallel_requests = default_max_parallel_requests + self.provider_default_deployment_ids: List[str] = [] + self.pattern_router = PatternMatchRouter() + + if model_list is not None: + model_list = copy.deepcopy(model_list) + self.set_model_list(model_list) + self.healthy_deployments: List = self.model_list # type: ignore + for m in model_list: + if "model" in m["litellm_params"]: + self.deployment_latency_map[m["litellm_params"]["model"]] = 0 + else: + self.model_list: List = ( + [] + ) # initialize an empty list - to allow _add_deployment and delete_deployment to work + + if allowed_fails is not None: + self.allowed_fails = allowed_fails + else: + self.allowed_fails = litellm.allowed_fails + self.cooldown_time = cooldown_time or DEFAULT_COOLDOWN_TIME_SECONDS + self.cooldown_cache = CooldownCache( + cache=self.cache, default_cooldown_time=self.cooldown_time + ) + self.disable_cooldowns = disable_cooldowns + self.failed_calls = ( + InMemoryCache() + ) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown + + if num_retries is not None: + self.num_retries = num_retries + elif litellm.num_retries is not None: + self.num_retries = litellm.num_retries + else: + self.num_retries = openai.DEFAULT_MAX_RETRIES + + if max_fallbacks is not None: + self.max_fallbacks = max_fallbacks + elif litellm.max_fallbacks is not None: + self.max_fallbacks = litellm.max_fallbacks + else: + self.max_fallbacks = litellm.ROUTER_MAX_FALLBACKS + + self.timeout = timeout or litellm.request_timeout + self.stream_timeout = stream_timeout + + self.retry_after = retry_after + self.routing_strategy = routing_strategy + + ## SETTING FALLBACKS ## + ### validate if it's set + in correct format + _fallbacks = fallbacks or litellm.fallbacks + + self.validate_fallbacks(fallback_param=_fallbacks) + ### set fallbacks + self.fallbacks = _fallbacks + + if default_fallbacks is not None or litellm.default_fallbacks is not None: + _fallbacks = 
default_fallbacks or litellm.default_fallbacks + if self.fallbacks is not None: + self.fallbacks.append({"*": _fallbacks}) + else: + self.fallbacks = [{"*": _fallbacks}] + + self.context_window_fallbacks = ( + context_window_fallbacks or litellm.context_window_fallbacks + ) + + _content_policy_fallbacks = ( + content_policy_fallbacks or litellm.content_policy_fallbacks + ) + self.validate_fallbacks(fallback_param=_content_policy_fallbacks) + self.content_policy_fallbacks = _content_policy_fallbacks + self.total_calls: defaultdict = defaultdict( + int + ) # dict to store total calls made to each model + self.fail_calls: defaultdict = defaultdict( + int + ) # dict to store fail_calls made to each model + self.success_calls: defaultdict = defaultdict( + int + ) # dict to store success_calls made to each model + self.previous_models: List = ( + [] + ) # list to store failed calls (passed in as metadata to next call) + self.model_group_alias: Dict[str, Union[str, RouterModelGroupAliasItem]] = ( + model_group_alias or {} + ) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group + + # make Router.chat.completions.create compatible for openai.chat.completions.create + default_litellm_params = default_litellm_params or {} + self.chat = litellm.Chat(params=default_litellm_params, router_obj=self) + + # default litellm args + self.default_litellm_params = default_litellm_params + self.default_litellm_params.setdefault("timeout", timeout) + self.default_litellm_params.setdefault("max_retries", 0) + self.default_litellm_params.setdefault("metadata", {}).update( + {"caching_groups": caching_groups} + ) + + self.deployment_stats: dict = {} # used for debugging load balancing + """ + deployment_stats = { + "122999-2828282-277: + { + "model": "gpt-3", + "api_base": "http://localhost:4000", + "num_requests": 20, + "avg_latency": 0.001, + "num_failures": 0, + "num_successes": 20 + } + } + """ + ### ROUTING SETUP ### + self.routing_strategy_init( + routing_strategy=routing_strategy, + routing_strategy_args=routing_strategy_args, + ) + self.access_groups = None + ## USAGE TRACKING ## + if isinstance(litellm._async_success_callback, list): + litellm.logging_callback_manager.add_litellm_async_success_callback( + self.deployment_callback_on_success + ) + else: + litellm.logging_callback_manager.add_litellm_async_success_callback( + self.deployment_callback_on_success + ) + if isinstance(litellm.success_callback, list): + litellm.logging_callback_manager.add_litellm_success_callback( + self.sync_deployment_callback_on_success + ) + else: + litellm.success_callback = [self.sync_deployment_callback_on_success] + if isinstance(litellm._async_failure_callback, list): + litellm.logging_callback_manager.add_litellm_async_failure_callback( + self.async_deployment_callback_on_failure + ) + else: + litellm._async_failure_callback = [ + self.async_deployment_callback_on_failure + ] + ## COOLDOWNS ## + if isinstance(litellm.failure_callback, list): + litellm.logging_callback_manager.add_litellm_failure_callback( + self.deployment_callback_on_failure + ) + else: + litellm.failure_callback = [self.deployment_callback_on_failure] + verbose_router_logger.debug( + f"Intialized router with Routing strategy: {self.routing_strategy}\n\n" + f"Routing enable_pre_call_checks: {self.enable_pre_call_checks}\n\n" + f"Routing fallbacks: {self.fallbacks}\n\n" + f"Routing content fallbacks: {self.content_policy_fallbacks}\n\n" + f"Routing context window fallbacks: 
{self.context_window_fallbacks}\n\n" + f"Router Redis Caching={self.cache.redis_cache}\n" + ) + self.service_logger_obj = ServiceLogging() + self.routing_strategy_args = routing_strategy_args + self.provider_budget_config = provider_budget_config + self.router_budget_logger: Optional[RouterBudgetLimiting] = None + if RouterBudgetLimiting.should_init_router_budget_limiter( + model_list=model_list, provider_budget_config=self.provider_budget_config + ): + if optional_pre_call_checks is not None: + optional_pre_call_checks.append("router_budget_limiting") + else: + optional_pre_call_checks = ["router_budget_limiting"] + self.retry_policy: Optional[RetryPolicy] = None + if retry_policy is not None: + if isinstance(retry_policy, dict): + self.retry_policy = RetryPolicy(**retry_policy) + elif isinstance(retry_policy, RetryPolicy): + self.retry_policy = retry_policy + verbose_router_logger.info( + "\033[32mRouter Custom Retry Policy Set:\n{}\033[0m".format( + self.retry_policy.model_dump(exclude_none=True) + ) + ) + + self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = ( + model_group_retry_policy + ) + + self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None + if allowed_fails_policy is not None: + if isinstance(allowed_fails_policy, dict): + self.allowed_fails_policy = AllowedFailsPolicy(**allowed_fails_policy) + elif isinstance(allowed_fails_policy, AllowedFailsPolicy): + self.allowed_fails_policy = allowed_fails_policy + + verbose_router_logger.info( + "\033[32mRouter Custom Allowed Fails Policy Set:\n{}\033[0m".format( + self.allowed_fails_policy.model_dump(exclude_none=True) + ) + ) + + self.alerting_config: Optional[AlertingConfig] = alerting_config + + if optional_pre_call_checks is not None: + self.add_optional_pre_call_checks(optional_pre_call_checks) + + if self.alerting_config is not None: + self._initialize_alerting() + + self.initialize_assistants_endpoint() + self.initialize_router_endpoints() + + def discard(self): + """ + Pseudo-destructor to be invoked to clean up global data structures when router is no longer used. + For now, unhook router's callbacks from all lists + """ + litellm.logging_callback_manager.remove_callback_from_list_by_object( + litellm._async_success_callback, self + ) + litellm.logging_callback_manager.remove_callback_from_list_by_object( + litellm.success_callback, self + ) + litellm.logging_callback_manager.remove_callback_from_list_by_object( + litellm._async_failure_callback, self + ) + litellm.logging_callback_manager.remove_callback_from_list_by_object( + litellm.failure_callback, self + ) + litellm.logging_callback_manager.remove_callback_from_list_by_object( + litellm.input_callback, self + ) + litellm.logging_callback_manager.remove_callback_from_list_by_object( + litellm.service_callback, self + ) + litellm.logging_callback_manager.remove_callback_from_list_by_object( + litellm.callbacks, self + ) + + @staticmethod + def _create_redis_cache( + cache_config: Dict[str, Any] + ) -> Union[RedisCache, RedisClusterCache]: + """ + Initializes either a RedisCache or RedisClusterCache based on the cache_config. + """ + if cache_config.get("startup_nodes"): + return RedisClusterCache(**cache_config) + else: + return RedisCache(**cache_config) + + def _update_redis_cache(self, cache: RedisCache): + """ + Update the redis cache for the router, if none set. + + Allows proxy user to just do + ```yaml + litellm_settings: + cache: true + ``` + and caching to just work. 
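A sketch of the custom retry and allowed-fails policies parsed above. The per-exception field names follow the litellm documentation and should be treated as assumptions rather than guarantees; a plain dict with the same keys is also accepted.

```python
from litellm import Router
from litellm.types.router import AllowedFailsPolicy, RetryPolicy

# Assumed field names: retry rate-limit errors more than timeouts, and only
# cool a deployment down after several content-policy failures.
retry_policy = RetryPolicy(
    RateLimitErrorRetries=3,
    TimeoutErrorRetries=1,
)
allowed_fails_policy = AllowedFailsPolicy(
    ContentPolicyViolationErrorAllowedFails=3,
)

router = Router(
    retry_policy=retry_policy,                # model_list omitted for brevity
    allowed_fails_policy=allowed_fails_policy,
    cooldown_time=60,
)
```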
+ """ + if self.cache.redis_cache is None: + self.cache.redis_cache = cache + + def routing_strategy_init( + self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict + ): + verbose_router_logger.info(f"Routing strategy: {routing_strategy}") + if ( + routing_strategy == RoutingStrategy.LEAST_BUSY.value + or routing_strategy == RoutingStrategy.LEAST_BUSY + ): + self.leastbusy_logger = LeastBusyLoggingHandler( + router_cache=self.cache, model_list=self.model_list + ) + ## add callback + if isinstance(litellm.input_callback, list): + litellm.input_callback.append(self.leastbusy_logger) # type: ignore + else: + litellm.input_callback = [self.leastbusy_logger] # type: ignore + if isinstance(litellm.callbacks, list): + litellm.logging_callback_manager.add_litellm_callback(self.leastbusy_logger) # type: ignore + elif ( + routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING.value + or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING + ): + self.lowesttpm_logger = LowestTPMLoggingHandler( + router_cache=self.cache, + model_list=self.model_list, + routing_args=routing_strategy_args, + ) + if isinstance(litellm.callbacks, list): + litellm.logging_callback_manager.add_litellm_callback(self.lowesttpm_logger) # type: ignore + elif ( + routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2.value + or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2 + ): + self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2( + router_cache=self.cache, + model_list=self.model_list, + routing_args=routing_strategy_args, + ) + if isinstance(litellm.callbacks, list): + litellm.logging_callback_manager.add_litellm_callback(self.lowesttpm_logger_v2) # type: ignore + elif ( + routing_strategy == RoutingStrategy.LATENCY_BASED.value + or routing_strategy == RoutingStrategy.LATENCY_BASED + ): + self.lowestlatency_logger = LowestLatencyLoggingHandler( + router_cache=self.cache, + model_list=self.model_list, + routing_args=routing_strategy_args, + ) + if isinstance(litellm.callbacks, list): + litellm.logging_callback_manager.add_litellm_callback(self.lowestlatency_logger) # type: ignore + elif ( + routing_strategy == RoutingStrategy.COST_BASED.value + or routing_strategy == RoutingStrategy.COST_BASED + ): + self.lowestcost_logger = LowestCostLoggingHandler( + router_cache=self.cache, + model_list=self.model_list, + routing_args={}, + ) + if isinstance(litellm.callbacks, list): + litellm.logging_callback_manager.add_litellm_callback(self.lowestcost_logger) # type: ignore + else: + pass + + def initialize_assistants_endpoint(self): + ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ## + self.acreate_assistants = self.factory_function(litellm.acreate_assistants) + self.adelete_assistant = self.factory_function(litellm.adelete_assistant) + self.aget_assistants = self.factory_function(litellm.aget_assistants) + self.acreate_thread = self.factory_function(litellm.acreate_thread) + self.aget_thread = self.factory_function(litellm.aget_thread) + self.a_add_message = self.factory_function(litellm.a_add_message) + self.aget_messages = self.factory_function(litellm.aget_messages) + self.arun_thread = self.factory_function(litellm.arun_thread) + + def initialize_router_endpoints(self): + self.amoderation = self.factory_function( + litellm.amoderation, call_type="moderation" + ) + self.aanthropic_messages = self.factory_function( + litellm.anthropic_messages, call_type="anthropic_messages" + ) + self.aresponses = self.factory_function( + litellm.aresponses, call_type="aresponses" + ) + self.responses 
= self.factory_function(litellm.responses, call_type="responses") + + def validate_fallbacks(self, fallback_param: Optional[List]): + """ + Validate the fallbacks parameter. + """ + if fallback_param is None: + return + for fallback_dict in fallback_param: + if not isinstance(fallback_dict, dict): + raise ValueError(f"Item '{fallback_dict}' is not a dictionary.") + if len(fallback_dict) != 1: + raise ValueError( + f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys." + ) + + def add_optional_pre_call_checks( + self, optional_pre_call_checks: Optional[OptionalPreCallChecks] + ): + if optional_pre_call_checks is not None: + for pre_call_check in optional_pre_call_checks: + _callback: Optional[CustomLogger] = None + if pre_call_check == "prompt_caching": + _callback = PromptCachingDeploymentCheck(cache=self.cache) + elif pre_call_check == "router_budget_limiting": + _callback = RouterBudgetLimiting( + dual_cache=self.cache, + provider_budget_config=self.provider_budget_config, + model_list=self.model_list, + ) + if _callback is not None: + litellm.logging_callback_manager.add_litellm_callback(_callback) + + def print_deployment(self, deployment: dict): + """ + returns a copy of the deployment with the api key masked + + Only returns 2 characters of the api key and masks the rest with * (10 *). + """ + try: + _deployment_copy = copy.deepcopy(deployment) + litellm_params: dict = _deployment_copy["litellm_params"] + if "api_key" in litellm_params: + litellm_params["api_key"] = litellm_params["api_key"][:2] + "*" * 10 + return _deployment_copy + except Exception as e: + verbose_router_logger.debug( + f"Error occurred while printing deployment - {str(e)}" + ) + raise e + + ### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS + + def completion( + self, model: str, messages: List[Dict[str, str]], **kwargs + ) -> Union[ModelResponse, CustomStreamWrapper]: + """ + Example usage: + response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}] + """ + try: + verbose_router_logger.debug(f"router.completion(model={model},..)") + kwargs["model"] = model + kwargs["messages"] = messages + kwargs["original_function"] = self._completion + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) + + response = self.function_with_fallbacks(**kwargs) + return response + except Exception as e: + raise e + + def _completion( + self, model: str, messages: List[Dict[str, str]], **kwargs + ) -> Union[ModelResponse, CustomStreamWrapper]: + model_name = None + try: + # pick the one that is available (lowest TPM/RPM) + deployment = self.get_available_deployment( + model=model, + messages=messages, + specific_deployment=kwargs.pop("specific_deployment", None), + ) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + + data = deployment["litellm_params"].copy() + model_name = data["model"] + potential_model_client = self._get_client( + deployment=deployment, kwargs=kwargs + ) + # check if provided keys == client keys # + dynamic_api_key = kwargs.get("api_key", None) + if ( + dynamic_api_key is not None + and potential_model_client is not None + and dynamic_api_key != potential_model_client.api_key + ): + model_client = None + else: + model_client = potential_model_client + + ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. 
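A short demonstration of the shape enforced by `validate_fallbacks()` above: every entry must be a dictionary with exactly one key, mapping a model group to its fallback group(s).

```python
from litellm import Router

# Single-key dicts pass validation at Router init time.
Router(fallbacks=[{"azure-gpt-3.5-turbo": ["openai-gpt-3.5-turbo"]}])

try:
    # Two keys in one dict -> validate_fallbacks() raises ValueError.
    Router(fallbacks=[{"azure-gpt-3.5-turbo": ["openai-gpt-3.5-turbo"],
                       "gpt-4": ["gpt-3.5-turbo"]}])
except ValueError as err:
    print(err)  # "... must have exactly one key, but has 2 keys."
```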
Raise error, if deployment over limit) + ## only run if model group given, not model id + if model not in self.get_model_ids(): + self.routing_strategy_pre_call_checks(deployment=deployment) + + response = litellm.completion( + **{ + **data, + "messages": messages, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + verbose_router_logger.info( + f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m" + ) + + ## CHECK CONTENT FILTER ERROR ## + if isinstance(response, ModelResponse): + _should_raise = self._should_raise_content_policy_error( + model=model, response=response, kwargs=kwargs + ) + if _should_raise: + raise litellm.ContentPolicyViolationError( + message="Response output was blocked.", + model=model, + llm_provider="", + ) + + return response + except Exception as e: + verbose_router_logger.info( + f"litellm.completion(model={model_name})\033[31m Exception {str(e)}\033[0m" + ) + raise e + + # fmt: off + + @overload + async def acompletion( + self, model: str, messages: List[AllMessageValues], stream: Literal[True], **kwargs + ) -> CustomStreamWrapper: + ... + + @overload + async def acompletion( + self, model: str, messages: List[AllMessageValues], stream: Literal[False] = False, **kwargs + ) -> ModelResponse: + ... + + @overload + async def acompletion( + self, model: str, messages: List[AllMessageValues], stream: Union[Literal[True], Literal[False]] = False, **kwargs + ) -> Union[CustomStreamWrapper, ModelResponse]: + ... + + # fmt: on + + # The actual implementation of the function + async def acompletion( + self, + model: str, + messages: List[AllMessageValues], + stream: bool = False, + **kwargs, + ): + try: + kwargs["model"] = model + kwargs["messages"] = messages + kwargs["stream"] = stream + kwargs["original_function"] = self._acompletion + + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) + request_priority = kwargs.get("priority") or self.default_priority + start_time = time.time() + _is_prompt_management_model = self._is_prompt_management_model(model) + + if _is_prompt_management_model: + return await self._prompt_management_factory( + model=model, + messages=messages, + kwargs=kwargs, + ) + if request_priority is not None and isinstance(request_priority, int): + response = await self.schedule_acompletion(**kwargs) + else: + response = await self.async_function_with_fallbacks(**kwargs) + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.ROUTER, + duration=_duration, + call_type="acompletion", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + + return response + except Exception as e: + asyncio.create_task( + send_llm_exception_alert( + litellm_router_instance=self, + request_kwargs=kwargs, + error_traceback_str=traceback.format_exc(), + original_exception=e, + ) + ) + raise e + + async def _acompletion( + self, model: str, messages: List[Dict[str, str]], **kwargs + ) -> Union[ModelResponse, CustomStreamWrapper]: + """ + - Get an available deployment + - call it with a semaphore over the call + - semaphore specific to it's rpm + - in the semaphore, make a check against it's local rpm before running + """ + model_name = None + _timeout_debug_deployment_dict = ( + {} + ) # this is a temporary dict to debug timeout issues + try: + verbose_router_logger.debug( + f"Inside _acompletion()- model: {model}; kwargs: {kwargs}" + ) + parent_otel_span = 
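A usage sketch for the async entry point defined above; `router` is assumed to be a `Router` constructed as in the earlier examples, with "gpt-3.5-turbo" as a model group name.

```python
import asyncio

async def main():
    # Non-streaming: returns a ModelResponse.
    resp = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(resp.choices[0].message.content)

    # Streaming: stream=True returns a CustomStreamWrapper (async iterator).
    stream = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Tell me a joke"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())
```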
_get_parent_otel_span_from_kwargs(kwargs) + start_time = time.time() + deployment = await self.async_get_available_deployment( + model=model, + messages=messages, + specific_deployment=kwargs.pop("specific_deployment", None), + request_kwargs=kwargs, + ) + + _timeout_debug_deployment_dict = deployment + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.ROUTER, + duration=_duration, + call_type="async_get_available_deployment", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + + # debug how often this deployment picked + + self._track_deployment_metrics( + deployment=deployment, parent_otel_span=parent_otel_span + ) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + data = deployment["litellm_params"].copy() + + model_name = data["model"] + + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, + ) + self.total_calls[model_name] += 1 + + _response = litellm.acompletion( + **{ + **data, + "messages": messages, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + + logging_obj: Optional[LiteLLMLogging] = kwargs.get( + "litellm_logging_obj", None + ) + + rpm_semaphore = self._get_client( + deployment=deployment, + kwargs=kwargs, + client_type="max_parallel_requests", + ) + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, + logging_obj=logging_obj, + parent_otel_span=parent_otel_span, + ) + response = await _response + else: + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, + logging_obj=logging_obj, + parent_otel_span=parent_otel_span, + ) + + response = await _response + + ## CHECK CONTENT FILTER ERROR ## + if isinstance(response, ModelResponse): + _should_raise = self._should_raise_content_policy_error( + model=model, response=response, kwargs=kwargs + ) + if _should_raise: + raise litellm.ContentPolicyViolationError( + message="Response output was blocked.", + model=model, + llm_provider="", + ) + + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m" + ) + # debug how often this deployment picked + self._track_deployment_metrics( + deployment=deployment, + response=response, + parent_otel_span=parent_otel_span, + ) + + return response + except litellm.Timeout as e: + deployment_request_timeout_param = _timeout_debug_deployment_dict.get( + "litellm_params", {} + ).get("request_timeout", None) + deployment_timeout_param = _timeout_debug_deployment_dict.get( + "litellm_params", {} + ).get("timeout", None) + e.message += f"\n\nDeployment Info: request_timeout: {deployment_request_timeout_param}\ntimeout: {deployment_timeout_param}" + raise e + except Exception as e: + verbose_router_logger.info( + f"litellm.acompletion(model={model_name})\033[31m Exception {str(e)}\033[0m" + ) + if model_name is not None: + self.fail_calls[model_name] += 1 + raise e + + def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None: + """ + Adds/updates to kwargs: + - num_retries + - litellm_trace_id + - metadata + """ + 
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + kwargs.setdefault("litellm_trace_id", str(uuid.uuid4())) + kwargs.setdefault("metadata", {}).update({"model_group": model}) + + def _update_kwargs_with_default_litellm_params( + self, kwargs: dict, metadata_variable_name: Optional[str] = "metadata" + ) -> None: + """ + Adds default litellm params to kwargs, if set. + """ + self.default_litellm_params[metadata_variable_name] = ( + self.default_litellm_params.pop("metadata", {}) + ) + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs and v is not None + ): # prioritize model-specific params > default router params + kwargs[k] = v + elif k == metadata_variable_name: + kwargs[metadata_variable_name].update(v) + + def _handle_clientside_credential( + self, deployment: dict, kwargs: dict + ) -> Deployment: + """ + Handle clientside credential + """ + model_info = deployment.get("model_info", {}).copy() + litellm_params = deployment["litellm_params"].copy() + dynamic_litellm_params = get_dynamic_litellm_params( + litellm_params=litellm_params, request_kwargs=kwargs + ) + metadata = kwargs.get("metadata", {}) + model_group = cast(str, metadata.get("model_group")) + _model_id = self._generate_model_id( + model_group=model_group, litellm_params=dynamic_litellm_params + ) + original_model_id = model_info.get("id") + model_info["id"] = _model_id + model_info["original_model_id"] = original_model_id + deployment_pydantic_obj = Deployment( + model_name=model_group, + litellm_params=LiteLLM_Params(**dynamic_litellm_params), + model_info=model_info, + ) + self.upsert_deployment( + deployment=deployment_pydantic_obj + ) # add new deployment to router + return deployment_pydantic_obj + + def _update_kwargs_with_deployment( + self, + deployment: dict, + kwargs: dict, + function_name: Optional[str] = None, + ) -> None: + """ + 2 jobs: + - Adds selected deployment, model_info and api_base to kwargs["metadata"] (used for logging) + - Adds default litellm params to kwargs, if set. + """ + model_info = deployment.get("model_info", {}).copy() + deployment_model_name = deployment["litellm_params"]["model"] + deployment_api_base = deployment["litellm_params"].get("api_base") + if is_clientside_credential(request_kwargs=kwargs): + deployment_pydantic_obj = self._handle_clientside_credential( + deployment=deployment, kwargs=kwargs + ) + model_info = deployment_pydantic_obj.model_info.model_dump() + deployment_model_name = deployment_pydantic_obj.litellm_params.model + deployment_api_base = deployment_pydantic_obj.litellm_params.api_base + + metadata_variable_name = _get_router_metadata_variable_name( + function_name=function_name, + ) + kwargs.setdefault(metadata_variable_name, {}).update( + { + "deployment": deployment_model_name, + "model_info": model_info, + "api_base": deployment_api_base, + } + ) + kwargs["model_info"] = model_info + + kwargs["timeout"] = self._get_timeout( + kwargs=kwargs, data=deployment["litellm_params"] + ) + + self._update_kwargs_with_default_litellm_params( + kwargs=kwargs, metadata_variable_name=metadata_variable_name + ) + + def _get_async_openai_model_client(self, deployment: dict, kwargs: dict): + """ + Helper to get AsyncOpenAI or AsyncAzureOpenAI client that was created for the deployment + + The same OpenAI client is re-used to optimize latency / performance in production + + If dynamic api key is provided: + Do not re-use the client. Pass model_client=None. 
The OpenAI/ AzureOpenAI client will be recreated in the handler for the llm provider + """ + potential_model_client = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="async" + ) + + # check if provided keys == client keys # + dynamic_api_key = kwargs.get("api_key", None) + if ( + dynamic_api_key is not None + and potential_model_client is not None + and dynamic_api_key != potential_model_client.api_key + ): + model_client = None + else: + model_client = potential_model_client + + return model_client + + def _get_stream_timeout( + self, kwargs: dict, data: dict + ) -> Optional[Union[float, int]]: + """Helper to get stream timeout from kwargs or deployment params""" + return ( + kwargs.get("stream_timeout", None) # the params dynamically set by user + or data.get( + "stream_timeout", None + ) # timeout set on litellm_params for this deployment + or self.stream_timeout # timeout set on router + or self.default_litellm_params.get("stream_timeout", None) + ) + + def _get_non_stream_timeout( + self, kwargs: dict, data: dict + ) -> Optional[Union[float, int]]: + """Helper to get non-stream timeout from kwargs or deployment params""" + timeout = ( + kwargs.get("timeout", None) # the params dynamically set by user + or kwargs.get("request_timeout", None) # the params dynamically set by user + or data.get( + "timeout", None + ) # timeout set on litellm_params for this deployment + or data.get( + "request_timeout", None + ) # timeout set on litellm_params for this deployment + or self.timeout # timeout set on router + or self.default_litellm_params.get("timeout", None) + ) + return timeout + + def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]: + """Helper to get timeout from kwargs or deployment params""" + timeout: Optional[Union[float, int]] = None + if kwargs.get("stream", False): + timeout = self._get_stream_timeout(kwargs=kwargs, data=data) + if timeout is None: + timeout = self._get_non_stream_timeout( + kwargs=kwargs, data=data + ) # default to this if no stream specific timeout set + return timeout + + async def abatch_completion( + self, + models: List[str], + messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], + **kwargs, + ): + """ + Async Batch Completion. Used for 2 scenarios: + 1. Batch Process 1 request to N models on litellm.Router. Pass messages as List[Dict[str, str]] to use this + 2. Batch Process N requests to M models on litellm.Router. 
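A sketch of how the timeout helpers above resolve a value, with the precedence written out; the deployment and router values are placeholders.

```python
from litellm import Router

# Timeout resolution (highest precedence first), per _get_timeout():
#   1. per-request kwarg:           router.acompletion(..., timeout=10)
#   2. deployment litellm_params:   {"timeout": 30} (or "request_timeout")
#   3. router-level default:        Router(..., timeout=60)
# For stream=True requests the "stream_timeout" chain is consulted first and
# falls back to the non-stream chain when unset.
router = Router(
    model_list=[{
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": "placeholder-key",
            "timeout": 30,          # deployment-level timeout
        },
    }],
    timeout=60,                     # router-level fallback
)
```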
Pass messages as List[List[Dict[str, str]]] to use this + + Example Request for 1 request to N models: + ``` + response = await router.abatch_completion( + models=["gpt-3.5-turbo", "groq-llama"], + messages=[ + {"role": "user", "content": "is litellm becoming a better product ?"} + ], + max_tokens=15, + ) + ``` + + + Example Request for N requests to M models: + ``` + response = await router.abatch_completion( + models=["gpt-3.5-turbo", "groq-llama"], + messages=[ + [{"role": "user", "content": "is litellm becoming a better product ?"}], + [{"role": "user", "content": "who is this"}], + ], + ) + ``` + """ + ############## Helpers for async completion ################## + + async def _async_completion_no_exceptions( + model: str, messages: List[AllMessageValues], **kwargs + ): + """ + Wrapper around self.async_completion that catches exceptions and returns them as a result + """ + try: + return await self.acompletion(model=model, messages=messages, **kwargs) + except Exception as e: + return e + + async def _async_completion_no_exceptions_return_idx( + model: str, + messages: List[AllMessageValues], + idx: int, # index of message this response corresponds to + **kwargs, + ): + """ + Wrapper around self.async_completion that catches exceptions and returns them as a result + """ + try: + return ( + await self.acompletion(model=model, messages=messages, **kwargs), + idx, + ) + except Exception as e: + return e, idx + + ############## Helpers for async completion ################## + + if isinstance(messages, list) and all(isinstance(m, dict) for m in messages): + _tasks = [] + for model in models: + # add each task but if the task fails + _tasks.append(_async_completion_no_exceptions(model=model, messages=messages, **kwargs)) # type: ignore + response = await asyncio.gather(*_tasks) + return response + elif isinstance(messages, list) and all(isinstance(m, list) for m in messages): + _tasks = [] + for idx, message in enumerate(messages): + for model in models: + # Request Number X, Model Number Y + _tasks.append( + _async_completion_no_exceptions_return_idx( + model=model, idx=idx, messages=message, **kwargs # type: ignore + ) + ) + responses = await asyncio.gather(*_tasks) + final_responses: List[List[Any]] = [[] for _ in range(len(messages))] + for response in responses: + if isinstance(response, tuple): + final_responses[response[1]].append(response[0]) + else: + final_responses[0].append(response) + return final_responses + + async def abatch_completion_one_model_multiple_requests( + self, model: str, messages: List[List[AllMessageValues]], **kwargs + ): + """ + Async Batch Completion - Batch Process multiple Messages to one model_group on litellm.Router + + Use this for sending multiple requests to 1 model + + Args: + model (List[str]): model group + messages (List[List[Dict[str, str]]]): list of messages. 
Each element in the list is one request + **kwargs: additional kwargs + Usage: + response = await self.abatch_completion_one_model_multiple_requests( + model="gpt-3.5-turbo", + messages=[ + [{"role": "user", "content": "hello"}, {"role": "user", "content": "tell me something funny"}], + [{"role": "user", "content": "hello good mornign"}], + ] + ) + """ + + async def _async_completion_no_exceptions( + model: str, messages: List[AllMessageValues], **kwargs + ): + """ + Wrapper around self.async_completion that catches exceptions and returns them as a result + """ + try: + return await self.acompletion(model=model, messages=messages, **kwargs) + except Exception as e: + return e + + _tasks = [] + for message_request in messages: + # add each task but if the task fails + _tasks.append( + _async_completion_no_exceptions( + model=model, messages=message_request, **kwargs + ) + ) + + response = await asyncio.gather(*_tasks) + return response + + # fmt: off + + @overload + async def abatch_completion_fastest_response( + self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs + ) -> CustomStreamWrapper: + ... + + + + @overload + async def abatch_completion_fastest_response( + self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs + ) -> ModelResponse: + ... + + # fmt: on + + async def abatch_completion_fastest_response( + self, + model: str, + messages: List[Dict[str, str]], + stream: bool = False, + **kwargs, + ): + """ + model - List of comma-separated model names. E.g. model="gpt-4, gpt-3.5-turbo" + + Returns fastest response from list of model names. OpenAI-compatible endpoint. + """ + models = [m.strip() for m in model.split(",")] + + async def _async_completion_no_exceptions( + model: str, messages: List[Dict[str, str]], stream: bool, **kwargs: Any + ) -> Union[ModelResponse, CustomStreamWrapper, Exception]: + """ + Wrapper around self.acompletion that catches exceptions and returns them as a result + """ + try: + return await self.acompletion(model=model, messages=messages, stream=stream, **kwargs) # type: ignore + except asyncio.CancelledError: + verbose_router_logger.debug( + "Received 'task.cancel'. Cancelling call w/ model={}.".format(model) + ) + raise + except Exception as e: + return e + + pending_tasks = [] # type: ignore + + async def check_response(task: asyncio.Task): + nonlocal pending_tasks + try: + result = await task + if isinstance(result, (ModelResponse, CustomStreamWrapper)): + verbose_router_logger.debug( + "Received successful response. Cancelling other LLM API calls." 
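A usage sketch for `abatch_completion_fastest_response`: the comma-separated model groups are fired concurrently and the first successful response wins, cancelling the rest. `router` and the group names are assumed from the earlier examples.

```python
import asyncio

async def main():
    resp = await router.abatch_completion_fastest_response(
        model="gpt-4, gpt-3.5-turbo",      # comma-separated model group names
        messages=[{"role": "user", "content": "ping"}],
    )
    # The winning response is tagged in its hidden params:
    print(resp._hidden_params.get("fastest_response_batch_completion"))  # True

asyncio.run(main())
```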
+ ) + # If a desired response is received, cancel all other pending tasks + for t in pending_tasks: + t.cancel() + return result + except Exception: + # Ignore exceptions, let the loop handle them + pass + finally: + # Remove the task from pending tasks if it finishes + try: + pending_tasks.remove(task) + except KeyError: + pass + + for model in models: + task = asyncio.create_task( + _async_completion_no_exceptions( + model=model, messages=messages, stream=stream, **kwargs + ) + ) + pending_tasks.append(task) + + # Await the first task to complete successfully + while pending_tasks: + done, pending_tasks = await asyncio.wait( # type: ignore + pending_tasks, return_when=asyncio.FIRST_COMPLETED + ) + for completed_task in done: + result = await check_response(completed_task) + if result is not None: + # Return the first successful result + result._hidden_params["fastest_response_batch_completion"] = True + return result + + # If we exit the loop without returning, all tasks failed + raise Exception("All tasks failed") + + ### SCHEDULER ### + + # fmt: off + + @overload + async def schedule_acompletion( + self, model: str, messages: List[AllMessageValues], priority: int, stream: Literal[False] = False, **kwargs + ) -> ModelResponse: + ... + + @overload + async def schedule_acompletion( + self, model: str, messages: List[AllMessageValues], priority: int, stream: Literal[True], **kwargs + ) -> CustomStreamWrapper: + ... + + # fmt: on + + async def schedule_acompletion( + self, + model: str, + messages: List[AllMessageValues], + priority: int, + stream=False, + **kwargs, + ): + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + ### FLOW ITEM ### + _request_id = str(uuid.uuid4()) + item = FlowItem( + priority=priority, # 👈 SET PRIORITY FOR REQUEST + request_id=_request_id, # 👈 SET REQUEST ID + model_name="gpt-3.5-turbo", # 👈 SAME as 'Router' + ) + ### [fin] ### + + ## ADDS REQUEST TO QUEUE ## + await self.scheduler.add_request(request=item) + + ## POLL QUEUE + end_time = time.time() + self.timeout + curr_time = time.time() + poll_interval = self.scheduler.polling_interval # poll every 3ms + make_request = False + + while curr_time < end_time: + _healthy_deployments, _ = await self._async_get_healthy_deployments( + model=model, parent_otel_span=parent_otel_span + ) + make_request = await self.scheduler.poll( ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue + id=item.request_id, + model_name=item.model_name, + health_deployments=_healthy_deployments, + ) + if make_request: ## IF TRUE -> MAKE REQUEST + break + else: ## ELSE -> loop till default_timeout + await asyncio.sleep(poll_interval) + curr_time = time.time() + + if make_request: + try: + _response = await self.acompletion( + model=model, messages=messages, stream=stream, **kwargs + ) + _response._hidden_params.setdefault("additional_headers", {}) + _response._hidden_params["additional_headers"].update( + {"x-litellm-request-prioritization-used": True} + ) + return _response + except Exception as e: + setattr(e, "priority", priority) + raise e + else: + raise litellm.Timeout( + message="Request timed out while polling queue", + model=model, + llm_provider="openai", + ) + + async def _schedule_factory( + self, + model: str, + priority: int, + original_function: Callable, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ): + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + ### FLOW ITEM ### + _request_id = str(uuid.uuid4()) + item = FlowItem( + priority=priority, # 👈 
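A sketch of request prioritization via `schedule_acompletion()` above: passing an integer `priority` to `acompletion()` queues the request on the scheduler, which polls for a healthy deployment until the router timeout elapses (raising `litellm.Timeout` otherwise). The assumption that lower values are served first is not stated in this diff.

```python
import asyncio

async def main():
    resp = await router.acompletion(            # `router` as constructed earlier
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "urgent request"}],
        priority=0,                             # assumed: lower value = served sooner
    )
    # Successful prioritized responses carry this marker header:
    print(resp._hidden_params["additional_headers"]
          ["x-litellm-request-prioritization-used"])  # True

asyncio.run(main())
```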
SET PRIORITY FOR REQUEST + request_id=_request_id, # 👈 SET REQUEST ID + model_name=model, # 👈 SAME as 'Router' + ) + ### [fin] ### + + ## ADDS REQUEST TO QUEUE ## + await self.scheduler.add_request(request=item) + + ## POLL QUEUE + end_time = time.time() + self.timeout + curr_time = time.time() + poll_interval = self.scheduler.polling_interval # poll every 3ms + make_request = False + + while curr_time < end_time: + _healthy_deployments, _ = await self._async_get_healthy_deployments( + model=model, parent_otel_span=parent_otel_span + ) + make_request = await self.scheduler.poll( ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue + id=item.request_id, + model_name=item.model_name, + health_deployments=_healthy_deployments, + ) + if make_request: ## IF TRUE -> MAKE REQUEST + break + else: ## ELSE -> loop till default_timeout + await asyncio.sleep(poll_interval) + curr_time = time.time() + + if make_request: + try: + _response = await original_function(*args, **kwargs) + if isinstance(_response._hidden_params, dict): + _response._hidden_params.setdefault("additional_headers", {}) + _response._hidden_params["additional_headers"].update( + {"x-litellm-request-prioritization-used": True} + ) + return _response + except Exception as e: + setattr(e, "priority", priority) + raise e + else: + raise litellm.Timeout( + message="Request timed out while polling queue", + model=model, + llm_provider="openai", + ) + + def _is_prompt_management_model(self, model: str) -> bool: + model_list = self.get_model_list(model_name=model) + if model_list is None: + return False + if len(model_list) != 1: + return False + + litellm_model = model_list[0]["litellm_params"].get("model", None) + + if litellm_model is None: + return False + + if "/" in litellm_model: + split_litellm_model = litellm_model.split("/")[0] + if split_litellm_model in litellm._known_custom_logger_compatible_callbacks: + return True + return False + + async def _prompt_management_factory( + self, + model: str, + messages: List[AllMessageValues], + kwargs: Dict[str, Any], + ): + litellm_logging_object = kwargs.get("litellm_logging_obj", None) + if litellm_logging_object is None: + litellm_logging_object, kwargs = function_setup( + **{ + "original_function": "acompletion", + "rules_obj": Rules(), + "start_time": get_utc_datetime(), + **kwargs, + } + ) + litellm_logging_object = cast(LiteLLMLogging, litellm_logging_object) + prompt_management_deployment = self.get_available_deployment( + model=model, + messages=[{"role": "user", "content": "prompt"}], + specific_deployment=kwargs.pop("specific_deployment", None), + ) + + litellm_model = prompt_management_deployment["litellm_params"].get( + "model", None + ) + prompt_id = kwargs.get("prompt_id") or prompt_management_deployment[ + "litellm_params" + ].get("prompt_id", None) + prompt_variables = kwargs.get( + "prompt_variables" + ) or prompt_management_deployment["litellm_params"].get( + "prompt_variables", None + ) + + if prompt_id is None or not isinstance(prompt_id, str): + raise ValueError( + f"Prompt ID is not set or not a string. Got={prompt_id}, type={type(prompt_id)}" + ) + if prompt_variables is not None and not isinstance(prompt_variables, dict): + raise ValueError( + f"Prompt variables is set but not a dictionary. 
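A heavily hedged sketch of a prompt-management deployment as detected by `_is_prompt_management_model()` and resolved by `_prompt_management_factory()`: a single deployment in the group whose model prefix is one of litellm's custom-logger-compatible callbacks, plus a `prompt_id` (string) and optional `prompt_variables` (dict). The `langfuse/` prefix and prompt id below are assumptions for illustration only.

```python
from litellm import Router

router = Router(
    model_list=[{
        "model_name": "my-prompt",
        "litellm_params": {
            "model": "langfuse/gpt-3.5-turbo",      # hypothetical callback prefix/model
            "prompt_id": "welcome-email",           # hypothetical prompt id (must be a str)
            "prompt_variables": {"user_name": "Ada"},
        },
    }],
)

# acompletion(model="my-prompt", ...) would resolve the prompt before dispatch.
```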
Got={prompt_variables}, type={type(prompt_variables)}" + ) + + model, messages, optional_params = ( + litellm_logging_object.get_chat_completion_prompt( + model=litellm_model, + messages=messages, + non_default_params=get_non_default_completion_params(kwargs=kwargs), + prompt_id=prompt_id, + prompt_variables=prompt_variables, + ) + ) + + kwargs = {**kwargs, **optional_params} + kwargs["model"] = model + kwargs["messages"] = messages + kwargs["litellm_logging_obj"] = litellm_logging_object + kwargs["prompt_id"] = prompt_id + kwargs["prompt_variables"] = prompt_variables + + _model_list = self.get_model_list(model_name=model) + if _model_list is None or len(_model_list) == 0: # if direct call to model + kwargs.pop("original_function") + return await litellm.acompletion(**kwargs) + + return await self.async_function_with_fallbacks(**kwargs) + + def image_generation(self, prompt: str, model: str, **kwargs): + try: + kwargs["model"] = model + kwargs["prompt"] = prompt + kwargs["original_function"] = self._image_generation + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + kwargs.setdefault("metadata", {}).update({"model_group": model}) + response = self.function_with_fallbacks(**kwargs) + + return response + except Exception as e: + raise e + + def _image_generation(self, prompt: str, model: str, **kwargs): + model_name = "" + try: + verbose_router_logger.debug( + f"Inside _image_generation()- model: {model}; kwargs: {kwargs}" + ) + deployment = self.get_available_deployment( + model=model, + messages=[{"role": "user", "content": "prompt"}], + specific_deployment=kwargs.pop("specific_deployment", None), + ) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + data = deployment["litellm_params"].copy() + + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, + ) + + self.total_calls[model_name] += 1 + + ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. 
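A usage sketch for the image-generation path defined here; it goes through the same deployment selection and fallback machinery as chat. The "dall-e-3" group name is a placeholder, and `router` is assumed to be configured with such a group.

```python
import asyncio

async def main():
    img = await router.aimage_generation(
        model="dall-e-3",                                     # model group name
        prompt="a watercolor fox reading a newspaper",
    )
    print(img.data[0].url)   # OpenAI-style ImageResponse payload

asyncio.run(main())
```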
Raise error, if deployment over limit) + self.routing_strategy_pre_call_checks(deployment=deployment) + + response = litellm.image_generation( + **{ + **data, + "prompt": prompt, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"litellm.image_generation(model={model_name})\033[32m 200 OK\033[0m" + ) + return response + except Exception as e: + verbose_router_logger.info( + f"litellm.image_generation(model={model_name})\033[31m Exception {str(e)}\033[0m" + ) + if model_name is not None: + self.fail_calls[model_name] += 1 + raise e + + async def aimage_generation(self, prompt: str, model: str, **kwargs): + try: + kwargs["model"] = model + kwargs["prompt"] = prompt + kwargs["original_function"] = self._aimage_generation + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) + response = await self.async_function_with_fallbacks(**kwargs) + + return response + except Exception as e: + asyncio.create_task( + send_llm_exception_alert( + litellm_router_instance=self, + request_kwargs=kwargs, + error_traceback_str=traceback.format_exc(), + original_exception=e, + ) + ) + raise e + + async def _aimage_generation(self, prompt: str, model: str, **kwargs): + model_name = model + try: + verbose_router_logger.debug( + f"Inside _image_generation()- model: {model}; kwargs: {kwargs}" + ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + deployment = await self.async_get_available_deployment( + model=model, + messages=[{"role": "user", "content": "prompt"}], + specific_deployment=kwargs.pop("specific_deployment", None), + request_kwargs=kwargs, + ) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + + data = deployment["litellm_params"].copy() + model_name = data["model"] + + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, + ) + + self.total_calls[model_name] += 1 + response = litellm.aimage_generation( + **{ + **data, + "prompt": prompt, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + + ### CONCURRENCY-SAFE RPM CHECKS ### + rpm_semaphore = self._get_client( + deployment=deployment, + kwargs=kwargs, + client_type="max_parallel_requests", + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response + else: + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response + + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"litellm.aimage_generation(model={model_name})\033[32m 200 OK\033[0m" + ) + return response + except Exception as e: + verbose_router_logger.info( + f"litellm.aimage_generation(model={model_name})\033[31m Exception {str(e)}\033[0m" + ) + if model_name is not None: + self.fail_calls[model_name] += 1 + raise e + + async def atranscription(self, file: FileTypes, model: str, **kwargs): + """ + Example Usage: + + ``` + from litellm import Router + client = Router(model_list = [ + { + "model_name": "whisper", + "litellm_params": { + "model": 
"whisper-1", + }, + }, + ]) + + audio_file = open("speech.mp3", "rb") + transcript = await client.atranscription( + model="whisper", + file=audio_file + ) + + ``` + """ + try: + kwargs["model"] = model + kwargs["file"] = file + kwargs["original_function"] = self._atranscription + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) + response = await self.async_function_with_fallbacks(**kwargs) + + return response + except Exception as e: + asyncio.create_task( + send_llm_exception_alert( + litellm_router_instance=self, + request_kwargs=kwargs, + error_traceback_str=traceback.format_exc(), + original_exception=e, + ) + ) + raise e + + async def _atranscription(self, file: FileTypes, model: str, **kwargs): + model_name = model + try: + verbose_router_logger.debug( + f"Inside _atranscription()- model: {model}; kwargs: {kwargs}" + ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + deployment = await self.async_get_available_deployment( + model=model, + messages=[{"role": "user", "content": "prompt"}], + specific_deployment=kwargs.pop("specific_deployment", None), + request_kwargs=kwargs, + ) + + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + data = deployment["litellm_params"].copy() + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, + ) + + self.total_calls[model_name] += 1 + response = litellm.atranscription( + **{ + **data, + "file": file, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + + ### CONCURRENCY-SAFE RPM CHECKS ### + rpm_semaphore = self._get_client( + deployment=deployment, + kwargs=kwargs, + client_type="max_parallel_requests", + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response + else: + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response + + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"litellm.atranscription(model={model_name})\033[32m 200 OK\033[0m" + ) + return response + except Exception as e: + verbose_router_logger.info( + f"litellm.atranscription(model={model_name})\033[31m Exception {str(e)}\033[0m" + ) + if model_name is not None: + self.fail_calls[model_name] += 1 + raise e + + async def aspeech(self, model: str, input: str, voice: str, **kwargs): + """ + Example Usage: + + ``` + from litellm import Router + client = Router(model_list = [ + { + "model_name": "tts", + "litellm_params": { + "model": "tts-1", + }, + }, + ]) + + async with client.aspeech( + model="tts", + voice="alloy", + input="the quick brown fox jumped over the lazy dogs", + api_base=None, + api_key=None, + organization=None, + project=None, + max_retries=1, + timeout=600, + client=None, + optional_params={}, + ) as response: + response.stream_to_file(speech_file_path) + + ``` + """ + try: + kwargs["input"] = input + kwargs["voice"] = voice + + deployment = await self.async_get_available_deployment( + model=model, + messages=[{"role": "user", "content": "prompt"}], + specific_deployment=kwargs.pop("specific_deployment", None), + request_kwargs=kwargs, + ) + 
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) + data = deployment["litellm_params"].copy() + data["model"] + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs + ): # prioritize model-specific params > default router params + kwargs[k] = v + elif k == "metadata": + kwargs[k].update(v) + + potential_model_client = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="async" + ) + # check if provided keys == client keys # + dynamic_api_key = kwargs.get("api_key", None) + if ( + dynamic_api_key is not None + and potential_model_client is not None + and dynamic_api_key != potential_model_client.api_key + ): + model_client = None + else: + model_client = potential_model_client + + response = await litellm.aspeech( + **{ + **data, + "client": model_client, + **kwargs, + } + ) + return response + except Exception as e: + asyncio.create_task( + send_llm_exception_alert( + litellm_router_instance=self, + request_kwargs=kwargs, + error_traceback_str=traceback.format_exc(), + original_exception=e, + ) + ) + raise e + + async def arerank(self, model: str, **kwargs): + try: + kwargs["model"] = model + kwargs["input"] = input + kwargs["original_function"] = self._arerank + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) + + response = await self.async_function_with_fallbacks(**kwargs) + + return response + except Exception as e: + asyncio.create_task( + send_llm_exception_alert( + litellm_router_instance=self, + request_kwargs=kwargs, + error_traceback_str=traceback.format_exc(), + original_exception=e, + ) + ) + raise e + + async def _arerank(self, model: str, **kwargs): + model_name = None + try: + verbose_router_logger.debug( + f"Inside _rerank()- model: {model}; kwargs: {kwargs}" + ) + deployment = await self.async_get_available_deployment( + model=model, + specific_deployment=kwargs.pop("specific_deployment", None), + request_kwargs=kwargs, + ) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + data = deployment["litellm_params"].copy() + model_name = data["model"] + + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, + ) + self.total_calls[model_name] += 1 + + response = await litellm.arerank( + **{ + **data, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"litellm.arerank(model={model_name})\033[32m 200 OK\033[0m" + ) + return response + except Exception as e: + verbose_router_logger.info( + f"litellm.arerank(model={model_name})\033[31m Exception {str(e)}\033[0m" + ) + if model_name is not None: + self.fail_calls[model_name] += 1 + raise e + + async def _arealtime(self, model: str, **kwargs): + messages = [{"role": "user", "content": "dummy-text"}] + try: + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) + + # pick the one that is available (lowest TPM/RPM) + deployment = await self.async_get_available_deployment( + model=model, + messages=messages, + specific_deployment=kwargs.pop("specific_deployment", None), + request_kwargs=kwargs, + ) + + data = deployment["litellm_params"].copy() + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs + ): # prioritize model-specific params > default router params + kwargs[k] = v + elif k == "metadata": + kwargs[k].update(v) + + return await litellm._arealtime(**{**data, "caching": self.cache_responses, 
**kwargs}) # type: ignore + except Exception as e: + if self.num_retries > 0: + kwargs["model"] = model + kwargs["messages"] = messages + kwargs["original_function"] = self._arealtime + return await self.async_function_with_retries(**kwargs) + else: + raise e + + def text_completion( + self, + model: str, + prompt: str, + is_retry: Optional[bool] = False, + is_fallback: Optional[bool] = False, + is_async: Optional[bool] = False, + **kwargs, + ): + messages = [{"role": "user", "content": prompt}] + try: + kwargs["model"] = model + kwargs["prompt"] = prompt + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + kwargs.setdefault("metadata", {}).update({"model_group": model}) + + # pick the one that is available (lowest TPM/RPM) + deployment = self.get_available_deployment( + model=model, + messages=messages, + specific_deployment=kwargs.pop("specific_deployment", None), + ) + + data = deployment["litellm_params"].copy() + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs + ): # prioritize model-specific params > default router params + kwargs[k] = v + elif k == "metadata": + kwargs[k].update(v) + + # call via litellm.completion() + return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore + except Exception as e: + raise e + + async def atext_completion( + self, + model: str, + prompt: str, + is_retry: Optional[bool] = False, + is_fallback: Optional[bool] = False, + is_async: Optional[bool] = False, + **kwargs, + ): + if kwargs.get("priority", None) is not None: + return await self._schedule_factory( + model=model, + priority=kwargs.pop("priority"), + original_function=self.atext_completion, + args=(model, prompt), + kwargs=kwargs, + ) + try: + kwargs["model"] = model + kwargs["prompt"] = prompt + kwargs["original_function"] = self._atext_completion + + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) + response = await self.async_function_with_fallbacks(**kwargs) + + return response + except Exception as e: + asyncio.create_task( + send_llm_exception_alert( + litellm_router_instance=self, + request_kwargs=kwargs, + error_traceback_str=traceback.format_exc(), + original_exception=e, + ) + ) + raise e + + async def _atext_completion(self, model: str, prompt: str, **kwargs): + try: + verbose_router_logger.debug( + f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}" + ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + deployment = await self.async_get_available_deployment( + model=model, + messages=[{"role": "user", "content": prompt}], + specific_deployment=kwargs.pop("specific_deployment", None), + request_kwargs=kwargs, + ) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + + data = deployment["litellm_params"].copy() + model_name = data["model"] + + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, + ) + self.total_calls[model_name] += 1 + + response = litellm.atext_completion( + **{ + **data, + "prompt": prompt, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + + rpm_semaphore = self._get_client( + deployment=deployment, + kwargs=kwargs, + client_type="max_parallel_requests", + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + 
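+                    # NOTE (added comment): `response` above is still an unawaited
+                    # coroutine returned by litellm.atext_completion(); it is only
+                    # awaited after the routing-strategy pre-call checks pass, so this
+                    # semaphore caps concurrent in-flight requests for the deployment.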
await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response + else: + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response + + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"litellm.atext_completion(model={model_name})\033[32m 200 OK\033[0m" + ) + return response + except Exception as e: + verbose_router_logger.info( + f"litellm.atext_completion(model={model})\033[31m Exception {str(e)}\033[0m" + ) + if model is not None: + self.fail_calls[model] += 1 + raise e + + async def aadapter_completion( + self, + adapter_id: str, + model: str, + is_retry: Optional[bool] = False, + is_fallback: Optional[bool] = False, + is_async: Optional[bool] = False, + **kwargs, + ): + try: + kwargs["model"] = model + kwargs["adapter_id"] = adapter_id + kwargs["original_function"] = self._aadapter_completion + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + kwargs.setdefault("metadata", {}).update({"model_group": model}) + response = await self.async_function_with_fallbacks(**kwargs) + + return response + except Exception as e: + asyncio.create_task( + send_llm_exception_alert( + litellm_router_instance=self, + request_kwargs=kwargs, + error_traceback_str=traceback.format_exc(), + original_exception=e, + ) + ) + raise e + + async def _aadapter_completion(self, adapter_id: str, model: str, **kwargs): + try: + verbose_router_logger.debug( + f"Inside _aadapter_completion()- model: {model}; kwargs: {kwargs}" + ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + deployment = await self.async_get_available_deployment( + model=model, + messages=[{"role": "user", "content": "default text"}], + specific_deployment=kwargs.pop("specific_deployment", None), + request_kwargs=kwargs, + ) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + + data = deployment["litellm_params"].copy() + model_name = data["model"] + + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, + ) + self.total_calls[model_name] += 1 + + response = litellm.aadapter_completion( + **{ + **data, + "adapter_id": adapter_id, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + + rpm_semaphore = self._get_client( + deployment=deployment, + kwargs=kwargs, + client_type="max_parallel_requests", + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response # type: ignore + else: + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response # type: ignore + + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"litellm.aadapter_completion(model={model_name})\033[32m 200 OK\033[0m" + ) + return response + except Exception as e: + verbose_router_logger.info( + f"litellm.aadapter_completion(model={model})\033[31m Exception {str(e)}\033[0m" + ) + if model is not None: + self.fail_calls[model] += 1 + raise e + + async def _ageneric_api_call_with_fallbacks( + self, model: str, original_function: 
Callable, **kwargs + ): + """ + Make a generic LLM API call through the router, this allows you to use retries/fallbacks with litellm router + + Args: + model: The model to use + handler_function: The handler function to call (e.g., litellm.anthropic_messages) + **kwargs: Additional arguments to pass to the handler function + + Returns: + The response from the handler function + """ + handler_name = original_function.__name__ + try: + verbose_router_logger.debug( + f"Inside _ageneric_api_call() - handler: {handler_name}, model: {model}; kwargs: {kwargs}" + ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + deployment = await self.async_get_available_deployment( + model=model, + request_kwargs=kwargs, + messages=kwargs.get("messages", None), + specific_deployment=kwargs.pop("specific_deployment", None), + ) + self._update_kwargs_with_deployment( + deployment=deployment, kwargs=kwargs, function_name="generic_api_call" + ) + + data = deployment["litellm_params"].copy() + model_name = data["model"] + self.total_calls[model_name] += 1 + + response = original_function( + **{ + **data, + "caching": self.cache_responses, + **kwargs, + } + ) + + rpm_semaphore = self._get_client( + deployment=deployment, + kwargs=kwargs, + client_type="max_parallel_requests", + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response # type: ignore + else: + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response # type: ignore + + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"{handler_name}(model={model_name})\033[32m 200 OK\033[0m" + ) + return response + except Exception as e: + verbose_router_logger.info( + f"{handler_name}(model={model})\033[31m Exception {str(e)}\033[0m" + ) + if model is not None: + self.fail_calls[model] += 1 + raise e + + def _generic_api_call_with_fallbacks( + self, model: str, original_function: Callable, **kwargs + ): + """ + Make a generic LLM API call through the router, this allows you to use retries/fallbacks with litellm router + Args: + model: The model to use + original_function: The handler function to call (e.g., litellm.completion) + **kwargs: Additional arguments to pass to the handler function + Returns: + The response from the handler function + """ + handler_name = original_function.__name__ + try: + verbose_router_logger.debug( + f"Inside _generic_api_call() - handler: {handler_name}, model: {model}; kwargs: {kwargs}" + ) + deployment = self.get_available_deployment( + model=model, + messages=kwargs.get("messages", None), + specific_deployment=kwargs.pop("specific_deployment", None), + ) + self._update_kwargs_with_deployment( + deployment=deployment, kwargs=kwargs, function_name="generic_api_call" + ) + + data = deployment["litellm_params"].copy() + model_name = data["model"] + + self.total_calls[model_name] += 1 + + # Perform pre-call checks for routing strategy + self.routing_strategy_pre_call_checks(deployment=deployment) + + response = original_function( + **{ + **data, + "caching": self.cache_responses, + **kwargs, + } + ) + + self.success_calls[model_name] += 1 + verbose_router_logger.info( + 
f"{handler_name}(model={model_name})\033[32m 200 OK\033[0m" + ) + return response + except Exception as e: + verbose_router_logger.info( + f"{handler_name}(model={model})\033[31m Exception {str(e)}\033[0m" + ) + if model is not None: + self.fail_calls[model] += 1 + raise e + + def embedding( + self, + model: str, + input: Union[str, List], + is_async: Optional[bool] = False, + **kwargs, + ) -> EmbeddingResponse: + try: + kwargs["model"] = model + kwargs["input"] = input + kwargs["original_function"] = self._embedding + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + kwargs.setdefault("metadata", {}).update({"model_group": model}) + response = self.function_with_fallbacks(**kwargs) + return response + except Exception as e: + raise e + + def _embedding(self, input: Union[str, List], model: str, **kwargs): + model_name = None + try: + verbose_router_logger.debug( + f"Inside embedding()- model: {model}; kwargs: {kwargs}" + ) + deployment = self.get_available_deployment( + model=model, + input=input, + specific_deployment=kwargs.pop("specific_deployment", None), + ) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + data = deployment["litellm_params"].copy() + model_name = data["model"] + + potential_model_client = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="sync" + ) + # check if provided keys == client keys # + dynamic_api_key = kwargs.get("api_key", None) + if ( + dynamic_api_key is not None + and potential_model_client is not None + and dynamic_api_key != potential_model_client.api_key + ): + model_client = None + else: + model_client = potential_model_client + + self.total_calls[model_name] += 1 + + ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit) + self.routing_strategy_pre_call_checks(deployment=deployment) + + response = litellm.embedding( + **{ + **data, + "input": input, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"litellm.embedding(model={model_name})\033[32m 200 OK\033[0m" + ) + return response + except Exception as e: + verbose_router_logger.info( + f"litellm.embedding(model={model_name})\033[31m Exception {str(e)}\033[0m" + ) + if model_name is not None: + self.fail_calls[model_name] += 1 + raise e + + async def aembedding( + self, + model: str, + input: Union[str, List], + is_async: Optional[bool] = True, + **kwargs, + ) -> EmbeddingResponse: + try: + kwargs["model"] = model + kwargs["input"] = input + kwargs["original_function"] = self._aembedding + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) + response = await self.async_function_with_fallbacks(**kwargs) + return response + except Exception as e: + asyncio.create_task( + send_llm_exception_alert( + litellm_router_instance=self, + request_kwargs=kwargs, + error_traceback_str=traceback.format_exc(), + original_exception=e, + ) + ) + raise e + + async def _aembedding(self, input: Union[str, List], model: str, **kwargs): + model_name = None + try: + verbose_router_logger.debug( + f"Inside _aembedding()- model: {model}; kwargs: {kwargs}" + ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + deployment = await self.async_get_available_deployment( + model=model, + input=input, + specific_deployment=kwargs.pop("specific_deployment", None), + request_kwargs=kwargs, + ) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + data = 
deployment["litellm_params"].copy() + model_name = data["model"] + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, + ) + + self.total_calls[model_name] += 1 + response = litellm.aembedding( + **{ + **data, + "input": input, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + + ### CONCURRENCY-SAFE RPM CHECKS ### + rpm_semaphore = self._get_client( + deployment=deployment, + kwargs=kwargs, + client_type="max_parallel_requests", + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response + else: + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response + + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"litellm.aembedding(model={model_name})\033[32m 200 OK\033[0m" + ) + return response + except Exception as e: + verbose_router_logger.info( + f"litellm.aembedding(model={model_name})\033[31m Exception {str(e)}\033[0m" + ) + if model_name is not None: + self.fail_calls[model_name] += 1 + raise e + + #### FILES API #### + async def acreate_file( + self, + model: str, + **kwargs, + ) -> FileObject: + try: + kwargs["model"] = model + kwargs["original_function"] = self._acreate_file + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) + response = await self.async_function_with_fallbacks(**kwargs) + + return response + except Exception as e: + asyncio.create_task( + send_llm_exception_alert( + litellm_router_instance=self, + request_kwargs=kwargs, + error_traceback_str=traceback.format_exc(), + original_exception=e, + ) + ) + raise e + + async def _acreate_file( + self, + model: str, + **kwargs, + ) -> FileObject: + try: + verbose_router_logger.debug( + f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}" + ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + deployment = await self.async_get_available_deployment( + model=model, + messages=[{"role": "user", "content": "files-api-fake-text"}], + specific_deployment=kwargs.pop("specific_deployment", None), + request_kwargs=kwargs, + ) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + + data = deployment["litellm_params"].copy() + model_name = data["model"] + + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, + ) + self.total_calls[model_name] += 1 + + ## REPLACE MODEL IN FILE WITH SELECTED DEPLOYMENT ## + stripped_model, custom_llm_provider, _, _ = get_llm_provider( + model=data["model"] + ) + kwargs["file"] = replace_model_in_jsonl( + file_content=kwargs["file"], new_model_name=stripped_model + ) + + response = litellm.acreate_file( + **{ + **data, + "custom_llm_provider": custom_llm_provider, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + + rpm_semaphore = self._get_client( + deployment=deployment, + kwargs=kwargs, + client_type="max_parallel_requests", + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before 
making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response # type: ignore + else: + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response # type: ignore + + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"litellm.acreate_file(model={model_name})\033[32m 200 OK\033[0m" + ) + return response # type: ignore + except Exception as e: + verbose_router_logger.exception( + f"litellm.acreate_file(model={model}, {kwargs})\033[31m Exception {str(e)}\033[0m" + ) + if model is not None: + self.fail_calls[model] += 1 + raise e + + async def acreate_batch( + self, + model: str, + **kwargs, + ) -> Batch: + try: + kwargs["model"] = model + kwargs["original_function"] = self._acreate_batch + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) + response = await self.async_function_with_fallbacks(**kwargs) + + return response + except Exception as e: + asyncio.create_task( + send_llm_exception_alert( + litellm_router_instance=self, + request_kwargs=kwargs, + error_traceback_str=traceback.format_exc(), + original_exception=e, + ) + ) + raise e + + async def _acreate_batch( + self, + model: str, + **kwargs, + ) -> Batch: + try: + verbose_router_logger.debug( + f"Inside _acreate_batch()- model: {model}; kwargs: {kwargs}" + ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + deployment = await self.async_get_available_deployment( + model=model, + messages=[{"role": "user", "content": "files-api-fake-text"}], + specific_deployment=kwargs.pop("specific_deployment", None), + request_kwargs=kwargs, + ) + metadata_variable_name = _get_router_metadata_variable_name( + function_name="_acreate_batch" + ) + + kwargs.setdefault(metadata_variable_name, {}).update( + { + "deployment": deployment["litellm_params"]["model"], + "model_info": deployment.get("model_info", {}), + "api_base": deployment.get("litellm_params", {}).get("api_base"), + } + ) + kwargs["model_info"] = deployment.get("model_info", {}) + data = deployment["litellm_params"].copy() + model_name = data["model"] + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, + ) + self.total_calls[model_name] += 1 + + ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ## + _, custom_llm_provider, _, _ = get_llm_provider(model=data["model"]) + + response = litellm.acreate_batch( + **{ + **data, + "custom_llm_provider": custom_llm_provider, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + + rpm_semaphore = self._get_client( + deployment=deployment, + kwargs=kwargs, + client_type="max_parallel_requests", + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response # type: ignore + else: + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, 
parent_otel_span=parent_otel_span + ) + response = await response # type: ignore + + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"litellm.acreate_file(model={model_name})\033[32m 200 OK\033[0m" + ) + return response # type: ignore + except Exception as e: + verbose_router_logger.exception( + f"litellm._acreate_batch(model={model}, {kwargs})\033[31m Exception {str(e)}\033[0m" + ) + if model is not None: + self.fail_calls[model] += 1 + raise e + + async def aretrieve_batch( + self, + **kwargs, + ) -> Batch: + """ + Iterate through all models in a model group to check for batch + + Future Improvement - cache the result. + """ + try: + + filtered_model_list = self.get_model_list() + if filtered_model_list is None: + raise Exception("Router not yet initialized.") + + receieved_exceptions = [] + + async def try_retrieve_batch(model_name): + try: + # Update kwargs with the current model name or any other model-specific adjustments + ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ## + _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore + model=model_name["litellm_params"]["model"] + ) + new_kwargs = copy.deepcopy(kwargs) + new_kwargs.pop("custom_llm_provider", None) + return await litellm.aretrieve_batch( + custom_llm_provider=custom_llm_provider, **new_kwargs # type: ignore + ) + except Exception as e: + receieved_exceptions.append(e) + return None + + # Check all models in parallel + results = await asyncio.gather( + *[try_retrieve_batch(model) for model in filtered_model_list], + return_exceptions=True, + ) + + # Check for successful responses and handle exceptions + for result in results: + if isinstance(result, Batch): + return result + + # If no valid Batch response was found, raise the first encountered exception + if receieved_exceptions: + raise receieved_exceptions[0] # Raising the first exception encountered + + # If no exceptions were encountered, raise a generic exception + raise Exception( + "Unable to find batch in any model. Received errors - {}".format( + receieved_exceptions + ) + ) + except Exception as e: + asyncio.create_task( + send_llm_exception_alert( + litellm_router_instance=self, + request_kwargs=kwargs, + error_traceback_str=traceback.format_exc(), + original_exception=e, + ) + ) + raise e + + async def alist_batches( + self, + model: str, + **kwargs, + ): + """ + Return all the batches across all deployments of a model group. 
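+
+        Returns a dict shaped like an OpenAI-style list response, aggregated across
+        deployments: {"object": "list", "data": [...], "first_id": ..., "last_id": ...,
+        "has_more": ...}.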
+ """ + + filtered_model_list = self.get_model_list(model_name=model) + if filtered_model_list is None: + raise Exception("Router not yet initialized.") + + async def try_retrieve_batch(model: DeploymentTypedDict): + try: + # Update kwargs with the current model name or any other model-specific adjustments + return await litellm.alist_batches( + **{**model["litellm_params"], **kwargs} + ) + except Exception: + return None + + # Check all models in parallel + results = await asyncio.gather( + *[try_retrieve_batch(model) for model in filtered_model_list] + ) + + final_results = { + "object": "list", + "data": [], + "first_id": None, + "last_id": None, + "has_more": False, + } + + for result in results: + if result is not None: + ## check batch id + if final_results["first_id"] is None and hasattr(result, "first_id"): + final_results["first_id"] = getattr(result, "first_id") + final_results["last_id"] = getattr(result, "last_id") + final_results["data"].extend(result.data) # type: ignore + + ## check 'has_more' + if getattr(result, "has_more", False) is True: + final_results["has_more"] = True + + return final_results + + #### PASSTHROUGH API #### + + async def _pass_through_moderation_endpoint_factory( + self, + original_function: Callable, + **kwargs, + ): + if kwargs.get("model") and self.get_model_list(model_name=kwargs["model"]): + deployment = await self.async_get_available_deployment( + model=kwargs["model"], + request_kwargs=kwargs, + ) + kwargs["model"] = deployment["litellm_params"]["model"] + return await original_function(**kwargs) + + def factory_function( + self, + original_function: Callable, + call_type: Literal[ + "assistants", + "moderation", + "anthropic_messages", + "aresponses", + "responses", + ] = "assistants", + ): + """ + Creates appropriate wrapper functions for different API call types. 
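+
+        Example (illustrative sketch, not from the original source; assumes a
+        configured `router` instance and a handler such as `litellm.aresponses`
+        that accepts these kwargs):
+
+            aresponses = router.factory_function(
+                litellm.aresponses, call_type="aresponses"
+            )
+            response = await aresponses(model="gpt-4o", input="hello")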
+ + Returns: + - A synchronous function for synchronous call types + - An asynchronous function for asynchronous call types + """ + # Handle synchronous call types + if call_type == "responses": + + def sync_wrapper( + custom_llm_provider: Optional[ + Literal["openai", "azure", "anthropic"] + ] = None, + client: Optional[Any] = None, + **kwargs, + ): + return self._generic_api_call_with_fallbacks( + original_function=original_function, **kwargs + ) + + return sync_wrapper + + # Handle asynchronous call types + async def async_wrapper( + custom_llm_provider: Optional[ + Literal["openai", "azure", "anthropic"] + ] = None, + client: Optional[Any] = None, + **kwargs, + ): + if call_type == "assistants": + return await self._pass_through_assistants_endpoint_factory( + original_function=original_function, + custom_llm_provider=custom_llm_provider, + client=client, + **kwargs, + ) + elif call_type == "moderation": + return await self._pass_through_moderation_endpoint_factory( + original_function=original_function, **kwargs + ) + elif call_type in ("anthropic_messages", "aresponses"): + return await self._ageneric_api_call_with_fallbacks( + original_function=original_function, + **kwargs, + ) + + return async_wrapper + + async def _pass_through_assistants_endpoint_factory( + self, + original_function: Callable, + custom_llm_provider: Optional[Literal["openai", "azure", "anthropic"]] = None, + client: Optional[AsyncOpenAI] = None, + **kwargs, + ): + """Internal helper function to pass through the assistants endpoint""" + if custom_llm_provider is None: + if self.assistants_config is not None: + custom_llm_provider = self.assistants_config["custom_llm_provider"] + kwargs.update(self.assistants_config["litellm_params"]) + else: + raise Exception( + "'custom_llm_provider' must be set. 
Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" + ) + return await original_function( # type: ignore + custom_llm_provider=custom_llm_provider, client=client, **kwargs + ) + + #### [END] ASSISTANTS API #### + + @tracer.wrap() + async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915 + """ + Try calling the function_with_retries + If it fails after num_retries, fall back to another model group + """ + model_group: Optional[str] = kwargs.get("model") + disable_fallbacks: Optional[bool] = kwargs.pop("disable_fallbacks", False) + fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks) + context_window_fallbacks: Optional[List] = kwargs.get( + "context_window_fallbacks", self.context_window_fallbacks + ) + content_policy_fallbacks: Optional[List] = kwargs.get( + "content_policy_fallbacks", self.content_policy_fallbacks + ) + + mock_timeout = kwargs.pop("mock_timeout", None) + + try: + self._handle_mock_testing_fallbacks( + kwargs=kwargs, + model_group=model_group, + fallbacks=fallbacks, + context_window_fallbacks=context_window_fallbacks, + content_policy_fallbacks=content_policy_fallbacks, + ) + + if mock_timeout is not None: + response = await self.async_function_with_retries( + *args, **kwargs, mock_timeout=mock_timeout + ) + else: + response = await self.async_function_with_retries(*args, **kwargs) + verbose_router_logger.debug(f"Async Response: {response}") + response = add_fallback_headers_to_response( + response=response, + attempted_fallbacks=0, + ) + return response + except Exception as e: + verbose_router_logger.debug(f"Traceback{traceback.format_exc()}") + original_exception = e + fallback_model_group = None + original_model_group: Optional[str] = kwargs.get("model") # type: ignore + fallback_failure_exception_str = "" + + if disable_fallbacks is True or original_model_group is None: + raise e + + input_kwargs = { + "litellm_router": self, + "original_exception": original_exception, + **kwargs, + } + + if "max_fallbacks" not in input_kwargs: + input_kwargs["max_fallbacks"] = self.max_fallbacks + if "fallback_depth" not in input_kwargs: + input_kwargs["fallback_depth"] = 0 + + try: + verbose_router_logger.info("Trying to fallback b/w models") + + # check if client-side fallbacks are used (e.g. fallbacks = ["gpt-3.5-turbo", "claude-3-haiku"] or fallbacks=[{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}] + is_non_standard_fallback_format = _check_non_standard_fallback_format( + fallbacks=fallbacks + ) + + if is_non_standard_fallback_format: + input_kwargs.update( + { + "fallback_model_group": fallbacks, + "original_model_group": original_model_group, + } + ) + + response = await run_async_fallback( + *args, + **input_kwargs, + ) + + return response + + if isinstance(e, litellm.ContextWindowExceededError): + if context_window_fallbacks is not None: + fallback_model_group: Optional[List[str]] = ( + self._get_fallback_model_group_from_fallbacks( + fallbacks=context_window_fallbacks, + model_group=model_group, + ) + ) + if fallback_model_group is None: + raise original_exception + + input_kwargs.update( + { + "fallback_model_group": fallback_model_group, + "original_model_group": original_model_group, + } + ) + + response = await run_async_fallback( + *args, + **input_kwargs, + ) + return response + + else: + error_message = "model={}. context_window_fallbacks={}. 
fallbacks={}.\n\nSet 'context_window_fallback' - https://docs.litellm.ai/docs/routing#fallbacks".format( + model_group, context_window_fallbacks, fallbacks + ) + verbose_router_logger.info( + msg="Got 'ContextWindowExceededError'. No context_window_fallback set. Defaulting \ + to fallbacks, if available.{}".format( + error_message + ) + ) + + e.message += "\n{}".format(error_message) + elif isinstance(e, litellm.ContentPolicyViolationError): + if content_policy_fallbacks is not None: + fallback_model_group: Optional[List[str]] = ( + self._get_fallback_model_group_from_fallbacks( + fallbacks=content_policy_fallbacks, + model_group=model_group, + ) + ) + if fallback_model_group is None: + raise original_exception + + input_kwargs.update( + { + "fallback_model_group": fallback_model_group, + "original_model_group": original_model_group, + } + ) + + response = await run_async_fallback( + *args, + **input_kwargs, + ) + return response + else: + error_message = "model={}. content_policy_fallback={}. fallbacks={}.\n\nSet 'content_policy_fallback' - https://docs.litellm.ai/docs/routing#fallbacks".format( + model_group, content_policy_fallbacks, fallbacks + ) + verbose_router_logger.info( + msg="Got 'ContentPolicyViolationError'. No content_policy_fallback set. Defaulting \ + to fallbacks, if available.{}".format( + error_message + ) + ) + + e.message += "\n{}".format(error_message) + if fallbacks is not None and model_group is not None: + verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}") + fallback_model_group, generic_fallback_idx = ( + get_fallback_model_group( + fallbacks=fallbacks, # if fallbacks = [{"gpt-3.5-turbo": ["claude-3-haiku"]}] + model_group=cast(str, model_group), + ) + ) + ## if none, check for generic fallback + if ( + fallback_model_group is None + and generic_fallback_idx is not None + ): + fallback_model_group = fallbacks[generic_fallback_idx]["*"] + + if fallback_model_group is None: + verbose_router_logger.info( + f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" + ) + if hasattr(original_exception, "message"): + original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" # type: ignore + raise original_exception + + input_kwargs.update( + { + "fallback_model_group": fallback_model_group, + "original_model_group": original_model_group, + } + ) + + response = await run_async_fallback( + *args, + **input_kwargs, + ) + + return response + except Exception as new_exception: + traceback.print_exc() + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + verbose_router_logger.error( + "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format( + str(new_exception), + traceback.format_exc(), + await _async_get_cooldown_deployments_with_debug_info( + litellm_router_instance=self, + parent_otel_span=parent_otel_span, + ), + ) + ) + fallback_failure_exception_str = str(new_exception) + + if hasattr(original_exception, "message"): + # add the available fallbacks to the exception + original_exception.message += ". 
Received Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore + model_group, + fallback_model_group, + ) + if len(fallback_failure_exception_str) > 0: + original_exception.message += ( # type: ignore + "\nError doing the fallback: {}".format( + fallback_failure_exception_str + ) + ) + + raise original_exception + + def _handle_mock_testing_fallbacks( + self, + kwargs: dict, + model_group: Optional[str] = None, + fallbacks: Optional[List] = None, + context_window_fallbacks: Optional[List] = None, + content_policy_fallbacks: Optional[List] = None, + ): + """ + Helper function to raise a litellm Error for mock testing purposes. + + Raises: + litellm.InternalServerError: when `mock_testing_fallbacks=True` passed in request params + litellm.ContextWindowExceededError: when `mock_testing_context_fallbacks=True` passed in request params + litellm.ContentPolicyViolationError: when `mock_testing_content_policy_fallbacks=True` passed in request params + """ + mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None) + mock_testing_context_fallbacks = kwargs.pop( + "mock_testing_context_fallbacks", None + ) + mock_testing_content_policy_fallbacks = kwargs.pop( + "mock_testing_content_policy_fallbacks", None + ) + + if mock_testing_fallbacks is not None and mock_testing_fallbacks is True: + raise litellm.InternalServerError( + model=model_group, + llm_provider="", + message=f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}", + ) + elif ( + mock_testing_context_fallbacks is not None + and mock_testing_context_fallbacks is True + ): + raise litellm.ContextWindowExceededError( + model=model_group, + llm_provider="", + message=f"This is a mock exception for model={model_group}, to trigger a fallback. \ + Context_Window_Fallbacks={context_window_fallbacks}", + ) + elif ( + mock_testing_content_policy_fallbacks is not None + and mock_testing_content_policy_fallbacks is True + ): + raise litellm.ContentPolicyViolationError( + model=model_group, + llm_provider="", + message=f"This is a mock exception for model={model_group}, to trigger a fallback. 
\ + Context_Policy_Fallbacks={content_policy_fallbacks}", + ) + + @tracer.wrap() + async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915 + verbose_router_logger.debug("Inside async function with retries.") + original_function = kwargs.pop("original_function") + fallbacks = kwargs.pop("fallbacks", self.fallbacks) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + context_window_fallbacks = kwargs.pop( + "context_window_fallbacks", self.context_window_fallbacks + ) + content_policy_fallbacks = kwargs.pop( + "content_policy_fallbacks", self.content_policy_fallbacks + ) + model_group: Optional[str] = kwargs.get("model") + num_retries = kwargs.pop("num_retries") + + ## ADD MODEL GROUP SIZE TO METADATA - used for model_group_rate_limit_error tracking + _metadata: dict = kwargs.get("metadata") or {} + if "model_group" in _metadata and isinstance(_metadata["model_group"], str): + model_list = self.get_model_list(model_name=_metadata["model_group"]) + if model_list is not None: + _metadata.update({"model_group_size": len(model_list)}) + + verbose_router_logger.debug( + f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}" + ) + try: + self._handle_mock_testing_rate_limit_error( + model_group=model_group, kwargs=kwargs + ) + # if the function call is successful, no exception will be raised and we'll break out of the loop + response = await self.make_call(original_function, *args, **kwargs) + response = add_retry_headers_to_response( + response=response, attempted_retries=0, max_retries=None + ) + return response + except Exception as e: + current_attempt = None + original_exception = e + deployment_num_retries = getattr(e, "num_retries", None) + + if deployment_num_retries is not None and isinstance( + deployment_num_retries, int + ): + num_retries = deployment_num_retries + """ + Retry Logic + """ + _healthy_deployments, _all_deployments = ( + await self._async_get_healthy_deployments( + model=kwargs.get("model") or "", + parent_otel_span=parent_otel_span, + ) + ) + + # raises an exception if this error should not be retries + self.should_retry_this_error( + error=e, + healthy_deployments=_healthy_deployments, + all_deployments=_all_deployments, + context_window_fallbacks=context_window_fallbacks, + regular_fallbacks=fallbacks, + content_policy_fallbacks=content_policy_fallbacks, + ) + + if ( + self.retry_policy is not None + or self.model_group_retry_policy is not None + ): + # get num_retries from retry policy + _retry_policy_retries = self.get_num_retries_from_retry_policy( + exception=original_exception, model_group=kwargs.get("model") + ) + if _retry_policy_retries is not None: + num_retries = _retry_policy_retries + ## LOGGING + if num_retries > 0: + kwargs = self.log_retry(kwargs=kwargs, e=original_exception) + else: + raise + + verbose_router_logger.info( + f"Retrying request with num_retries: {num_retries}" + ) + # decides how long to sleep before retry + retry_after = self._time_to_sleep_before_retry( + e=original_exception, + remaining_retries=num_retries, + num_retries=num_retries, + healthy_deployments=_healthy_deployments, + all_deployments=_all_deployments, + ) + + await asyncio.sleep(retry_after) + + for current_attempt in range(num_retries): + try: + # if the function call is successful, no exception will be raised and we'll break out of the loop + response = await self.make_call(original_function, *args, **kwargs) + if inspect.iscoroutinefunction( + response + ): # async errors are often returned as 
coroutines + response = await response + + response = add_retry_headers_to_response( + response=response, + attempted_retries=current_attempt + 1, + max_retries=num_retries, + ) + return response + + except Exception as e: + ## LOGGING + kwargs = self.log_retry(kwargs=kwargs, e=e) + remaining_retries = num_retries - current_attempt + _model: Optional[str] = kwargs.get("model") # type: ignore + if _model is not None: + _healthy_deployments, _ = ( + await self._async_get_healthy_deployments( + model=_model, + parent_otel_span=parent_otel_span, + ) + ) + else: + _healthy_deployments = [] + _timeout = self._time_to_sleep_before_retry( + e=original_exception, + remaining_retries=remaining_retries, + num_retries=num_retries, + healthy_deployments=_healthy_deployments, + all_deployments=_all_deployments, + ) + await asyncio.sleep(_timeout) + + if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES: + setattr(original_exception, "max_retries", num_retries) + setattr(original_exception, "num_retries", current_attempt) + + raise original_exception + + async def make_call(self, original_function: Any, *args, **kwargs): + """ + Handler for making a call to the .completion()/.embeddings()/etc. functions. + """ + model_group = kwargs.get("model") + response = original_function(*args, **kwargs) + if inspect.iscoroutinefunction(response) or inspect.isawaitable(response): + response = await response + ## PROCESS RESPONSE HEADERS + response = await self.set_response_headers( + response=response, model_group=model_group + ) + + return response + + def _handle_mock_testing_rate_limit_error( + self, kwargs: dict, model_group: Optional[str] = None + ): + """ + Helper function to raise a mock litellm.RateLimitError error for testing purposes. + + Raises: + litellm.RateLimitError error when `mock_testing_rate_limit_error=True` passed in request params + """ + mock_testing_rate_limit_error: Optional[bool] = kwargs.pop( + "mock_testing_rate_limit_error", None + ) + + available_models = self.get_model_list(model_name=model_group) + num_retries: Optional[int] = None + + if available_models is not None and len(available_models) == 1: + num_retries = cast( + Optional[int], available_models[0]["litellm_params"].get("num_retries") + ) + + if ( + mock_testing_rate_limit_error is not None + and mock_testing_rate_limit_error is True + ): + verbose_router_logger.info( + f"litellm.router.py::_mock_rate_limit_error() - Raising mock RateLimitError for model={model_group}" + ) + raise litellm.RateLimitError( + model=model_group, + llm_provider="", + message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.", + num_retries=num_retries, + ) + + def should_retry_this_error( + self, + error: Exception, + healthy_deployments: Optional[List] = None, + all_deployments: Optional[List] = None, + context_window_fallbacks: Optional[List] = None, + content_policy_fallbacks: Optional[List] = None, + regular_fallbacks: Optional[List] = None, + ): + """ + 1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None + 2. raise an exception for ContentPolicyViolationError if content_policy_fallbacks is not None + + 2. 
raise an exception for RateLimitError if + - there are no fallbacks + - there are no healthy deployments in the same model group + """ + _num_healthy_deployments = 0 + if healthy_deployments is not None and isinstance(healthy_deployments, list): + _num_healthy_deployments = len(healthy_deployments) + + _num_all_deployments = 0 + if all_deployments is not None and isinstance(all_deployments, list): + _num_all_deployments = len(all_deployments) + + ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR / CONTENT POLICY VIOLATION ERROR w/ fallbacks available / Bad Request Error + if ( + isinstance(error, litellm.ContextWindowExceededError) + and context_window_fallbacks is not None + ): + raise error + + if ( + isinstance(error, litellm.ContentPolicyViolationError) + and content_policy_fallbacks is not None + ): + raise error + + if isinstance(error, litellm.NotFoundError): + raise error + # Error we should only retry if there are other deployments + if isinstance(error, openai.RateLimitError): + if ( + _num_healthy_deployments <= 0 # if no healthy deployments + and regular_fallbacks is not None # and fallbacks available + and len(regular_fallbacks) > 0 + ): + raise error # then raise the error + + if isinstance(error, openai.AuthenticationError): + """ + - if other deployments available -> retry + - else -> raise error + """ + if ( + _num_all_deployments <= 1 + ): # if there is only 1 deployment for this model group then don't retry + raise error # then raise error + + # Do not retry if there are no healthy deployments + # just raise the error + if _num_healthy_deployments <= 0: # if no healthy deployments + raise error + + return True + + def function_with_fallbacks(self, *args, **kwargs): + """ + Sync wrapper for async_function_with_fallbacks + + Wrapped to reduce code duplication and prevent bugs. + """ + return run_async_function(self.async_function_with_fallbacks, *args, **kwargs) + + def _get_fallback_model_group_from_fallbacks( + self, + fallbacks: List[Dict[str, List[str]]], + model_group: Optional[str] = None, + ) -> Optional[List[str]]: + """ + Returns the list of fallback models to use for a given model group + + If no fallback model group is found, returns None + + Example: + fallbacks = [{"gpt-3.5-turbo": ["gpt-4"]}, {"gpt-4o": ["gpt-3.5-turbo"]}] + model_group = "gpt-3.5-turbo" + returns: ["gpt-4"] + """ + if model_group is None: + return None + + fallback_model_group: Optional[List[str]] = None + for item in fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] + if list(item.keys())[0] == model_group: + fallback_model_group = item[model_group] + break + return fallback_model_group + + def _time_to_sleep_before_retry( + self, + e: Exception, + remaining_retries: int, + num_retries: int, + healthy_deployments: Optional[List] = None, + all_deployments: Optional[List] = None, + ) -> Union[int, float]: + """ + Calculate back-off, then retry + + It should instantly retry only when: + 1. there are healthy deployments in the same model group + 2. 
there are fallbacks for the completion call + """ + + ## base case - single deployment + if all_deployments is not None and len(all_deployments) == 1: + pass + elif ( + healthy_deployments is not None + and isinstance(healthy_deployments, list) + and len(healthy_deployments) > 0 + ): + return 0 + + response_headers: Optional[httpx.Headers] = None + if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore + response_headers = e.response.headers # type: ignore + if hasattr(e, "litellm_response_headers"): + response_headers = e.litellm_response_headers # type: ignore + + if response_headers is not None: + timeout = litellm._calculate_retry_after( + remaining_retries=remaining_retries, + max_retries=num_retries, + response_headers=response_headers, + min_timeout=self.retry_after, + ) + + else: + timeout = litellm._calculate_retry_after( + remaining_retries=remaining_retries, + max_retries=num_retries, + min_timeout=self.retry_after, + ) + + return timeout + + ### HELPER FUNCTIONS + + async def deployment_callback_on_success( + self, + kwargs, # kwargs to completion + completion_response, # response from completion + start_time, + end_time, # start/end time + ): + """ + Track remaining tpm/rpm quota for model in model_list + """ + try: + standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + if standard_logging_object is None: + raise ValueError("standard_logging_object is None") + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + deployment_name = kwargs["litellm_params"]["metadata"].get( + "deployment", None + ) # stable name - works for wildcard routes as well + model_group = standard_logging_object.get("model_group", None) + id = standard_logging_object.get("model_id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + total_tokens: float = standard_logging_object.get("total_tokens", 0) + + # ------------ + # Setup values + # ------------ + dt = get_utc_datetime() + current_minute = dt.strftime( + "%H-%M" + ) # use the same timezone regardless of system clock + + tpm_key = RouterCacheEnum.TPM.value.format( + id=id, current_minute=current_minute, model=deployment_name + ) + # ------------ + # Update usage + # ------------ + # update cache + + ## TPM + await self.cache.async_increment_cache( + key=tpm_key, + value=total_tokens, + parent_otel_span=parent_otel_span, + ttl=RoutingArgs.ttl.value, + ) + + ## RPM + rpm_key = RouterCacheEnum.RPM.value.format( + id=id, current_minute=current_minute, model=deployment_name + ) + await self.cache.async_increment_cache( + key=rpm_key, + value=1, + parent_otel_span=parent_otel_span, + ttl=RoutingArgs.ttl.value, + ) + + increment_deployment_successes_for_current_minute( + litellm_router_instance=self, + deployment_id=id, + ) + + return tpm_key + + except Exception as e: + verbose_router_logger.exception( + "litellm.router.Router::deployment_callback_on_success(): Exception occured - {}".format( + str(e) + ) + ) + pass + + def sync_deployment_callback_on_success( + self, + kwargs, # kwargs to completion + completion_response, # response from completion + start_time, + end_time, # start/end time + ) -> Optional[str]: + """ + Tracks the number of successes for a deployment in the current minute (using in-memory cache) + + Returns: + - key: str - The key used to increment the cache + - None: if no key is found + """ + id = None + if 
kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get("model_group", None) + model_info = kwargs["litellm_params"].get("model_info", {}) or {} + id = model_info.get("id", None) + if model_group is None or id is None: + return None + elif isinstance(id, int): + id = str(id) + + if id is not None: + key = increment_deployment_successes_for_current_minute( + litellm_router_instance=self, + deployment_id=id, + ) + return key + + return None + + def deployment_callback_on_failure( + self, + kwargs, # kwargs to completion + completion_response, # response from completion + start_time, + end_time, # start/end time + ) -> bool: + """ + 2 jobs: + - Tracks the number of failures for a deployment in the current minute (using in-memory cache) + - Puts the deployment in cooldown if it exceeds the allowed fails / minute + + Returns: + - True if the deployment should be put in cooldown + - False if the deployment should not be put in cooldown + """ + verbose_router_logger.debug("Router: Entering 'deployment_callback_on_failure'") + try: + exception = kwargs.get("exception", None) + exception_status = getattr(exception, "status_code", "") + _model_info = kwargs.get("litellm_params", {}).get("model_info", {}) + + exception_headers = litellm.litellm_core_utils.exception_mapping_utils._get_response_headers( + original_exception=exception + ) + + _time_to_cooldown = kwargs.get("litellm_params", {}).get( + "cooldown_time", self.cooldown_time + ) + + if exception_headers is not None: + + _time_to_cooldown = ( + litellm.utils._get_retry_after_from_exception_header( + response_headers=exception_headers + ) + ) + + if _time_to_cooldown is None or _time_to_cooldown < 0: + # if the response headers did not read it -> set to default cooldown time + _time_to_cooldown = self.cooldown_time + + if isinstance(_model_info, dict): + deployment_id = _model_info.get("id", None) + increment_deployment_failures_for_current_minute( + litellm_router_instance=self, + deployment_id=deployment_id, + ) + result = _set_cooldown_deployments( + litellm_router_instance=self, + exception_status=exception_status, + original_exception=exception, + deployment=deployment_id, + time_to_cooldown=_time_to_cooldown, + ) # setting deployment_id in cooldown deployments + + return result + else: + verbose_router_logger.debug( + "Router: Exiting 'deployment_callback_on_failure' without cooldown. No model_info found." 
+ ) + return False + + except Exception as e: + raise e + + async def async_deployment_callback_on_failure( + self, kwargs, completion_response: Optional[Any], start_time, end_time + ): + """ + Update RPM usage for a deployment + """ + deployment_name = kwargs["litellm_params"]["metadata"].get( + "deployment", None + ) # handles wildcard routes - by giving the original name sent to `litellm.completion` + model_group = kwargs["litellm_params"]["metadata"].get("model_group", None) + model_info = kwargs["litellm_params"].get("model_info", {}) or {} + id = model_info.get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + + dt = get_utc_datetime() + current_minute = dt.strftime( + "%H-%M" + ) # use the same timezone regardless of system clock + + ## RPM + rpm_key = RouterCacheEnum.RPM.value.format( + id=id, current_minute=current_minute, model=deployment_name + ) + await self.cache.async_increment_cache( + key=rpm_key, + value=1, + parent_otel_span=parent_otel_span, + ttl=RoutingArgs.ttl.value, + ) + + def log_retry(self, kwargs: dict, e: Exception) -> dict: + """ + When a retry or fallback happens, log the details of the just failed model call - similar to Sentry breadcrumbing + """ + try: + # Log failed model as the previous model + previous_model = { + "exception_type": type(e).__name__, + "exception_string": str(e), + } + for ( + k, + v, + ) in ( + kwargs.items() + ): # log everything in kwargs except the old previous_models value - prevent nesting + if k not in ["metadata", "messages", "original_function"]: + previous_model[k] = v + elif k == "metadata" and isinstance(v, dict): + previous_model["metadata"] = {} # type: ignore + for metadata_k, metadata_v in kwargs["metadata"].items(): + if metadata_k != "previous_models": + previous_model[k][metadata_k] = metadata_v # type: ignore + + # check current size of self.previous_models, if it's larger than 3, remove the first element + if len(self.previous_models) > 3: + self.previous_models.pop(0) + + self.previous_models.append(previous_model) + kwargs["metadata"]["previous_models"] = self.previous_models + return kwargs + except Exception as e: + raise e + + def _update_usage( + self, deployment_id: str, parent_otel_span: Optional[Span] + ) -> int: + """ + Update deployment rpm for that minute + + Returns: + - int: request count + """ + rpm_key = deployment_id + + request_count = self.cache.get_cache( + key=rpm_key, parent_otel_span=parent_otel_span, local_only=True + ) + if request_count is None: + request_count = 1 + self.cache.set_cache( + key=rpm_key, value=request_count, local_only=True, ttl=60 + ) # only store for 60s + else: + request_count += 1 + self.cache.set_cache( + key=rpm_key, value=request_count, local_only=True + ) # don't change existing ttl + + return request_count + + def _has_default_fallbacks(self) -> bool: + if self.fallbacks is None: + return False + for fallback in self.fallbacks: + if isinstance(fallback, dict): + if "*" in fallback: + return True + return False + + def _should_raise_content_policy_error( + self, model: str, response: ModelResponse, kwargs: dict + ) -> bool: + """ + Determines if a content policy error should be raised. + + Only raised if a fallback is available. + + Else, original response is returned. 
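+
+        Illustrative sketch (hypothetical model names) - a router configured with
+        content_policy_fallbacks treats a 'content_filter' finish_reason as an error,
+        so the fallback group can be tried instead of returning the filtered response:
+
+            from litellm import Router
+
+            router = Router(
+                model_list=[
+                    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
+                    {"model_name": "claude-3-haiku", "litellm_params": {"model": "anthropic/claude-3-haiku-20240307"}},
+                ],
+                content_policy_fallbacks=[{"gpt-3.5-turbo": ["claude-3-haiku"]}],
+            )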
+ """ + if response.choices and len(response.choices) > 0: + if response.choices[0].finish_reason != "content_filter": + return False + + content_policy_fallbacks = kwargs.get( + "content_policy_fallbacks", self.content_policy_fallbacks + ) + + ### ONLY RAISE ERROR IF CP FALLBACK AVAILABLE ### + if content_policy_fallbacks is not None: + fallback_model_group = None + for item in content_policy_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] + if list(item.keys())[0] == model: + fallback_model_group = item[model] + break + + if fallback_model_group is not None: + return True + elif self._has_default_fallbacks(): # default fallbacks set + return True + + verbose_router_logger.info( + "Content Policy Error occurred. No available fallbacks. Returning original response. model={}, content_policy_fallbacks={}".format( + model, content_policy_fallbacks + ) + ) + return False + + def _get_healthy_deployments(self, model: str, parent_otel_span: Optional[Span]): + _all_deployments: list = [] + try: + _, _all_deployments = self._common_checks_available_deployment( # type: ignore + model=model, + ) + if isinstance(_all_deployments, dict): + return [] + except Exception: + pass + + unhealthy_deployments = _get_cooldown_deployments( + litellm_router_instance=self, parent_otel_span=parent_otel_span + ) + healthy_deployments: list = [] + for deployment in _all_deployments: + if deployment["model_info"]["id"] in unhealthy_deployments: + continue + else: + healthy_deployments.append(deployment) + + return healthy_deployments, _all_deployments + + async def _async_get_healthy_deployments( + self, model: str, parent_otel_span: Optional[Span] + ) -> Tuple[List[Dict], List[Dict]]: + """ + Returns Tuple of: + - Tuple[List[Dict], List[Dict]]: + 1. healthy_deployments: list of healthy deployments + 2. all_deployments: list of all deployments + """ + _all_deployments: list = [] + try: + _, _all_deployments = self._common_checks_available_deployment( # type: ignore + model=model, + ) + if isinstance(_all_deployments, dict): + return [], _all_deployments + except Exception: + pass + + unhealthy_deployments = await _async_get_cooldown_deployments( + litellm_router_instance=self, parent_otel_span=parent_otel_span + ) + healthy_deployments: list = [] + for deployment in _all_deployments: + if deployment["model_info"]["id"] in unhealthy_deployments: + continue + else: + healthy_deployments.append(deployment) + return healthy_deployments, _all_deployments + + def routing_strategy_pre_call_checks(self, deployment: dict): + """ + Mimics 'async_routing_strategy_pre_call_checks' + + Ensures consistent update rpm implementation for 'usage-based-routing-v2' + + Returns: + - None + + Raises: + - Rate Limit Exception - If the deployment is over it's tpm/rpm limits + """ + for _callback in litellm.callbacks: + if isinstance(_callback, CustomLogger): + _callback.pre_call_check(deployment) + + async def async_routing_strategy_pre_call_checks( + self, + deployment: dict, + parent_otel_span: Optional[Span], + logging_obj: Optional[LiteLLMLogging] = None, + ): + """ + For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore. 
+ + -> makes the calls concurrency-safe, when rpm limits are set for a deployment + + Returns: + - None + + Raises: + - Rate Limit Exception - If the deployment is over it's tpm/rpm limits + """ + for _callback in litellm.callbacks: + if isinstance(_callback, CustomLogger): + try: + await _callback.async_pre_call_check(deployment, parent_otel_span) + except litellm.RateLimitError as e: + ## LOG FAILURE EVENT + if logging_obj is not None: + asyncio.create_task( + logging_obj.async_failure_handler( + exception=e, + traceback_exception=traceback.format_exc(), + end_time=time.time(), + ) + ) + ## LOGGING + threading.Thread( + target=logging_obj.failure_handler, + args=(e, traceback.format_exc()), + ).start() # log response + _set_cooldown_deployments( + litellm_router_instance=self, + exception_status=e.status_code, + original_exception=e, + deployment=deployment["model_info"]["id"], + time_to_cooldown=self.cooldown_time, + ) + raise e + except Exception as e: + ## LOG FAILURE EVENT + if logging_obj is not None: + asyncio.create_task( + logging_obj.async_failure_handler( + exception=e, + traceback_exception=traceback.format_exc(), + end_time=time.time(), + ) + ) + ## LOGGING + threading.Thread( + target=logging_obj.failure_handler, + args=(e, traceback.format_exc()), + ).start() # log response + raise e + + async def async_callback_filter_deployments( + self, + model: str, + healthy_deployments: List[dict], + messages: Optional[List[AllMessageValues]], + parent_otel_span: Optional[Span], + request_kwargs: Optional[dict] = None, + logging_obj: Optional[LiteLLMLogging] = None, + ): + """ + For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore. + + -> makes the calls concurrency-safe, when rpm limits are set for a deployment + + Returns: + - None + + Raises: + - Rate Limit Exception - If the deployment is over it's tpm/rpm limits + """ + returned_healthy_deployments = healthy_deployments + for _callback in litellm.callbacks: + if isinstance(_callback, CustomLogger): + try: + returned_healthy_deployments = ( + await _callback.async_filter_deployments( + model=model, + healthy_deployments=returned_healthy_deployments, + messages=messages, + request_kwargs=request_kwargs, + parent_otel_span=parent_otel_span, + ) + ) + except Exception as e: + ## LOG FAILURE EVENT + if logging_obj is not None: + asyncio.create_task( + logging_obj.async_failure_handler( + exception=e, + traceback_exception=traceback.format_exc(), + end_time=time.time(), + ) + ) + ## LOGGING + threading.Thread( + target=logging_obj.failure_handler, + args=(e, traceback.format_exc()), + ).start() # log response + raise e + return returned_healthy_deployments + + def _generate_model_id(self, model_group: str, litellm_params: dict): + """ + Helper function to consistently generate the same id for a deployment + + - create a string from all the litellm params + - hash + - use hash as id + """ + concat_str = model_group + for k, v in litellm_params.items(): + if isinstance(k, str): + concat_str += k + elif isinstance(k, dict): + concat_str += json.dumps(k) + else: + concat_str += str(k) + + if isinstance(v, str): + concat_str += v + elif isinstance(v, dict): + concat_str += json.dumps(v) + else: + concat_str += str(v) + + hash_object = hashlib.sha256(concat_str.encode()) + + return hash_object.hexdigest() + + def _create_deployment( + self, + deployment_info: dict, + _model_name: str, + _litellm_params: dict, + _model_info: dict, + ) -> Optional[Deployment]: + """ + Create a deployment object 
and add it to the model list + + If the deployment is not active for the current environment, it is ignored + + Returns: + - Deployment: The deployment object + - None: If the deployment is not active for the current environment (if 'supported_environments' is set in litellm_params) + """ + deployment = Deployment( + **deployment_info, + model_name=_model_name, + litellm_params=LiteLLM_Params(**_litellm_params), + model_info=_model_info, + ) + + ## REGISTER MODEL INFO IN LITELLM MODEL COST MAP + _model_name = deployment.litellm_params.model + if deployment.litellm_params.custom_llm_provider is not None: + _model_name = ( + deployment.litellm_params.custom_llm_provider + "/" + _model_name + ) + + litellm.register_model( + model_cost={ + _model_name: _model_info, + } + ) + + ## Check if LLM Deployment is allowed for this deployment + if self.deployment_is_active_for_environment(deployment=deployment) is not True: + verbose_router_logger.warning( + f"Ignoring deployment {deployment.model_name} as it is not active for environment {deployment.model_info['supported_environments']}" + ) + return None + + deployment = self._add_deployment(deployment=deployment) + + model = deployment.to_json(exclude_none=True) + + self.model_list.append(model) + return deployment + + def deployment_is_active_for_environment(self, deployment: Deployment) -> bool: + """ + Function to check if a llm deployment is active for a given environment. Allows using the same config.yaml across multople environments + + Requires `LITELLM_ENVIRONMENT` to be set in .env. Valid values for environment: + - development + - staging + - production + + Raises: + - ValueError: If LITELLM_ENVIRONMENT is not set in .env or not one of the valid values + - ValueError: If supported_environments is not set in model_info or not one of the valid values + """ + if ( + deployment.model_info is None + or "supported_environments" not in deployment.model_info + or deployment.model_info["supported_environments"] is None + ): + return True + litellm_environment = get_secret_str(secret_name="LITELLM_ENVIRONMENT") + if litellm_environment is None: + raise ValueError( + "Set 'supported_environments' for model but not 'LITELLM_ENVIRONMENT' set in .env" + ) + + if litellm_environment not in VALID_LITELLM_ENVIRONMENTS: + raise ValueError( + f"LITELLM_ENVIRONMENT must be one of {VALID_LITELLM_ENVIRONMENTS}. but set as: {litellm_environment}" + ) + + for _env in deployment.model_info["supported_environments"]: + if _env not in VALID_LITELLM_ENVIRONMENTS: + raise ValueError( + f"supported_environments must be one of {VALID_LITELLM_ENVIRONMENTS}. 
but set as: {_env} for deployment: {deployment}" + ) + + if litellm_environment in deployment.model_info["supported_environments"]: + return True + return False + + def set_model_list(self, model_list: list): + original_model_list = copy.deepcopy(model_list) + self.model_list = [] + # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works + + for model in original_model_list: + _model_name = model.pop("model_name") + _litellm_params = model.pop("litellm_params") + ## check if litellm params in os.environ + if isinstance(_litellm_params, dict): + for k, v in _litellm_params.items(): + if isinstance(v, str) and v.startswith("os.environ/"): + _litellm_params[k] = get_secret(v) + + _model_info: dict = model.pop("model_info", {}) + + # check if model info has id + if "id" not in _model_info: + _id = self._generate_model_id(_model_name, _litellm_params) + _model_info["id"] = _id + + if _litellm_params.get("organization", None) is not None and isinstance( + _litellm_params["organization"], list + ): # Addresses https://github.com/BerriAI/litellm/issues/3949 + for org in _litellm_params["organization"]: + _litellm_params["organization"] = org + self._create_deployment( + deployment_info=model, + _model_name=_model_name, + _litellm_params=_litellm_params, + _model_info=_model_info, + ) + else: + self._create_deployment( + deployment_info=model, + _model_name=_model_name, + _litellm_params=_litellm_params, + _model_info=_model_info, + ) + + verbose_router_logger.debug( + f"\nInitialized Model List {self.get_model_names()}" + ) + self.model_names = [m["model_name"] for m in model_list] + + def _add_deployment(self, deployment: Deployment) -> Deployment: + import os + + #### DEPLOYMENT NAMES INIT ######## + self.deployment_names.append(deployment.litellm_params.model) + ############ Users can either pass tpm/rpm as a litellm_param or a router param ########### + # for get_available_deployment, we use the litellm_param["rpm"] + # in this snippet we also set rpm to be a litellm_param + if ( + deployment.litellm_params.rpm is None + and getattr(deployment, "rpm", None) is not None + ): + deployment.litellm_params.rpm = getattr(deployment, "rpm") + + if ( + deployment.litellm_params.tpm is None + and getattr(deployment, "tpm", None) is not None + ): + deployment.litellm_params.tpm = getattr(deployment, "tpm") + + #### VALIDATE MODEL ######## + # check if model provider in supported providers + ( + _model, + custom_llm_provider, + dynamic_api_key, + api_base, + ) = litellm.get_llm_provider( + model=deployment.litellm_params.model, + custom_llm_provider=deployment.litellm_params.get( + "custom_llm_provider", None + ), + ) + + # Check if user is trying to use model_name == "*" + # this is a catch all model for their specific api key + # if deployment.model_name == "*": + # if deployment.litellm_params.model == "*": + # # user wants to pass through all requests to litellm.acompletion for unknown deployments + # self.router_general_settings.pass_through_all_models = True + # else: + # self.default_deployment = deployment.to_json(exclude_none=True) + # Check if user is using provider specific wildcard routing + # example model_name = "databricks/*" or model_name = "anthropic/*" + if "*" in deployment.model_name: + # store this as a regex pattern - all deployments matching this pattern will be sent to this deployment + # Store deployment.model_name as a regex pattern + self.pattern_router.add_pattern( + deployment.model_name, deployment.to_json(exclude_none=True) 
+ ) + if deployment.model_info.id: + self.provider_default_deployment_ids.append(deployment.model_info.id) + + # Azure GPT-Vision Enhancements, users can pass os.environ/ + data_sources = deployment.litellm_params.get("dataSources", []) or [] + + for data_source in data_sources: + params = data_source.get("parameters", {}) + for param_key in ["endpoint", "key"]: + # if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var + if param_key in params and params[param_key].startswith("os.environ/"): + env_name = params[param_key].replace("os.environ/", "") + params[param_key] = os.environ.get(env_name, "") + + # done reading model["litellm_params"] + if custom_llm_provider not in litellm.provider_list: + raise Exception(f"Unsupported provider - {custom_llm_provider}") + + # # init OpenAI, Azure clients + # InitalizeOpenAISDKClient.set_client( + # litellm_router_instance=self, model=deployment.to_json(exclude_none=True) + # ) + + self._initialize_deployment_for_pass_through( + deployment=deployment, + custom_llm_provider=custom_llm_provider, + model=deployment.litellm_params.model, + ) + + return deployment + + def _initialize_deployment_for_pass_through( + self, deployment: Deployment, custom_llm_provider: str, model: str + ): + """ + Optional: Initialize deployment for pass-through endpoints if `deployment.litellm_params.use_in_pass_through` is True + + Each provider uses diff .env vars for pass-through endpoints, this helper uses the deployment credentials to set the .env vars for pass-through endpoints + """ + if deployment.litellm_params.use_in_pass_through is True: + if custom_llm_provider == "vertex_ai": + from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + vertex_pass_through_router, + ) + + if ( + deployment.litellm_params.vertex_project is None + or deployment.litellm_params.vertex_location is None + or deployment.litellm_params.vertex_credentials is None + ): + raise ValueError( + "vertex_project, vertex_location, and vertex_credentials must be set in litellm_params for pass-through endpoints" + ) + vertex_pass_through_router.add_vertex_credentials( + project_id=deployment.litellm_params.vertex_project, + location=deployment.litellm_params.vertex_location, + vertex_credentials=deployment.litellm_params.vertex_credentials, + ) + else: + from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( + passthrough_endpoint_router, + ) + + passthrough_endpoint_router.set_pass_through_credentials( + custom_llm_provider=custom_llm_provider, + api_base=deployment.litellm_params.api_base, + api_key=deployment.litellm_params.api_key, + ) + pass + pass + + def add_deployment(self, deployment: Deployment) -> Optional[Deployment]: + """ + Parameters: + - deployment: Deployment - the deployment to be added to the Router + + Returns: + - The added deployment + - OR None (if deployment already exists) + """ + # check if deployment already exists + + if deployment.model_info.id in self.get_model_ids(): + return None + + # add to model list + _deployment = deployment.to_json(exclude_none=True) + self.model_list.append(_deployment) + + # initialize client + self._add_deployment(deployment=deployment) + + # add to model names + self.model_names.append(deployment.model_name) + return deployment + + def upsert_deployment(self, deployment: Deployment) -> Optional[Deployment]: + """ + Add or update deployment + Parameters: + - deployment: Deployment - the deployment to be added to the Router + + Returns: + - The added/updated deployment + """ + # check 
if deployment already exists + _deployment_model_id = deployment.model_info.id or "" + + _deployment_on_router: Optional[Deployment] = self.get_deployment( + model_id=_deployment_model_id + ) + if _deployment_on_router is not None: + # deployment with this model_id exists on the router + if deployment.litellm_params == _deployment_on_router.litellm_params: + # No need to update + return None + + # if there is a new litellm param -> then update the deployment + # remove the previous deployment + removal_idx: Optional[int] = None + for idx, model in enumerate(self.model_list): + if model["model_info"]["id"] == deployment.model_info.id: + removal_idx = idx + + if removal_idx is not None: + self.model_list.pop(removal_idx) + + # if the model_id is not in router + self.add_deployment(deployment=deployment) + return deployment + + def delete_deployment(self, id: str) -> Optional[Deployment]: + """ + Parameters: + - id: str - the id of the deployment to be deleted + + Returns: + - The deleted deployment + - OR None (if deleted deployment not found) + """ + deployment_idx = None + for idx, m in enumerate(self.model_list): + if m["model_info"]["id"] == id: + deployment_idx = idx + + try: + if deployment_idx is not None: + item = self.model_list.pop(deployment_idx) + return item + else: + return None + except Exception: + return None + + def get_deployment(self, model_id: str) -> Optional[Deployment]: + """ + Returns -> Deployment or None + + Raise Exception -> if model found in invalid format + """ + for model in self.model_list: + if "model_info" in model and "id" in model["model_info"]: + if model_id == model["model_info"]["id"]: + if isinstance(model, dict): + return Deployment(**model) + elif isinstance(model, Deployment): + return model + else: + raise Exception("Model invalid format - {}".format(type(model))) + return None + + def get_deployment_credentials(self, model_id: str) -> Optional[dict]: + """ + Returns -> dict of credentials for a given model id + """ + deployment = self.get_deployment(model_id=model_id) + if deployment is None: + return None + return CredentialLiteLLMParams( + **deployment.litellm_params.model_dump(exclude_none=True) + ).model_dump(exclude_none=True) + + def get_deployment_by_model_group_name( + self, model_group_name: str + ) -> Optional[Deployment]: + """ + Returns -> Deployment or None + + Raise Exception -> if model found in invalid format + """ + for model in self.model_list: + if model["model_name"] == model_group_name: + if isinstance(model, dict): + return Deployment(**model) + elif isinstance(model, Deployment): + return model + else: + raise Exception("Model Name invalid - {}".format(type(model))) + return None + + @overload + def get_router_model_info( + self, deployment: dict, received_model_name: str, id: None = None + ) -> ModelMapInfo: + pass + + @overload + def get_router_model_info( + self, deployment: None, received_model_name: str, id: str + ) -> ModelMapInfo: + pass + + def get_router_model_info( + self, + deployment: Optional[dict], + received_model_name: str, + id: Optional[str] = None, + ) -> ModelMapInfo: + """ + For a given model id, return the model info (max tokens, input cost, output cost, etc.). + + Augment litellm info with additional params set in `model_info`. + + For azure models, ignore the `model:`. Only set max tokens, cost values if base_model is set. + + Returns + - ModelInfo - If found -> typed dict with max tokens, input cost, etc. 
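+
+        Illustrative deployment sketch (hypothetical names) - for azure deployments,
+        set 'base_model' in model_info so max-token / cost lookups resolve against the
+        litellm model map:
+
+            {
+                "model_name": "azure-gpt-4o",
+                "litellm_params": {"model": "azure/my-deployment", "api_key": "os.environ/AZURE_API_KEY"},
+                "model_info": {"base_model": "azure/gpt-4o"},
+            }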
+ + Raises: + - ValueError -> If model is not mapped yet + """ + if id is not None: + _deployment = self.get_deployment(model_id=id) + if _deployment is not None: + deployment = _deployment.model_dump(exclude_none=True) + + if deployment is None: + raise ValueError("Deployment not found") + + ## GET BASE MODEL + base_model = deployment.get("model_info", {}).get("base_model", None) + if base_model is None: + base_model = deployment.get("litellm_params", {}).get("base_model", None) + + model = base_model + + ## GET PROVIDER + _model, custom_llm_provider, _, _ = litellm.get_llm_provider( + model=deployment.get("litellm_params", {}).get("model", ""), + litellm_params=LiteLLM_Params(**deployment.get("litellm_params", {})), + ) + + ## SET MODEL TO 'model=' - if base_model is None + not azure + if custom_llm_provider == "azure" and base_model is None: + verbose_router_logger.error( + "Could not identify azure model. Set azure 'base_model' for accurate max tokens, cost tracking, etc.- https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models" + ) + elif custom_llm_provider != "azure": + model = _model + + potential_models = self.pattern_router.route(received_model_name) + if "*" in model and potential_models is not None: # if wildcard route + for potential_model in potential_models: + try: + if potential_model.get("model_info", {}).get( + "id" + ) == deployment.get("model_info", {}).get("id"): + model = potential_model.get("litellm_params", {}).get( + "model" + ) + break + except Exception: + pass + + ## GET LITELLM MODEL INFO - raises exception, if model is not mapped + if not model.startswith("{}/".format(custom_llm_provider)): + model_info_name = "{}/{}".format(custom_llm_provider, model) + else: + model_info_name = model + + model_info = litellm.get_model_info(model=model_info_name) + + ## CHECK USER SET MODEL INFO + user_model_info = deployment.get("model_info", {}) + + model_info.update(user_model_info) + + return model_info + + def get_model_info(self, id: str) -> Optional[dict]: + """ + For a given model id, return the model info + + Returns + - dict: the model in list with 'model_name', 'litellm_params', Optional['model_info'] + - None: could not find deployment in list + """ + for model in self.model_list: + if "model_info" in model and "id" in model["model_info"]: + if id == model["model_info"]["id"]: + return model + return None + + def get_model_group(self, id: str) -> Optional[List]: + """ + Return list of all models in the same model group as that model id + """ + + model_info = self.get_model_info(id=id) + if model_info is None: + return None + + model_name = model_info["model_name"] + return self.get_model_list(model_name=model_name) + + def _set_model_group_info( # noqa: PLR0915 + self, model_group: str, user_facing_model_group_name: str + ) -> Optional[ModelGroupInfo]: + """ + For a given model group name, return the combined model info + + Returns: + - ModelGroupInfo if able to construct a model group + - None if error constructing model group info + """ + model_group_info: Optional[ModelGroupInfo] = None + + total_tpm: Optional[int] = None + total_rpm: Optional[int] = None + configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None + model_list = self.get_model_list(model_name=model_group) + if model_list is None: + return None + for model in model_list: + is_match = False + if ( + "model_name" in model and model["model_name"] == model_group + ): # exact match + is_match = True + elif ( + "model_name" in model + and 
self.pattern_router.route(model_group) is not None + ): # wildcard model + is_match = True + + if not is_match: + continue + # model in model group found # + litellm_params = LiteLLM_Params(**model["litellm_params"]) # type: ignore + # get configurable clientside auth params + configurable_clientside_auth_params = ( + litellm_params.configurable_clientside_auth_params + ) + # get model tpm + _deployment_tpm: Optional[int] = None + if _deployment_tpm is None: + _deployment_tpm = model.get("tpm", None) # type: ignore + if _deployment_tpm is None: + _deployment_tpm = model.get("litellm_params", {}).get("tpm", None) # type: ignore + if _deployment_tpm is None: + _deployment_tpm = model.get("model_info", {}).get("tpm", None) # type: ignore + + # get model rpm + _deployment_rpm: Optional[int] = None + if _deployment_rpm is None: + _deployment_rpm = model.get("rpm", None) # type: ignore + if _deployment_rpm is None: + _deployment_rpm = model.get("litellm_params", {}).get("rpm", None) # type: ignore + if _deployment_rpm is None: + _deployment_rpm = model.get("model_info", {}).get("rpm", None) # type: ignore + + # get model info + try: + model_info = litellm.get_model_info(model=litellm_params.model) + except Exception: + model_info = None + # get llm provider + litellm_model, llm_provider = "", "" + try: + litellm_model, llm_provider, _, _ = litellm.get_llm_provider( + model=litellm_params.model, + custom_llm_provider=litellm_params.custom_llm_provider, + ) + except litellm.exceptions.BadRequestError as e: + verbose_router_logger.error( + "litellm.router.py::get_model_group_info() - {}".format(str(e)) + ) + + if model_info is None: + supported_openai_params = litellm.get_supported_openai_params( + model=litellm_model, custom_llm_provider=llm_provider + ) + if supported_openai_params is None: + supported_openai_params = [] + model_info = ModelMapInfo( + key=model_group, + max_tokens=None, + max_input_tokens=None, + max_output_tokens=None, + input_cost_per_token=0, + output_cost_per_token=0, + litellm_provider=llm_provider, + mode="chat", + supported_openai_params=supported_openai_params, + supports_system_messages=None, + ) + + if model_group_info is None: + model_group_info = ModelGroupInfo( + model_group=user_facing_model_group_name, providers=[llm_provider], **model_info # type: ignore + ) + else: + # if max_input_tokens > curr + # if max_output_tokens > curr + # if input_cost_per_token > curr + # if output_cost_per_token > curr + # supports_parallel_function_calling == True + # supports_vision == True + # supports_function_calling == True + if llm_provider not in model_group_info.providers: + model_group_info.providers.append(llm_provider) + if ( + model_info.get("max_input_tokens", None) is not None + and model_info["max_input_tokens"] is not None + and ( + model_group_info.max_input_tokens is None + or model_info["max_input_tokens"] + > model_group_info.max_input_tokens + ) + ): + model_group_info.max_input_tokens = model_info["max_input_tokens"] + if ( + model_info.get("max_output_tokens", None) is not None + and model_info["max_output_tokens"] is not None + and ( + model_group_info.max_output_tokens is None + or model_info["max_output_tokens"] + > model_group_info.max_output_tokens + ) + ): + model_group_info.max_output_tokens = model_info["max_output_tokens"] + if model_info.get("input_cost_per_token", None) is not None and ( + model_group_info.input_cost_per_token is None + or model_info["input_cost_per_token"] + > model_group_info.input_cost_per_token + ): + 
model_group_info.input_cost_per_token = model_info[ + "input_cost_per_token" + ] + if model_info.get("output_cost_per_token", None) is not None and ( + model_group_info.output_cost_per_token is None + or model_info["output_cost_per_token"] + > model_group_info.output_cost_per_token + ): + model_group_info.output_cost_per_token = model_info[ + "output_cost_per_token" + ] + if ( + model_info.get("supports_parallel_function_calling", None) + is not None + and model_info["supports_parallel_function_calling"] is True # type: ignore + ): + model_group_info.supports_parallel_function_calling = True + if ( + model_info.get("supports_vision", None) is not None + and model_info["supports_vision"] is True # type: ignore + ): + model_group_info.supports_vision = True + if ( + model_info.get("supports_function_calling", None) is not None + and model_info["supports_function_calling"] is True # type: ignore + ): + model_group_info.supports_function_calling = True + if ( + model_info.get("supported_openai_params", None) is not None + and model_info["supported_openai_params"] is not None + ): + model_group_info.supported_openai_params = model_info[ + "supported_openai_params" + ] + if model_info.get("tpm", None) is not None and _deployment_tpm is None: + _deployment_tpm = model_info.get("tpm") + if model_info.get("rpm", None) is not None and _deployment_rpm is None: + _deployment_rpm = model_info.get("rpm") + + if _deployment_tpm is not None: + if total_tpm is None: + total_tpm = 0 + total_tpm += _deployment_tpm # type: ignore + + if _deployment_rpm is not None: + if total_rpm is None: + total_rpm = 0 + total_rpm += _deployment_rpm # type: ignore + if model_group_info is not None: + ## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP + if total_tpm is not None: + model_group_info.tpm = total_tpm + + if total_rpm is not None: + model_group_info.rpm = total_rpm + + ## UPDATE WITH CONFIGURABLE CLIENTSIDE AUTH PARAMS FOR MODEL GROUP + if configurable_clientside_auth_params is not None: + model_group_info.configurable_clientside_auth_params = ( + configurable_clientside_auth_params + ) + + return model_group_info + + def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]: + """ + For a given model group name, return the combined model info + + Returns: + - ModelGroupInfo if able to construct a model group + - None if error constructing model group info or hidden model group + """ + ## Check if model group alias + if model_group in self.model_group_alias: + item = self.model_group_alias[model_group] + if isinstance(item, str): + _router_model_group = item + elif isinstance(item, dict): + if item["hidden"] is True: + return None + else: + _router_model_group = item["model"] + else: + return None + + return self._set_model_group_info( + model_group=_router_model_group, + user_facing_model_group_name=model_group, + ) + + ## Check if actual model + return self._set_model_group_info( + model_group=model_group, user_facing_model_group_name=model_group + ) + + async def get_model_group_usage( + self, model_group: str + ) -> Tuple[Optional[int], Optional[int]]: + """ + Returns current tpm/rpm usage for model group + + Parameters: + - model_group: str - the received model name from the user (can be a wildcard route). 
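+
+        Illustrative usage (hypothetical model group name; assumes the router has
+        already served traffic for this group in the current minute):
+
+            tpm_used, rpm_used = await router.get_model_group_usage("gpt-3.5-turbo")
+            # e.g. (1543, 12); (None, None) if no matching deployments or no usage recorded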
+ + Returns: + - usage: Tuple[tpm, rpm] + """ + dt = get_utc_datetime() + current_minute = dt.strftime( + "%H-%M" + ) # use the same timezone regardless of system clock + tpm_keys: List[str] = [] + rpm_keys: List[str] = [] + + model_list = self.get_model_list(model_name=model_group) + if model_list is None: # no matching deployments + return None, None + + for model in model_list: + id: Optional[str] = model.get("model_info", {}).get("id") # type: ignore + litellm_model: Optional[str] = model["litellm_params"].get( + "model" + ) # USE THE MODEL SENT TO litellm.completion() - consistent with how global_router cache is written. + if id is None or litellm_model is None: + continue + tpm_keys.append( + RouterCacheEnum.TPM.value.format( + id=id, + model=litellm_model, + current_minute=current_minute, + ) + ) + rpm_keys.append( + RouterCacheEnum.RPM.value.format( + id=id, + model=litellm_model, + current_minute=current_minute, + ) + ) + combined_tpm_rpm_keys = tpm_keys + rpm_keys + + combined_tpm_rpm_values = await self.cache.async_batch_get_cache( + keys=combined_tpm_rpm_keys + ) + if combined_tpm_rpm_values is None: + return None, None + + tpm_usage_list: Optional[List] = combined_tpm_rpm_values[: len(tpm_keys)] + rpm_usage_list: Optional[List] = combined_tpm_rpm_values[len(tpm_keys) :] + + ## TPM + tpm_usage: Optional[int] = None + if tpm_usage_list is not None: + for t in tpm_usage_list: + if isinstance(t, int): + if tpm_usage is None: + tpm_usage = 0 + tpm_usage += t + ## RPM + rpm_usage: Optional[int] = None + if rpm_usage_list is not None: + for t in rpm_usage_list: + if isinstance(t, int): + if rpm_usage is None: + rpm_usage = 0 + rpm_usage += t + return tpm_usage, rpm_usage + + @lru_cache(maxsize=64) + def _cached_get_model_group_info( + self, model_group: str + ) -> Optional[ModelGroupInfo]: + """ + Cached version of get_model_group_info, uses @lru_cache wrapper + + This is a speed optimization, since set_response_headers makes a call to get_model_group_info on every request + """ + return self.get_model_group_info(model_group) + + async def get_remaining_model_group_usage(self, model_group: str) -> Dict[str, int]: + model_group_info = self._cached_get_model_group_info(model_group) + + if model_group_info is not None and model_group_info.tpm is not None: + tpm_limit = model_group_info.tpm + else: + tpm_limit = None + + if model_group_info is not None and model_group_info.rpm is not None: + rpm_limit = model_group_info.rpm + else: + rpm_limit = None + + if tpm_limit is None and rpm_limit is None: + return {} + + current_tpm, current_rpm = await self.get_model_group_usage(model_group) + + returned_dict = {} + if tpm_limit is not None: + returned_dict["x-ratelimit-remaining-tokens"] = tpm_limit - ( + current_tpm or 0 + ) + returned_dict["x-ratelimit-limit-tokens"] = tpm_limit + if rpm_limit is not None: + returned_dict["x-ratelimit-remaining-requests"] = rpm_limit - ( + current_rpm or 0 + ) + returned_dict["x-ratelimit-limit-requests"] = rpm_limit + + return returned_dict + + async def set_response_headers( + self, response: Any, model_group: Optional[str] = None + ) -> Any: + """ + Add the most accurate rate limit headers for a given model response. 
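+
+        Illustrative result (hypothetical values) - headers are written into
+        response._hidden_params["additional_headers"]:
+
+            {
+                "x-litellm-model-group": "gpt-3.5-turbo",
+                "x-ratelimit-limit-tokens": 100000,
+                "x-ratelimit-remaining-tokens": 98457,
+                "x-ratelimit-limit-requests": 100,
+                "x-ratelimit-remaining-requests": 88,
+            }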
+ + ## TODO: add model group rate limit headers + # - if healthy_deployments > 1, return model group rate limit headers + # - else return the model's rate limit headers + """ + if ( + isinstance(response, BaseModel) + and hasattr(response, "_hidden_params") + and isinstance(response._hidden_params, dict) # type: ignore + ): + response._hidden_params.setdefault("additional_headers", {}) # type: ignore + response._hidden_params["additional_headers"][ # type: ignore + "x-litellm-model-group" + ] = model_group + + additional_headers = response._hidden_params["additional_headers"] # type: ignore + + if ( + "x-ratelimit-remaining-tokens" not in additional_headers + and "x-ratelimit-remaining-requests" not in additional_headers + and model_group is not None + ): + remaining_usage = await self.get_remaining_model_group_usage( + model_group + ) + + for header, value in remaining_usage.items(): + if value is not None: + additional_headers[header] = value + return response + + def get_model_ids(self, model_name: Optional[str] = None) -> List[str]: + """ + if 'model_name' is none, returns all. + + Returns list of model id's. + """ + ids = [] + for model in self.model_list: + if "model_info" in model and "id" in model["model_info"]: + id = model["model_info"]["id"] + if model_name is not None and model["model_name"] == model_name: + ids.append(id) + elif model_name is None: + ids.append(id) + return ids + + def _get_all_deployments( + self, model_name: str, model_alias: Optional[str] = None + ) -> List[DeploymentTypedDict]: + """ + Return all deployments of a model name + + Used for accurate 'get_model_list'. + """ + returned_models: List[DeploymentTypedDict] = [] + for model in self.model_list: + if model_name is not None and model["model_name"] == model_name: + if model_alias is not None: + alias_model = copy.deepcopy(model) + alias_model["model_name"] = model_alias + returned_models.append(alias_model) + else: + returned_models.append(model) + + return returned_models + + def get_model_names(self, team_id: Optional[str] = None) -> List[str]: + """ + Returns all possible model names for the router, including models defined via model_group_alias. + + If a team_id is provided, only deployments configured with that team_id (i.e. team‐specific models) + will yield their team public name. + """ + deployments = self.get_model_list() or [] + model_names = [] + + for deployment in deployments: + model_info = deployment.get("model_info") + if self._is_team_specific_model(model_info): + team_model_name = self._get_team_specific_model( + deployment=deployment, team_id=team_id + ) + if team_model_name: + model_names.append(team_model_name) + else: + model_names.append(deployment.get("model_name", "")) + + return model_names + + def _get_team_specific_model( + self, deployment: DeploymentTypedDict, team_id: Optional[str] = None + ) -> Optional[str]: + """ + Get the team-specific model name if team_id matches the deployment. + + Args: + deployment: DeploymentTypedDict - The model deployment + team_id: Optional[str] - If passed, will return router models set with a `team_id` matching the passed `team_id`. 
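+
+        Illustrative model_info sketch (hypothetical ids) - a deployment only yields its
+        team public name when the passed team_id matches:
+
+            "model_info": {
+                "team_id": "team-1234",
+                "team_public_model_name": "internal-gpt",
+            }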
+ + Returns: + str: The `team_public_model_name` if team_id matches + None: If team_id doesn't match or no team info exists + """ + model_info: Optional[Dict] = deployment.get("model_info") or {} + if model_info is None: + return None + if team_id == model_info.get("team_id"): + return model_info.get("team_public_model_name") + return None + + def _is_team_specific_model(self, model_info: Optional[Dict]) -> bool: + """ + Check if model info contains team-specific configuration. + + Args: + model_info: Model information dictionary + + Returns: + bool: True if model has team-specific configuration + """ + return bool(model_info and model_info.get("team_id")) + + def get_model_list_from_model_alias( + self, model_name: Optional[str] = None + ) -> List[DeploymentTypedDict]: + """ + Helper function to get model list from model alias. + + Used by `.get_model_list` to get model list from model alias. + """ + returned_models: List[DeploymentTypedDict] = [] + for model_alias, model_value in self.model_group_alias.items(): + if model_name is not None and model_alias != model_name: + continue + if isinstance(model_value, str): + _router_model_name: str = model_value + elif isinstance(model_value, dict): + _model_value = RouterModelGroupAliasItem(**model_value) # type: ignore + if _model_value["hidden"] is True: + continue + else: + _router_model_name = _model_value["model"] + else: + continue + + returned_models.extend( + self._get_all_deployments( + model_name=_router_model_name, model_alias=model_alias + ) + ) + + return returned_models + + def get_model_list( + self, model_name: Optional[str] = None + ) -> Optional[List[DeploymentTypedDict]]: + """ + Includes router model_group_alias'es as well + """ + if hasattr(self, "model_list"): + returned_models: List[DeploymentTypedDict] = [] + + if model_name is not None: + returned_models.extend(self._get_all_deployments(model_name=model_name)) + + if hasattr(self, "model_group_alias"): + returned_models.extend( + self.get_model_list_from_model_alias(model_name=model_name) + ) + + if len(returned_models) == 0: # check if wildcard route + potential_wildcard_models = self.pattern_router.route(model_name) + if potential_wildcard_models is not None: + returned_models.extend( + [DeploymentTypedDict(**m) for m in potential_wildcard_models] # type: ignore + ) + + if model_name is None: + returned_models += self.model_list + + return returned_models + + return returned_models + return None + + def get_model_access_groups( + self, model_name: Optional[str] = None, model_access_group: Optional[str] = None + ) -> Dict[str, List[str]]: + """ + If model_name is provided, only return access groups for that model. + + Parameters: + - model_name: Optional[str] - the received model name from the user (can be a wildcard route). If set, will only return access groups for that model. + - model_access_group: Optional[str] - the received model access group from the user. If set, will only return models for that access group. 
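+
+        Illustrative sketch (hypothetical names) - deployments declare access groups in
+        model_info, and the result maps each group to the model names in it:
+
+            # deployment: {"model_name": "gpt-4o", ..., "model_info": {"access_groups": ["beta-models"]}}
+            router.get_model_access_groups()
+            # -> {"beta-models": ["gpt-4o"]}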
+ """ + from collections import defaultdict + + access_groups = defaultdict(list) + + model_list = self.get_model_list(model_name=model_name) + if model_list: + for m in model_list: + _model_info = m.get("model_info") + if _model_info: + for group in _model_info.get("access_groups", []) or []: + if model_access_group is not None: + if group == model_access_group: + model_name = m["model_name"] + access_groups[group].append(model_name) + else: + model_name = m["model_name"] + access_groups[group].append(model_name) + + return access_groups + + def _is_model_access_group_for_wildcard_route( + self, model_access_group: str + ) -> bool: + """ + Return True if model access group is a wildcard route + """ + # GET ACCESS GROUPS + access_groups = self.get_model_access_groups( + model_access_group=model_access_group + ) + + if len(access_groups) == 0: + return False + + models = access_groups.get(model_access_group, []) + + for model in models: + # CHECK IF MODEL ACCESS GROUP IS A WILDCARD ROUTE + if self.pattern_router.route(request=model) is not None: + return True + + return False + + def get_settings(self): + """ + Get router settings method, returns a dictionary of the settings and their values. + For example get the set values for routing_strategy_args, routing_strategy, allowed_fails, cooldown_time, num_retries, timeout, max_retries, retry_after + """ + _all_vars = vars(self) + _settings_to_return = {} + vars_to_include = [ + "routing_strategy_args", + "routing_strategy", + "allowed_fails", + "cooldown_time", + "num_retries", + "timeout", + "max_retries", + "retry_after", + "fallbacks", + "context_window_fallbacks", + "model_group_retry_policy", + ] + + for var in vars_to_include: + if var in _all_vars: + _settings_to_return[var] = _all_vars[var] + if ( + var == "routing_strategy_args" + and self.routing_strategy == "latency-based-routing" + ): + _settings_to_return[var] = self.lowestlatency_logger.routing_args.json() + return _settings_to_return + + def update_settings(self, **kwargs): + """ + Update the router settings. + """ + # only the following settings are allowed to be configured + _allowed_settings = [ + "routing_strategy_args", + "routing_strategy", + "allowed_fails", + "cooldown_time", + "num_retries", + "timeout", + "max_retries", + "retry_after", + "fallbacks", + "context_window_fallbacks", + "model_group_retry_policy", + ] + + _int_settings = [ + "timeout", + "num_retries", + "retry_after", + "allowed_fails", + "cooldown_time", + ] + + _existing_router_settings = self.get_settings() + for var in kwargs: + if var in _allowed_settings: + if var in _int_settings: + _casted_value = int(kwargs[var]) + setattr(self, var, _casted_value) + else: + # only run routing strategy init if it has changed + if ( + var == "routing_strategy" + and _existing_router_settings["routing_strategy"] != kwargs[var] + ): + self.routing_strategy_init( + routing_strategy=kwargs[var], + routing_strategy_args=kwargs.get( + "routing_strategy_args", {} + ), + ) + setattr(self, var, kwargs[var]) + else: + verbose_router_logger.debug("Setting {} is not allowed".format(var)) + verbose_router_logger.debug(f"Updated Router settings: {self.get_settings()}") + + def _get_client(self, deployment, kwargs, client_type=None): + """ + Returns the appropriate client based on the given deployment, kwargs, and client_type. + + Parameters: + deployment (dict): The deployment dictionary containing the clients. + kwargs (dict): The keyword arguments passed to the function. + client_type (str): The type of client to return. 
+ + Returns: + The appropriate client based on the given client_type and kwargs. + """ + model_id = deployment["model_info"]["id"] + parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(kwargs) + if client_type == "max_parallel_requests": + cache_key = "{}_max_parallel_requests_client".format(model_id) + client = self.cache.get_cache( + key=cache_key, local_only=True, parent_otel_span=parent_otel_span + ) + if client is None: + InitalizeCachedClient.set_max_parallel_requests_client( + litellm_router_instance=self, model=deployment + ) + client = self.cache.get_cache( + key=cache_key, local_only=True, parent_otel_span=parent_otel_span + ) + return client + elif client_type == "async": + if kwargs.get("stream") is True: + cache_key = f"{model_id}_stream_async_client" + client = self.cache.get_cache( + key=cache_key, local_only=True, parent_otel_span=parent_otel_span + ) + return client + else: + cache_key = f"{model_id}_async_client" + client = self.cache.get_cache( + key=cache_key, local_only=True, parent_otel_span=parent_otel_span + ) + return client + else: + if kwargs.get("stream") is True: + cache_key = f"{model_id}_stream_client" + client = self.cache.get_cache( + key=cache_key, parent_otel_span=parent_otel_span + ) + return client + else: + cache_key = f"{model_id}_client" + client = self.cache.get_cache( + key=cache_key, parent_otel_span=parent_otel_span + ) + return client + + def _pre_call_checks( # noqa: PLR0915 + self, + model: str, + healthy_deployments: List, + messages: List[Dict[str, str]], + request_kwargs: Optional[dict] = None, + ): + """ + Filter out model in model group, if: + + - model context window < message length. For azure openai models, requires 'base_model' is set. - https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models + - filter models above rpm limits + - if region given, filter out models not in that region / unknown region + - [TODO] function call and model doesn't support function calling + """ + + verbose_router_logger.debug( + f"Starting Pre-call checks for deployments in model={model}" + ) + + _returned_deployments = copy.deepcopy(healthy_deployments) + + invalid_model_indices = [] + + try: + input_tokens = litellm.token_counter(messages=messages) + except Exception as e: + verbose_router_logger.error( + "litellm.router.py::_pre_call_checks: failed to count tokens. Returning initial list of deployments. Got - {}".format( + str(e) + ) + ) + return _returned_deployments + + _context_window_error = False + _potential_error_str = "" + _rate_limit_error = False + parent_otel_span = _get_parent_otel_span_from_kwargs(request_kwargs) + + ## get model group RPM ## + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + rpm_key = f"{model}:rpm:{current_minute}" + model_group_cache = ( + self.cache.get_cache( + key=rpm_key, local_only=True, parent_otel_span=parent_otel_span + ) + or {} + ) # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache. 
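+        # e.g. for model="gpt-3.5-turbo" at 14:03 UTC, rpm_key == "gpt-3.5-turbo:rpm:14-03";
+        # the cached value is a dict of {model_id: requests_this_minute}, which the loop
+        # below compares against each deployment's configured 'rpm' litellm_param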
+ for idx, deployment in enumerate(_returned_deployments): + # see if we have the info for this model + try: + base_model = deployment.get("model_info", {}).get("base_model", None) + if base_model is None: + base_model = deployment.get("litellm_params", {}).get( + "base_model", None + ) + model_info = self.get_router_model_info( + deployment=deployment, received_model_name=model + ) + model = base_model or deployment.get("litellm_params", {}).get( + "model", None + ) + + if ( + isinstance(model_info, dict) + and model_info.get("max_input_tokens", None) is not None + ): + if ( + isinstance(model_info["max_input_tokens"], int) + and input_tokens > model_info["max_input_tokens"] + ): + invalid_model_indices.append(idx) + _context_window_error = True + _potential_error_str += ( + "Model={}, Max Input Tokens={}, Got={}".format( + model, model_info["max_input_tokens"], input_tokens + ) + ) + continue + except Exception as e: + verbose_router_logger.exception("An error occurs - {}".format(str(e))) + + _litellm_params = deployment.get("litellm_params", {}) + model_id = deployment.get("model_info", {}).get("id", "") + ## RPM CHECK ## + ### get local router cache ### + current_request_cache_local = ( + self.cache.get_cache( + key=model_id, local_only=True, parent_otel_span=parent_otel_span + ) + or 0 + ) + ### get usage based cache ### + if ( + isinstance(model_group_cache, dict) + and self.routing_strategy != "usage-based-routing-v2" + ): + model_group_cache[model_id] = model_group_cache.get(model_id, 0) + + current_request = max( + current_request_cache_local, model_group_cache[model_id] + ) + + if ( + isinstance(_litellm_params, dict) + and _litellm_params.get("rpm", None) is not None + ): + if ( + isinstance(_litellm_params["rpm"], int) + and _litellm_params["rpm"] <= current_request + ): + invalid_model_indices.append(idx) + _rate_limit_error = True + continue + + ## REGION CHECK ## + if ( + request_kwargs is not None + and request_kwargs.get("allowed_model_region") is not None + ): + allowed_model_region = request_kwargs.get("allowed_model_region") + + if allowed_model_region is not None: + if not is_region_allowed( + litellm_params=LiteLLM_Params(**_litellm_params), + allowed_model_region=allowed_model_region, + ): + invalid_model_indices.append(idx) + continue + + ## INVALID PARAMS ## -> catch 'gpt-3.5-turbo-16k' not supporting 'response_format' param + if request_kwargs is not None and litellm.drop_params is False: + # get supported params + model, custom_llm_provider, _, _ = litellm.get_llm_provider( + model=model, litellm_params=LiteLLM_Params(**_litellm_params) + ) + + supported_openai_params = litellm.get_supported_openai_params( + model=model, custom_llm_provider=custom_llm_provider + ) + + if supported_openai_params is None: + continue + else: + # check the non-default openai params in request kwargs + non_default_params = litellm.utils.get_non_default_params( + passed_params=request_kwargs + ) + special_params = ["response_format"] + # check if all params are supported + for k, v in non_default_params.items(): + if k not in supported_openai_params and k in special_params: + # if not -> invalid model + verbose_router_logger.debug( + f"INVALID MODEL INDEX @ REQUEST KWARG FILTERING, k={k}" + ) + invalid_model_indices.append(idx) + + if len(invalid_model_indices) == len(_returned_deployments): + """ + - no healthy deployments available b/c context window checks or rate limit error + + - First check for rate limit errors (if this is true, it means the model passed the context window check 
but failed the rate limit check) + """ + + if _rate_limit_error is True: # allow generic fallback logic to take place + raise RouterRateLimitErrorBasic( + model=model, + ) + + elif _context_window_error is True: + raise litellm.ContextWindowExceededError( + message="litellm._pre_call_checks: Context Window exceeded for given call. No models have context window large enough for this call.\n{}".format( + _potential_error_str + ), + model=model, + llm_provider="", + ) + if len(invalid_model_indices) > 0: + for idx in reversed(invalid_model_indices): + _returned_deployments.pop(idx) + + ## ORDER FILTERING ## -> if user set 'order' in deployments, return deployments with lowest order (e.g. order=1 > order=2) + if len(_returned_deployments) > 0: + _returned_deployments = litellm.utils._get_order_filtered_deployments( + _returned_deployments + ) + + return _returned_deployments + + def _get_model_from_alias(self, model: str) -> Optional[str]: + """ + Get the model from the alias. + + Returns: + - str, the litellm model name + - None, if model is not in model group alias + """ + if model not in self.model_group_alias: + return None + + _item = self.model_group_alias[model] + if isinstance(_item, str): + model = _item + else: + model = _item["model"] + + return model + + def _get_deployment_by_litellm_model(self, model: str) -> List: + """ + Get the deployment by litellm model. + """ + return [m for m in self.model_list if m["litellm_params"]["model"] == model] + + def _common_checks_available_deployment( + self, + model: str, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + specific_deployment: Optional[bool] = False, + ) -> Tuple[str, Union[List, Dict]]: + """ + Common checks for 'get_available_deployment' across sync + async call. 
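+
+        Illustrative return shapes (hypothetical names):
+
+            ("gpt-3.5-turbo", [<deployment dict>, <deployment dict>])   # model group
+            ("azure/my-deployment", <deployment dict>)                  # specific model id passed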
+ 
+ If 'healthy_deployments' is returned as a Dict (not a List), this means the user chose a specific deployment
+ 
+ Returns:
+ - str, the litellm model name
+ - List, if multiple models chosen
+ - Dict, if specific model chosen
+ """
+ # check if aliases set on litellm model alias map
+ if specific_deployment is True:
+ return model, self._get_deployment_by_litellm_model(model=model)
+ elif model in self.get_model_ids():
+ deployment = self.get_deployment(model_id=model)
+ if deployment is not None:
+ deployment_model = deployment.litellm_params.model
+ return deployment_model, deployment.model_dump(exclude_none=True)
+ raise ValueError(
+ f"LiteLLM Router: Trying to call specific deployment, but Model ID: {model} does not exist in \
+ Model ID List: {self.get_model_ids()}"
+ )
+
+ _model_from_alias = self._get_model_from_alias(model=model)
+ if _model_from_alias is not None:
+ model = _model_from_alias
+
+ if model not in self.model_names:
+ # check if provider-specific wildcard routing applies, using pattern matching
+ pattern_deployments = self.pattern_router.get_deployments_by_pattern(
+ model=model,
+ )
+ if pattern_deployments:
+ return model, pattern_deployments
+
+ # check if default deployment is set
+ if self.default_deployment is not None:
+ updated_deployment = copy.deepcopy(
+ self.default_deployment
+ ) # self.default_deployment
+ updated_deployment["litellm_params"]["model"] = model
+ return model, updated_deployment
+
+ ## get healthy deployments
+ ### get all deployments
+ healthy_deployments = self._get_all_deployments(model_name=model)
+
+ if len(healthy_deployments) == 0:
+ # check if the user sent in a deployment name instead
+ healthy_deployments = self._get_deployment_by_litellm_model(model=model)
+
+ verbose_router_logger.debug(
+ f"initial list of deployments: {healthy_deployments}"
+ )
+
+ if len(healthy_deployments) == 0:
+ raise litellm.BadRequestError(
+ message="You passed in model={}. There is no 'model_name' with this string.".format(
+ model
+ ),
+ model=model,
+ llm_provider="",
+ )
+
+ if litellm.model_alias_map and model in litellm.model_alias_map:
+ model = litellm.model_alias_map[
+ model
+ ] # update the model to the actual value if an alias has been passed in
+
+ return model, healthy_deployments
+
+ async def async_get_available_deployment(
+ self,
+ model: str,
+ request_kwargs: Dict,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ specific_deployment: Optional[bool] = False,
+ ):
+ """
+ Async implementation of 'get_available_deployment'.
+
+ Allows all cache calls to be made async => 10x perf impact (8rps -> 100 rps).
+ """
+ if (
+ self.routing_strategy != "usage-based-routing-v2"
+ and self.routing_strategy != "simple-shuffle"
+ and self.routing_strategy != "cost-based-routing"
+ and self.routing_strategy != "latency-based-routing"
+ and self.routing_strategy != "least-busy"
+ ): # prevent regressions for other routing strategies that don't have an async get_available_deployments implementation
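+ # these strategies have no async-aware implementation -> fall back to the synchronous get_available_deployment path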
+ return self.get_available_deployment( + model=model, + messages=messages, + input=input, + specific_deployment=specific_deployment, + request_kwargs=request_kwargs, + ) + try: + parent_otel_span = _get_parent_otel_span_from_kwargs(request_kwargs) + model, healthy_deployments = self._common_checks_available_deployment( + model=model, + messages=messages, + input=input, + specific_deployment=specific_deployment, + ) # type: ignore + if isinstance(healthy_deployments, dict): + return healthy_deployments + + cooldown_deployments = await _async_get_cooldown_deployments( + litellm_router_instance=self, parent_otel_span=parent_otel_span + ) + verbose_router_logger.debug( + f"async cooldown deployments: {cooldown_deployments}" + ) + verbose_router_logger.debug(f"cooldown_deployments: {cooldown_deployments}") + healthy_deployments = self._filter_cooldown_deployments( + healthy_deployments=healthy_deployments, + cooldown_deployments=cooldown_deployments, + ) + + healthy_deployments = await self.async_callback_filter_deployments( + model=model, + healthy_deployments=healthy_deployments, + messages=( + cast(List[AllMessageValues], messages) + if messages is not None + else None + ), + request_kwargs=request_kwargs, + parent_otel_span=parent_otel_span, + ) + + if self.enable_pre_call_checks and messages is not None: + healthy_deployments = self._pre_call_checks( + model=model, + healthy_deployments=cast(List[Dict], healthy_deployments), + messages=messages, + request_kwargs=request_kwargs, + ) + # check if user wants to do tag based routing + healthy_deployments = await get_deployments_for_tag( # type: ignore + llm_router_instance=self, + model=model, + request_kwargs=request_kwargs, + healthy_deployments=healthy_deployments, + ) + + if len(healthy_deployments) == 0: + exception = await async_raise_no_deployment_exception( + litellm_router_instance=self, + model=model, + parent_otel_span=parent_otel_span, + ) + raise exception + start_time = time.time() + if ( + self.routing_strategy == "usage-based-routing-v2" + and self.lowesttpm_logger_v2 is not None + ): + deployment = ( + await self.lowesttpm_logger_v2.async_get_available_deployments( + model_group=model, + healthy_deployments=healthy_deployments, # type: ignore + messages=messages, + input=input, + ) + ) + elif ( + self.routing_strategy == "cost-based-routing" + and self.lowestcost_logger is not None + ): + deployment = ( + await self.lowestcost_logger.async_get_available_deployments( + model_group=model, + healthy_deployments=healthy_deployments, # type: ignore + messages=messages, + input=input, + ) + ) + elif ( + self.routing_strategy == "latency-based-routing" + and self.lowestlatency_logger is not None + ): + deployment = ( + await self.lowestlatency_logger.async_get_available_deployments( + model_group=model, + healthy_deployments=healthy_deployments, # type: ignore + messages=messages, + input=input, + request_kwargs=request_kwargs, + ) + ) + elif self.routing_strategy == "simple-shuffle": + return simple_shuffle( + llm_router_instance=self, + healthy_deployments=healthy_deployments, + model=model, + ) + elif ( + self.routing_strategy == "least-busy" + and self.leastbusy_logger is not None + ): + deployment = ( + await self.leastbusy_logger.async_get_available_deployments( + model_group=model, + healthy_deployments=healthy_deployments, # type: ignore + ) + ) + else: + deployment = None + if deployment is None: + exception = await async_raise_no_deployment_exception( + litellm_router_instance=self, + model=model, + 
parent_otel_span=parent_otel_span,
+ )
+ raise exception
+ verbose_router_logger.info(
+ f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)}"
+ )
+
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_success_hook(
+ service=ServiceTypes.ROUTER,
+ duration=_duration,
+ call_type="<routing_strategy>.async_get_available_deployments",
+ parent_otel_span=parent_otel_span,
+ start_time=start_time,
+ end_time=end_time,
+ )
+ )
+
+ return deployment
+ except Exception as e:
+ traceback_exception = traceback.format_exc()
+ # if router rejects call -> log to langfuse/otel/etc.
+ if request_kwargs is not None:
+ logging_obj = request_kwargs.get("litellm_logging_obj", None)
+
+ if logging_obj is not None:
+ ## LOGGING
+ threading.Thread(
+ target=logging_obj.failure_handler,
+ args=(e, traceback_exception),
+ ).start() # log response
+ # Handle any exceptions that might occur during streaming
+ asyncio.create_task(
+ logging_obj.async_failure_handler(e, traceback_exception) # type: ignore
+ )
+ raise e
+
+ def get_available_deployment(
+ self,
+ model: str,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ specific_deployment: Optional[bool] = False,
+ request_kwargs: Optional[Dict] = None,
+ ):
+ """
+ Returns the deployment based on routing strategy
+ """
+ # users need to explicitly call a specific deployment by setting `specific_deployment = True` as a completion()/embedding() kwarg
+ # When this was not explicit, we had several issues with fallbacks timing out
+
+ model, healthy_deployments = self._common_checks_available_deployment(
+ model=model,
+ messages=messages,
+ input=input,
+ specific_deployment=specific_deployment,
+ )
+
+ if isinstance(healthy_deployments, dict):
+ return healthy_deployments
+
+ parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
+ request_kwargs
+ )
+ cooldown_deployments = _get_cooldown_deployments(
+ litellm_router_instance=self, parent_otel_span=parent_otel_span
+ )
+ healthy_deployments = self._filter_cooldown_deployments(
+ healthy_deployments=healthy_deployments,
+ cooldown_deployments=cooldown_deployments,
+ )
+
+ # filter pre-call checks
+ if self.enable_pre_call_checks and messages is not None:
+ healthy_deployments = self._pre_call_checks(
+ model=model,
+ healthy_deployments=healthy_deployments,
+ messages=messages,
+ request_kwargs=request_kwargs,
+ )
+
+ if len(healthy_deployments) == 0:
+ model_ids = self.get_model_ids(model_name=model)
+ _cooldown_time = self.cooldown_cache.get_min_cooldown(
+ model_ids=model_ids, parent_otel_span=parent_otel_span
+ )
+ _cooldown_list = _get_cooldown_deployments(
+ litellm_router_instance=self, parent_otel_span=parent_otel_span
+ )
+ raise RouterRateLimitError(
+ model=model,
+ cooldown_time=_cooldown_time,
+ enable_pre_call_checks=self.enable_pre_call_checks,
+ cooldown_list=_cooldown_list,
+ )
+
+ if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
+ deployment = self.leastbusy_logger.get_available_deployments(
+ model_group=model, healthy_deployments=healthy_deployments # type: ignore
+ )
+ elif self.routing_strategy == "simple-shuffle":
+ # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
+ ############## Check 'weight' param set for weighted pick #################
+ return simple_shuffle(
+ llm_router_instance=self,
+ healthy_deployments=healthy_deployments,
+ 
model=model,
+ )
+ elif (
+ self.routing_strategy == "latency-based-routing"
+ and self.lowestlatency_logger is not None
+ ):
+ deployment = self.lowestlatency_logger.get_available_deployments(
+ model_group=model,
+ healthy_deployments=healthy_deployments, # type: ignore
+ request_kwargs=request_kwargs,
+ )
+ elif (
+ self.routing_strategy == "usage-based-routing"
+ and self.lowesttpm_logger is not None
+ ):
+ deployment = self.lowesttpm_logger.get_available_deployments(
+ model_group=model,
+ healthy_deployments=healthy_deployments, # type: ignore
+ messages=messages,
+ input=input,
+ )
+ elif (
+ self.routing_strategy == "usage-based-routing-v2"
+ and self.lowesttpm_logger_v2 is not None
+ ):
+ deployment = self.lowesttpm_logger_v2.get_available_deployments(
+ model_group=model,
+ healthy_deployments=healthy_deployments, # type: ignore
+ messages=messages,
+ input=input,
+ )
+ else:
+ deployment = None
+
+ if deployment is None:
+ verbose_router_logger.info(
+ f"get_available_deployment for model: {model}, No deployment available"
+ )
+ model_ids = self.get_model_ids(model_name=model)
+ _cooldown_time = self.cooldown_cache.get_min_cooldown(
+ model_ids=model_ids, parent_otel_span=parent_otel_span
+ )
+ _cooldown_list = _get_cooldown_deployments(
+ litellm_router_instance=self, parent_otel_span=parent_otel_span
+ )
+ raise RouterRateLimitError(
+ model=model,
+ cooldown_time=_cooldown_time,
+ enable_pre_call_checks=self.enable_pre_call_checks,
+ cooldown_list=_cooldown_list,
+ )
+ verbose_router_logger.info(
+ f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)}"
+ )
+ return deployment
+
+ def _filter_cooldown_deployments(
+ self, healthy_deployments: List[Dict], cooldown_deployments: List[str]
+ ) -> List[Dict]:
+ """
+ Filters out the deployments currently cooling down from the list of healthy deployments
+
+ Args:
+ healthy_deployments: List of healthy deployments
+ cooldown_deployments: List of model_ids currently cooling down, e.g. ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
+
+ Returns:
+ List of healthy deployments, with the cooling-down deployments removed
+ """
+ # filter out the deployments currently cooling down
+ deployments_to_remove = []
+ verbose_router_logger.debug(f"cooldown deployments: {cooldown_deployments}")
+ # Find deployments in model_list whose model_id is cooling down
+ for deployment in healthy_deployments:
+ deployment_id = deployment["model_info"]["id"]
+ if deployment_id in cooldown_deployments:
+ deployments_to_remove.append(deployment)
+
+ # remove cooling-down deployments from healthy deployments
+ for deployment in deployments_to_remove:
+ healthy_deployments.remove(deployment)
+ return healthy_deployments
+
+ def _track_deployment_metrics(
+ self, deployment, parent_otel_span: Optional[Span], response=None
+ ):
+ """
+ Tracks RPM usage for successful requests.
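+ Only updates the in-memory usage counter when no response object is passed in.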
+ """
+ try:
+ model_id = deployment.get("model_info", {}).get("id", None)
+ if response is None:
+
+ # update self.deployment_stats
+ if model_id is not None:
+ self._update_usage(
+ model_id, parent_otel_span
+ ) # update in-memory cache for tracking
+ except Exception as e:
+ verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")
+
+ def get_num_retries_from_retry_policy(
+ self, exception: Exception, model_group: Optional[str] = None
+ ):
+ return _get_num_retries_from_retry_policy(
+ exception=exception,
+ model_group=model_group,
+ model_group_retry_policy=self.model_group_retry_policy,
+ retry_policy=self.retry_policy,
+ )
+
+ def get_allowed_fails_from_policy(self, exception: Exception):
+ """
+ BadRequestErrorAllowedFails: Optional[int] = None
+ AuthenticationErrorAllowedFails: Optional[int] = None
+ TimeoutErrorAllowedFails: Optional[int] = None
+ RateLimitErrorAllowedFails: Optional[int] = None
+ ContentPolicyViolationErrorAllowedFails: Optional[int] = None
+ """
+ # if the exception type is found in the allowed fails policy -> return its allowed fails count
+ allowed_fails_policy: Optional[AllowedFailsPolicy] = self.allowed_fails_policy
+
+ if allowed_fails_policy is None:
+ return None
+
+ if (
+ isinstance(exception, litellm.BadRequestError)
+ and allowed_fails_policy.BadRequestErrorAllowedFails is not None
+ ):
+ return allowed_fails_policy.BadRequestErrorAllowedFails
+ if (
+ isinstance(exception, litellm.AuthenticationError)
+ and allowed_fails_policy.AuthenticationErrorAllowedFails is not None
+ ):
+ return allowed_fails_policy.AuthenticationErrorAllowedFails
+ if (
+ isinstance(exception, litellm.Timeout)
+ and allowed_fails_policy.TimeoutErrorAllowedFails is not None
+ ):
+ return allowed_fails_policy.TimeoutErrorAllowedFails
+ if (
+ isinstance(exception, litellm.RateLimitError)
+ and allowed_fails_policy.RateLimitErrorAllowedFails is not None
+ ):
+ return allowed_fails_policy.RateLimitErrorAllowedFails
+ if (
+ isinstance(exception, litellm.ContentPolicyViolationError)
+ and allowed_fails_policy.ContentPolicyViolationErrorAllowedFails is not None
+ ):
+ return allowed_fails_policy.ContentPolicyViolationErrorAllowedFails
+
+ def _initialize_alerting(self):
+ from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
+
+ if self.alerting_config is None:
+ return
+
+ router_alerting_config: AlertingConfig = self.alerting_config
+
+ _slack_alerting_logger = SlackAlerting(
+ alerting_threshold=router_alerting_config.alerting_threshold,
+ alerting=["slack"],
+ default_webhook_url=router_alerting_config.webhook_url,
+ )
+
+ self.slack_alerting_logger = _slack_alerting_logger
+
+ litellm.logging_callback_manager.add_litellm_callback(_slack_alerting_logger) # type: ignore
+ litellm.logging_callback_manager.add_litellm_success_callback(
+ _slack_alerting_logger.response_taking_too_long_callback
+ )
+ verbose_router_logger.info(
+ "\033[94m\nInitialized Alerting for litellm.Router\033[0m\n"
+ )
+
+ def set_custom_routing_strategy(
+ self, CustomRoutingStrategy: CustomRoutingStrategyBase
+ ):
+ """
+ Sets get_available_deployment and async_get_available_deployment on an instance of litellm.Router
+
+ Use this to set your custom routing strategy
+
+ Args:
+ CustomRoutingStrategy: litellm.router.CustomRoutingStrategyBase
+ """
+
+ setattr(
+ self,
+ "get_available_deployment",
+ CustomRoutingStrategy.get_available_deployment,
+ )
+ setattr(
+ self,
+ "async_get_available_deployment",
+ CustomRoutingStrategy.async_get_available_deployment,
+ )
+
+ def flush_cache(self):
+ 
litellm.cache = None
+ self.cache.flush_cache()
+
+ def reset(self):
+ ## clean up on close
+ litellm.success_callback = []
+ litellm._async_success_callback = []
+ litellm.failure_callback = []
+ litellm._async_failure_callback = []
+ self.retry_policy = None
+ self.flush_cache()