path: root/.venv/lib/python3.12/site-packages/litellm/router.py
author      S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer   S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit      4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree        ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/router.py
parent      cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download    gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/router.py')
-rw-r--r--    .venv/lib/python3.12/site-packages/litellm/router.py    6249
1 file changed, 6249 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/router.py b/.venv/lib/python3.12/site-packages/litellm/router.py
new file mode 100644
index 00000000..a395c851
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router.py
@@ -0,0 +1,6249 @@
+# +-----------------------------------------------+
+# | |
+# | Give Feedback / Get Help |
+# | https://github.com/BerriAI/litellm/issues/new |
+# | |
+# +-----------------------------------------------+
+#
+# Thank you ! We ❤️ you! - Krrish & Ishaan
+
+import asyncio
+import copy
+import enum
+import hashlib
+import inspect
+import json
+import logging
+import threading
+import time
+import traceback
+import uuid
+from collections import defaultdict
+from functools import lru_cache
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
+
+import httpx
+import openai
+from openai import AsyncOpenAI
+from pydantic import BaseModel
+from typing_extensions import overload
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.exception_mapping_utils
+from litellm import get_secret_str
+from litellm._logging import verbose_router_logger
+from litellm.caching.caching import (
+ DualCache,
+ InMemoryCache,
+ RedisCache,
+ RedisClusterCache,
+)
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.asyncify import run_async_function
+from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
+from litellm.litellm_core_utils.dd_tracing import tracer
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
+from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
+from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
+from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
+from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
+from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
+from litellm.router_strategy.simple_shuffle import simple_shuffle
+from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
+from litellm.router_utils.add_retry_fallback_headers import (
+ add_fallback_headers_to_response,
+ add_retry_headers_to_response,
+)
+from litellm.router_utils.batch_utils import (
+ _get_router_metadata_variable_name,
+ replace_model_in_jsonl,
+)
+from litellm.router_utils.client_initalization_utils import InitalizeCachedClient
+from litellm.router_utils.clientside_credential_handler import (
+ get_dynamic_litellm_params,
+ is_clientside_credential,
+)
+from litellm.router_utils.cooldown_cache import CooldownCache
+from litellm.router_utils.cooldown_handlers import (
+ DEFAULT_COOLDOWN_TIME_SECONDS,
+ _async_get_cooldown_deployments,
+ _async_get_cooldown_deployments_with_debug_info,
+ _get_cooldown_deployments,
+ _set_cooldown_deployments,
+)
+from litellm.router_utils.fallback_event_handlers import (
+ _check_non_standard_fallback_format,
+ get_fallback_model_group,
+ run_async_fallback,
+)
+from litellm.router_utils.get_retry_from_policy import (
+ get_num_retries_from_retry_policy as _get_num_retries_from_retry_policy,
+)
+from litellm.router_utils.handle_error import (
+ async_raise_no_deployment_exception,
+ send_llm_exception_alert,
+)
+from litellm.router_utils.pre_call_checks.prompt_caching_deployment_check import (
+ PromptCachingDeploymentCheck,
+)
+from litellm.router_utils.router_callbacks.track_deployment_metrics import (
+ increment_deployment_failures_for_current_minute,
+ increment_deployment_successes_for_current_minute,
+)
+from litellm.scheduler import FlowItem, Scheduler
+from litellm.types.llms.openai import AllMessageValues, Batch, FileObject, FileTypes
+from litellm.types.router import (
+ CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
+ VALID_LITELLM_ENVIRONMENTS,
+ AlertingConfig,
+ AllowedFailsPolicy,
+ AssistantsTypedDict,
+ CredentialLiteLLMParams,
+ CustomRoutingStrategyBase,
+ Deployment,
+ DeploymentTypedDict,
+ LiteLLM_Params,
+ ModelGroupInfo,
+ OptionalPreCallChecks,
+ RetryPolicy,
+ RouterCacheEnum,
+ RouterGeneralSettings,
+ RouterModelGroupAliasItem,
+ RouterRateLimitError,
+ RouterRateLimitErrorBasic,
+ RoutingStrategy,
+)
+from litellm.types.services import ServiceTypes
+from litellm.types.utils import GenericBudgetConfigType
+from litellm.types.utils import ModelInfo as ModelMapInfo
+from litellm.types.utils import StandardLoggingPayload
+from litellm.utils import (
+ CustomStreamWrapper,
+ EmbeddingResponse,
+ ModelResponse,
+ Rules,
+ function_setup,
+ get_llm_provider,
+ get_non_default_completion_params,
+ get_secret,
+ get_utc_datetime,
+ is_region_allowed,
+)
+
+from .router_utils.pattern_match_deployments import PatternMatchRouter
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+
+ Span = _Span
+else:
+ Span = Any
+
+
+class RoutingArgs(enum.Enum):
+ ttl = 60 # 1min (RPM/TPM expire key)
+
+
+class Router:
+ model_names: List = []
+ cache_responses: Optional[bool] = False
+ default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
+ tenacity = None
+ leastbusy_logger: Optional[LeastBusyLoggingHandler] = None
+ lowesttpm_logger: Optional[LowestTPMLoggingHandler] = None
+
+ def __init__( # noqa: PLR0915
+ self,
+ model_list: Optional[
+ Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
+ ] = None,
+ ## ASSISTANTS API ##
+ assistants_config: Optional[AssistantsTypedDict] = None,
+ ## CACHING ##
+ redis_url: Optional[str] = None,
+ redis_host: Optional[str] = None,
+ redis_port: Optional[int] = None,
+ redis_password: Optional[str] = None,
+ cache_responses: Optional[bool] = False,
+ cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py)
+ caching_groups: Optional[
+ List[tuple]
+ ] = None, # if you want to cache across model groups
+ client_ttl: int = 3600, # ttl for cached clients - will re-initialize after this time in seconds
+ ## SCHEDULER ##
+ polling_interval: Optional[float] = None,
+ default_priority: Optional[int] = None,
+ ## RELIABILITY ##
+ num_retries: Optional[int] = None,
+ max_fallbacks: Optional[
+ int
+ ] = None, # max fallbacks to try before exiting the call. Defaults to 5.
+ timeout: Optional[float] = None,
+ stream_timeout: Optional[float] = None,
+ default_litellm_params: Optional[
+ dict
+ ] = None, # default params for Router.chat.completion.create
+ default_max_parallel_requests: Optional[int] = None,
+ set_verbose: bool = False,
+ debug_level: Literal["DEBUG", "INFO"] = "INFO",
+ default_fallbacks: Optional[
+ List[str]
+ ] = None, # generic fallbacks, works across all deployments
+ fallbacks: List = [],
+ context_window_fallbacks: List = [],
+ content_policy_fallbacks: List = [],
+ model_group_alias: Optional[
+ Dict[str, Union[str, RouterModelGroupAliasItem]]
+ ] = {},
+ enable_pre_call_checks: bool = False,
+ enable_tag_filtering: bool = False,
+ retry_after: int = 0, # min time to wait before retrying a failed request
+ retry_policy: Optional[
+ Union[RetryPolicy, dict]
+ ] = None, # set custom retries for different exceptions
+ model_group_retry_policy: Dict[
+ str, RetryPolicy
+ ] = {}, # set custom retry policies based on model group
+ allowed_fails: Optional[
+ int
+        ] = None, # Number of times a deployment can fail before being added to cooldown
+ allowed_fails_policy: Optional[
+ AllowedFailsPolicy
+ ] = None, # set custom allowed fails policy
+ cooldown_time: Optional[
+ float
+ ] = None, # (seconds) time to cooldown a deployment after failure
+ disable_cooldowns: Optional[bool] = None,
+ routing_strategy: Literal[
+ "simple-shuffle",
+ "least-busy",
+ "usage-based-routing",
+ "latency-based-routing",
+ "cost-based-routing",
+ "usage-based-routing-v2",
+ ] = "simple-shuffle",
+ optional_pre_call_checks: Optional[OptionalPreCallChecks] = None,
+ routing_strategy_args: dict = {}, # just for latency-based
+ provider_budget_config: Optional[GenericBudgetConfigType] = None,
+ alerting_config: Optional[AlertingConfig] = None,
+ router_general_settings: Optional[
+ RouterGeneralSettings
+ ] = RouterGeneralSettings(),
+ ) -> None:
+ """
+ Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
+
+ Args:
+ model_list (Optional[list]): List of models to be used. Defaults to None.
+ redis_url (Optional[str]): URL of the Redis server. Defaults to None.
+ redis_host (Optional[str]): Hostname of the Redis server. Defaults to None.
+ redis_port (Optional[int]): Port of the Redis server. Defaults to None.
+ redis_password (Optional[str]): Password of the Redis server. Defaults to None.
+ cache_responses (Optional[bool]): Flag to enable caching of responses. Defaults to False.
+ cache_kwargs (dict): Additional kwargs to pass to RedisCache. Defaults to {}.
+ caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None.
+ client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600.
+ polling_interval: (Optional[float]): frequency of polling queue. Only for '.scheduler_acompletion()'. Default is 3ms.
+ default_priority: (Optional[int]): the default priority for a request. Only for '.scheduler_acompletion()'. Default is None.
+ num_retries (Optional[int]): Number of retries for failed requests. Defaults to 2.
+ timeout (Optional[float]): Timeout for requests. Defaults to None.
+ default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}.
+ set_verbose (bool): Flag to set verbose mode. Defaults to False.
+ debug_level (Literal["DEBUG", "INFO"]): Debug level for logging. Defaults to "INFO".
+ fallbacks (List): List of fallback options. Defaults to [].
+ context_window_fallbacks (List): List of context window fallback options. Defaults to [].
+ enable_pre_call_checks (boolean): Filter out deployments which are outside context window limits for a given prompt
+ model_group_alias (Optional[dict]): Alias for model groups. Defaults to {}.
+ retry_after (int): Minimum time to wait before retrying a failed request. Defaults to 0.
+ allowed_fails (Optional[int]): Number of allowed fails before adding to cooldown. Defaults to None.
+ cooldown_time (float): Time to cooldown a deployment after failure in seconds. Defaults to 1.
+            routing_strategy (Literal["simple-shuffle", "least-busy", "usage-based-routing", "usage-based-routing-v2", "latency-based-routing", "cost-based-routing"]): Routing strategy. Defaults to "simple-shuffle".
+ routing_strategy_args (dict): Additional args for latency-based routing. Defaults to {}.
+ alerting_config (AlertingConfig): Slack alerting configuration. Defaults to None.
+ provider_budget_config (ProviderBudgetConfig): Provider budget configuration. Use this to set llm_provider budget limits. example $100/day to OpenAI, $100/day to Azure, etc. Defaults to None.
+ Returns:
+ Router: An instance of the litellm.Router class.
+
+ Example Usage:
+ ```python
+ from litellm import Router
+ model_list = [
+ {
+ "model_name": "azure-gpt-3.5-turbo", # model alias
+ "litellm_params": { # params for litellm completion/embedding call
+ "model": "azure/<your-deployment-name-1>",
+ "api_key": <your-api-key>,
+ "api_version": <your-api-version>,
+ "api_base": <your-api-base>
+ },
+ },
+ {
+ "model_name": "azure-gpt-3.5-turbo", # model alias
+ "litellm_params": { # params for litellm completion/embedding call
+ "model": "azure/<your-deployment-name-2>",
+ "api_key": <your-api-key>,
+ "api_version": <your-api-version>,
+ "api_base": <your-api-base>
+ },
+ },
+ {
+ "model_name": "openai-gpt-3.5-turbo", # model alias
+ "litellm_params": { # params for litellm completion/embedding call
+ "model": "gpt-3.5-turbo",
+ "api_key": <your-api-key>,
+ },
+            },
+        ]
+
+ router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
+ ```
+ """
+
+ from litellm._service_logger import ServiceLogging
+
+ self.set_verbose = set_verbose
+ self.debug_level = debug_level
+ self.enable_pre_call_checks = enable_pre_call_checks
+ self.enable_tag_filtering = enable_tag_filtering
+ litellm.suppress_debug_info = True # prevents 'Give Feedback/Get help' message from being emitted on Router - Relevant Issue: https://github.com/BerriAI/litellm/issues/5942
+ if self.set_verbose is True:
+ if debug_level == "INFO":
+ verbose_router_logger.setLevel(logging.INFO)
+ elif debug_level == "DEBUG":
+ verbose_router_logger.setLevel(logging.DEBUG)
+ self.router_general_settings: RouterGeneralSettings = (
+ router_general_settings or RouterGeneralSettings()
+ )
+
+ self.assistants_config = assistants_config
+ self.deployment_names: List = (
+ []
+ ) # names of models under litellm_params. ex. azure/chatgpt-v-2
+ self.deployment_latency_map = {}
+ ### CACHING ###
+ cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
+ "local" # default to an in-memory cache
+ )
+ redis_cache = None
+ cache_config: Dict[str, Any] = {}
+
+ self.client_ttl = client_ttl
+ if redis_url is not None or (redis_host is not None and redis_port is not None):
+ cache_type = "redis"
+
+ if redis_url is not None:
+ cache_config["url"] = redis_url
+
+ if redis_host is not None:
+ cache_config["host"] = redis_host
+
+ if redis_port is not None:
+ cache_config["port"] = str(redis_port) # type: ignore
+
+ if redis_password is not None:
+ cache_config["password"] = redis_password
+
+ # Add additional key-value pairs from cache_kwargs
+ cache_config.update(cache_kwargs)
+ redis_cache = self._create_redis_cache(cache_config)
+
+ if cache_responses:
+ if litellm.cache is None:
+ # the cache can be initialized on the proxy server. We should not overwrite it
+ litellm.cache = litellm.Cache(type=cache_type, **cache_config) # type: ignore
+ self.cache_responses = cache_responses
+ self.cache = DualCache(
+ redis_cache=redis_cache, in_memory_cache=InMemoryCache()
+ ) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
+
+ ### SCHEDULER ###
+ self.scheduler = Scheduler(
+ polling_interval=polling_interval, redis_cache=redis_cache
+ )
+ self.default_priority = default_priority
+        self.default_deployment = None # use this to track the user's default deployment, when they want to use model = *
+ self.default_max_parallel_requests = default_max_parallel_requests
+ self.provider_default_deployment_ids: List[str] = []
+ self.pattern_router = PatternMatchRouter()
+
+ if model_list is not None:
+ model_list = copy.deepcopy(model_list)
+ self.set_model_list(model_list)
+ self.healthy_deployments: List = self.model_list # type: ignore
+ for m in model_list:
+ if "model" in m["litellm_params"]:
+ self.deployment_latency_map[m["litellm_params"]["model"]] = 0
+ else:
+ self.model_list: List = (
+ []
+ ) # initialize an empty list - to allow _add_deployment and delete_deployment to work
+
+ if allowed_fails is not None:
+ self.allowed_fails = allowed_fails
+ else:
+ self.allowed_fails = litellm.allowed_fails
+ self.cooldown_time = cooldown_time or DEFAULT_COOLDOWN_TIME_SECONDS
+ self.cooldown_cache = CooldownCache(
+ cache=self.cache, default_cooldown_time=self.cooldown_time
+ )
+ self.disable_cooldowns = disable_cooldowns
+ self.failed_calls = (
+ InMemoryCache()
+ ) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
+
+ if num_retries is not None:
+ self.num_retries = num_retries
+ elif litellm.num_retries is not None:
+ self.num_retries = litellm.num_retries
+ else:
+ self.num_retries = openai.DEFAULT_MAX_RETRIES
+
+ if max_fallbacks is not None:
+ self.max_fallbacks = max_fallbacks
+ elif litellm.max_fallbacks is not None:
+ self.max_fallbacks = litellm.max_fallbacks
+ else:
+ self.max_fallbacks = litellm.ROUTER_MAX_FALLBACKS
+
+ self.timeout = timeout or litellm.request_timeout
+ self.stream_timeout = stream_timeout
+
+ self.retry_after = retry_after
+ self.routing_strategy = routing_strategy
+
+ ## SETTING FALLBACKS ##
+ ### validate if it's set + in correct format
+ _fallbacks = fallbacks or litellm.fallbacks
+
+ self.validate_fallbacks(fallback_param=_fallbacks)
+ ### set fallbacks
+ self.fallbacks = _fallbacks
+
+ if default_fallbacks is not None or litellm.default_fallbacks is not None:
+ _fallbacks = default_fallbacks or litellm.default_fallbacks
+ if self.fallbacks is not None:
+ self.fallbacks.append({"*": _fallbacks})
+ else:
+ self.fallbacks = [{"*": _fallbacks}]
+
+ self.context_window_fallbacks = (
+ context_window_fallbacks or litellm.context_window_fallbacks
+ )
+
+ _content_policy_fallbacks = (
+ content_policy_fallbacks or litellm.content_policy_fallbacks
+ )
+ self.validate_fallbacks(fallback_param=_content_policy_fallbacks)
+ self.content_policy_fallbacks = _content_policy_fallbacks
+ self.total_calls: defaultdict = defaultdict(
+ int
+ ) # dict to store total calls made to each model
+ self.fail_calls: defaultdict = defaultdict(
+ int
+ ) # dict to store fail_calls made to each model
+ self.success_calls: defaultdict = defaultdict(
+ int
+ ) # dict to store success_calls made to each model
+ self.previous_models: List = (
+ []
+ ) # list to store failed calls (passed in as metadata to next call)
+ self.model_group_alias: Dict[str, Union[str, RouterModelGroupAliasItem]] = (
+ model_group_alias or {}
+ ) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group
+
+ # make Router.chat.completions.create compatible for openai.chat.completions.create
+ default_litellm_params = default_litellm_params or {}
+ self.chat = litellm.Chat(params=default_litellm_params, router_obj=self)
+
+ # default litellm args
+ self.default_litellm_params = default_litellm_params
+ self.default_litellm_params.setdefault("timeout", timeout)
+ self.default_litellm_params.setdefault("max_retries", 0)
+ self.default_litellm_params.setdefault("metadata", {}).update(
+ {"caching_groups": caching_groups}
+ )
+
+ self.deployment_stats: dict = {} # used for debugging load balancing
+ """
+ deployment_stats = {
+            "122999-2828282-277":
+ {
+ "model": "gpt-3",
+ "api_base": "http://localhost:4000",
+ "num_requests": 20,
+ "avg_latency": 0.001,
+ "num_failures": 0,
+ "num_successes": 20
+ }
+ }
+ """
+ ### ROUTING SETUP ###
+ self.routing_strategy_init(
+ routing_strategy=routing_strategy,
+ routing_strategy_args=routing_strategy_args,
+ )
+ self.access_groups = None
+ ## USAGE TRACKING ##
+ if isinstance(litellm._async_success_callback, list):
+ litellm.logging_callback_manager.add_litellm_async_success_callback(
+ self.deployment_callback_on_success
+ )
+ else:
+ litellm.logging_callback_manager.add_litellm_async_success_callback(
+ self.deployment_callback_on_success
+ )
+ if isinstance(litellm.success_callback, list):
+ litellm.logging_callback_manager.add_litellm_success_callback(
+ self.sync_deployment_callback_on_success
+ )
+ else:
+ litellm.success_callback = [self.sync_deployment_callback_on_success]
+ if isinstance(litellm._async_failure_callback, list):
+ litellm.logging_callback_manager.add_litellm_async_failure_callback(
+ self.async_deployment_callback_on_failure
+ )
+ else:
+ litellm._async_failure_callback = [
+ self.async_deployment_callback_on_failure
+ ]
+ ## COOLDOWNS ##
+ if isinstance(litellm.failure_callback, list):
+ litellm.logging_callback_manager.add_litellm_failure_callback(
+ self.deployment_callback_on_failure
+ )
+ else:
+ litellm.failure_callback = [self.deployment_callback_on_failure]
+ verbose_router_logger.debug(
+            f"Initialized router with Routing strategy: {self.routing_strategy}\n\n"
+ f"Routing enable_pre_call_checks: {self.enable_pre_call_checks}\n\n"
+ f"Routing fallbacks: {self.fallbacks}\n\n"
+ f"Routing content fallbacks: {self.content_policy_fallbacks}\n\n"
+ f"Routing context window fallbacks: {self.context_window_fallbacks}\n\n"
+ f"Router Redis Caching={self.cache.redis_cache}\n"
+ )
+ self.service_logger_obj = ServiceLogging()
+ self.routing_strategy_args = routing_strategy_args
+ self.provider_budget_config = provider_budget_config
+ self.router_budget_logger: Optional[RouterBudgetLimiting] = None
+ if RouterBudgetLimiting.should_init_router_budget_limiter(
+ model_list=model_list, provider_budget_config=self.provider_budget_config
+ ):
+ if optional_pre_call_checks is not None:
+ optional_pre_call_checks.append("router_budget_limiting")
+ else:
+ optional_pre_call_checks = ["router_budget_limiting"]
+ self.retry_policy: Optional[RetryPolicy] = None
+ if retry_policy is not None:
+ if isinstance(retry_policy, dict):
+ self.retry_policy = RetryPolicy(**retry_policy)
+ elif isinstance(retry_policy, RetryPolicy):
+ self.retry_policy = retry_policy
+ verbose_router_logger.info(
+ "\033[32mRouter Custom Retry Policy Set:\n{}\033[0m".format(
+ self.retry_policy.model_dump(exclude_none=True)
+ )
+ )
+
+ self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
+ model_group_retry_policy
+ )
+
+ self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
+ if allowed_fails_policy is not None:
+ if isinstance(allowed_fails_policy, dict):
+ self.allowed_fails_policy = AllowedFailsPolicy(**allowed_fails_policy)
+ elif isinstance(allowed_fails_policy, AllowedFailsPolicy):
+ self.allowed_fails_policy = allowed_fails_policy
+
+ verbose_router_logger.info(
+ "\033[32mRouter Custom Allowed Fails Policy Set:\n{}\033[0m".format(
+ self.allowed_fails_policy.model_dump(exclude_none=True)
+ )
+ )
+
+ self.alerting_config: Optional[AlertingConfig] = alerting_config
+
+ if optional_pre_call_checks is not None:
+ self.add_optional_pre_call_checks(optional_pre_call_checks)
+
+ if self.alerting_config is not None:
+ self._initialize_alerting()
+
+ self.initialize_assistants_endpoint()
+ self.initialize_router_endpoints()
+
+ def discard(self):
+ """
+ Pseudo-destructor to be invoked to clean up global data structures when router is no longer used.
+        For now, this unhooks the router's callbacks from all callback lists.
+ """
+ litellm.logging_callback_manager.remove_callback_from_list_by_object(
+ litellm._async_success_callback, self
+ )
+ litellm.logging_callback_manager.remove_callback_from_list_by_object(
+ litellm.success_callback, self
+ )
+ litellm.logging_callback_manager.remove_callback_from_list_by_object(
+ litellm._async_failure_callback, self
+ )
+ litellm.logging_callback_manager.remove_callback_from_list_by_object(
+ litellm.failure_callback, self
+ )
+ litellm.logging_callback_manager.remove_callback_from_list_by_object(
+ litellm.input_callback, self
+ )
+ litellm.logging_callback_manager.remove_callback_from_list_by_object(
+ litellm.service_callback, self
+ )
+ litellm.logging_callback_manager.remove_callback_from_list_by_object(
+ litellm.callbacks, self
+ )
+
+ @staticmethod
+ def _create_redis_cache(
+ cache_config: Dict[str, Any]
+ ) -> Union[RedisCache, RedisClusterCache]:
+ """
+ Initializes either a RedisCache or RedisClusterCache based on the cache_config.
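+
+        Illustrative (exact keys depend on caching.py; the values below are assumed examples):
+        - {"startup_nodes": [{"host": "redis-node-1", "port": 6379}]} -> RedisClusterCache
+        - {"host": "localhost", "port": 6379, "password": "my-password"} -> RedisCache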
+ """
+ if cache_config.get("startup_nodes"):
+ return RedisClusterCache(**cache_config)
+ else:
+ return RedisCache(**cache_config)
+
+ def _update_redis_cache(self, cache: RedisCache):
+ """
+ Update the redis cache for the router, if none set.
+
+ Allows proxy user to just do
+ ```yaml
+ litellm_settings:
+ cache: true
+ ```
+ and caching to just work.
+ """
+ if self.cache.redis_cache is None:
+ self.cache.redis_cache = cache
+
+ def routing_strategy_init(
+ self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
+ ):
+ verbose_router_logger.info(f"Routing strategy: {routing_strategy}")
+ if (
+ routing_strategy == RoutingStrategy.LEAST_BUSY.value
+ or routing_strategy == RoutingStrategy.LEAST_BUSY
+ ):
+ self.leastbusy_logger = LeastBusyLoggingHandler(
+ router_cache=self.cache, model_list=self.model_list
+ )
+ ## add callback
+ if isinstance(litellm.input_callback, list):
+ litellm.input_callback.append(self.leastbusy_logger) # type: ignore
+ else:
+ litellm.input_callback = [self.leastbusy_logger] # type: ignore
+ if isinstance(litellm.callbacks, list):
+ litellm.logging_callback_manager.add_litellm_callback(self.leastbusy_logger) # type: ignore
+ elif (
+ routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING.value
+ or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING
+ ):
+ self.lowesttpm_logger = LowestTPMLoggingHandler(
+ router_cache=self.cache,
+ model_list=self.model_list,
+ routing_args=routing_strategy_args,
+ )
+ if isinstance(litellm.callbacks, list):
+ litellm.logging_callback_manager.add_litellm_callback(self.lowesttpm_logger) # type: ignore
+ elif (
+ routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2.value
+ or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2
+ ):
+ self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
+ router_cache=self.cache,
+ model_list=self.model_list,
+ routing_args=routing_strategy_args,
+ )
+ if isinstance(litellm.callbacks, list):
+ litellm.logging_callback_manager.add_litellm_callback(self.lowesttpm_logger_v2) # type: ignore
+ elif (
+ routing_strategy == RoutingStrategy.LATENCY_BASED.value
+ or routing_strategy == RoutingStrategy.LATENCY_BASED
+ ):
+ self.lowestlatency_logger = LowestLatencyLoggingHandler(
+ router_cache=self.cache,
+ model_list=self.model_list,
+ routing_args=routing_strategy_args,
+ )
+ if isinstance(litellm.callbacks, list):
+ litellm.logging_callback_manager.add_litellm_callback(self.lowestlatency_logger) # type: ignore
+ elif (
+ routing_strategy == RoutingStrategy.COST_BASED.value
+ or routing_strategy == RoutingStrategy.COST_BASED
+ ):
+ self.lowestcost_logger = LowestCostLoggingHandler(
+ router_cache=self.cache,
+ model_list=self.model_list,
+ routing_args={},
+ )
+ if isinstance(litellm.callbacks, list):
+ litellm.logging_callback_manager.add_litellm_callback(self.lowestcost_logger) # type: ignore
+ else:
+ pass
+
+ def initialize_assistants_endpoint(self):
+ ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
+ self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
+ self.adelete_assistant = self.factory_function(litellm.adelete_assistant)
+ self.aget_assistants = self.factory_function(litellm.aget_assistants)
+ self.acreate_thread = self.factory_function(litellm.acreate_thread)
+ self.aget_thread = self.factory_function(litellm.aget_thread)
+ self.a_add_message = self.factory_function(litellm.a_add_message)
+ self.aget_messages = self.factory_function(litellm.aget_messages)
+ self.arun_thread = self.factory_function(litellm.arun_thread)
+
+ def initialize_router_endpoints(self):
+ self.amoderation = self.factory_function(
+ litellm.amoderation, call_type="moderation"
+ )
+ self.aanthropic_messages = self.factory_function(
+ litellm.anthropic_messages, call_type="anthropic_messages"
+ )
+ self.aresponses = self.factory_function(
+ litellm.aresponses, call_type="aresponses"
+ )
+ self.responses = self.factory_function(litellm.responses, call_type="responses")
+
+ def validate_fallbacks(self, fallback_param: Optional[List]):
+ """
+ Validate the fallbacks parameter.
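+
+        Expected shape (illustrative): a list of single-key dicts, e.g.
+        [{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}] - each dict maps one model
+        group to its fallback(s).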
+ """
+ if fallback_param is None:
+ return
+ for fallback_dict in fallback_param:
+ if not isinstance(fallback_dict, dict):
+ raise ValueError(f"Item '{fallback_dict}' is not a dictionary.")
+ if len(fallback_dict) != 1:
+ raise ValueError(
+ f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
+ )
+
+ def add_optional_pre_call_checks(
+ self, optional_pre_call_checks: Optional[OptionalPreCallChecks]
+ ):
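+        """
+        Registers router-level pre-call checks (currently "prompt_caching" and
+        "router_budget_limiting", handled below) as litellm callbacks.
+        """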
+ if optional_pre_call_checks is not None:
+ for pre_call_check in optional_pre_call_checks:
+ _callback: Optional[CustomLogger] = None
+ if pre_call_check == "prompt_caching":
+ _callback = PromptCachingDeploymentCheck(cache=self.cache)
+ elif pre_call_check == "router_budget_limiting":
+ _callback = RouterBudgetLimiting(
+ dual_cache=self.cache,
+ provider_budget_config=self.provider_budget_config,
+ model_list=self.model_list,
+ )
+ if _callback is not None:
+ litellm.logging_callback_manager.add_litellm_callback(_callback)
+
+ def print_deployment(self, deployment: dict):
+ """
+ returns a copy of the deployment with the api key masked
+
+        Only returns the first 2 characters of the api key and masks the rest with 10 '*' characters.
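+        e.g. (illustrative) "sk-my-secret-key" -> "sk**********"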
+ """
+ try:
+ _deployment_copy = copy.deepcopy(deployment)
+ litellm_params: dict = _deployment_copy["litellm_params"]
+ if "api_key" in litellm_params:
+ litellm_params["api_key"] = litellm_params["api_key"][:2] + "*" * 10
+ return _deployment_copy
+ except Exception as e:
+ verbose_router_logger.debug(
+ f"Error occurred while printing deployment - {str(e)}"
+ )
+ raise e
+
+ ### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS
+
+ def completion(
+ self, model: str, messages: List[Dict[str, str]], **kwargs
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
+ """
+ Example usage:
+ response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}]
+ """
+ try:
+ verbose_router_logger.debug(f"router.completion(model={model},..)")
+ kwargs["model"] = model
+ kwargs["messages"] = messages
+ kwargs["original_function"] = self._completion
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
+
+ response = self.function_with_fallbacks(**kwargs)
+ return response
+ except Exception as e:
+ raise e
+
+ def _completion(
+ self, model: str, messages: List[Dict[str, str]], **kwargs
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
+ model_name = None
+ try:
+ # pick the one that is available (lowest TPM/RPM)
+ deployment = self.get_available_deployment(
+ model=model,
+ messages=messages,
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ )
+ self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
+
+ data = deployment["litellm_params"].copy()
+ model_name = data["model"]
+ potential_model_client = self._get_client(
+ deployment=deployment, kwargs=kwargs
+ )
+ # check if provided keys == client keys #
+ dynamic_api_key = kwargs.get("api_key", None)
+ if (
+ dynamic_api_key is not None
+ and potential_model_client is not None
+ and dynamic_api_key != potential_model_client.api_key
+ ):
+ model_client = None
+ else:
+ model_client = potential_model_client
+
+ ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
+ ## only run if model group given, not model id
+ if model not in self.get_model_ids():
+ self.routing_strategy_pre_call_checks(deployment=deployment)
+
+ response = litellm.completion(
+ **{
+ **data,
+ "messages": messages,
+ "caching": self.cache_responses,
+ "client": model_client,
+ **kwargs,
+ }
+ )
+ verbose_router_logger.info(
+ f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m"
+ )
+
+ ## CHECK CONTENT FILTER ERROR ##
+ if isinstance(response, ModelResponse):
+ _should_raise = self._should_raise_content_policy_error(
+ model=model, response=response, kwargs=kwargs
+ )
+ if _should_raise:
+ raise litellm.ContentPolicyViolationError(
+ message="Response output was blocked.",
+ model=model,
+ llm_provider="",
+ )
+
+ return response
+ except Exception as e:
+ verbose_router_logger.info(
+ f"litellm.completion(model={model_name})\033[31m Exception {str(e)}\033[0m"
+ )
+ raise e
+
+ # fmt: off
+
+ @overload
+ async def acompletion(
+ self, model: str, messages: List[AllMessageValues], stream: Literal[True], **kwargs
+ ) -> CustomStreamWrapper:
+ ...
+
+ @overload
+ async def acompletion(
+ self, model: str, messages: List[AllMessageValues], stream: Literal[False] = False, **kwargs
+ ) -> ModelResponse:
+ ...
+
+ @overload
+ async def acompletion(
+ self, model: str, messages: List[AllMessageValues], stream: Union[Literal[True], Literal[False]] = False, **kwargs
+ ) -> Union[CustomStreamWrapper, ModelResponse]:
+ ...
+
+ # fmt: on
+
+ # The actual implementation of the function
+ async def acompletion(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ stream: bool = False,
+ **kwargs,
+ ):
+ try:
+ kwargs["model"] = model
+ kwargs["messages"] = messages
+ kwargs["stream"] = stream
+ kwargs["original_function"] = self._acompletion
+
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
+ request_priority = kwargs.get("priority") or self.default_priority
+ start_time = time.time()
+ _is_prompt_management_model = self._is_prompt_management_model(model)
+
+ if _is_prompt_management_model:
+ return await self._prompt_management_factory(
+ model=model,
+ messages=messages,
+ kwargs=kwargs,
+ )
+ if request_priority is not None and isinstance(request_priority, int):
+ response = await self.schedule_acompletion(**kwargs)
+ else:
+ response = await self.async_function_with_fallbacks(**kwargs)
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_success_hook(
+ service=ServiceTypes.ROUTER,
+ duration=_duration,
+ call_type="acompletion",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+ )
+ )
+
+ return response
+ except Exception as e:
+ asyncio.create_task(
+ send_llm_exception_alert(
+ litellm_router_instance=self,
+ request_kwargs=kwargs,
+ error_traceback_str=traceback.format_exc(),
+ original_exception=e,
+ )
+ )
+ raise e
+
+ async def _acompletion(
+ self, model: str, messages: List[Dict[str, str]], **kwargs
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
+ """
+ - Get an available deployment
+ - call it with a semaphore over the call
+            - semaphore specific to its rpm
+            - in the semaphore, make a check against its local rpm before running
+ """
+ model_name = None
+ _timeout_debug_deployment_dict = (
+ {}
+ ) # this is a temporary dict to debug timeout issues
+ try:
+ verbose_router_logger.debug(
+ f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
+ )
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ start_time = time.time()
+ deployment = await self.async_get_available_deployment(
+ model=model,
+ messages=messages,
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ request_kwargs=kwargs,
+ )
+
+ _timeout_debug_deployment_dict = deployment
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_success_hook(
+ service=ServiceTypes.ROUTER,
+ duration=_duration,
+ call_type="async_get_available_deployment",
+ start_time=start_time,
+ end_time=end_time,
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+ )
+ )
+
+ # debug how often this deployment picked
+
+ self._track_deployment_metrics(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
+ data = deployment["litellm_params"].copy()
+
+ model_name = data["model"]
+
+ model_client = self._get_async_openai_model_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ )
+ self.total_calls[model_name] += 1
+
+ _response = litellm.acompletion(
+ **{
+ **data,
+ "messages": messages,
+ "caching": self.cache_responses,
+ "client": model_client,
+ **kwargs,
+ }
+ )
+
+ logging_obj: Optional[LiteLLMLogging] = kwargs.get(
+ "litellm_logging_obj", None
+ )
+
+ rpm_semaphore = self._get_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ client_type="max_parallel_requests",
+ )
+ if rpm_semaphore is not None and isinstance(
+ rpm_semaphore, asyncio.Semaphore
+ ):
+ async with rpm_semaphore:
+ """
+ - Check rpm limits before making the call
+ - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+ """
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment,
+ logging_obj=logging_obj,
+ parent_otel_span=parent_otel_span,
+ )
+ response = await _response
+ else:
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment,
+ logging_obj=logging_obj,
+ parent_otel_span=parent_otel_span,
+ )
+
+ response = await _response
+
+ ## CHECK CONTENT FILTER ERROR ##
+ if isinstance(response, ModelResponse):
+ _should_raise = self._should_raise_content_policy_error(
+ model=model, response=response, kwargs=kwargs
+ )
+ if _should_raise:
+ raise litellm.ContentPolicyViolationError(
+ message="Response output was blocked.",
+ model=model,
+ llm_provider="",
+ )
+
+ self.success_calls[model_name] += 1
+ verbose_router_logger.info(
+ f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
+ )
+ # debug how often this deployment picked
+ self._track_deployment_metrics(
+ deployment=deployment,
+ response=response,
+ parent_otel_span=parent_otel_span,
+ )
+
+ return response
+ except litellm.Timeout as e:
+ deployment_request_timeout_param = _timeout_debug_deployment_dict.get(
+ "litellm_params", {}
+ ).get("request_timeout", None)
+ deployment_timeout_param = _timeout_debug_deployment_dict.get(
+ "litellm_params", {}
+ ).get("timeout", None)
+ e.message += f"\n\nDeployment Info: request_timeout: {deployment_request_timeout_param}\ntimeout: {deployment_timeout_param}"
+ raise e
+ except Exception as e:
+ verbose_router_logger.info(
+ f"litellm.acompletion(model={model_name})\033[31m Exception {str(e)}\033[0m"
+ )
+ if model_name is not None:
+ self.fail_calls[model_name] += 1
+ raise e
+
+ def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None:
+ """
+ Adds/updates to kwargs:
+ - num_retries
+ - litellm_trace_id
+ - metadata
+ """
+ kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+ kwargs.setdefault("litellm_trace_id", str(uuid.uuid4()))
+ kwargs.setdefault("metadata", {}).update({"model_group": model})
+
+ def _update_kwargs_with_default_litellm_params(
+ self, kwargs: dict, metadata_variable_name: Optional[str] = "metadata"
+ ) -> None:
+ """
+ Adds default litellm params to kwargs, if set.
+ """
+ self.default_litellm_params[metadata_variable_name] = (
+ self.default_litellm_params.pop("metadata", {})
+ )
+ for k, v in self.default_litellm_params.items():
+ if (
+ k not in kwargs and v is not None
+ ): # prioritize model-specific params > default router params
+ kwargs[k] = v
+ elif k == metadata_variable_name:
+ kwargs[metadata_variable_name].update(v)
+
+ def _handle_clientside_credential(
+ self, deployment: dict, kwargs: dict
+ ) -> Deployment:
+ """
+ Handle clientside credential
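+
+        Builds a new Deployment from the request's client-side credentials, generates a
+        new model id for it (the original id is kept under model_info["original_model_id"]),
+        and upserts it into the router.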
+ """
+ model_info = deployment.get("model_info", {}).copy()
+ litellm_params = deployment["litellm_params"].copy()
+ dynamic_litellm_params = get_dynamic_litellm_params(
+ litellm_params=litellm_params, request_kwargs=kwargs
+ )
+ metadata = kwargs.get("metadata", {})
+ model_group = cast(str, metadata.get("model_group"))
+ _model_id = self._generate_model_id(
+ model_group=model_group, litellm_params=dynamic_litellm_params
+ )
+ original_model_id = model_info.get("id")
+ model_info["id"] = _model_id
+ model_info["original_model_id"] = original_model_id
+ deployment_pydantic_obj = Deployment(
+ model_name=model_group,
+ litellm_params=LiteLLM_Params(**dynamic_litellm_params),
+ model_info=model_info,
+ )
+ self.upsert_deployment(
+ deployment=deployment_pydantic_obj
+ ) # add new deployment to router
+ return deployment_pydantic_obj
+
+ def _update_kwargs_with_deployment(
+ self,
+ deployment: dict,
+ kwargs: dict,
+ function_name: Optional[str] = None,
+ ) -> None:
+ """
+ 2 jobs:
+ - Adds selected deployment, model_info and api_base to kwargs["metadata"] (used for logging)
+ - Adds default litellm params to kwargs, if set.
+ """
+ model_info = deployment.get("model_info", {}).copy()
+ deployment_model_name = deployment["litellm_params"]["model"]
+ deployment_api_base = deployment["litellm_params"].get("api_base")
+ if is_clientside_credential(request_kwargs=kwargs):
+ deployment_pydantic_obj = self._handle_clientside_credential(
+ deployment=deployment, kwargs=kwargs
+ )
+ model_info = deployment_pydantic_obj.model_info.model_dump()
+ deployment_model_name = deployment_pydantic_obj.litellm_params.model
+ deployment_api_base = deployment_pydantic_obj.litellm_params.api_base
+
+ metadata_variable_name = _get_router_metadata_variable_name(
+ function_name=function_name,
+ )
+ kwargs.setdefault(metadata_variable_name, {}).update(
+ {
+ "deployment": deployment_model_name,
+ "model_info": model_info,
+ "api_base": deployment_api_base,
+ }
+ )
+ kwargs["model_info"] = model_info
+
+ kwargs["timeout"] = self._get_timeout(
+ kwargs=kwargs, data=deployment["litellm_params"]
+ )
+
+ self._update_kwargs_with_default_litellm_params(
+ kwargs=kwargs, metadata_variable_name=metadata_variable_name
+ )
+
+ def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
+ """
+ Helper to get AsyncOpenAI or AsyncAzureOpenAI client that was created for the deployment
+
+ The same OpenAI client is re-used to optimize latency / performance in production
+
+ If dynamic api key is provided:
+            Do not re-use the client. Pass model_client=None. The OpenAI/AzureOpenAI client will be recreated in the handler for the LLM provider.
+ """
+ potential_model_client = self._get_client(
+ deployment=deployment, kwargs=kwargs, client_type="async"
+ )
+
+ # check if provided keys == client keys #
+ dynamic_api_key = kwargs.get("api_key", None)
+ if (
+ dynamic_api_key is not None
+ and potential_model_client is not None
+ and dynamic_api_key != potential_model_client.api_key
+ ):
+ model_client = None
+ else:
+ model_client = potential_model_client
+
+ return model_client
+
+ def _get_stream_timeout(
+ self, kwargs: dict, data: dict
+ ) -> Optional[Union[float, int]]:
+ """Helper to get stream timeout from kwargs or deployment params"""
+ return (
+ kwargs.get("stream_timeout", None) # the params dynamically set by user
+ or data.get(
+ "stream_timeout", None
+ ) # timeout set on litellm_params for this deployment
+ or self.stream_timeout # timeout set on router
+ or self.default_litellm_params.get("stream_timeout", None)
+ )
+
+ def _get_non_stream_timeout(
+ self, kwargs: dict, data: dict
+ ) -> Optional[Union[float, int]]:
+ """Helper to get non-stream timeout from kwargs or deployment params"""
+ timeout = (
+ kwargs.get("timeout", None) # the params dynamically set by user
+ or kwargs.get("request_timeout", None) # the params dynamically set by user
+ or data.get(
+ "timeout", None
+ ) # timeout set on litellm_params for this deployment
+ or data.get(
+ "request_timeout", None
+ ) # timeout set on litellm_params for this deployment
+ or self.timeout # timeout set on router
+ or self.default_litellm_params.get("timeout", None)
+ )
+ return timeout
+
+ def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]:
+ """Helper to get timeout from kwargs or deployment params"""
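+        # Resolution order (implemented in the helpers above): for streaming calls try the
+        # stream-specific timeout first; otherwise per-request kwargs > deployment
+        # litellm_params > router-level timeout > default_litellm_params.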
+ timeout: Optional[Union[float, int]] = None
+ if kwargs.get("stream", False):
+ timeout = self._get_stream_timeout(kwargs=kwargs, data=data)
+ if timeout is None:
+ timeout = self._get_non_stream_timeout(
+ kwargs=kwargs, data=data
+ ) # default to this if no stream specific timeout set
+ return timeout
+
+ async def abatch_completion(
+ self,
+ models: List[str],
+ messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
+ **kwargs,
+ ):
+ """
+ Async Batch Completion. Used for 2 scenarios:
+ 1. Batch Process 1 request to N models on litellm.Router. Pass messages as List[Dict[str, str]] to use this
+ 2. Batch Process N requests to M models on litellm.Router. Pass messages as List[List[Dict[str, str]]] to use this
+
+ Example Request for 1 request to N models:
+ ```
+ response = await router.abatch_completion(
+ models=["gpt-3.5-turbo", "groq-llama"],
+ messages=[
+ {"role": "user", "content": "is litellm becoming a better product ?"}
+ ],
+ max_tokens=15,
+ )
+ ```
+
+
+ Example Request for N requests to M models:
+ ```
+ response = await router.abatch_completion(
+ models=["gpt-3.5-turbo", "groq-llama"],
+ messages=[
+ [{"role": "user", "content": "is litellm becoming a better product ?"}],
+ [{"role": "user", "content": "who is this"}],
+ ],
+ )
+ ```
+ """
+ ############## Helpers for async completion ##################
+
+ async def _async_completion_no_exceptions(
+ model: str, messages: List[AllMessageValues], **kwargs
+ ):
+ """
+ Wrapper around self.async_completion that catches exceptions and returns them as a result
+ """
+ try:
+ return await self.acompletion(model=model, messages=messages, **kwargs)
+ except Exception as e:
+ return e
+
+ async def _async_completion_no_exceptions_return_idx(
+ model: str,
+ messages: List[AllMessageValues],
+ idx: int, # index of message this response corresponds to
+ **kwargs,
+ ):
+ """
+ Wrapper around self.async_completion that catches exceptions and returns them as a result
+ """
+ try:
+ return (
+ await self.acompletion(model=model, messages=messages, **kwargs),
+ idx,
+ )
+ except Exception as e:
+ return e, idx
+
+ ############## Helpers for async completion ##################
+
+ if isinstance(messages, list) and all(isinstance(m, dict) for m in messages):
+ _tasks = []
+ for model in models:
+ # add each task but if the task fails
+ _tasks.append(_async_completion_no_exceptions(model=model, messages=messages, **kwargs)) # type: ignore
+ response = await asyncio.gather(*_tasks)
+ return response
+ elif isinstance(messages, list) and all(isinstance(m, list) for m in messages):
+ _tasks = []
+ for idx, message in enumerate(messages):
+ for model in models:
+ # Request Number X, Model Number Y
+ _tasks.append(
+ _async_completion_no_exceptions_return_idx(
+ model=model, idx=idx, messages=message, **kwargs # type: ignore
+ )
+ )
+ responses = await asyncio.gather(*_tasks)
+ final_responses: List[List[Any]] = [[] for _ in range(len(messages))]
+ for response in responses:
+ if isinstance(response, tuple):
+ final_responses[response[1]].append(response[0])
+ else:
+ final_responses[0].append(response)
+ return final_responses
+
+ async def abatch_completion_one_model_multiple_requests(
+ self, model: str, messages: List[List[AllMessageValues]], **kwargs
+ ):
+ """
+ Async Batch Completion - Batch Process multiple Messages to one model_group on litellm.Router
+
+ Use this for sending multiple requests to 1 model
+
+ Args:
+ model (List[str]): model group
+ messages (List[List[Dict[str, str]]]): list of messages. Each element in the list is one request
+ **kwargs: additional kwargs
+ Usage:
+ response = await self.abatch_completion_one_model_multiple_requests(
+ model="gpt-3.5-turbo",
+ messages=[
+ [{"role": "user", "content": "hello"}, {"role": "user", "content": "tell me something funny"}],
+                [{"role": "user", "content": "hello good morning"}],
+ ]
+ )
+ """
+
+ async def _async_completion_no_exceptions(
+ model: str, messages: List[AllMessageValues], **kwargs
+ ):
+ """
+ Wrapper around self.async_completion that catches exceptions and returns them as a result
+ """
+ try:
+ return await self.acompletion(model=model, messages=messages, **kwargs)
+ except Exception as e:
+ return e
+
+ _tasks = []
+ for message_request in messages:
+ # add each task but if the task fails
+ _tasks.append(
+ _async_completion_no_exceptions(
+ model=model, messages=message_request, **kwargs
+ )
+ )
+
+ response = await asyncio.gather(*_tasks)
+ return response
+
+ # fmt: off
+
+ @overload
+ async def abatch_completion_fastest_response(
+ self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
+ ) -> CustomStreamWrapper:
+ ...
+
+
+
+ @overload
+ async def abatch_completion_fastest_response(
+ self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
+ ) -> ModelResponse:
+ ...
+
+ # fmt: on
+
+ async def abatch_completion_fastest_response(
+ self,
+ model: str,
+ messages: List[Dict[str, str]],
+ stream: bool = False,
+ **kwargs,
+ ):
+ """
+ model - List of comma-separated model names. E.g. model="gpt-4, gpt-3.5-turbo"
+
+ Returns fastest response from list of model names. OpenAI-compatible endpoint.
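+
+        Example (illustrative - assumes both model groups exist on this router):
+        ```
+        response = await router.abatch_completion_fastest_response(
+            model="gpt-4, gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "hello"}],
+        )
+        ```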
+ """
+ models = [m.strip() for m in model.split(",")]
+
+ async def _async_completion_no_exceptions(
+ model: str, messages: List[Dict[str, str]], stream: bool, **kwargs: Any
+ ) -> Union[ModelResponse, CustomStreamWrapper, Exception]:
+ """
+ Wrapper around self.acompletion that catches exceptions and returns them as a result
+ """
+ try:
+ return await self.acompletion(model=model, messages=messages, stream=stream, **kwargs) # type: ignore
+ except asyncio.CancelledError:
+ verbose_router_logger.debug(
+ "Received 'task.cancel'. Cancelling call w/ model={}.".format(model)
+ )
+ raise
+ except Exception as e:
+ return e
+
+ pending_tasks = [] # type: ignore
+
+ async def check_response(task: asyncio.Task):
+ nonlocal pending_tasks
+ try:
+ result = await task
+ if isinstance(result, (ModelResponse, CustomStreamWrapper)):
+ verbose_router_logger.debug(
+ "Received successful response. Cancelling other LLM API calls."
+ )
+ # If a desired response is received, cancel all other pending tasks
+ for t in pending_tasks:
+ t.cancel()
+ return result
+ except Exception:
+ # Ignore exceptions, let the loop handle them
+ pass
+ finally:
+ # Remove the task from pending tasks if it finishes
+ try:
+ pending_tasks.remove(task)
+ except KeyError:
+ pass
+
+ for model in models:
+ task = asyncio.create_task(
+ _async_completion_no_exceptions(
+ model=model, messages=messages, stream=stream, **kwargs
+ )
+ )
+ pending_tasks.append(task)
+
+ # Await the first task to complete successfully
+ while pending_tasks:
+ done, pending_tasks = await asyncio.wait( # type: ignore
+ pending_tasks, return_when=asyncio.FIRST_COMPLETED
+ )
+ for completed_task in done:
+ result = await check_response(completed_task)
+ if result is not None:
+ # Return the first successful result
+ result._hidden_params["fastest_response_batch_completion"] = True
+ return result
+
+ # If we exit the loop without returning, all tasks failed
+ raise Exception("All tasks failed")
+
+ ### SCHEDULER ###
+
+ # fmt: off
+
+ @overload
+ async def schedule_acompletion(
+ self, model: str, messages: List[AllMessageValues], priority: int, stream: Literal[False] = False, **kwargs
+ ) -> ModelResponse:
+ ...
+
+ @overload
+ async def schedule_acompletion(
+ self, model: str, messages: List[AllMessageValues], priority: int, stream: Literal[True], **kwargs
+ ) -> CustomStreamWrapper:
+ ...
+
+ # fmt: on
+
+ async def schedule_acompletion(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ priority: int,
+ stream=False,
+ **kwargs,
+ ):
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ ### FLOW ITEM ###
+ _request_id = str(uuid.uuid4())
+ item = FlowItem(
+ priority=priority, # 👈 SET PRIORITY FOR REQUEST
+ request_id=_request_id, # 👈 SET REQUEST ID
+            model_name=model, # 👈 SAME as 'Router'
+ )
+ ### [fin] ###
+
+ ## ADDS REQUEST TO QUEUE ##
+ await self.scheduler.add_request(request=item)
+
+ ## POLL QUEUE
+ end_time = time.time() + self.timeout
+ curr_time = time.time()
+ poll_interval = self.scheduler.polling_interval # poll every 3ms
+ make_request = False
+
+ while curr_time < end_time:
+ _healthy_deployments, _ = await self._async_get_healthy_deployments(
+ model=model, parent_otel_span=parent_otel_span
+ )
+            make_request = await self.scheduler.poll( ## POLL QUEUE ## - returns 'True' if there are healthy deployments OR if the request is at the top of the queue
+ id=item.request_id,
+ model_name=item.model_name,
+ health_deployments=_healthy_deployments,
+ )
+ if make_request: ## IF TRUE -> MAKE REQUEST
+ break
+ else: ## ELSE -> loop till default_timeout
+ await asyncio.sleep(poll_interval)
+ curr_time = time.time()
+
+ if make_request:
+ try:
+ _response = await self.acompletion(
+ model=model, messages=messages, stream=stream, **kwargs
+ )
+ _response._hidden_params.setdefault("additional_headers", {})
+ _response._hidden_params["additional_headers"].update(
+ {"x-litellm-request-prioritization-used": True}
+ )
+ return _response
+ except Exception as e:
+ setattr(e, "priority", priority)
+ raise e
+ else:
+ raise litellm.Timeout(
+ message="Request timed out while polling queue",
+ model=model,
+ llm_provider="openai",
+ )
+
+ async def _schedule_factory(
+ self,
+ model: str,
+ priority: int,
+ original_function: Callable,
+ args: Tuple[Any, ...],
+ kwargs: Dict[str, Any],
+ ):
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ ### FLOW ITEM ###
+ _request_id = str(uuid.uuid4())
+ item = FlowItem(
+ priority=priority, # 👈 SET PRIORITY FOR REQUEST
+ request_id=_request_id, # 👈 SET REQUEST ID
+ model_name=model, # 👈 SAME as 'Router'
+ )
+ ### [fin] ###
+
+ ## ADDS REQUEST TO QUEUE ##
+ await self.scheduler.add_request(request=item)
+
+ ## POLL QUEUE
+ end_time = time.time() + self.timeout
+ curr_time = time.time()
+ poll_interval = self.scheduler.polling_interval # poll every 3ms
+ make_request = False
+
+ while curr_time < end_time:
+ _healthy_deployments, _ = await self._async_get_healthy_deployments(
+ model=model, parent_otel_span=parent_otel_span
+ )
+            make_request = await self.scheduler.poll( ## POLL QUEUE ## - returns 'True' if there are healthy deployments OR if the request is at the top of the queue
+ id=item.request_id,
+ model_name=item.model_name,
+ health_deployments=_healthy_deployments,
+ )
+ if make_request: ## IF TRUE -> MAKE REQUEST
+ break
+ else: ## ELSE -> loop till default_timeout
+ await asyncio.sleep(poll_interval)
+ curr_time = time.time()
+
+ if make_request:
+ try:
+ _response = await original_function(*args, **kwargs)
+ if isinstance(_response._hidden_params, dict):
+ _response._hidden_params.setdefault("additional_headers", {})
+ _response._hidden_params["additional_headers"].update(
+ {"x-litellm-request-prioritization-used": True}
+ )
+ return _response
+ except Exception as e:
+ setattr(e, "priority", priority)
+ raise e
+ else:
+ raise litellm.Timeout(
+ message="Request timed out while polling queue",
+ model=model,
+ llm_provider="openai",
+ )
+
+ def _is_prompt_management_model(self, model: str) -> bool:
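+        # A model group counts as a prompt-management model only if it maps to exactly one
+        # deployment whose provider prefix (the part before "/") is in
+        # litellm._known_custom_logger_compatible_callbacks.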
+ model_list = self.get_model_list(model_name=model)
+ if model_list is None:
+ return False
+ if len(model_list) != 1:
+ return False
+
+ litellm_model = model_list[0]["litellm_params"].get("model", None)
+
+ if litellm_model is None:
+ return False
+
+ if "/" in litellm_model:
+ split_litellm_model = litellm_model.split("/")[0]
+ if split_litellm_model in litellm._known_custom_logger_compatible_callbacks:
+ return True
+ return False
+
+ async def _prompt_management_factory(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ kwargs: Dict[str, Any],
+ ):
+ litellm_logging_object = kwargs.get("litellm_logging_obj", None)
+ if litellm_logging_object is None:
+ litellm_logging_object, kwargs = function_setup(
+ **{
+ "original_function": "acompletion",
+ "rules_obj": Rules(),
+ "start_time": get_utc_datetime(),
+ **kwargs,
+ }
+ )
+ litellm_logging_object = cast(LiteLLMLogging, litellm_logging_object)
+ prompt_management_deployment = self.get_available_deployment(
+ model=model,
+ messages=[{"role": "user", "content": "prompt"}],
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ )
+
+ litellm_model = prompt_management_deployment["litellm_params"].get(
+ "model", None
+ )
+ prompt_id = kwargs.get("prompt_id") or prompt_management_deployment[
+ "litellm_params"
+ ].get("prompt_id", None)
+ prompt_variables = kwargs.get(
+ "prompt_variables"
+ ) or prompt_management_deployment["litellm_params"].get(
+ "prompt_variables", None
+ )
+
+ if prompt_id is None or not isinstance(prompt_id, str):
+ raise ValueError(
+ f"Prompt ID is not set or not a string. Got={prompt_id}, type={type(prompt_id)}"
+ )
+ if prompt_variables is not None and not isinstance(prompt_variables, dict):
+ raise ValueError(
+ f"Prompt variables is set but not a dictionary. Got={prompt_variables}, type={type(prompt_variables)}"
+ )
+
+ model, messages, optional_params = (
+ litellm_logging_object.get_chat_completion_prompt(
+ model=litellm_model,
+ messages=messages,
+ non_default_params=get_non_default_completion_params(kwargs=kwargs),
+ prompt_id=prompt_id,
+ prompt_variables=prompt_variables,
+ )
+ )
+
+ kwargs = {**kwargs, **optional_params}
+ kwargs["model"] = model
+ kwargs["messages"] = messages
+ kwargs["litellm_logging_obj"] = litellm_logging_object
+ kwargs["prompt_id"] = prompt_id
+ kwargs["prompt_variables"] = prompt_variables
+
+ _model_list = self.get_model_list(model_name=model)
+ if _model_list is None or len(_model_list) == 0: # if direct call to model
+ kwargs.pop("original_function")
+ return await litellm.acompletion(**kwargs)
+
+ return await self.async_function_with_fallbacks(**kwargs)
+
+ def image_generation(self, prompt: str, model: str, **kwargs):
+ try:
+ kwargs["model"] = model
+ kwargs["prompt"] = prompt
+ kwargs["original_function"] = self._image_generation
+ kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+ kwargs.setdefault("metadata", {}).update({"model_group": model})
+ response = self.function_with_fallbacks(**kwargs)
+
+ return response
+ except Exception as e:
+ raise e
+
+ def _image_generation(self, prompt: str, model: str, **kwargs):
+ model_name = ""
+ try:
+ verbose_router_logger.debug(
+ f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
+ )
+ deployment = self.get_available_deployment(
+ model=model,
+ messages=[{"role": "user", "content": "prompt"}],
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ )
+ self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
+ data = deployment["litellm_params"].copy()
+ model_name = data["model"]
+
+ model_client = self._get_async_openai_model_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ )
+
+ self.total_calls[model_name] += 1
+
+ ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
+ self.routing_strategy_pre_call_checks(deployment=deployment)
+
+ response = litellm.image_generation(
+ **{
+ **data,
+ "prompt": prompt,
+ "caching": self.cache_responses,
+ "client": model_client,
+ **kwargs,
+ }
+ )
+ self.success_calls[model_name] += 1
+ verbose_router_logger.info(
+ f"litellm.image_generation(model={model_name})\033[32m 200 OK\033[0m"
+ )
+ return response
+ except Exception as e:
+ verbose_router_logger.info(
+ f"litellm.image_generation(model={model_name})\033[31m Exception {str(e)}\033[0m"
+ )
+ if model_name is not None:
+ self.fail_calls[model_name] += 1
+ raise e
+
+ async def aimage_generation(self, prompt: str, model: str, **kwargs):
+ try:
+ kwargs["model"] = model
+ kwargs["prompt"] = prompt
+ kwargs["original_function"] = self._aimage_generation
+ kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
+ response = await self.async_function_with_fallbacks(**kwargs)
+
+ return response
+ except Exception as e:
+ asyncio.create_task(
+ send_llm_exception_alert(
+ litellm_router_instance=self,
+ request_kwargs=kwargs,
+ error_traceback_str=traceback.format_exc(),
+ original_exception=e,
+ )
+ )
+ raise e
+
+ async def _aimage_generation(self, prompt: str, model: str, **kwargs):
+ model_name = model
+ try:
+ verbose_router_logger.debug(
+ f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
+ )
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ deployment = await self.async_get_available_deployment(
+ model=model,
+ messages=[{"role": "user", "content": "prompt"}],
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ request_kwargs=kwargs,
+ )
+ self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
+
+ data = deployment["litellm_params"].copy()
+ model_name = data["model"]
+
+ model_client = self._get_async_openai_model_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ )
+
+ self.total_calls[model_name] += 1
+ response = litellm.aimage_generation(
+ **{
+ **data,
+ "prompt": prompt,
+ "caching": self.cache_responses,
+ "client": model_client,
+ **kwargs,
+ }
+ )
+
+ ### CONCURRENCY-SAFE RPM CHECKS ###
+ rpm_semaphore = self._get_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ client_type="max_parallel_requests",
+ )
+
+ if rpm_semaphore is not None and isinstance(
+ rpm_semaphore, asyncio.Semaphore
+ ):
+ async with rpm_semaphore:
+ """
+ - Check rpm limits before making the call
+ - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+ """
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response
+ else:
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response
+
+ self.success_calls[model_name] += 1
+ verbose_router_logger.info(
+ f"litellm.aimage_generation(model={model_name})\033[32m 200 OK\033[0m"
+ )
+ return response
+ except Exception as e:
+ verbose_router_logger.info(
+ f"litellm.aimage_generation(model={model_name})\033[31m Exception {str(e)}\033[0m"
+ )
+ if model_name is not None:
+ self.fail_calls[model_name] += 1
+ raise e
+
+ async def atranscription(self, file: FileTypes, model: str, **kwargs):
+ """
+ Example Usage:
+
+ ```
+ from litellm import Router
+ client = Router(model_list = [
+ {
+ "model_name": "whisper",
+ "litellm_params": {
+ "model": "whisper-1",
+ },
+ },
+ ])
+
+ audio_file = open("speech.mp3", "rb")
+ transcript = await client.atranscription(
+ model="whisper",
+ file=audio_file
+ )
+
+ ```
+ """
+ try:
+ kwargs["model"] = model
+ kwargs["file"] = file
+ kwargs["original_function"] = self._atranscription
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
+ response = await self.async_function_with_fallbacks(**kwargs)
+
+ return response
+ except Exception as e:
+ asyncio.create_task(
+ send_llm_exception_alert(
+ litellm_router_instance=self,
+ request_kwargs=kwargs,
+ error_traceback_str=traceback.format_exc(),
+ original_exception=e,
+ )
+ )
+ raise e
+
+ async def _atranscription(self, file: FileTypes, model: str, **kwargs):
+ model_name = model
+ try:
+ verbose_router_logger.debug(
+ f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
+ )
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ deployment = await self.async_get_available_deployment(
+ model=model,
+ messages=[{"role": "user", "content": "prompt"}],
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ request_kwargs=kwargs,
+ )
+
+ self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
+ data = deployment["litellm_params"].copy()
+ model_client = self._get_async_openai_model_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ )
+
+ self.total_calls[model_name] += 1
+ response = litellm.atranscription(
+ **{
+ **data,
+ "file": file,
+ "caching": self.cache_responses,
+ "client": model_client,
+ **kwargs,
+ }
+ )
+
+ ### CONCURRENCY-SAFE RPM CHECKS ###
+ rpm_semaphore = self._get_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ client_type="max_parallel_requests",
+ )
+
+ if rpm_semaphore is not None and isinstance(
+ rpm_semaphore, asyncio.Semaphore
+ ):
+ async with rpm_semaphore:
+ """
+ - Check rpm limits before making the call
+ - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+ """
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response
+ else:
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response
+
+ self.success_calls[model_name] += 1
+ verbose_router_logger.info(
+ f"litellm.atranscription(model={model_name})\033[32m 200 OK\033[0m"
+ )
+ return response
+ except Exception as e:
+ verbose_router_logger.info(
+ f"litellm.atranscription(model={model_name})\033[31m Exception {str(e)}\033[0m"
+ )
+ if model_name is not None:
+ self.fail_calls[model_name] += 1
+ raise e
+
+ async def aspeech(self, model: str, input: str, voice: str, **kwargs):
+ """
+ Example Usage:
+
+ ```
+ from litellm import Router
+ client = Router(model_list = [
+ {
+ "model_name": "tts",
+ "litellm_params": {
+ "model": "tts-1",
+ },
+ },
+ ])
+
+ async with client.aspeech(
+ model="tts",
+ voice="alloy",
+ input="the quick brown fox jumped over the lazy dogs",
+ api_base=None,
+ api_key=None,
+ organization=None,
+ project=None,
+ max_retries=1,
+ timeout=600,
+ client=None,
+ optional_params={},
+ ) as response:
+ response.stream_to_file(speech_file_path)
+
+ ```
+ """
+ try:
+ kwargs["input"] = input
+ kwargs["voice"] = voice
+
+ deployment = await self.async_get_available_deployment(
+ model=model,
+ messages=[{"role": "user", "content": "prompt"}],
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ request_kwargs=kwargs,
+ )
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
+ data = deployment["litellm_params"].copy()
+ data["model"]
+ for k, v in self.default_litellm_params.items():
+ if (
+ k not in kwargs
+ ): # prioritize model-specific params > default router params
+ kwargs[k] = v
+ elif k == "metadata":
+ kwargs[k].update(v)
+
+ potential_model_client = self._get_client(
+ deployment=deployment, kwargs=kwargs, client_type="async"
+ )
+ # check if provided keys == client keys #
+ dynamic_api_key = kwargs.get("api_key", None)
+ if (
+ dynamic_api_key is not None
+ and potential_model_client is not None
+ and dynamic_api_key != potential_model_client.api_key
+ ):
+ model_client = None
+ else:
+ model_client = potential_model_client
+
+ response = await litellm.aspeech(
+ **{
+ **data,
+ "client": model_client,
+ **kwargs,
+ }
+ )
+ return response
+ except Exception as e:
+ asyncio.create_task(
+ send_llm_exception_alert(
+ litellm_router_instance=self,
+ request_kwargs=kwargs,
+ error_traceback_str=traceback.format_exc(),
+ original_exception=e,
+ )
+ )
+ raise e
+
+ async def arerank(self, model: str, **kwargs):
+ try:
+ kwargs["model"] = model
+ kwargs["input"] = input
+ kwargs["original_function"] = self._arerank
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
+
+ response = await self.async_function_with_fallbacks(**kwargs)
+
+ return response
+ except Exception as e:
+ asyncio.create_task(
+ send_llm_exception_alert(
+ litellm_router_instance=self,
+ request_kwargs=kwargs,
+ error_traceback_str=traceback.format_exc(),
+ original_exception=e,
+ )
+ )
+ raise e
+
+ async def _arerank(self, model: str, **kwargs):
+ model_name = None
+ try:
+ verbose_router_logger.debug(
+ f"Inside _rerank()- model: {model}; kwargs: {kwargs}"
+ )
+ deployment = await self.async_get_available_deployment(
+ model=model,
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ request_kwargs=kwargs,
+ )
+ self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
+ data = deployment["litellm_params"].copy()
+ model_name = data["model"]
+
+ model_client = self._get_async_openai_model_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ )
+ self.total_calls[model_name] += 1
+
+ response = await litellm.arerank(
+ **{
+ **data,
+ "caching": self.cache_responses,
+ "client": model_client,
+ **kwargs,
+ }
+ )
+
+ self.success_calls[model_name] += 1
+ verbose_router_logger.info(
+ f"litellm.arerank(model={model_name})\033[32m 200 OK\033[0m"
+ )
+ return response
+ except Exception as e:
+ verbose_router_logger.info(
+ f"litellm.arerank(model={model_name})\033[31m Exception {str(e)}\033[0m"
+ )
+ if model_name is not None:
+ self.fail_calls[model_name] += 1
+ raise e
+
+ async def _arealtime(self, model: str, **kwargs):
+ messages = [{"role": "user", "content": "dummy-text"}]
+ try:
+ kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
+
+ # pick the one that is available (lowest TPM/RPM)
+ deployment = await self.async_get_available_deployment(
+ model=model,
+ messages=messages,
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ request_kwargs=kwargs,
+ )
+
+ data = deployment["litellm_params"].copy()
+ for k, v in self.default_litellm_params.items():
+ if (
+ k not in kwargs
+ ): # prioritize model-specific params > default router params
+ kwargs[k] = v
+ elif k == "metadata":
+ kwargs[k].update(v)
+
+ return await litellm._arealtime(**{**data, "caching": self.cache_responses, **kwargs}) # type: ignore
+ except Exception as e:
+ if self.num_retries > 0:
+ kwargs["model"] = model
+ kwargs["messages"] = messages
+ kwargs["original_function"] = self._arealtime
+ return await self.async_function_with_retries(**kwargs)
+ else:
+ raise e
+
+ def text_completion(
+ self,
+ model: str,
+ prompt: str,
+ is_retry: Optional[bool] = False,
+ is_fallback: Optional[bool] = False,
+ is_async: Optional[bool] = False,
+ **kwargs,
+ ):
+ messages = [{"role": "user", "content": prompt}]
+ try:
+ kwargs["model"] = model
+ kwargs["prompt"] = prompt
+ kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+ kwargs.setdefault("metadata", {}).update({"model_group": model})
+
+ # pick the one that is available (lowest TPM/RPM)
+ deployment = self.get_available_deployment(
+ model=model,
+ messages=messages,
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ )
+
+ data = deployment["litellm_params"].copy()
+ for k, v in self.default_litellm_params.items():
+ if (
+ k not in kwargs
+ ): # prioritize model-specific params > default router params
+ kwargs[k] = v
+ elif k == "metadata":
+ kwargs[k].update(v)
+
+ # call via litellm.completion()
+ return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
+ except Exception as e:
+ raise e
+
+ async def atext_completion(
+ self,
+ model: str,
+ prompt: str,
+ is_retry: Optional[bool] = False,
+ is_fallback: Optional[bool] = False,
+ is_async: Optional[bool] = False,
+ **kwargs,
+ ):
+ if kwargs.get("priority", None) is not None:
+ return await self._schedule_factory(
+ model=model,
+ priority=kwargs.pop("priority"),
+ original_function=self.atext_completion,
+ args=(model, prompt),
+ kwargs=kwargs,
+ )
+ try:
+ kwargs["model"] = model
+ kwargs["prompt"] = prompt
+ kwargs["original_function"] = self._atext_completion
+
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
+ response = await self.async_function_with_fallbacks(**kwargs)
+
+ return response
+ except Exception as e:
+ asyncio.create_task(
+ send_llm_exception_alert(
+ litellm_router_instance=self,
+ request_kwargs=kwargs,
+ error_traceback_str=traceback.format_exc(),
+ original_exception=e,
+ )
+ )
+ raise e
+
+ async def _atext_completion(self, model: str, prompt: str, **kwargs):
+ try:
+ verbose_router_logger.debug(
+ f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
+ )
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ deployment = await self.async_get_available_deployment(
+ model=model,
+ messages=[{"role": "user", "content": prompt}],
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ request_kwargs=kwargs,
+ )
+ self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
+
+ data = deployment["litellm_params"].copy()
+ model_name = data["model"]
+
+ model_client = self._get_async_openai_model_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ )
+ self.total_calls[model_name] += 1
+
+ response = litellm.atext_completion(
+ **{
+ **data,
+ "prompt": prompt,
+ "caching": self.cache_responses,
+ "client": model_client,
+ **kwargs,
+ }
+ )
+
+ rpm_semaphore = self._get_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ client_type="max_parallel_requests",
+ )
+
+ if rpm_semaphore is not None and isinstance(
+ rpm_semaphore, asyncio.Semaphore
+ ):
+ async with rpm_semaphore:
+ """
+ - Check rpm limits before making the call
+ - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+ """
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response
+ else:
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response
+
+ self.success_calls[model_name] += 1
+ verbose_router_logger.info(
+ f"litellm.atext_completion(model={model_name})\033[32m 200 OK\033[0m"
+ )
+ return response
+ except Exception as e:
+ verbose_router_logger.info(
+ f"litellm.atext_completion(model={model})\033[31m Exception {str(e)}\033[0m"
+ )
+ if model is not None:
+ self.fail_calls[model] += 1
+ raise e
+
+ async def aadapter_completion(
+ self,
+ adapter_id: str,
+ model: str,
+ is_retry: Optional[bool] = False,
+ is_fallback: Optional[bool] = False,
+ is_async: Optional[bool] = False,
+ **kwargs,
+ ):
+ try:
+ kwargs["model"] = model
+ kwargs["adapter_id"] = adapter_id
+ kwargs["original_function"] = self._aadapter_completion
+ kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+ kwargs.setdefault("metadata", {}).update({"model_group": model})
+ response = await self.async_function_with_fallbacks(**kwargs)
+
+ return response
+ except Exception as e:
+ asyncio.create_task(
+ send_llm_exception_alert(
+ litellm_router_instance=self,
+ request_kwargs=kwargs,
+ error_traceback_str=traceback.format_exc(),
+ original_exception=e,
+ )
+ )
+ raise e
+
+ async def _aadapter_completion(self, adapter_id: str, model: str, **kwargs):
+ try:
+ verbose_router_logger.debug(
+ f"Inside _aadapter_completion()- model: {model}; kwargs: {kwargs}"
+ )
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ deployment = await self.async_get_available_deployment(
+ model=model,
+ messages=[{"role": "user", "content": "default text"}],
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ request_kwargs=kwargs,
+ )
+ self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
+
+ data = deployment["litellm_params"].copy()
+ model_name = data["model"]
+
+ model_client = self._get_async_openai_model_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ )
+ self.total_calls[model_name] += 1
+
+ response = litellm.aadapter_completion(
+ **{
+ **data,
+ "adapter_id": adapter_id,
+ "caching": self.cache_responses,
+ "client": model_client,
+ **kwargs,
+ }
+ )
+
+ rpm_semaphore = self._get_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ client_type="max_parallel_requests",
+ )
+
+ if rpm_semaphore is not None and isinstance(
+ rpm_semaphore, asyncio.Semaphore
+ ):
+ async with rpm_semaphore:
+ """
+ - Check rpm limits before making the call
+ - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+ """
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response # type: ignore
+ else:
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response # type: ignore
+
+ self.success_calls[model_name] += 1
+ verbose_router_logger.info(
+ f"litellm.aadapter_completion(model={model_name})\033[32m 200 OK\033[0m"
+ )
+ return response
+ except Exception as e:
+ verbose_router_logger.info(
+ f"litellm.aadapter_completion(model={model})\033[31m Exception {str(e)}\033[0m"
+ )
+ if model is not None:
+ self.fail_calls[model] += 1
+ raise e
+
+ async def _ageneric_api_call_with_fallbacks(
+ self, model: str, original_function: Callable, **kwargs
+ ):
+ """
+ Make a generic LLM API call through the router; this allows you to use retries/fallbacks with the litellm router
+
+ Args:
+ model: The model to use
+ original_function: The handler function to call (e.g., litellm.anthropic_messages)
+ **kwargs: Additional arguments to pass to the handler function
+
+ Returns:
+ The response from the handler function
+ """
+ handler_name = original_function.__name__
+ try:
+ verbose_router_logger.debug(
+ f"Inside _ageneric_api_call() - handler: {handler_name}, model: {model}; kwargs: {kwargs}"
+ )
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ deployment = await self.async_get_available_deployment(
+ model=model,
+ request_kwargs=kwargs,
+ messages=kwargs.get("messages", None),
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ )
+ self._update_kwargs_with_deployment(
+ deployment=deployment, kwargs=kwargs, function_name="generic_api_call"
+ )
+
+ data = deployment["litellm_params"].copy()
+ model_name = data["model"]
+ self.total_calls[model_name] += 1
+
+ response = original_function(
+ **{
+ **data,
+ "caching": self.cache_responses,
+ **kwargs,
+ }
+ )
+
+ rpm_semaphore = self._get_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ client_type="max_parallel_requests",
+ )
+
+ if rpm_semaphore is not None and isinstance(
+ rpm_semaphore, asyncio.Semaphore
+ ):
+ async with rpm_semaphore:
+ """
+ - Check rpm limits before making the call
+ - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+ """
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response # type: ignore
+ else:
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response # type: ignore
+
+ self.success_calls[model_name] += 1
+ verbose_router_logger.info(
+ f"{handler_name}(model={model_name})\033[32m 200 OK\033[0m"
+ )
+ return response
+ except Exception as e:
+ verbose_router_logger.info(
+ f"{handler_name}(model={model})\033[31m Exception {str(e)}\033[0m"
+ )
+ if model is not None:
+ self.fail_calls[model] += 1
+ raise e
+
+ def _generic_api_call_with_fallbacks(
+ self, model: str, original_function: Callable, **kwargs
+ ):
+ """
+ Make a generic LLM API call through the router; this allows you to use retries/fallbacks with the litellm router
+ Args:
+ model: The model to use
+ original_function: The handler function to call (e.g., litellm.completion)
+ **kwargs: Additional arguments to pass to the handler function
+ Returns:
+ The response from the handler function
+ """
+ handler_name = original_function.__name__
+ try:
+ verbose_router_logger.debug(
+ f"Inside _generic_api_call() - handler: {handler_name}, model: {model}; kwargs: {kwargs}"
+ )
+ deployment = self.get_available_deployment(
+ model=model,
+ messages=kwargs.get("messages", None),
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ )
+ self._update_kwargs_with_deployment(
+ deployment=deployment, kwargs=kwargs, function_name="generic_api_call"
+ )
+
+ data = deployment["litellm_params"].copy()
+ model_name = data["model"]
+
+ self.total_calls[model_name] += 1
+
+ # Perform pre-call checks for routing strategy
+ self.routing_strategy_pre_call_checks(deployment=deployment)
+
+ response = original_function(
+ **{
+ **data,
+ "caching": self.cache_responses,
+ **kwargs,
+ }
+ )
+
+ self.success_calls[model_name] += 1
+ verbose_router_logger.info(
+ f"{handler_name}(model={model_name})\033[32m 200 OK\033[0m"
+ )
+ return response
+ except Exception as e:
+ verbose_router_logger.info(
+ f"{handler_name}(model={model})\033[31m Exception {str(e)}\033[0m"
+ )
+ if model is not None:
+ self.fail_calls[model] += 1
+ raise e
+
+ def embedding(
+ self,
+ model: str,
+ input: Union[str, List],
+ is_async: Optional[bool] = False,
+ **kwargs,
+ ) -> EmbeddingResponse:
+ try:
+ kwargs["model"] = model
+ kwargs["input"] = input
+ kwargs["original_function"] = self._embedding
+ kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+ kwargs.setdefault("metadata", {}).update({"model_group": model})
+ response = self.function_with_fallbacks(**kwargs)
+ return response
+ except Exception as e:
+ raise e
+
+ def _embedding(self, input: Union[str, List], model: str, **kwargs):
+ model_name = None
+ try:
+ verbose_router_logger.debug(
+ f"Inside embedding()- model: {model}; kwargs: {kwargs}"
+ )
+ deployment = self.get_available_deployment(
+ model=model,
+ input=input,
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ )
+ self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
+ data = deployment["litellm_params"].copy()
+ model_name = data["model"]
+
+ potential_model_client = self._get_client(
+ deployment=deployment, kwargs=kwargs, client_type="sync"
+ )
+ # check if provided keys == client keys #
+ dynamic_api_key = kwargs.get("api_key", None)
+ if (
+ dynamic_api_key is not None
+ and potential_model_client is not None
+ and dynamic_api_key != potential_model_client.api_key
+ ):
+ model_client = None
+ else:
+ model_client = potential_model_client
+
+ self.total_calls[model_name] += 1
+
+ ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
+ self.routing_strategy_pre_call_checks(deployment=deployment)
+
+ response = litellm.embedding(
+ **{
+ **data,
+ "input": input,
+ "caching": self.cache_responses,
+ "client": model_client,
+ **kwargs,
+ }
+ )
+ self.success_calls[model_name] += 1
+ verbose_router_logger.info(
+ f"litellm.embedding(model={model_name})\033[32m 200 OK\033[0m"
+ )
+ return response
+ except Exception as e:
+ verbose_router_logger.info(
+ f"litellm.embedding(model={model_name})\033[31m Exception {str(e)}\033[0m"
+ )
+ if model_name is not None:
+ self.fail_calls[model_name] += 1
+ raise e
+
+ async def aembedding(
+ self,
+ model: str,
+ input: Union[str, List],
+ is_async: Optional[bool] = True,
+ **kwargs,
+ ) -> EmbeddingResponse:
+ try:
+ kwargs["model"] = model
+ kwargs["input"] = input
+ kwargs["original_function"] = self._aembedding
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
+ response = await self.async_function_with_fallbacks(**kwargs)
+ return response
+ except Exception as e:
+ asyncio.create_task(
+ send_llm_exception_alert(
+ litellm_router_instance=self,
+ request_kwargs=kwargs,
+ error_traceback_str=traceback.format_exc(),
+ original_exception=e,
+ )
+ )
+ raise e
+
+ async def _aembedding(self, input: Union[str, List], model: str, **kwargs):
+ model_name = None
+ try:
+ verbose_router_logger.debug(
+ f"Inside _aembedding()- model: {model}; kwargs: {kwargs}"
+ )
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ deployment = await self.async_get_available_deployment(
+ model=model,
+ input=input,
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ request_kwargs=kwargs,
+ )
+ self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
+ data = deployment["litellm_params"].copy()
+ model_name = data["model"]
+ model_client = self._get_async_openai_model_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ )
+
+ self.total_calls[model_name] += 1
+ response = litellm.aembedding(
+ **{
+ **data,
+ "input": input,
+ "caching": self.cache_responses,
+ "client": model_client,
+ **kwargs,
+ }
+ )
+
+ ### CONCURRENCY-SAFE RPM CHECKS ###
+ rpm_semaphore = self._get_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ client_type="max_parallel_requests",
+ )
+
+ if rpm_semaphore is not None and isinstance(
+ rpm_semaphore, asyncio.Semaphore
+ ):
+ async with rpm_semaphore:
+ """
+ - Check rpm limits before making the call
+ - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+ """
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response
+ else:
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response
+
+ self.success_calls[model_name] += 1
+ verbose_router_logger.info(
+ f"litellm.aembedding(model={model_name})\033[32m 200 OK\033[0m"
+ )
+ return response
+ except Exception as e:
+ verbose_router_logger.info(
+ f"litellm.aembedding(model={model_name})\033[31m Exception {str(e)}\033[0m"
+ )
+ if model_name is not None:
+ self.fail_calls[model_name] += 1
+ raise e
+
+ #### FILES API ####
+ async def acreate_file(
+ self,
+ model: str,
+ **kwargs,
+ ) -> FileObject:
+ try:
+ kwargs["model"] = model
+ kwargs["original_function"] = self._acreate_file
+ kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
+ response = await self.async_function_with_fallbacks(**kwargs)
+
+ return response
+ except Exception as e:
+ asyncio.create_task(
+ send_llm_exception_alert(
+ litellm_router_instance=self,
+ request_kwargs=kwargs,
+ error_traceback_str=traceback.format_exc(),
+ original_exception=e,
+ )
+ )
+ raise e
+
+ async def _acreate_file(
+ self,
+ model: str,
+ **kwargs,
+ ) -> FileObject:
+ try:
+ verbose_router_logger.debug(
+ f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
+ )
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ deployment = await self.async_get_available_deployment(
+ model=model,
+ messages=[{"role": "user", "content": "files-api-fake-text"}],
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ request_kwargs=kwargs,
+ )
+ self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
+
+ data = deployment["litellm_params"].copy()
+ model_name = data["model"]
+
+ model_client = self._get_async_openai_model_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ )
+ self.total_calls[model_name] += 1
+
+ ## REPLACE MODEL IN FILE WITH SELECTED DEPLOYMENT ##
+ stripped_model, custom_llm_provider, _, _ = get_llm_provider(
+ model=data["model"]
+ )
+ kwargs["file"] = replace_model_in_jsonl(
+ file_content=kwargs["file"], new_model_name=stripped_model
+ )
+
+ response = litellm.acreate_file(
+ **{
+ **data,
+ "custom_llm_provider": custom_llm_provider,
+ "caching": self.cache_responses,
+ "client": model_client,
+ **kwargs,
+ }
+ )
+
+ rpm_semaphore = self._get_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ client_type="max_parallel_requests",
+ )
+
+ if rpm_semaphore is not None and isinstance(
+ rpm_semaphore, asyncio.Semaphore
+ ):
+ async with rpm_semaphore:
+ """
+ - Check rpm limits before making the call
+ - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+ """
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response # type: ignore
+ else:
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response # type: ignore
+
+ self.success_calls[model_name] += 1
+ verbose_router_logger.info(
+ f"litellm.acreate_file(model={model_name})\033[32m 200 OK\033[0m"
+ )
+ return response # type: ignore
+ except Exception as e:
+ verbose_router_logger.exception(
+ f"litellm.acreate_file(model={model}, {kwargs})\033[31m Exception {str(e)}\033[0m"
+ )
+ if model is not None:
+ self.fail_calls[model] += 1
+ raise e
+
+ async def acreate_batch(
+ self,
+ model: str,
+ **kwargs,
+ ) -> Batch:
+ try:
+ kwargs["model"] = model
+ kwargs["original_function"] = self._acreate_batch
+ kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
+ response = await self.async_function_with_fallbacks(**kwargs)
+
+ return response
+ except Exception as e:
+ asyncio.create_task(
+ send_llm_exception_alert(
+ litellm_router_instance=self,
+ request_kwargs=kwargs,
+ error_traceback_str=traceback.format_exc(),
+ original_exception=e,
+ )
+ )
+ raise e
+
+ async def _acreate_batch(
+ self,
+ model: str,
+ **kwargs,
+ ) -> Batch:
+ try:
+ verbose_router_logger.debug(
+ f"Inside _acreate_batch()- model: {model}; kwargs: {kwargs}"
+ )
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ deployment = await self.async_get_available_deployment(
+ model=model,
+ messages=[{"role": "user", "content": "files-api-fake-text"}],
+ specific_deployment=kwargs.pop("specific_deployment", None),
+ request_kwargs=kwargs,
+ )
+ metadata_variable_name = _get_router_metadata_variable_name(
+ function_name="_acreate_batch"
+ )
+
+ kwargs.setdefault(metadata_variable_name, {}).update(
+ {
+ "deployment": deployment["litellm_params"]["model"],
+ "model_info": deployment.get("model_info", {}),
+ "api_base": deployment.get("litellm_params", {}).get("api_base"),
+ }
+ )
+ kwargs["model_info"] = deployment.get("model_info", {})
+ data = deployment["litellm_params"].copy()
+ model_name = data["model"]
+ self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
+
+ model_client = self._get_async_openai_model_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ )
+ self.total_calls[model_name] += 1
+
+ ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ##
+ _, custom_llm_provider, _, _ = get_llm_provider(model=data["model"])
+
+ response = litellm.acreate_batch(
+ **{
+ **data,
+ "custom_llm_provider": custom_llm_provider,
+ "caching": self.cache_responses,
+ "client": model_client,
+ **kwargs,
+ }
+ )
+
+ rpm_semaphore = self._get_client(
+ deployment=deployment,
+ kwargs=kwargs,
+ client_type="max_parallel_requests",
+ )
+
+ if rpm_semaphore is not None and isinstance(
+ rpm_semaphore, asyncio.Semaphore
+ ):
+ async with rpm_semaphore:
+ """
+ - Check rpm limits before making the call
+ - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+ """
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response # type: ignore
+ else:
+ await self.async_routing_strategy_pre_call_checks(
+ deployment=deployment, parent_otel_span=parent_otel_span
+ )
+ response = await response # type: ignore
+
+ self.success_calls[model_name] += 1
+ verbose_router_logger.info(
+ f"litellm.acreate_file(model={model_name})\033[32m 200 OK\033[0m"
+ )
+ return response # type: ignore
+ except Exception as e:
+ verbose_router_logger.exception(
+ f"litellm._acreate_batch(model={model}, {kwargs})\033[31m Exception {str(e)}\033[0m"
+ )
+ if model is not None:
+ self.fail_calls[model] += 1
+ raise e
+
+ async def aretrieve_batch(
+ self,
+ **kwargs,
+ ) -> Batch:
+ """
+ Iterate through all models in a model group to check for the batch
+
+ Future Improvement - cache the result.
+ """
+ try:
+ filtered_model_list = self.get_model_list()
+ if filtered_model_list is None:
+ raise Exception("Router not yet initialized.")
+
+ received_exceptions = []
+
+ async def try_retrieve_batch(model_name):
+ try:
+ # Update kwargs with the current model name or any other model-specific adjustments
+ ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ##
+ _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
+ model=model_name["litellm_params"]["model"]
+ )
+ new_kwargs = copy.deepcopy(kwargs)
+ new_kwargs.pop("custom_llm_provider", None)
+ return await litellm.aretrieve_batch(
+ custom_llm_provider=custom_llm_provider, **new_kwargs # type: ignore
+ )
+ except Exception as e:
+ received_exceptions.append(e)
+ return None
+
+ # Check all models in parallel
+ results = await asyncio.gather(
+ *[try_retrieve_batch(model) for model in filtered_model_list],
+ return_exceptions=True,
+ )
+
+ # Check for successful responses and handle exceptions
+ for result in results:
+ if isinstance(result, Batch):
+ return result
+
+ # If no valid Batch response was found, raise the first encountered exception
+ if received_exceptions:
+ raise received_exceptions[0] # Raising the first exception encountered
+
+ # If no exceptions were encountered, raise a generic exception
+ raise Exception(
+ "Unable to find batch in any model. Received errors - {}".format(
+ received_exceptions
+ )
+ )
+ except Exception as e:
+ asyncio.create_task(
+ send_llm_exception_alert(
+ litellm_router_instance=self,
+ request_kwargs=kwargs,
+ error_traceback_str=traceback.format_exc(),
+ original_exception=e,
+ )
+ )
+ raise e
+
+ async def alist_batches(
+ self,
+ model: str,
+ **kwargs,
+ ):
+ """
+ Return all the batches across all deployments of a model group.
+ """
+
+ filtered_model_list = self.get_model_list(model_name=model)
+ if filtered_model_list is None:
+ raise Exception("Router not yet initialized.")
+
+ async def try_retrieve_batch(model: DeploymentTypedDict):
+ try:
+ # Update kwargs with the current model name or any other model-specific adjustments
+ return await litellm.alist_batches(
+ **{**model["litellm_params"], **kwargs}
+ )
+ except Exception:
+ return None
+
+ # Check all models in parallel
+ results = await asyncio.gather(
+ *[try_retrieve_batch(model) for model in filtered_model_list]
+ )
+
+ final_results = {
+ "object": "list",
+ "data": [],
+ "first_id": None,
+ "last_id": None,
+ "has_more": False,
+ }
+
+ for result in results:
+ if result is not None:
+ ## check batch id
+ if final_results["first_id"] is None and hasattr(result, "first_id"):
+ final_results["first_id"] = getattr(result, "first_id")
+ final_results["last_id"] = getattr(result, "last_id")
+ final_results["data"].extend(result.data) # type: ignore
+
+ ## check 'has_more'
+ if getattr(result, "has_more", False) is True:
+ final_results["has_more"] = True
+
+ return final_results
+
+ #### PASSTHROUGH API ####
+
+ async def _pass_through_moderation_endpoint_factory(
+ self,
+ original_function: Callable,
+ **kwargs,
+ ):
+ if kwargs.get("model") and self.get_model_list(model_name=kwargs["model"]):
+ deployment = await self.async_get_available_deployment(
+ model=kwargs["model"],
+ request_kwargs=kwargs,
+ )
+ kwargs["model"] = deployment["litellm_params"]["model"]
+ return await original_function(**kwargs)
+
+ def factory_function(
+ self,
+ original_function: Callable,
+ call_type: Literal[
+ "assistants",
+ "moderation",
+ "anthropic_messages",
+ "aresponses",
+ "responses",
+ ] = "assistants",
+ ):
+ """
+ Creates appropriate wrapper functions for different API call types.
+
+ Returns:
+ - A synchronous function for synchronous call types
+ - An asynchronous function for asynchronous call types
+ """
+ # Handle synchronous call types
+ if call_type == "responses":
+
+ def sync_wrapper(
+ custom_llm_provider: Optional[
+ Literal["openai", "azure", "anthropic"]
+ ] = None,
+ client: Optional[Any] = None,
+ **kwargs,
+ ):
+ return self._generic_api_call_with_fallbacks(
+ original_function=original_function, **kwargs
+ )
+
+ return sync_wrapper
+
+ # Handle asynchronous call types
+ async def async_wrapper(
+ custom_llm_provider: Optional[
+ Literal["openai", "azure", "anthropic"]
+ ] = None,
+ client: Optional[Any] = None,
+ **kwargs,
+ ):
+ if call_type == "assistants":
+ return await self._pass_through_assistants_endpoint_factory(
+ original_function=original_function,
+ custom_llm_provider=custom_llm_provider,
+ client=client,
+ **kwargs,
+ )
+ elif call_type == "moderation":
+ return await self._pass_through_moderation_endpoint_factory(
+ original_function=original_function, **kwargs
+ )
+ elif call_type in ("anthropic_messages", "aresponses"):
+ return await self._ageneric_api_call_with_fallbacks(
+ original_function=original_function,
+ **kwargs,
+ )
+
+ return async_wrapper
+
+ async def _pass_through_assistants_endpoint_factory(
+ self,
+ original_function: Callable,
+ custom_llm_provider: Optional[Literal["openai", "azure", "anthropic"]] = None,
+ client: Optional[AsyncOpenAI] = None,
+ **kwargs,
+ ):
+ """Internal helper function to pass through the assistants endpoint"""
+ if custom_llm_provider is None:
+ if self.assistants_config is not None:
+ custom_llm_provider = self.assistants_config["custom_llm_provider"]
+ kwargs.update(self.assistants_config["litellm_params"])
+ else:
+ raise Exception(
+ "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
+ )
+ return await original_function( # type: ignore
+ custom_llm_provider=custom_llm_provider, client=client, **kwargs
+ )
+
+ #### [END] ASSISTANTS API ####
+
+ @tracer.wrap()
+ async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915
+ """
+ Try calling the function_with_retries
+ If it fails after num_retries, fall back to another model group
+ """
+ model_group: Optional[str] = kwargs.get("model")
+ disable_fallbacks: Optional[bool] = kwargs.pop("disable_fallbacks", False)
+ fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks)
+ context_window_fallbacks: Optional[List] = kwargs.get(
+ "context_window_fallbacks", self.context_window_fallbacks
+ )
+ content_policy_fallbacks: Optional[List] = kwargs.get(
+ "content_policy_fallbacks", self.content_policy_fallbacks
+ )
+
+ mock_timeout = kwargs.pop("mock_timeout", None)
+
+ try:
+ self._handle_mock_testing_fallbacks(
+ kwargs=kwargs,
+ model_group=model_group,
+ fallbacks=fallbacks,
+ context_window_fallbacks=context_window_fallbacks,
+ content_policy_fallbacks=content_policy_fallbacks,
+ )
+
+ if mock_timeout is not None:
+ response = await self.async_function_with_retries(
+ *args, **kwargs, mock_timeout=mock_timeout
+ )
+ else:
+ response = await self.async_function_with_retries(*args, **kwargs)
+ verbose_router_logger.debug(f"Async Response: {response}")
+ response = add_fallback_headers_to_response(
+ response=response,
+ attempted_fallbacks=0,
+ )
+ return response
+ except Exception as e:
+ verbose_router_logger.debug(f"Traceback{traceback.format_exc()}")
+ original_exception = e
+ fallback_model_group = None
+ original_model_group: Optional[str] = kwargs.get("model") # type: ignore
+ fallback_failure_exception_str = ""
+
+ if disable_fallbacks is True or original_model_group is None:
+ raise e
+
+ input_kwargs = {
+ "litellm_router": self,
+ "original_exception": original_exception,
+ **kwargs,
+ }
+
+ if "max_fallbacks" not in input_kwargs:
+ input_kwargs["max_fallbacks"] = self.max_fallbacks
+ if "fallback_depth" not in input_kwargs:
+ input_kwargs["fallback_depth"] = 0
+
+ try:
+ verbose_router_logger.info("Trying to fallback b/w models")
+
+ # check if client-side fallbacks are used (e.g. fallbacks = ["gpt-3.5-turbo", "claude-3-haiku"] or fallbacks=[{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}])
+ is_non_standard_fallback_format = _check_non_standard_fallback_format(
+ fallbacks=fallbacks
+ )
+
+ if is_non_standard_fallback_format:
+ input_kwargs.update(
+ {
+ "fallback_model_group": fallbacks,
+ "original_model_group": original_model_group,
+ }
+ )
+
+ response = await run_async_fallback(
+ *args,
+ **input_kwargs,
+ )
+
+ return response
+
+ if isinstance(e, litellm.ContextWindowExceededError):
+ if context_window_fallbacks is not None:
+ fallback_model_group: Optional[List[str]] = (
+ self._get_fallback_model_group_from_fallbacks(
+ fallbacks=context_window_fallbacks,
+ model_group=model_group,
+ )
+ )
+ if fallback_model_group is None:
+ raise original_exception
+
+ input_kwargs.update(
+ {
+ "fallback_model_group": fallback_model_group,
+ "original_model_group": original_model_group,
+ }
+ )
+
+ response = await run_async_fallback(
+ *args,
+ **input_kwargs,
+ )
+ return response
+
+ else:
+ error_message = "model={}. context_window_fallbacks={}. fallbacks={}.\n\nSet 'context_window_fallback' - https://docs.litellm.ai/docs/routing#fallbacks".format(
+ model_group, context_window_fallbacks, fallbacks
+ )
+ verbose_router_logger.info(
+ msg="Got 'ContextWindowExceededError'. No context_window_fallback set. Defaulting \
+ to fallbacks, if available.{}".format(
+ error_message
+ )
+ )
+
+ e.message += "\n{}".format(error_message)
+ elif isinstance(e, litellm.ContentPolicyViolationError):
+ if content_policy_fallbacks is not None:
+ fallback_model_group: Optional[List[str]] = (
+ self._get_fallback_model_group_from_fallbacks(
+ fallbacks=content_policy_fallbacks,
+ model_group=model_group,
+ )
+ )
+ if fallback_model_group is None:
+ raise original_exception
+
+ input_kwargs.update(
+ {
+ "fallback_model_group": fallback_model_group,
+ "original_model_group": original_model_group,
+ }
+ )
+
+ response = await run_async_fallback(
+ *args,
+ **input_kwargs,
+ )
+ return response
+ else:
+ error_message = "model={}. content_policy_fallback={}. fallbacks={}.\n\nSet 'content_policy_fallback' - https://docs.litellm.ai/docs/routing#fallbacks".format(
+ model_group, content_policy_fallbacks, fallbacks
+ )
+ verbose_router_logger.info(
+ msg="Got 'ContentPolicyViolationError'. No content_policy_fallback set. Defaulting \
+ to fallbacks, if available.{}".format(
+ error_message
+ )
+ )
+
+ e.message += "\n{}".format(error_message)
+ if fallbacks is not None and model_group is not None:
+ verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
+ fallback_model_group, generic_fallback_idx = (
+ get_fallback_model_group(
+ fallbacks=fallbacks, # if fallbacks = [{"gpt-3.5-turbo": ["claude-3-haiku"]}]
+ model_group=cast(str, model_group),
+ )
+ )
+ ## if none, check for generic fallback
+ if (
+ fallback_model_group is None
+ and generic_fallback_idx is not None
+ ):
+ fallback_model_group = fallbacks[generic_fallback_idx]["*"]
+
+ if fallback_model_group is None:
+ verbose_router_logger.info(
+ f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
+ )
+ if hasattr(original_exception, "message"):
+ original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" # type: ignore
+ raise original_exception
+
+ input_kwargs.update(
+ {
+ "fallback_model_group": fallback_model_group,
+ "original_model_group": original_model_group,
+ }
+ )
+
+ response = await run_async_fallback(
+ *args,
+ **input_kwargs,
+ )
+
+ return response
+ except Exception as new_exception:
+ traceback.print_exc()
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ verbose_router_logger.error(
+ "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
+ str(new_exception),
+ traceback.format_exc(),
+ await _async_get_cooldown_deployments_with_debug_info(
+ litellm_router_instance=self,
+ parent_otel_span=parent_otel_span,
+ ),
+ )
+ )
+ fallback_failure_exception_str = str(new_exception)
+
+ if hasattr(original_exception, "message"):
+ # add the available fallbacks to the exception
+ original_exception.message += ". Received Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore
+ model_group,
+ fallback_model_group,
+ )
+ if len(fallback_failure_exception_str) > 0:
+ original_exception.message += ( # type: ignore
+ "\nError doing the fallback: {}".format(
+ fallback_failure_exception_str
+ )
+ )
+
+ raise original_exception
+
+ def _handle_mock_testing_fallbacks(
+ self,
+ kwargs: dict,
+ model_group: Optional[str] = None,
+ fallbacks: Optional[List] = None,
+ context_window_fallbacks: Optional[List] = None,
+ content_policy_fallbacks: Optional[List] = None,
+ ):
+ """
+ Helper function to raise a litellm Error for mock testing purposes.
+
+ Raises:
+ litellm.InternalServerError: when `mock_testing_fallbacks=True` passed in request params
+ litellm.ContextWindowExceededError: when `mock_testing_context_fallbacks=True` passed in request params
+ litellm.ContentPolicyViolationError: when `mock_testing_content_policy_fallbacks=True` passed in request params
+ """
+ mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
+ mock_testing_context_fallbacks = kwargs.pop(
+ "mock_testing_context_fallbacks", None
+ )
+ mock_testing_content_policy_fallbacks = kwargs.pop(
+ "mock_testing_content_policy_fallbacks", None
+ )
+
+ if mock_testing_fallbacks is not None and mock_testing_fallbacks is True:
+ raise litellm.InternalServerError(
+ model=model_group,
+ llm_provider="",
+ message=f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}",
+ )
+ elif (
+ mock_testing_context_fallbacks is not None
+ and mock_testing_context_fallbacks is True
+ ):
+ raise litellm.ContextWindowExceededError(
+ model=model_group,
+ llm_provider="",
+ message=f"This is a mock exception for model={model_group}, to trigger a fallback. \
+ Context_Window_Fallbacks={context_window_fallbacks}",
+ )
+ elif (
+ mock_testing_content_policy_fallbacks is not None
+ and mock_testing_content_policy_fallbacks is True
+ ):
+ raise litellm.ContentPolicyViolationError(
+ model=model_group,
+ llm_provider="",
+ message=f"This is a mock exception for model={model_group}, to trigger a fallback. \
+ Context_Policy_Fallbacks={content_policy_fallbacks}",
+ )
+
+ @tracer.wrap()
+ async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915
+ verbose_router_logger.debug("Inside async function with retries.")
+ original_function = kwargs.pop("original_function")
+ fallbacks = kwargs.pop("fallbacks", self.fallbacks)
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ context_window_fallbacks = kwargs.pop(
+ "context_window_fallbacks", self.context_window_fallbacks
+ )
+ content_policy_fallbacks = kwargs.pop(
+ "content_policy_fallbacks", self.content_policy_fallbacks
+ )
+ model_group: Optional[str] = kwargs.get("model")
+ num_retries = kwargs.pop("num_retries")
+
+ ## ADD MODEL GROUP SIZE TO METADATA - used for model_group_rate_limit_error tracking
+ _metadata: dict = kwargs.get("metadata") or {}
+ if "model_group" in _metadata and isinstance(_metadata["model_group"], str):
+ model_list = self.get_model_list(model_name=_metadata["model_group"])
+ if model_list is not None:
+ _metadata.update({"model_group_size": len(model_list)})
+
+ verbose_router_logger.debug(
+ f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}"
+ )
+ try:
+ self._handle_mock_testing_rate_limit_error(
+ model_group=model_group, kwargs=kwargs
+ )
+ # if the function call is successful, no exception will be raised and we'll break out of the loop
+ response = await self.make_call(original_function, *args, **kwargs)
+ response = add_retry_headers_to_response(
+ response=response, attempted_retries=0, max_retries=None
+ )
+ return response
+ except Exception as e:
+ current_attempt = None
+ original_exception = e
+ deployment_num_retries = getattr(e, "num_retries", None)
+
+ if deployment_num_retries is not None and isinstance(
+ deployment_num_retries, int
+ ):
+ num_retries = deployment_num_retries
+ """
+ Retry Logic
+ """
+ _healthy_deployments, _all_deployments = (
+ await self._async_get_healthy_deployments(
+ model=kwargs.get("model") or "",
+ parent_otel_span=parent_otel_span,
+ )
+ )
+
+ # raises an exception if this error should not be retried
+ self.should_retry_this_error(
+ error=e,
+ healthy_deployments=_healthy_deployments,
+ all_deployments=_all_deployments,
+ context_window_fallbacks=context_window_fallbacks,
+ regular_fallbacks=fallbacks,
+ content_policy_fallbacks=content_policy_fallbacks,
+ )
+
+ if (
+ self.retry_policy is not None
+ or self.model_group_retry_policy is not None
+ ):
+ # get num_retries from retry policy
+ _retry_policy_retries = self.get_num_retries_from_retry_policy(
+ exception=original_exception, model_group=kwargs.get("model")
+ )
+ if _retry_policy_retries is not None:
+ num_retries = _retry_policy_retries
+ ## LOGGING
+ if num_retries > 0:
+ kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
+ else:
+ raise
+
+ verbose_router_logger.info(
+ f"Retrying request with num_retries: {num_retries}"
+ )
+ # decides how long to sleep before retry
+ retry_after = self._time_to_sleep_before_retry(
+ e=original_exception,
+ remaining_retries=num_retries,
+ num_retries=num_retries,
+ healthy_deployments=_healthy_deployments,
+ all_deployments=_all_deployments,
+ )
+
+ await asyncio.sleep(retry_after)
+
+ for current_attempt in range(num_retries):
+ try:
+ # if the function call is successful, no exception will be raised and we'll break out of the loop
+ response = await self.make_call(original_function, *args, **kwargs)
+ if inspect.iscoroutinefunction(
+ response
+ ): # async errors are often returned as coroutines
+ response = await response
+
+ response = add_retry_headers_to_response(
+ response=response,
+ attempted_retries=current_attempt + 1,
+ max_retries=num_retries,
+ )
+ return response
+
+ except Exception as e:
+ ## LOGGING
+ kwargs = self.log_retry(kwargs=kwargs, e=e)
+ remaining_retries = num_retries - current_attempt
+ _model: Optional[str] = kwargs.get("model") # type: ignore
+ if _model is not None:
+ _healthy_deployments, _ = (
+ await self._async_get_healthy_deployments(
+ model=_model,
+ parent_otel_span=parent_otel_span,
+ )
+ )
+ else:
+ _healthy_deployments = []
+ _timeout = self._time_to_sleep_before_retry(
+ e=original_exception,
+ remaining_retries=remaining_retries,
+ num_retries=num_retries,
+ healthy_deployments=_healthy_deployments,
+ all_deployments=_all_deployments,
+ )
+ await asyncio.sleep(_timeout)
+
+ if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
+ setattr(original_exception, "max_retries", num_retries)
+ setattr(original_exception, "num_retries", current_attempt)
+
+ raise original_exception
+
+ async def make_call(self, original_function: Any, *args, **kwargs):
+ """
+ Handler for making a call to the .completion()/.embeddings()/etc. functions.
+ """
+ model_group = kwargs.get("model")
+ response = original_function(*args, **kwargs)
+ if inspect.iscoroutinefunction(response) or inspect.isawaitable(response):
+ response = await response
+ ## PROCESS RESPONSE HEADERS
+ response = await self.set_response_headers(
+ response=response, model_group=model_group
+ )
+
+ return response
+
+ def _handle_mock_testing_rate_limit_error(
+ self, kwargs: dict, model_group: Optional[str] = None
+ ):
+ """
+ Helper function to raise a mock litellm.RateLimitError error for testing purposes.
+
+ Raises:
+ litellm.RateLimitError error when `mock_testing_rate_limit_error=True` passed in request params
+ """
+ mock_testing_rate_limit_error: Optional[bool] = kwargs.pop(
+ "mock_testing_rate_limit_error", None
+ )
+
+ available_models = self.get_model_list(model_name=model_group)
+ num_retries: Optional[int] = None
+
+ if available_models is not None and len(available_models) == 1:
+ num_retries = cast(
+ Optional[int], available_models[0]["litellm_params"].get("num_retries")
+ )
+
+ if (
+ mock_testing_rate_limit_error is not None
+ and mock_testing_rate_limit_error is True
+ ):
+ verbose_router_logger.info(
+ f"litellm.router.py::_mock_rate_limit_error() - Raising mock RateLimitError for model={model_group}"
+ )
+ raise litellm.RateLimitError(
+ model=model_group,
+ llm_provider="",
+ message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
+ num_retries=num_retries,
+ )
+
+ def should_retry_this_error(
+ self,
+ error: Exception,
+ healthy_deployments: Optional[List] = None,
+ all_deployments: Optional[List] = None,
+ context_window_fallbacks: Optional[List] = None,
+ content_policy_fallbacks: Optional[List] = None,
+ regular_fallbacks: Optional[List] = None,
+ ):
+ """
+ 1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
+ 2. raise an exception for ContentPolicyViolationError if content_policy_fallbacks is not None
+
+ 3. raise an exception for RateLimitError if
+ - there are no fallbacks
+ - there are no healthy deployments in the same model group
+ """
+ _num_healthy_deployments = 0
+ if healthy_deployments is not None and isinstance(healthy_deployments, list):
+ _num_healthy_deployments = len(healthy_deployments)
+
+ _num_all_deployments = 0
+ if all_deployments is not None and isinstance(all_deployments, list):
+ _num_all_deployments = len(all_deployments)
+
+ ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR / CONTENT POLICY VIOLATION ERROR w/ fallbacks available / Bad Request Error
+ if (
+ isinstance(error, litellm.ContextWindowExceededError)
+ and context_window_fallbacks is not None
+ ):
+ raise error
+
+ if (
+ isinstance(error, litellm.ContentPolicyViolationError)
+ and content_policy_fallbacks is not None
+ ):
+ raise error
+
+ if isinstance(error, litellm.NotFoundError):
+ raise error
+ # Error we should only retry if there are other deployments
+ if isinstance(error, openai.RateLimitError):
+ if (
+ _num_healthy_deployments <= 0 # if no healthy deployments
+ and regular_fallbacks is not None # and fallbacks available
+ and len(regular_fallbacks) > 0
+ ):
+ raise error # then raise the error
+
+ if isinstance(error, openai.AuthenticationError):
+ """
+ - if other deployments available -> retry
+ - else -> raise error
+ """
+ if (
+ _num_all_deployments <= 1
+ ): # if there is only 1 deployment for this model group then don't retry
+ raise error # then raise error
+
+ # Do not retry if there are no healthy deployments
+ # just raise the error
+ if _num_healthy_deployments <= 0: # if no healthy deployments
+ raise error
+
+ return True
+
+ def function_with_fallbacks(self, *args, **kwargs):
+ """
+ Sync wrapper for async_function_with_fallbacks
+
+ Wrapped to reduce code duplication and prevent bugs.
+ """
+ return run_async_function(self.async_function_with_fallbacks, *args, **kwargs)
+
+ def _get_fallback_model_group_from_fallbacks(
+ self,
+ fallbacks: List[Dict[str, List[str]]],
+ model_group: Optional[str] = None,
+ ) -> Optional[List[str]]:
+ """
+ Returns the list of fallback models to use for a given model group
+
+ If no fallback model group is found, returns None
+
+ Example:
+ fallbacks = [{"gpt-3.5-turbo": ["gpt-4"]}, {"gpt-4o": ["gpt-3.5-turbo"]}]
+ model_group = "gpt-3.5-turbo"
+ returns: ["gpt-4"]
+ """
+ if model_group is None:
+ return None
+
+ fallback_model_group: Optional[List[str]] = None
+ for item in fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
+ if list(item.keys())[0] == model_group:
+ fallback_model_group = item[model_group]
+ break
+ return fallback_model_group
+
+ def _time_to_sleep_before_retry(
+ self,
+ e: Exception,
+ remaining_retries: int,
+ num_retries: int,
+ healthy_deployments: Optional[List] = None,
+ all_deployments: Optional[List] = None,
+ ) -> Union[int, float]:
+ """
+ Calculate the back-off time to sleep before retrying.
+
+ The retry is instant (0s sleep) only when:
+ 1. there are healthy deployments in the same model group
+ 2. there are fallbacks for the completion call
+ """
+
+ ## base case - single deployment
+ if all_deployments is not None and len(all_deployments) == 1:
+ pass
+ elif (
+ healthy_deployments is not None
+ and isinstance(healthy_deployments, list)
+ and len(healthy_deployments) > 0
+ ):
+ return 0
+
+ response_headers: Optional[httpx.Headers] = None
+ if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
+ response_headers = e.response.headers # type: ignore
+ if hasattr(e, "litellm_response_headers"):
+ response_headers = e.litellm_response_headers # type: ignore
+
+ if response_headers is not None:
+ timeout = litellm._calculate_retry_after(
+ remaining_retries=remaining_retries,
+ max_retries=num_retries,
+ response_headers=response_headers,
+ min_timeout=self.retry_after,
+ )
+
+ else:
+ timeout = litellm._calculate_retry_after(
+ remaining_retries=remaining_retries,
+ max_retries=num_retries,
+ min_timeout=self.retry_after,
+ )
+
+ return timeout
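+ # Illustrative sketch of the back-off decision above (argument values are assumptions,
+ # not exact library behaviour): with more than one deployment in the group and at least
+ # one healthy deployment, the sleep time is 0 (retry immediately); otherwise the sleep
+ # time comes from litellm._calculate_retry_after(), which can honour a provider's
+ # Retry-After response header, bounded below by `self.retry_after`.
+ #
+ #   sleep_s = router._time_to_sleep_before_retry(
+ #       e=exception, remaining_retries=2, num_retries=3,
+ #       healthy_deployments=healthy, all_deployments=all_deps,
+ #   )
+ #   await asyncio.sleep(sleep_s)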
+
+ ### HELPER FUNCTIONS
+
+ async def deployment_callback_on_success(
+ self,
+ kwargs, # kwargs to completion
+ completion_response, # response from completion
+ start_time,
+ end_time, # start/end time
+ ):
+ """
+ Track remaining tpm/rpm quota for model in model_list
+ """
+ try:
+ standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object", None
+ )
+ if standard_logging_object is None:
+ raise ValueError("standard_logging_object is None")
+ if kwargs["litellm_params"].get("metadata") is None:
+ pass
+ else:
+ deployment_name = kwargs["litellm_params"]["metadata"].get(
+ "deployment", None
+ ) # stable name - works for wildcard routes as well
+ model_group = standard_logging_object.get("model_group", None)
+ id = standard_logging_object.get("model_id", None)
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ total_tokens: float = standard_logging_object.get("total_tokens", 0)
+
+ # ------------
+ # Setup values
+ # ------------
+ dt = get_utc_datetime()
+ current_minute = dt.strftime(
+ "%H-%M"
+ ) # use the same timezone regardless of system clock
+
+ tpm_key = RouterCacheEnum.TPM.value.format(
+ id=id, current_minute=current_minute, model=deployment_name
+ )
+ # ------------
+ # Update usage
+ # ------------
+ # update cache
+
+ ## TPM
+ await self.cache.async_increment_cache(
+ key=tpm_key,
+ value=total_tokens,
+ parent_otel_span=parent_otel_span,
+ ttl=RoutingArgs.ttl.value,
+ )
+
+ ## RPM
+ rpm_key = RouterCacheEnum.RPM.value.format(
+ id=id, current_minute=current_minute, model=deployment_name
+ )
+ await self.cache.async_increment_cache(
+ key=rpm_key,
+ value=1,
+ parent_otel_span=parent_otel_span,
+ ttl=RoutingArgs.ttl.value,
+ )
+
+ increment_deployment_successes_for_current_minute(
+ litellm_router_instance=self,
+ deployment_id=id,
+ )
+
+ return tpm_key
+
+ except Exception as e:
+ verbose_router_logger.exception(
+ "litellm.router.Router::deployment_callback_on_success(): Exception occurred - {}".format(
+ str(e)
+ )
+ )
+ pass
+
+ def sync_deployment_callback_on_success(
+ self,
+ kwargs, # kwargs to completion
+ completion_response, # response from completion
+ start_time,
+ end_time, # start/end time
+ ) -> Optional[str]:
+ """
+ Tracks the number of successes for a deployment in the current minute (using in-memory cache)
+
+ Returns:
+ - key: str - The key used to increment the cache
+ - None: if no key is found
+ """
+ id = None
+ if kwargs["litellm_params"].get("metadata") is None:
+ pass
+ else:
+ model_group = kwargs["litellm_params"]["metadata"].get("model_group", None)
+ model_info = kwargs["litellm_params"].get("model_info", {}) or {}
+ id = model_info.get("id", None)
+ if model_group is None or id is None:
+ return None
+ elif isinstance(id, int):
+ id = str(id)
+
+ if id is not None:
+ key = increment_deployment_successes_for_current_minute(
+ litellm_router_instance=self,
+ deployment_id=id,
+ )
+ return key
+
+ return None
+
+ def deployment_callback_on_failure(
+ self,
+ kwargs, # kwargs to completion
+ completion_response, # response from completion
+ start_time,
+ end_time, # start/end time
+ ) -> bool:
+ """
+ 2 jobs:
+ - Tracks the number of failures for a deployment in the current minute (using in-memory cache)
+ - Puts the deployment in cooldown if it exceeds the allowed fails / minute
+
+ Returns:
+ - True if the deployment should be put in cooldown
+ - False if the deployment should not be put in cooldown
+ """
+ verbose_router_logger.debug("Router: Entering 'deployment_callback_on_failure'")
+ try:
+ exception = kwargs.get("exception", None)
+ exception_status = getattr(exception, "status_code", "")
+ _model_info = kwargs.get("litellm_params", {}).get("model_info", {})
+
+ exception_headers = litellm.litellm_core_utils.exception_mapping_utils._get_response_headers(
+ original_exception=exception
+ )
+
+ _time_to_cooldown = kwargs.get("litellm_params", {}).get(
+ "cooldown_time", self.cooldown_time
+ )
+
+ if exception_headers is not None:
+
+ _time_to_cooldown = (
+ litellm.utils._get_retry_after_from_exception_header(
+ response_headers=exception_headers
+ )
+ )
+
+ if _time_to_cooldown is None or _time_to_cooldown < 0:
+ # if the retry-after value could not be read from the response headers -> use the default cooldown time
+ _time_to_cooldown = self.cooldown_time
+
+ if isinstance(_model_info, dict):
+ deployment_id = _model_info.get("id", None)
+ increment_deployment_failures_for_current_minute(
+ litellm_router_instance=self,
+ deployment_id=deployment_id,
+ )
+ result = _set_cooldown_deployments(
+ litellm_router_instance=self,
+ exception_status=exception_status,
+ original_exception=exception,
+ deployment=deployment_id,
+ time_to_cooldown=_time_to_cooldown,
+ ) # setting deployment_id in cooldown deployments
+
+ return result
+ else:
+ verbose_router_logger.debug(
+ "Router: Exiting 'deployment_callback_on_failure' without cooldown. No model_info found."
+ )
+ return False
+
+ except Exception as e:
+ raise e
+
+ async def async_deployment_callback_on_failure(
+ self, kwargs, completion_response: Optional[Any], start_time, end_time
+ ):
+ """
+ Update RPM usage for a deployment
+ """
+ deployment_name = kwargs["litellm_params"]["metadata"].get(
+ "deployment", None
+ ) # handles wildcard routes - by giving the original name sent to `litellm.completion`
+ model_group = kwargs["litellm_params"]["metadata"].get("model_group", None)
+ model_info = kwargs["litellm_params"].get("model_info", {}) or {}
+ id = model_info.get("id", None)
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+
+ dt = get_utc_datetime()
+ current_minute = dt.strftime(
+ "%H-%M"
+ ) # use the same timezone regardless of system clock
+
+ ## RPM
+ rpm_key = RouterCacheEnum.RPM.value.format(
+ id=id, current_minute=current_minute, model=deployment_name
+ )
+ await self.cache.async_increment_cache(
+ key=rpm_key,
+ value=1,
+ parent_otel_span=parent_otel_span,
+ ttl=RoutingArgs.ttl.value,
+ )
+
+ def log_retry(self, kwargs: dict, e: Exception) -> dict:
+ """
+ When a retry or fallback happens, log the details of the model call that just failed - similar to Sentry breadcrumbing
+ """
+ try:
+ # Log failed model as the previous model
+ previous_model = {
+ "exception_type": type(e).__name__,
+ "exception_string": str(e),
+ }
+ for (
+ k,
+ v,
+ ) in (
+ kwargs.items()
+ ): # log everything in kwargs except the old previous_models value - prevent nesting
+ if k not in ["metadata", "messages", "original_function"]:
+ previous_model[k] = v
+ elif k == "metadata" and isinstance(v, dict):
+ previous_model["metadata"] = {} # type: ignore
+ for metadata_k, metadata_v in kwargs["metadata"].items():
+ if metadata_k != "previous_models":
+ previous_model[k][metadata_k] = metadata_v # type: ignore
+
+ # check current size of self.previous_models, if it's larger than 3, remove the first element
+ if len(self.previous_models) > 3:
+ self.previous_models.pop(0)
+
+ self.previous_models.append(previous_model)
+ kwargs["metadata"]["previous_models"] = self.previous_models
+ return kwargs
+ except Exception as e:
+ raise e
+
+ def _update_usage(
+ self, deployment_id: str, parent_otel_span: Optional[Span]
+ ) -> int:
+ """
+ Update deployment rpm for that minute
+
+ Returns:
+ - int: request count
+ """
+ rpm_key = deployment_id
+
+ request_count = self.cache.get_cache(
+ key=rpm_key, parent_otel_span=parent_otel_span, local_only=True
+ )
+ if request_count is None:
+ request_count = 1
+ self.cache.set_cache(
+ key=rpm_key, value=request_count, local_only=True, ttl=60
+ ) # only store for 60s
+ else:
+ request_count += 1
+ self.cache.set_cache(
+ key=rpm_key, value=request_count, local_only=True
+ ) # don't change existing ttl
+
+ return request_count
+
+ def _has_default_fallbacks(self) -> bool:
+ if self.fallbacks is None:
+ return False
+ for fallback in self.fallbacks:
+ if isinstance(fallback, dict):
+ if "*" in fallback:
+ return True
+ return False
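+ # Minimal config sketch for the "*" default-fallback check above (model names are
+ # hypothetical): a fallback entry keyed on "*" applies to every model group, so
+ # _has_default_fallbacks() returns True for this router.
+ #
+ #   router = Router(
+ #       model_list=[...],
+ #       fallbacks=[{"*": ["my-default-fallback-model"]}],
+ #   )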
+
+ def _should_raise_content_policy_error(
+ self, model: str, response: ModelResponse, kwargs: dict
+ ) -> bool:
+ """
+ Determines if a content policy error should be raised.
+
+ Only raised if a fallback is available.
+
+ Else, original response is returned.
+ """
+ if response.choices and len(response.choices) > 0:
+ if response.choices[0].finish_reason != "content_filter":
+ return False
+
+ content_policy_fallbacks = kwargs.get(
+ "content_policy_fallbacks", self.content_policy_fallbacks
+ )
+
+ ### ONLY RAISE ERROR IF CP FALLBACK AVAILABLE ###
+ if content_policy_fallbacks is not None:
+ fallback_model_group = None
+ for item in content_policy_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
+ if list(item.keys())[0] == model:
+ fallback_model_group = item[model]
+ break
+
+ if fallback_model_group is not None:
+ return True
+ elif self._has_default_fallbacks(): # default fallbacks set
+ return True
+
+ verbose_router_logger.info(
+ "Content Policy Error occurred. No available fallbacks. Returning original response. model={}, content_policy_fallbacks={}".format(
+ model, content_policy_fallbacks
+ )
+ )
+ return False
+
+ def _get_healthy_deployments(self, model: str, parent_otel_span: Optional[Span]):
+ _all_deployments: list = []
+ try:
+ _, _all_deployments = self._common_checks_available_deployment( # type: ignore
+ model=model,
+ )
+ if isinstance(_all_deployments, dict):
+ return []
+ except Exception:
+ pass
+
+ unhealthy_deployments = _get_cooldown_deployments(
+ litellm_router_instance=self, parent_otel_span=parent_otel_span
+ )
+ healthy_deployments: list = []
+ for deployment in _all_deployments:
+ if deployment["model_info"]["id"] in unhealthy_deployments:
+ continue
+ else:
+ healthy_deployments.append(deployment)
+
+ return healthy_deployments, _all_deployments
+
+ async def _async_get_healthy_deployments(
+ self, model: str, parent_otel_span: Optional[Span]
+ ) -> Tuple[List[Dict], List[Dict]]:
+ """
+ Returns:
+ - Tuple[List[Dict], List[Dict]]:
+ 1. healthy_deployments: list of healthy deployments
+ 2. all_deployments: list of all deployments
+ """
+ _all_deployments: list = []
+ try:
+ _, _all_deployments = self._common_checks_available_deployment( # type: ignore
+ model=model,
+ )
+ if isinstance(_all_deployments, dict):
+ return [], _all_deployments
+ except Exception:
+ pass
+
+ unhealthy_deployments = await _async_get_cooldown_deployments(
+ litellm_router_instance=self, parent_otel_span=parent_otel_span
+ )
+ healthy_deployments: list = []
+ for deployment in _all_deployments:
+ if deployment["model_info"]["id"] in unhealthy_deployments:
+ continue
+ else:
+ healthy_deployments.append(deployment)
+ return healthy_deployments, _all_deployments
+
+ def routing_strategy_pre_call_checks(self, deployment: dict):
+ """
+ Mimics 'async_routing_strategy_pre_call_checks'
+
+ Ensures a consistent rpm-update implementation for 'usage-based-routing-v2'
+
+ Returns:
+ - None
+
+ Raises:
+ - Rate Limit Exception - If the deployment is over its tpm/rpm limits
+ """
+ for _callback in litellm.callbacks:
+ if isinstance(_callback, CustomLogger):
+ _callback.pre_call_check(deployment)
+
+ async def async_routing_strategy_pre_call_checks(
+ self,
+ deployment: dict,
+ parent_otel_span: Optional[Span],
+ logging_obj: Optional[LiteLLMLogging] = None,
+ ):
+ """
+ For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
+
+ -> makes the calls concurrency-safe when rpm limits are set for a deployment
+
+ Returns:
+ - None
+
+ Raises:
+ - Rate Limit Exception - If the deployment is over its tpm/rpm limits
+ """
+ for _callback in litellm.callbacks:
+ if isinstance(_callback, CustomLogger):
+ try:
+ await _callback.async_pre_call_check(deployment, parent_otel_span)
+ except litellm.RateLimitError as e:
+ ## LOG FAILURE EVENT
+ if logging_obj is not None:
+ asyncio.create_task(
+ logging_obj.async_failure_handler(
+ exception=e,
+ traceback_exception=traceback.format_exc(),
+ end_time=time.time(),
+ )
+ )
+ ## LOGGING
+ threading.Thread(
+ target=logging_obj.failure_handler,
+ args=(e, traceback.format_exc()),
+ ).start() # log response
+ _set_cooldown_deployments(
+ litellm_router_instance=self,
+ exception_status=e.status_code,
+ original_exception=e,
+ deployment=deployment["model_info"]["id"],
+ time_to_cooldown=self.cooldown_time,
+ )
+ raise e
+ except Exception as e:
+ ## LOG FAILURE EVENT
+ if logging_obj is not None:
+ asyncio.create_task(
+ logging_obj.async_failure_handler(
+ exception=e,
+ traceback_exception=traceback.format_exc(),
+ end_time=time.time(),
+ )
+ )
+ ## LOGGING
+ threading.Thread(
+ target=logging_obj.failure_handler,
+ args=(e, traceback.format_exc()),
+ ).start() # log response
+ raise e
+
+ async def async_callback_filter_deployments(
+ self,
+ model: str,
+ healthy_deployments: List[dict],
+ messages: Optional[List[AllMessageValues]],
+ parent_otel_span: Optional[Span],
+ request_kwargs: Optional[dict] = None,
+ logging_obj: Optional[LiteLLMLogging] = None,
+ ):
+ """
+ Allows each registered CustomLogger callback to filter the list of healthy deployments before a deployment is picked.
+
+ -> lets callback-based pre-call checks (e.g. rpm limits) remove deployments from consideration
+
+ Returns:
+ - List[dict]: the filtered list of healthy deployments
+
+ Raises:
+ - Exception - If a callback raises (e.g. a Rate Limit Exception when a deployment is over its tpm/rpm limits)
+ """
+ returned_healthy_deployments = healthy_deployments
+ for _callback in litellm.callbacks:
+ if isinstance(_callback, CustomLogger):
+ try:
+ returned_healthy_deployments = (
+ await _callback.async_filter_deployments(
+ model=model,
+ healthy_deployments=returned_healthy_deployments,
+ messages=messages,
+ request_kwargs=request_kwargs,
+ parent_otel_span=parent_otel_span,
+ )
+ )
+ except Exception as e:
+ ## LOG FAILURE EVENT
+ if logging_obj is not None:
+ asyncio.create_task(
+ logging_obj.async_failure_handler(
+ exception=e,
+ traceback_exception=traceback.format_exc(),
+ end_time=time.time(),
+ )
+ )
+ ## LOGGING
+ threading.Thread(
+ target=logging_obj.failure_handler,
+ args=(e, traceback.format_exc()),
+ ).start() # log response
+ raise e
+ return returned_healthy_deployments
+
+ def _generate_model_id(self, model_group: str, litellm_params: dict):
+ """
+ Helper function to consistently generate the same id for a deployment
+
+ - create a string from all the litellm params
+ - hash
+ - use hash as id
+ """
+ concat_str = model_group
+ for k, v in litellm_params.items():
+ if isinstance(k, str):
+ concat_str += k
+ elif isinstance(k, dict):
+ concat_str += json.dumps(k)
+ else:
+ concat_str += str(k)
+
+ if isinstance(v, str):
+ concat_str += v
+ elif isinstance(v, dict):
+ concat_str += json.dumps(v)
+ else:
+ concat_str += str(v)
+
+ hash_object = hashlib.sha256(concat_str.encode())
+
+ return hash_object.hexdigest()
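+ # Sketch of the id-generation contract above: the id is the sha256 hex digest of the
+ # model group name concatenated with every litellm_params key/value, so two deployments
+ # with identical params always hash to the same id (params below are illustrative only).
+ #
+ #   id_a = router._generate_model_id("gpt-group", {"model": "gpt-4o", "rpm": 10})
+ #   id_b = router._generate_model_id("gpt-group", {"model": "gpt-4o", "rpm": 10})
+ #   assert id_a == id_b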
+
+ def _create_deployment(
+ self,
+ deployment_info: dict,
+ _model_name: str,
+ _litellm_params: dict,
+ _model_info: dict,
+ ) -> Optional[Deployment]:
+ """
+ Create a deployment object and add it to the model list
+
+ If the deployment is not active for the current environment, it is ignored
+
+ Returns:
+ - Deployment: The deployment object
+ - None: If the deployment is not active for the current environment (if 'supported_environments' is set in litellm_params)
+ """
+ deployment = Deployment(
+ **deployment_info,
+ model_name=_model_name,
+ litellm_params=LiteLLM_Params(**_litellm_params),
+ model_info=_model_info,
+ )
+
+ ## REGISTER MODEL INFO IN LITELLM MODEL COST MAP
+ _model_name = deployment.litellm_params.model
+ if deployment.litellm_params.custom_llm_provider is not None:
+ _model_name = (
+ deployment.litellm_params.custom_llm_provider + "/" + _model_name
+ )
+
+ litellm.register_model(
+ model_cost={
+ _model_name: _model_info,
+ }
+ )
+
+ ## Check if LLM Deployment is allowed for this deployment
+ if self.deployment_is_active_for_environment(deployment=deployment) is not True:
+ verbose_router_logger.warning(
+ f"Ignoring deployment {deployment.model_name} as it is not active for environment {deployment.model_info['supported_environments']}"
+ )
+ return None
+
+ deployment = self._add_deployment(deployment=deployment)
+
+ model = deployment.to_json(exclude_none=True)
+
+ self.model_list.append(model)
+ return deployment
+
+ def deployment_is_active_for_environment(self, deployment: Deployment) -> bool:
+ """
+ Function to check if an LLM deployment is active for a given environment. Allows using the same config.yaml across multiple environments
+
+ Requires `LITELLM_ENVIRONMENT` to be set in .env. Valid values for environment:
+ - development
+ - staging
+ - production
+
+ Raises:
+ - ValueError: If LITELLM_ENVIRONMENT is not set in .env or not one of the valid values
+ - ValueError: If supported_environments is not set in model_info or not one of the valid values
+ """
+ if (
+ deployment.model_info is None
+ or "supported_environments" not in deployment.model_info
+ or deployment.model_info["supported_environments"] is None
+ ):
+ return True
+ litellm_environment = get_secret_str(secret_name="LITELLM_ENVIRONMENT")
+ if litellm_environment is None:
+ raise ValueError(
+ "'supported_environments' is set for the model, but 'LITELLM_ENVIRONMENT' is not set in .env"
+ )
+
+ if litellm_environment not in VALID_LITELLM_ENVIRONMENTS:
+ raise ValueError(
+ f"LITELLM_ENVIRONMENT must be one of {VALID_LITELLM_ENVIRONMENTS}, but is set to: {litellm_environment}"
+ )
+
+ for _env in deployment.model_info["supported_environments"]:
+ if _env not in VALID_LITELLM_ENVIRONMENTS:
+ raise ValueError(
+ f"supported_environments must be one of {VALID_LITELLM_ENVIRONMENTS}, but is set to: {_env} for deployment: {deployment}"
+ )
+
+ if litellm_environment in deployment.model_info["supported_environments"]:
+ return True
+ return False
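+ # Hypothetical config sketch for the environment gating above: with
+ # LITELLM_ENVIRONMENT=staging in .env, a deployment whose model_info sets
+ # supported_environments=["staging", "production"] is kept, while one that only lists
+ # ["development"] is skipped by _create_deployment().
+ #
+ #   - model_name: my-model
+ #     litellm_params:
+ #       model: openai/gpt-4o
+ #     model_info:
+ #       supported_environments: ["staging", "production"]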
+
+ def set_model_list(self, model_list: list):
+ original_model_list = copy.deepcopy(model_list)
+ self.model_list = []
+ # we add api_base/api_key to each model so load balancing between azure/gpt on api_base1 and api_base2 works
+
+ for model in original_model_list:
+ _model_name = model.pop("model_name")
+ _litellm_params = model.pop("litellm_params")
+ ## check if litellm params in os.environ
+ if isinstance(_litellm_params, dict):
+ for k, v in _litellm_params.items():
+ if isinstance(v, str) and v.startswith("os.environ/"):
+ _litellm_params[k] = get_secret(v)
+
+ _model_info: dict = model.pop("model_info", {})
+
+ # check if model info has id
+ if "id" not in _model_info:
+ _id = self._generate_model_id(_model_name, _litellm_params)
+ _model_info["id"] = _id
+
+ if _litellm_params.get("organization", None) is not None and isinstance(
+ _litellm_params["organization"], list
+ ): # Addresses https://github.com/BerriAI/litellm/issues/3949
+ for org in _litellm_params["organization"]:
+ _litellm_params["organization"] = org
+ self._create_deployment(
+ deployment_info=model,
+ _model_name=_model_name,
+ _litellm_params=_litellm_params,
+ _model_info=_model_info,
+ )
+ else:
+ self._create_deployment(
+ deployment_info=model,
+ _model_name=_model_name,
+ _litellm_params=_litellm_params,
+ _model_info=_model_info,
+ )
+
+ verbose_router_logger.debug(
+ f"\nInitialized Model List {self.get_model_names()}"
+ )
+ self.model_names = [m["model_name"] for m in model_list]
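+ # Example model_list entry for the secret resolution above (names are placeholders):
+ # any litellm_params string value of the form "os.environ/<VAR>" is resolved via
+ # get_secret() before the deployment is created.
+ #
+ #   router = Router(model_list=[{
+ #       "model_name": "my-gpt-group",
+ #       "litellm_params": {
+ #           "model": "azure/my-deployment",
+ #           "api_key": "os.environ/AZURE_API_KEY",
+ #           "api_base": "os.environ/AZURE_API_BASE",
+ #       },
+ #   }])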
+
+ def _add_deployment(self, deployment: Deployment) -> Deployment:
+ import os
+
+ #### DEPLOYMENT NAMES INIT ########
+ self.deployment_names.append(deployment.litellm_params.model)
+ ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
+ # for get_available_deployment, we use the litellm_param["rpm"]
+ # in this snippet we also set rpm to be a litellm_param
+ if (
+ deployment.litellm_params.rpm is None
+ and getattr(deployment, "rpm", None) is not None
+ ):
+ deployment.litellm_params.rpm = getattr(deployment, "rpm")
+
+ if (
+ deployment.litellm_params.tpm is None
+ and getattr(deployment, "tpm", None) is not None
+ ):
+ deployment.litellm_params.tpm = getattr(deployment, "tpm")
+
+ #### VALIDATE MODEL ########
+ # check if model provider in supported providers
+ (
+ _model,
+ custom_llm_provider,
+ dynamic_api_key,
+ api_base,
+ ) = litellm.get_llm_provider(
+ model=deployment.litellm_params.model,
+ custom_llm_provider=deployment.litellm_params.get(
+ "custom_llm_provider", None
+ ),
+ )
+
+ # Check if user is trying to use model_name == "*"
+ # this is a catch all model for their specific api key
+ # if deployment.model_name == "*":
+ # if deployment.litellm_params.model == "*":
+ # # user wants to pass through all requests to litellm.acompletion for unknown deployments
+ # self.router_general_settings.pass_through_all_models = True
+ # else:
+ # self.default_deployment = deployment.to_json(exclude_none=True)
+ # Check if user is using provider specific wildcard routing
+ # example model_name = "databricks/*" or model_name = "anthropic/*"
+ if "*" in deployment.model_name:
+ # store this as a regex pattern - all deployments matching this pattern will be sent to this deployment
+ # Store deployment.model_name as a regex pattern
+ self.pattern_router.add_pattern(
+ deployment.model_name, deployment.to_json(exclude_none=True)
+ )
+ if deployment.model_info.id:
+ self.provider_default_deployment_ids.append(deployment.model_info.id)
+
+ # Azure GPT-Vision Enhancements, users can pass os.environ/
+ data_sources = deployment.litellm_params.get("dataSources", []) or []
+
+ for data_source in data_sources:
+ params = data_source.get("parameters", {})
+ for param_key in ["endpoint", "key"]:
+ # if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var
+ if param_key in params and params[param_key].startswith("os.environ/"):
+ env_name = params[param_key].replace("os.environ/", "")
+ params[param_key] = os.environ.get(env_name, "")
+
+ # done reading model["litellm_params"]
+ if custom_llm_provider not in litellm.provider_list:
+ raise Exception(f"Unsupported provider - {custom_llm_provider}")
+
+ # # init OpenAI, Azure clients
+ # InitalizeOpenAISDKClient.set_client(
+ # litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
+ # )
+
+ self._initialize_deployment_for_pass_through(
+ deployment=deployment,
+ custom_llm_provider=custom_llm_provider,
+ model=deployment.litellm_params.model,
+ )
+
+ return deployment
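+ # Wildcard-routing sketch (provider/model names are illustrative): a deployment whose
+ # model_name contains "*" (e.g. "anthropic/*") is registered as a pattern in
+ # self.pattern_router rather than matched exactly, so any request for a model matching
+ # the pattern can be served by this deployment.
+ #
+ #   router = Router(model_list=[{
+ #       "model_name": "anthropic/*",
+ #       "litellm_params": {"model": "anthropic/*"},
+ #   }])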
+
+ def _initialize_deployment_for_pass_through(
+ self, deployment: Deployment, custom_llm_provider: str, model: str
+ ):
+ """
+ Optional: Initialize deployment for pass-through endpoints if `deployment.litellm_params.use_in_pass_through` is True
+
+ Each provider uses different .env vars for pass-through endpoints; this helper uses the deployment credentials to set those .env vars for pass-through endpoints
+ """
+ if deployment.litellm_params.use_in_pass_through is True:
+ if custom_llm_provider == "vertex_ai":
+ from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
+ vertex_pass_through_router,
+ )
+
+ if (
+ deployment.litellm_params.vertex_project is None
+ or deployment.litellm_params.vertex_location is None
+ or deployment.litellm_params.vertex_credentials is None
+ ):
+ raise ValueError(
+ "vertex_project, vertex_location, and vertex_credentials must be set in litellm_params for pass-through endpoints"
+ )
+ vertex_pass_through_router.add_vertex_credentials(
+ project_id=deployment.litellm_params.vertex_project,
+ location=deployment.litellm_params.vertex_location,
+ vertex_credentials=deployment.litellm_params.vertex_credentials,
+ )
+ else:
+ from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+ passthrough_endpoint_router,
+ )
+
+ passthrough_endpoint_router.set_pass_through_credentials(
+ custom_llm_provider=custom_llm_provider,
+ api_base=deployment.litellm_params.api_base,
+ api_key=deployment.litellm_params.api_key,
+ )
+
+ def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
+ """
+ Parameters:
+ - deployment: Deployment - the deployment to be added to the Router
+
+ Returns:
+ - The added deployment
+ - OR None (if deployment already exists)
+ """
+ # check if deployment already exists
+
+ if deployment.model_info.id in self.get_model_ids():
+ return None
+
+ # add to model list
+ _deployment = deployment.to_json(exclude_none=True)
+ self.model_list.append(_deployment)
+
+ # initialize client
+ self._add_deployment(deployment=deployment)
+
+ # add to model names
+ self.model_names.append(deployment.model_name)
+ return deployment
+
+ def upsert_deployment(self, deployment: Deployment) -> Optional[Deployment]:
+ """
+ Add or update deployment
+ Parameters:
+ - deployment: Deployment - the deployment to be added to the Router
+
+ Returns:
+ - The added/updated deployment
+ """
+ # check if deployment already exists
+ _deployment_model_id = deployment.model_info.id or ""
+
+ _deployment_on_router: Optional[Deployment] = self.get_deployment(
+ model_id=_deployment_model_id
+ )
+ if _deployment_on_router is not None:
+ # deployment with this model_id exists on the router
+ if deployment.litellm_params == _deployment_on_router.litellm_params:
+ # No need to update
+ return None
+
+ # if there is a new litellm param -> then update the deployment
+ # remove the previous deployment
+ removal_idx: Optional[int] = None
+ for idx, model in enumerate(self.model_list):
+ if model["model_info"]["id"] == deployment.model_info.id:
+ removal_idx = idx
+
+ if removal_idx is not None:
+ self.model_list.pop(removal_idx)
+
+ # if the model_id is not in router
+ self.add_deployment(deployment=deployment)
+ return deployment
+
+ def delete_deployment(self, id: str) -> Optional[Deployment]:
+ """
+ Parameters:
+ - id: str - the id of the deployment to be deleted
+
+ Returns:
+ - The deleted deployment
+ - OR None (if deleted deployment not found)
+ """
+ deployment_idx = None
+ for idx, m in enumerate(self.model_list):
+ if m["model_info"]["id"] == id:
+ deployment_idx = idx
+
+ try:
+ if deployment_idx is not None:
+ item = self.model_list.pop(deployment_idx)
+ return item
+ else:
+ return None
+ except Exception:
+ return None
+
+ def get_deployment(self, model_id: str) -> Optional[Deployment]:
+ """
+ Returns -> Deployment or None
+
+ Raises an Exception -> if the model is found in an invalid format
+ """
+ for model in self.model_list:
+ if "model_info" in model and "id" in model["model_info"]:
+ if model_id == model["model_info"]["id"]:
+ if isinstance(model, dict):
+ return Deployment(**model)
+ elif isinstance(model, Deployment):
+ return model
+ else:
+ raise Exception("Model invalid format - {}".format(type(model)))
+ return None
+
+ def get_deployment_credentials(self, model_id: str) -> Optional[dict]:
+ """
+ Returns -> dict of credentials for a given model id
+ """
+ deployment = self.get_deployment(model_id=model_id)
+ if deployment is None:
+ return None
+ return CredentialLiteLLMParams(
+ **deployment.litellm_params.model_dump(exclude_none=True)
+ ).model_dump(exclude_none=True)
+
+ def get_deployment_by_model_group_name(
+ self, model_group_name: str
+ ) -> Optional[Deployment]:
+ """
+ Returns -> Deployment or None
+
+ Raises an Exception -> if the model is found in an invalid format
+ """
+ for model in self.model_list:
+ if model["model_name"] == model_group_name:
+ if isinstance(model, dict):
+ return Deployment(**model)
+ elif isinstance(model, Deployment):
+ return model
+ else:
+ raise Exception("Model Name invalid - {}".format(type(model)))
+ return None
+
+ @overload
+ def get_router_model_info(
+ self, deployment: dict, received_model_name: str, id: None = None
+ ) -> ModelMapInfo:
+ pass
+
+ @overload
+ def get_router_model_info(
+ self, deployment: None, received_model_name: str, id: str
+ ) -> ModelMapInfo:
+ pass
+
+ def get_router_model_info(
+ self,
+ deployment: Optional[dict],
+ received_model_name: str,
+ id: Optional[str] = None,
+ ) -> ModelMapInfo:
+ """
+ For a given model id, return the model info (max tokens, input cost, output cost, etc.).
+
+ Augment litellm info with additional params set in `model_info`.
+
+ For azure models, the `model:` value is ignored. Max tokens / cost values are only set if base_model is set.
+
+ Returns
+ - ModelInfo - If found -> typed dict with max tokens, input cost, etc.
+
+ Raises:
+ - ValueError -> If model is not mapped yet
+ """
+ if id is not None:
+ _deployment = self.get_deployment(model_id=id)
+ if _deployment is not None:
+ deployment = _deployment.model_dump(exclude_none=True)
+
+ if deployment is None:
+ raise ValueError("Deployment not found")
+
+ ## GET BASE MODEL
+ base_model = deployment.get("model_info", {}).get("base_model", None)
+ if base_model is None:
+ base_model = deployment.get("litellm_params", {}).get("base_model", None)
+
+ model = base_model
+
+ ## GET PROVIDER
+ _model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+ model=deployment.get("litellm_params", {}).get("model", ""),
+ litellm_params=LiteLLM_Params(**deployment.get("litellm_params", {})),
+ )
+
+ ## SET MODEL TO 'model=' - if base_model is None + not azure
+ if custom_llm_provider == "azure" and base_model is None:
+ verbose_router_logger.error(
+ "Could not identify azure model. Set azure 'base_model' for accurate max tokens, cost tracking, etc.- https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models"
+ )
+ elif custom_llm_provider != "azure":
+ model = _model
+
+ potential_models = self.pattern_router.route(received_model_name)
+ if "*" in model and potential_models is not None: # if wildcard route
+ for potential_model in potential_models:
+ try:
+ if potential_model.get("model_info", {}).get(
+ "id"
+ ) == deployment.get("model_info", {}).get("id"):
+ model = potential_model.get("litellm_params", {}).get(
+ "model"
+ )
+ break
+ except Exception:
+ pass
+
+ ## GET LITELLM MODEL INFO - raises exception, if model is not mapped
+ if not model.startswith("{}/".format(custom_llm_provider)):
+ model_info_name = "{}/{}".format(custom_llm_provider, model)
+ else:
+ model_info_name = model
+
+ model_info = litellm.get_model_info(model=model_info_name)
+
+ ## CHECK USER SET MODEL INFO
+ user_model_info = deployment.get("model_info", {})
+
+ model_info.update(user_model_info)
+
+ return model_info
+
+ def get_model_info(self, id: str) -> Optional[dict]:
+ """
+ For a given model id, return the model info
+
+ Returns
+ - dict: the model in list with 'model_name', 'litellm_params', Optional['model_info']
+ - None: could not find deployment in list
+ """
+ for model in self.model_list:
+ if "model_info" in model and "id" in model["model_info"]:
+ if id == model["model_info"]["id"]:
+ return model
+ return None
+
+ def get_model_group(self, id: str) -> Optional[List]:
+ """
+ Return list of all models in the same model group as that model id
+ """
+
+ model_info = self.get_model_info(id=id)
+ if model_info is None:
+ return None
+
+ model_name = model_info["model_name"]
+ return self.get_model_list(model_name=model_name)
+
+ def _set_model_group_info( # noqa: PLR0915
+ self, model_group: str, user_facing_model_group_name: str
+ ) -> Optional[ModelGroupInfo]:
+ """
+ For a given model group name, return the combined model info
+
+ Returns:
+ - ModelGroupInfo if able to construct a model group
+ - None if error constructing model group info
+ """
+ model_group_info: Optional[ModelGroupInfo] = None
+
+ total_tpm: Optional[int] = None
+ total_rpm: Optional[int] = None
+ configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
+ model_list = self.get_model_list(model_name=model_group)
+ if model_list is None:
+ return None
+ for model in model_list:
+ is_match = False
+ if (
+ "model_name" in model and model["model_name"] == model_group
+ ): # exact match
+ is_match = True
+ elif (
+ "model_name" in model
+ and self.pattern_router.route(model_group) is not None
+ ): # wildcard model
+ is_match = True
+
+ if not is_match:
+ continue
+ # model in model group found #
+ litellm_params = LiteLLM_Params(**model["litellm_params"]) # type: ignore
+ # get configurable clientside auth params
+ configurable_clientside_auth_params = (
+ litellm_params.configurable_clientside_auth_params
+ )
+ # get model tpm
+ _deployment_tpm: Optional[int] = None
+ if _deployment_tpm is None:
+ _deployment_tpm = model.get("tpm", None) # type: ignore
+ if _deployment_tpm is None:
+ _deployment_tpm = model.get("litellm_params", {}).get("tpm", None) # type: ignore
+ if _deployment_tpm is None:
+ _deployment_tpm = model.get("model_info", {}).get("tpm", None) # type: ignore
+
+ # get model rpm
+ _deployment_rpm: Optional[int] = None
+ if _deployment_rpm is None:
+ _deployment_rpm = model.get("rpm", None) # type: ignore
+ if _deployment_rpm is None:
+ _deployment_rpm = model.get("litellm_params", {}).get("rpm", None) # type: ignore
+ if _deployment_rpm is None:
+ _deployment_rpm = model.get("model_info", {}).get("rpm", None) # type: ignore
+
+ # get model info
+ try:
+ model_info = litellm.get_model_info(model=litellm_params.model)
+ except Exception:
+ model_info = None
+ # get llm provider
+ litellm_model, llm_provider = "", ""
+ try:
+ litellm_model, llm_provider, _, _ = litellm.get_llm_provider(
+ model=litellm_params.model,
+ custom_llm_provider=litellm_params.custom_llm_provider,
+ )
+ except litellm.exceptions.BadRequestError as e:
+ verbose_router_logger.error(
+ "litellm.router.py::get_model_group_info() - {}".format(str(e))
+ )
+
+ if model_info is None:
+ supported_openai_params = litellm.get_supported_openai_params(
+ model=litellm_model, custom_llm_provider=llm_provider
+ )
+ if supported_openai_params is None:
+ supported_openai_params = []
+ model_info = ModelMapInfo(
+ key=model_group,
+ max_tokens=None,
+ max_input_tokens=None,
+ max_output_tokens=None,
+ input_cost_per_token=0,
+ output_cost_per_token=0,
+ litellm_provider=llm_provider,
+ mode="chat",
+ supported_openai_params=supported_openai_params,
+ supports_system_messages=None,
+ )
+
+ if model_group_info is None:
+ model_group_info = ModelGroupInfo(
+ model_group=user_facing_model_group_name, providers=[llm_provider], **model_info # type: ignore
+ )
+ else:
+ # if max_input_tokens > curr
+ # if max_output_tokens > curr
+ # if input_cost_per_token > curr
+ # if output_cost_per_token > curr
+ # supports_parallel_function_calling == True
+ # supports_vision == True
+ # supports_function_calling == True
+ if llm_provider not in model_group_info.providers:
+ model_group_info.providers.append(llm_provider)
+ if (
+ model_info.get("max_input_tokens", None) is not None
+ and model_info["max_input_tokens"] is not None
+ and (
+ model_group_info.max_input_tokens is None
+ or model_info["max_input_tokens"]
+ > model_group_info.max_input_tokens
+ )
+ ):
+ model_group_info.max_input_tokens = model_info["max_input_tokens"]
+ if (
+ model_info.get("max_output_tokens", None) is not None
+ and model_info["max_output_tokens"] is not None
+ and (
+ model_group_info.max_output_tokens is None
+ or model_info["max_output_tokens"]
+ > model_group_info.max_output_tokens
+ )
+ ):
+ model_group_info.max_output_tokens = model_info["max_output_tokens"]
+ if model_info.get("input_cost_per_token", None) is not None and (
+ model_group_info.input_cost_per_token is None
+ or model_info["input_cost_per_token"]
+ > model_group_info.input_cost_per_token
+ ):
+ model_group_info.input_cost_per_token = model_info[
+ "input_cost_per_token"
+ ]
+ if model_info.get("output_cost_per_token", None) is not None and (
+ model_group_info.output_cost_per_token is None
+ or model_info["output_cost_per_token"]
+ > model_group_info.output_cost_per_token
+ ):
+ model_group_info.output_cost_per_token = model_info[
+ "output_cost_per_token"
+ ]
+ if (
+ model_info.get("supports_parallel_function_calling", None)
+ is not None
+ and model_info["supports_parallel_function_calling"] is True # type: ignore
+ ):
+ model_group_info.supports_parallel_function_calling = True
+ if (
+ model_info.get("supports_vision", None) is not None
+ and model_info["supports_vision"] is True # type: ignore
+ ):
+ model_group_info.supports_vision = True
+ if (
+ model_info.get("supports_function_calling", None) is not None
+ and model_info["supports_function_calling"] is True # type: ignore
+ ):
+ model_group_info.supports_function_calling = True
+ if (
+ model_info.get("supported_openai_params", None) is not None
+ and model_info["supported_openai_params"] is not None
+ ):
+ model_group_info.supported_openai_params = model_info[
+ "supported_openai_params"
+ ]
+ if model_info.get("tpm", None) is not None and _deployment_tpm is None:
+ _deployment_tpm = model_info.get("tpm")
+ if model_info.get("rpm", None) is not None and _deployment_rpm is None:
+ _deployment_rpm = model_info.get("rpm")
+
+ if _deployment_tpm is not None:
+ if total_tpm is None:
+ total_tpm = 0
+ total_tpm += _deployment_tpm # type: ignore
+
+ if _deployment_rpm is not None:
+ if total_rpm is None:
+ total_rpm = 0
+ total_rpm += _deployment_rpm # type: ignore
+ if model_group_info is not None:
+ ## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
+ if total_tpm is not None:
+ model_group_info.tpm = total_tpm
+
+ if total_rpm is not None:
+ model_group_info.rpm = total_rpm
+
+ ## UPDATE WITH CONFIGURABLE CLIENTSIDE AUTH PARAMS FOR MODEL GROUP
+ if configurable_clientside_auth_params is not None:
+ model_group_info.configurable_clientside_auth_params = (
+ configurable_clientside_auth_params
+ )
+
+ return model_group_info
+
+ def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
+ """
+ For a given model group name, return the combined model info
+
+ Returns:
+ - ModelGroupInfo if able to construct a model group
+ - None if error constructing model group info or hidden model group
+ """
+ ## Check if model group alias
+ if model_group in self.model_group_alias:
+ item = self.model_group_alias[model_group]
+ if isinstance(item, str):
+ _router_model_group = item
+ elif isinstance(item, dict):
+ if item["hidden"] is True:
+ return None
+ else:
+ _router_model_group = item["model"]
+ else:
+ return None
+
+ return self._set_model_group_info(
+ model_group=_router_model_group,
+ user_facing_model_group_name=model_group,
+ )
+
+ ## Check if actual model
+ return self._set_model_group_info(
+ model_group=model_group, user_facing_model_group_name=model_group
+ )
+
+ async def get_model_group_usage(
+ self, model_group: str
+ ) -> Tuple[Optional[int], Optional[int]]:
+ """
+ Returns current tpm/rpm usage for model group
+
+ Parameters:
+ - model_group: str - the received model name from the user (can be a wildcard route).
+
+ Returns:
+ - usage: Tuple[tpm, rpm]
+ """
+ dt = get_utc_datetime()
+ current_minute = dt.strftime(
+ "%H-%M"
+ ) # use the same timezone regardless of system clock
+ tpm_keys: List[str] = []
+ rpm_keys: List[str] = []
+
+ model_list = self.get_model_list(model_name=model_group)
+ if model_list is None: # no matching deployments
+ return None, None
+
+ for model in model_list:
+ id: Optional[str] = model.get("model_info", {}).get("id") # type: ignore
+ litellm_model: Optional[str] = model["litellm_params"].get(
+ "model"
+ ) # USE THE MODEL SENT TO litellm.completion() - consistent with how global_router cache is written.
+ if id is None or litellm_model is None:
+ continue
+ tpm_keys.append(
+ RouterCacheEnum.TPM.value.format(
+ id=id,
+ model=litellm_model,
+ current_minute=current_minute,
+ )
+ )
+ rpm_keys.append(
+ RouterCacheEnum.RPM.value.format(
+ id=id,
+ model=litellm_model,
+ current_minute=current_minute,
+ )
+ )
+ combined_tpm_rpm_keys = tpm_keys + rpm_keys
+
+ combined_tpm_rpm_values = await self.cache.async_batch_get_cache(
+ keys=combined_tpm_rpm_keys
+ )
+ if combined_tpm_rpm_values is None:
+ return None, None
+
+ tpm_usage_list: Optional[List] = combined_tpm_rpm_values[: len(tpm_keys)]
+ rpm_usage_list: Optional[List] = combined_tpm_rpm_values[len(tpm_keys) :]
+
+ ## TPM
+ tpm_usage: Optional[int] = None
+ if tpm_usage_list is not None:
+ for t in tpm_usage_list:
+ if isinstance(t, int):
+ if tpm_usage is None:
+ tpm_usage = 0
+ tpm_usage += t
+ ## RPM
+ rpm_usage: Optional[int] = None
+ if rpm_usage_list is not None:
+ for t in rpm_usage_list:
+ if isinstance(t, int):
+ if rpm_usage is None:
+ rpm_usage = 0
+ rpm_usage += t
+ return tpm_usage, rpm_usage
+
+ @lru_cache(maxsize=64)
+ def _cached_get_model_group_info(
+ self, model_group: str
+ ) -> Optional[ModelGroupInfo]:
+ """
+ Cached version of get_model_group_info, uses @lru_cache wrapper
+
+ This is a speed optimization, since set_response_headers makes a call to get_model_group_info on every request
+ """
+ return self.get_model_group_info(model_group)
+
+ async def get_remaining_model_group_usage(self, model_group: str) -> Dict[str, int]:
+ model_group_info = self._cached_get_model_group_info(model_group)
+
+ if model_group_info is not None and model_group_info.tpm is not None:
+ tpm_limit = model_group_info.tpm
+ else:
+ tpm_limit = None
+
+ if model_group_info is not None and model_group_info.rpm is not None:
+ rpm_limit = model_group_info.rpm
+ else:
+ rpm_limit = None
+
+ if tpm_limit is None and rpm_limit is None:
+ return {}
+
+ current_tpm, current_rpm = await self.get_model_group_usage(model_group)
+
+ returned_dict = {}
+ if tpm_limit is not None:
+ returned_dict["x-ratelimit-remaining-tokens"] = tpm_limit - (
+ current_tpm or 0
+ )
+ returned_dict["x-ratelimit-limit-tokens"] = tpm_limit
+ if rpm_limit is not None:
+ returned_dict["x-ratelimit-remaining-requests"] = rpm_limit - (
+ current_rpm or 0
+ )
+ returned_dict["x-ratelimit-limit-requests"] = rpm_limit
+
+ return returned_dict
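+ # Worked sketch of the returned dict above (numbers are made up): for a model group
+ # whose summed tpm limit is 1000 and rpm limit is 10, with current usage of 400 tokens
+ # and 3 requests this minute, the helper returns:
+ #
+ #   {
+ #       "x-ratelimit-remaining-tokens": 600,
+ #       "x-ratelimit-limit-tokens": 1000,
+ #       "x-ratelimit-remaining-requests": 7,
+ #       "x-ratelimit-limit-requests": 10,
+ #   }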
+
+ async def set_response_headers(
+ self, response: Any, model_group: Optional[str] = None
+ ) -> Any:
+ """
+ Add the most accurate rate limit headers for a given model response.
+
+ ## TODO: add model group rate limit headers
+ # - if healthy_deployments > 1, return model group rate limit headers
+ # - else return the model's rate limit headers
+ """
+ if (
+ isinstance(response, BaseModel)
+ and hasattr(response, "_hidden_params")
+ and isinstance(response._hidden_params, dict) # type: ignore
+ ):
+ response._hidden_params.setdefault("additional_headers", {}) # type: ignore
+ response._hidden_params["additional_headers"][ # type: ignore
+ "x-litellm-model-group"
+ ] = model_group
+
+ additional_headers = response._hidden_params["additional_headers"] # type: ignore
+
+ if (
+ "x-ratelimit-remaining-tokens" not in additional_headers
+ and "x-ratelimit-remaining-requests" not in additional_headers
+ and model_group is not None
+ ):
+ remaining_usage = await self.get_remaining_model_group_usage(
+ model_group
+ )
+
+ for header, value in remaining_usage.items():
+ if value is not None:
+ additional_headers[header] = value
+ return response
+
+ def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
+ """
+ if 'model_name' is none, returns all.
+
+ Returns list of model id's.
+ """
+ ids = []
+ for model in self.model_list:
+ if "model_info" in model and "id" in model["model_info"]:
+ id = model["model_info"]["id"]
+ if model_name is not None and model["model_name"] == model_name:
+ ids.append(id)
+ elif model_name is None:
+ ids.append(id)
+ return ids
+
+ def _get_all_deployments(
+ self, model_name: str, model_alias: Optional[str] = None
+ ) -> List[DeploymentTypedDict]:
+ """
+ Return all deployments of a model name
+
+ Used for accurate 'get_model_list'.
+ """
+ returned_models: List[DeploymentTypedDict] = []
+ for model in self.model_list:
+ if model_name is not None and model["model_name"] == model_name:
+ if model_alias is not None:
+ alias_model = copy.deepcopy(model)
+ alias_model["model_name"] = model_alias
+ returned_models.append(alias_model)
+ else:
+ returned_models.append(model)
+
+ return returned_models
+
+ def get_model_names(self, team_id: Optional[str] = None) -> List[str]:
+ """
+ Returns all possible model names for the router, including models defined via model_group_alias.
+
+ If a team_id is provided, only deployments configured with that team_id (i.e. team-specific models)
+ will yield their team public name.
+ """
+ deployments = self.get_model_list() or []
+ model_names = []
+
+ for deployment in deployments:
+ model_info = deployment.get("model_info")
+ if self._is_team_specific_model(model_info):
+ team_model_name = self._get_team_specific_model(
+ deployment=deployment, team_id=team_id
+ )
+ if team_model_name:
+ model_names.append(team_model_name)
+ else:
+ model_names.append(deployment.get("model_name", ""))
+
+ return model_names
+
+ def _get_team_specific_model(
+ self, deployment: DeploymentTypedDict, team_id: Optional[str] = None
+ ) -> Optional[str]:
+ """
+ Get the team-specific model name if team_id matches the deployment.
+
+ Args:
+ deployment: DeploymentTypedDict - The model deployment
+ team_id: Optional[str] - If passed, will return router models set with a `team_id` matching the passed `team_id`.
+
+ Returns:
+ str: The `team_public_model_name` if team_id matches
+ None: If team_id doesn't match or no team info exists
+ """
+ model_info: Optional[Dict] = deployment.get("model_info") or {}
+ if model_info is None:
+ return None
+ if team_id == model_info.get("team_id"):
+ return model_info.get("team_public_model_name")
+ return None
+
+ def _is_team_specific_model(self, model_info: Optional[Dict]) -> bool:
+ """
+ Check if model info contains team-specific configuration.
+
+ Args:
+ model_info: Model information dictionary
+
+ Returns:
+ bool: True if model has team-specific configuration
+ """
+ return bool(model_info and model_info.get("team_id"))
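+ # Sketch of a team-scoped deployment for the two helpers above (ids/names are
+ # placeholders): when get_model_names(team_id="team-1") is called, this deployment
+ # contributes its team_public_model_name instead of its model_name.
+ #
+ #   {
+ #       "model_name": "gpt-4o-team-1",
+ #       "litellm_params": {"model": "gpt-4o"},
+ #       "model_info": {
+ #           "team_id": "team-1",
+ #           "team_public_model_name": "gpt-4o",
+ #       },
+ #   }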
+
+ def get_model_list_from_model_alias(
+ self, model_name: Optional[str] = None
+ ) -> List[DeploymentTypedDict]:
+ """
+ Helper function to get model list from model alias.
+
+ Used by `.get_model_list` to get model list from model alias.
+ """
+ returned_models: List[DeploymentTypedDict] = []
+ for model_alias, model_value in self.model_group_alias.items():
+ if model_name is not None and model_alias != model_name:
+ continue
+ if isinstance(model_value, str):
+ _router_model_name: str = model_value
+ elif isinstance(model_value, dict):
+ _model_value = RouterModelGroupAliasItem(**model_value) # type: ignore
+ if _model_value["hidden"] is True:
+ continue
+ else:
+ _router_model_name = _model_value["model"]
+ else:
+ continue
+
+ returned_models.extend(
+ self._get_all_deployments(
+ model_name=_router_model_name, model_alias=model_alias
+ )
+ )
+
+ return returned_models
+
+ def get_model_list(
+ self, model_name: Optional[str] = None
+ ) -> Optional[List[DeploymentTypedDict]]:
+ """
+ Includes the router's model_group_alias entries as well
+ """
+ if hasattr(self, "model_list"):
+ returned_models: List[DeploymentTypedDict] = []
+
+ if model_name is not None:
+ returned_models.extend(self._get_all_deployments(model_name=model_name))
+
+ if hasattr(self, "model_group_alias"):
+ returned_models.extend(
+ self.get_model_list_from_model_alias(model_name=model_name)
+ )
+
+ if len(returned_models) == 0: # check if wildcard route
+ potential_wildcard_models = self.pattern_router.route(model_name)
+ if potential_wildcard_models is not None:
+ returned_models.extend(
+ [DeploymentTypedDict(**m) for m in potential_wildcard_models] # type: ignore
+ )
+
+ if model_name is None:
+ returned_models += self.model_list
+
+ return returned_models
+
+ return returned_models
+ return None
+
+ def get_model_access_groups(
+ self, model_name: Optional[str] = None, model_access_group: Optional[str] = None
+ ) -> Dict[str, List[str]]:
+ """
+ If model_name is provided, only return access groups for that model.
+
+ Parameters:
+ - model_name: Optional[str] - the received model name from the user (can be a wildcard route). If set, will only return access groups for that model.
+ - model_access_group: Optional[str] - the received model access group from the user. If set, will only return models for that access group.
+ """
+ from collections import defaultdict
+
+ access_groups = defaultdict(list)
+
+ model_list = self.get_model_list(model_name=model_name)
+ if model_list:
+ for m in model_list:
+ _model_info = m.get("model_info")
+ if _model_info:
+ for group in _model_info.get("access_groups", []) or []:
+ if model_access_group is not None:
+ if group == model_access_group:
+ model_name = m["model_name"]
+ access_groups[group].append(model_name)
+ else:
+ model_name = m["model_name"]
+ access_groups[group].append(model_name)
+
+ return access_groups
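+ # Access-group sketch (group/model names are hypothetical): deployments opt into groups
+ # via model_info["access_groups"], and this helper inverts that mapping.
+ #
+ #   # model_info: {"access_groups": ["beta-models"]} on deployment "my-gpt-group"
+ #   router.get_model_access_groups()
+ #   # -> {"beta-models": ["my-gpt-group"]}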
+
+ def _is_model_access_group_for_wildcard_route(
+ self, model_access_group: str
+ ) -> bool:
+ """
+ Return True if model access group is a wildcard route
+ """
+ # GET ACCESS GROUPS
+ access_groups = self.get_model_access_groups(
+ model_access_group=model_access_group
+ )
+
+ if len(access_groups) == 0:
+ return False
+
+ models = access_groups.get(model_access_group, [])
+
+ for model in models:
+ # CHECK IF MODEL ACCESS GROUP IS A WILDCARD ROUTE
+ if self.pattern_router.route(request=model) is not None:
+ return True
+
+ return False
+
+ def get_settings(self):
+ """
+ Get router settings method, returns a dictionary of the settings and their values.
+ For example get the set values for routing_strategy_args, routing_strategy, allowed_fails, cooldown_time, num_retries, timeout, max_retries, retry_after
+ """
+ _all_vars = vars(self)
+ _settings_to_return = {}
+ vars_to_include = [
+ "routing_strategy_args",
+ "routing_strategy",
+ "allowed_fails",
+ "cooldown_time",
+ "num_retries",
+ "timeout",
+ "max_retries",
+ "retry_after",
+ "fallbacks",
+ "context_window_fallbacks",
+ "model_group_retry_policy",
+ ]
+
+ for var in vars_to_include:
+ if var in _all_vars:
+ _settings_to_return[var] = _all_vars[var]
+ if (
+ var == "routing_strategy_args"
+ and self.routing_strategy == "latency-based-routing"
+ ):
+ _settings_to_return[var] = self.lowestlatency_logger.routing_args.json()
+ return _settings_to_return
+
+ def update_settings(self, **kwargs):
+ """
+ Update the router settings.
+ """
+ # only the following settings are allowed to be configured
+ _allowed_settings = [
+ "routing_strategy_args",
+ "routing_strategy",
+ "allowed_fails",
+ "cooldown_time",
+ "num_retries",
+ "timeout",
+ "max_retries",
+ "retry_after",
+ "fallbacks",
+ "context_window_fallbacks",
+ "model_group_retry_policy",
+ ]
+
+ _int_settings = [
+ "timeout",
+ "num_retries",
+ "retry_after",
+ "allowed_fails",
+ "cooldown_time",
+ ]
+
+ _existing_router_settings = self.get_settings()
+ for var in kwargs:
+ if var in _allowed_settings:
+ if var in _int_settings:
+ _casted_value = int(kwargs[var])
+ setattr(self, var, _casted_value)
+ else:
+ # only run routing strategy init if it has changed
+ if (
+ var == "routing_strategy"
+ and _existing_router_settings["routing_strategy"] != kwargs[var]
+ ):
+ self.routing_strategy_init(
+ routing_strategy=kwargs[var],
+ routing_strategy_args=kwargs.get(
+ "routing_strategy_args", {}
+ ),
+ )
+ setattr(self, var, kwargs[var])
+ else:
+ verbose_router_logger.debug("Setting {} is not allowed".format(var))
+ verbose_router_logger.debug(f"Updated Router settings: {self.get_settings()}")
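+ # Usage sketch for update_settings (values are illustrative): only the allow-listed
+ # settings above are applied, integer settings are cast with int(), and changing
+ # routing_strategy re-runs routing_strategy_init().
+ #
+ #   router.update_settings(num_retries=3, timeout=30, routing_strategy="simple-shuffle")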
+
+ def _get_client(self, deployment, kwargs, client_type=None):
+ """
+ Returns the appropriate client based on the given deployment, kwargs, and client_type.
+
+ Parameters:
+ deployment (dict): The deployment dictionary containing the clients.
+ kwargs (dict): The keyword arguments passed to the function.
+ client_type (str): The type of client to return.
+
+ Returns:
+ The appropriate client based on the given client_type and kwargs.
+ """
+ model_id = deployment["model_info"]["id"]
+ parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(kwargs)
+ if client_type == "max_parallel_requests":
+ cache_key = "{}_max_parallel_requests_client".format(model_id)
+ client = self.cache.get_cache(
+ key=cache_key, local_only=True, parent_otel_span=parent_otel_span
+ )
+ if client is None:
+ InitalizeCachedClient.set_max_parallel_requests_client(
+ litellm_router_instance=self, model=deployment
+ )
+ client = self.cache.get_cache(
+ key=cache_key, local_only=True, parent_otel_span=parent_otel_span
+ )
+ return client
+ elif client_type == "async":
+ if kwargs.get("stream") is True:
+ cache_key = f"{model_id}_stream_async_client"
+ client = self.cache.get_cache(
+ key=cache_key, local_only=True, parent_otel_span=parent_otel_span
+ )
+ return client
+ else:
+ cache_key = f"{model_id}_async_client"
+ client = self.cache.get_cache(
+ key=cache_key, local_only=True, parent_otel_span=parent_otel_span
+ )
+ return client
+ else:
+ if kwargs.get("stream") is True:
+ cache_key = f"{model_id}_stream_client"
+ client = self.cache.get_cache(
+ key=cache_key, parent_otel_span=parent_otel_span
+ )
+ return client
+ else:
+ cache_key = f"{model_id}_client"
+ client = self.cache.get_cache(
+ key=cache_key, parent_otel_span=parent_otel_span
+ )
+ return client
+
+ def _pre_call_checks( # noqa: PLR0915
+ self,
+ model: str,
+ healthy_deployments: List,
+ messages: List[Dict[str, str]],
+ request_kwargs: Optional[dict] = None,
+ ):
+ """
+ Filter out model in model group, if:
+
+ - model context window < message length. For azure openai models, this requires 'base_model' to be set. - https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models
+ - filter models above rpm limits
+ - if region given, filter out models not in that region / unknown region
+ - [TODO] function call and model doesn't support function calling
+ """
+
+ verbose_router_logger.debug(
+ f"Starting Pre-call checks for deployments in model={model}"
+ )
+
+ _returned_deployments = copy.deepcopy(healthy_deployments)
+
+ invalid_model_indices = []
+
+ try:
+ input_tokens = litellm.token_counter(messages=messages)
+ except Exception as e:
+ verbose_router_logger.error(
+ "litellm.router.py::_pre_call_checks: failed to count tokens. Returning initial list of deployments. Got - {}".format(
+ str(e)
+ )
+ )
+ return _returned_deployments
+
+ _context_window_error = False
+ _potential_error_str = ""
+ _rate_limit_error = False
+ parent_otel_span = _get_parent_otel_span_from_kwargs(request_kwargs)
+
+ ## get model group RPM ##
+ dt = get_utc_datetime()
+ current_minute = dt.strftime("%H-%M")
+ rpm_key = f"{model}:rpm:{current_minute}"
+ model_group_cache = (
+ self.cache.get_cache(
+ key=rpm_key, local_only=True, parent_otel_span=parent_otel_span
+ )
+ or {}
+ ) # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
+ for idx, deployment in enumerate(_returned_deployments):
+ # see if we have the info for this model
+ try:
+ base_model = deployment.get("model_info", {}).get("base_model", None)
+ if base_model is None:
+ base_model = deployment.get("litellm_params", {}).get(
+ "base_model", None
+ )
+ model_info = self.get_router_model_info(
+ deployment=deployment, received_model_name=model
+ )
+ model = base_model or deployment.get("litellm_params", {}).get(
+ "model", None
+ )
+
+ if (
+ isinstance(model_info, dict)
+ and model_info.get("max_input_tokens", None) is not None
+ ):
+ if (
+ isinstance(model_info["max_input_tokens"], int)
+ and input_tokens > model_info["max_input_tokens"]
+ ):
+ invalid_model_indices.append(idx)
+ _context_window_error = True
+ _potential_error_str += (
+ "Model={}, Max Input Tokens={}, Got={}".format(
+ model, model_info["max_input_tokens"], input_tokens
+ )
+ )
+ continue
+ except Exception as e:
+ verbose_router_logger.exception("An error occurred - {}".format(str(e)))
+
+ _litellm_params = deployment.get("litellm_params", {})
+ model_id = deployment.get("model_info", {}).get("id", "")
+ ## RPM CHECK ##
+ ### get local router cache ###
+ current_request_cache_local = (
+ self.cache.get_cache(
+ key=model_id, local_only=True, parent_otel_span=parent_otel_span
+ )
+ or 0
+ )
+ ### get usage based cache ###
+ if (
+ isinstance(model_group_cache, dict)
+ and self.routing_strategy != "usage-based-routing-v2"
+ ):
+ model_group_cache[model_id] = model_group_cache.get(model_id, 0)
+
+ current_request = max(
+ current_request_cache_local, model_group_cache[model_id]
+ )
+
+ if (
+ isinstance(_litellm_params, dict)
+ and _litellm_params.get("rpm", None) is not None
+ ):
+ if (
+ isinstance(_litellm_params["rpm"], int)
+ and _litellm_params["rpm"] <= current_request
+ ):
+ invalid_model_indices.append(idx)
+ _rate_limit_error = True
+ continue
+
+ ## REGION CHECK ##
+ if (
+ request_kwargs is not None
+ and request_kwargs.get("allowed_model_region") is not None
+ ):
+ allowed_model_region = request_kwargs.get("allowed_model_region")
+
+ if allowed_model_region is not None:
+ if not is_region_allowed(
+ litellm_params=LiteLLM_Params(**_litellm_params),
+ allowed_model_region=allowed_model_region,
+ ):
+ invalid_model_indices.append(idx)
+ continue
+
+ ## INVALID PARAMS ## -> catch 'gpt-3.5-turbo-16k' not supporting 'response_format' param
+ if request_kwargs is not None and litellm.drop_params is False:
+ # get supported params
+ model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+ model=model, litellm_params=LiteLLM_Params(**_litellm_params)
+ )
+
+ supported_openai_params = litellm.get_supported_openai_params(
+ model=model, custom_llm_provider=custom_llm_provider
+ )
+
+ if supported_openai_params is None:
+ continue
+ else:
+ # check the non-default openai params in request kwargs
+ non_default_params = litellm.utils.get_non_default_params(
+ passed_params=request_kwargs
+ )
+ special_params = ["response_format"]
+ # check if all params are supported
+ for k, v in non_default_params.items():
+ if k not in supported_openai_params and k in special_params:
+ # if not -> invalid model
+ verbose_router_logger.debug(
+ f"INVALID MODEL INDEX @ REQUEST KWARG FILTERING, k={k}"
+ )
+ invalid_model_indices.append(idx)
+
+ if len(invalid_model_indices) == len(_returned_deployments):
+ """
+ - no healthy deployments available b/c context window checks or rate limit error
+
+ - First check for rate limit errors (if this is true, it means the model passed the context window check but failed the rate limit check)
+ """
+
+ if _rate_limit_error is True: # allow generic fallback logic to take place
+ raise RouterRateLimitErrorBasic(
+ model=model,
+ )
+
+ elif _context_window_error is True:
+ raise litellm.ContextWindowExceededError(
+ message="litellm._pre_call_checks: Context Window exceeded for given call. No models have context window large enough for this call.\n{}".format(
+ _potential_error_str
+ ),
+ model=model,
+ llm_provider="",
+ )
+ if len(invalid_model_indices) > 0:
+ for idx in reversed(invalid_model_indices):
+ _returned_deployments.pop(idx)
+
+ ## ORDER FILTERING ## -> if user set 'order' in deployments, return deployments with lowest order (e.g. order=1 > order=2)
+ if len(_returned_deployments) > 0:
+ _returned_deployments = litellm.utils._get_order_filtered_deployments(
+ _returned_deployments
+ )
+
+ return _returned_deployments
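+ # Hedged illustration of the filtering above (router instance, messages and deployment
+ # values are assumptions; only the method name and error types come from the code):
+ #
+ #   healthy = [
+ #       {"model_info": {"id": "a", "base_model": "gpt-3.5-turbo"}, "litellm_params": {"model": "azure/gpt-35", "rpm": 1}},
+ #       {"model_info": {"id": "b"}, "litellm_params": {"model": "gpt-4-turbo"}},
+ #   ]
+ #   # if the prompt exceeds deployment "a"'s max_input_tokens, or "a" is already at its
+ #   # rpm limit, only deployment "b" is returned; if *every* deployment is filtered out,
+ #   # RouterRateLimitErrorBasic or ContextWindowExceededError is raised as shown above.
+ #   filtered = router._pre_call_checks(model="my-group", healthy_deployments=healthy, messages=msgs)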
+
+ def _get_model_from_alias(self, model: str) -> Optional[str]:
+ """
+ Get the model from the alias.
+
+ Returns:
+ - str, the litellm model name
+ - None, if model is not in model group alias
+ """
+ if model not in self.model_group_alias:
+ return None
+
+ _item = self.model_group_alias[model]
+ if isinstance(_item, str):
+ model = _item
+ else:
+ model = _item["model"]
+
+ return model
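+ # Hedged example of alias resolution (the alias map contents are assumptions):
+ #
+ #   self.model_group_alias = {"fast": "gpt-4o-mini", "smart": {"model": "gpt-4o"}}
+ #   self._get_model_from_alias("fast")     # -> "gpt-4o-mini"
+ #   self._get_model_from_alias("smart")    # -> "gpt-4o"
+ #   self._get_model_from_alias("unknown")  # -> None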
+
+ def _get_deployment_by_litellm_model(self, model: str) -> List:
+ """
+ Get the deployment by litellm model.
+ """
+ return [m for m in self.model_list if m["litellm_params"]["model"] == model]
+
+ def _common_checks_available_deployment(
+ self,
+ model: str,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ specific_deployment: Optional[bool] = False,
+ ) -> Tuple[str, Union[List, Dict]]:
+ """
+ Common checks for 'get_available_deployment' across sync + async calls.
+
+ If a Dict (rather than a List) is returned for the deployments, a single specific deployment was chosen.
+
+ Returns
+ - str, the litellm model name
+ - List, if multiple models chosen
+ - Dict, if specific model chosen
+ """
+ # check if aliases set on litellm model alias map
+ if specific_deployment is True:
+ return model, self._get_deployment_by_litellm_model(model=model)
+ elif model in self.get_model_ids():
+ deployment = self.get_deployment(model_id=model)
+ if deployment is not None:
+ deployment_model = deployment.litellm_params.model
+ return deployment_model, deployment.model_dump(exclude_none=True)
+ raise ValueError(
+ f"LiteLLM Router: Trying to call specific deployment, but Model ID :{model} does not exist in \
+ Model ID List: {self.get_model_ids}"
+ )
+
+ _model_from_alias = self._get_model_from_alias(model=model)
+ if _model_from_alias is not None:
+ model = _model_from_alias
+
+ if model not in self.model_names:
+ # check provider-specific / wildcard routing via pattern matching
+ pattern_deployments = self.pattern_router.get_deployments_by_pattern(
+ model=model,
+ )
+ if pattern_deployments:
+ return model, pattern_deployments
+
+ # check if default deployment is set
+ if self.default_deployment is not None:
+ updated_deployment = copy.deepcopy(
+ self.default_deployment
+ ) # self.default_deployment
+ updated_deployment["litellm_params"]["model"] = model
+ return model, updated_deployment
+
+ ## get healthy deployments
+ ### get all deployments
+ healthy_deployments = self._get_all_deployments(model_name=model)
+
+ if len(healthy_deployments) == 0:
+ # check if the user sent in a deployment name instead
+ healthy_deployments = self._get_deployment_by_litellm_model(model=model)
+
+ verbose_router_logger.debug(
+ f"initial list of deployments: {healthy_deployments}"
+ )
+
+ if len(healthy_deployments) == 0:
+ raise litellm.BadRequestError(
+ message="You passed in model={}. There is no 'model_name' with this string ".format(
+ model
+ ),
+ model=model,
+ llm_provider="",
+ )
+
+ if litellm.model_alias_map and model in litellm.model_alias_map:
+ model = litellm.model_alias_map[
+ model
+ ] # update the model to the actual value if an alias has been passed in
+
+ return model, healthy_deployments
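+ # Hedged sketch of the two return shapes (values are illustrative):
+ #
+ #   # single specific deployment chosen (model id / default deployment):
+ #   ("azure/gpt-35", {"model_name": "gpt-3.5-turbo", "litellm_params": {...}, "model_info": {...}})
+ #   # model group chosen:
+ #   ("gpt-3.5-turbo", [<deployment dict>, <deployment dict>, ...])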
+
+ async def async_get_available_deployment(
+ self,
+ model: str,
+ request_kwargs: Dict,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ specific_deployment: Optional[bool] = False,
+ ):
+ """
+ Async implementation of 'get_available_deployment'.
+
+ Allows all cache calls to be made async => ~10x perf improvement (8 RPS -> 100 RPS).
+ """
+ if (
+ self.routing_strategy != "usage-based-routing-v2"
+ and self.routing_strategy != "simple-shuffle"
+ and self.routing_strategy != "cost-based-routing"
+ and self.routing_strategy != "latency-based-routing"
+ and self.routing_strategy != "least-busy"
+ ): # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
+ return self.get_available_deployment(
+ model=model,
+ messages=messages,
+ input=input,
+ specific_deployment=specific_deployment,
+ request_kwargs=request_kwargs,
+ )
+ try:
+ parent_otel_span = _get_parent_otel_span_from_kwargs(request_kwargs)
+ model, healthy_deployments = self._common_checks_available_deployment(
+ model=model,
+ messages=messages,
+ input=input,
+ specific_deployment=specific_deployment,
+ ) # type: ignore
+ if isinstance(healthy_deployments, dict):
+ return healthy_deployments
+
+ cooldown_deployments = await _async_get_cooldown_deployments(
+ litellm_router_instance=self, parent_otel_span=parent_otel_span
+ )
+ verbose_router_logger.debug(
+ f"async cooldown deployments: {cooldown_deployments}"
+ )
+ verbose_router_logger.debug(f"cooldown_deployments: {cooldown_deployments}")
+ healthy_deployments = self._filter_cooldown_deployments(
+ healthy_deployments=healthy_deployments,
+ cooldown_deployments=cooldown_deployments,
+ )
+
+ healthy_deployments = await self.async_callback_filter_deployments(
+ model=model,
+ healthy_deployments=healthy_deployments,
+ messages=(
+ cast(List[AllMessageValues], messages)
+ if messages is not None
+ else None
+ ),
+ request_kwargs=request_kwargs,
+ parent_otel_span=parent_otel_span,
+ )
+
+ if self.enable_pre_call_checks and messages is not None:
+ healthy_deployments = self._pre_call_checks(
+ model=model,
+ healthy_deployments=cast(List[Dict], healthy_deployments),
+ messages=messages,
+ request_kwargs=request_kwargs,
+ )
+ # check if user wants to do tag based routing
+ healthy_deployments = await get_deployments_for_tag( # type: ignore
+ llm_router_instance=self,
+ model=model,
+ request_kwargs=request_kwargs,
+ healthy_deployments=healthy_deployments,
+ )
+
+ if len(healthy_deployments) == 0:
+ exception = await async_raise_no_deployment_exception(
+ litellm_router_instance=self,
+ model=model,
+ parent_otel_span=parent_otel_span,
+ )
+ raise exception
+ start_time = time.time()
+ if (
+ self.routing_strategy == "usage-based-routing-v2"
+ and self.lowesttpm_logger_v2 is not None
+ ):
+ deployment = (
+ await self.lowesttpm_logger_v2.async_get_available_deployments(
+ model_group=model,
+ healthy_deployments=healthy_deployments, # type: ignore
+ messages=messages,
+ input=input,
+ )
+ )
+ elif (
+ self.routing_strategy == "cost-based-routing"
+ and self.lowestcost_logger is not None
+ ):
+ deployment = (
+ await self.lowestcost_logger.async_get_available_deployments(
+ model_group=model,
+ healthy_deployments=healthy_deployments, # type: ignore
+ messages=messages,
+ input=input,
+ )
+ )
+ elif (
+ self.routing_strategy == "latency-based-routing"
+ and self.lowestlatency_logger is not None
+ ):
+ deployment = (
+ await self.lowestlatency_logger.async_get_available_deployments(
+ model_group=model,
+ healthy_deployments=healthy_deployments, # type: ignore
+ messages=messages,
+ input=input,
+ request_kwargs=request_kwargs,
+ )
+ )
+ elif self.routing_strategy == "simple-shuffle":
+ return simple_shuffle(
+ llm_router_instance=self,
+ healthy_deployments=healthy_deployments,
+ model=model,
+ )
+ elif (
+ self.routing_strategy == "least-busy"
+ and self.leastbusy_logger is not None
+ ):
+ deployment = (
+ await self.leastbusy_logger.async_get_available_deployments(
+ model_group=model,
+ healthy_deployments=healthy_deployments, # type: ignore
+ )
+ )
+ else:
+ deployment = None
+ if deployment is None:
+ exception = await async_raise_no_deployment_exception(
+ litellm_router_instance=self,
+ model=model,
+ parent_otel_span=parent_otel_span,
+ )
+ raise exception
+ verbose_router_logger.info(
+ f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
+ )
+
+ end_time = time.time()
+ _duration = end_time - start_time
+ asyncio.create_task(
+ self.service_logger_obj.async_service_success_hook(
+ service=ServiceTypes.ROUTER,
+ duration=_duration,
+ call_type="<routing_strategy>.async_get_available_deployments",
+ parent_otel_span=parent_otel_span,
+ start_time=start_time,
+ end_time=end_time,
+ )
+ )
+
+ return deployment
+ except Exception as e:
+ traceback_exception = traceback.format_exc()
+ # if router rejects call -> log to langfuse/otel/etc.
+ if request_kwargs is not None:
+ logging_obj = request_kwargs.get("litellm_logging_obj", None)
+
+ if logging_obj is not None:
+ ## LOGGING
+ threading.Thread(
+ target=logging_obj.failure_handler,
+ args=(e, traceback_exception),
+ ).start() # log response
+ # also log the failure via the async failure handler
+ asyncio.create_task(
+ logging_obj.async_failure_handler(e, traceback_exception) # type: ignore
+ )
+ raise e
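+ # Hedged usage sketch (the `router` instance, model name and message are assumptions;
+ # the method name and parameters come from the definition above):
+ #
+ #   deployment = await router.async_get_available_deployment(
+ #       model="gpt-4o",
+ #       request_kwargs={},
+ #       messages=[{"role": "user", "content": "hello"}],
+ #   )
+ #   underlying_model = deployment["litellm_params"]["model"]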
+
+ def get_available_deployment(
+ self,
+ model: str,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ specific_deployment: Optional[bool] = False,
+ request_kwargs: Optional[Dict] = None,
+ ):
+ """
+ Returns the deployment based on routing strategy
+ """
+ # users need to explicitly call a specific deployment by setting `specific_deployment = True` as a completion()/embedding() kwarg
+ # When this was not explicit, we had several issues with fallbacks timing out
+
+ model, healthy_deployments = self._common_checks_available_deployment(
+ model=model,
+ messages=messages,
+ input=input,
+ specific_deployment=specific_deployment,
+ )
+
+ if isinstance(healthy_deployments, dict):
+ return healthy_deployments
+
+ parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
+ request_kwargs
+ )
+ cooldown_deployments = _get_cooldown_deployments(
+ litellm_router_instance=self, parent_otel_span=parent_otel_span
+ )
+ healthy_deployments = self._filter_cooldown_deployments(
+ healthy_deployments=healthy_deployments,
+ cooldown_deployments=cooldown_deployments,
+ )
+
+ # filter pre-call checks
+ if self.enable_pre_call_checks and messages is not None:
+ healthy_deployments = self._pre_call_checks(
+ model=model,
+ healthy_deployments=healthy_deployments,
+ messages=messages,
+ request_kwargs=request_kwargs,
+ )
+
+ if len(healthy_deployments) == 0:
+ model_ids = self.get_model_ids(model_name=model)
+ _cooldown_time = self.cooldown_cache.get_min_cooldown(
+ model_ids=model_ids, parent_otel_span=parent_otel_span
+ )
+ _cooldown_list = _get_cooldown_deployments(
+ litellm_router_instance=self, parent_otel_span=parent_otel_span
+ )
+ raise RouterRateLimitError(
+ model=model,
+ cooldown_time=_cooldown_time,
+ enable_pre_call_checks=self.enable_pre_call_checks,
+ cooldown_list=_cooldown_list,
+ )
+
+ if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
+ deployment = self.leastbusy_logger.get_available_deployments(
+ model_group=model, healthy_deployments=healthy_deployments # type: ignore
+ )
+ elif self.routing_strategy == "simple-shuffle":
+ # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
+ ############## Check 'weight' param set for weighted pick #################
+ return simple_shuffle(
+ llm_router_instance=self,
+ healthy_deployments=healthy_deployments,
+ model=model,
+ )
+ elif (
+ self.routing_strategy == "latency-based-routing"
+ and self.lowestlatency_logger is not None
+ ):
+ deployment = self.lowestlatency_logger.get_available_deployments(
+ model_group=model,
+ healthy_deployments=healthy_deployments, # type: ignore
+ request_kwargs=request_kwargs,
+ )
+ elif (
+ self.routing_strategy == "usage-based-routing"
+ and self.lowesttpm_logger is not None
+ ):
+ deployment = self.lowesttpm_logger.get_available_deployments(
+ model_group=model,
+ healthy_deployments=healthy_deployments, # type: ignore
+ messages=messages,
+ input=input,
+ )
+ elif (
+ self.routing_strategy == "usage-based-routing-v2"
+ and self.lowesttpm_logger_v2 is not None
+ ):
+ deployment = self.lowesttpm_logger_v2.get_available_deployments(
+ model_group=model,
+ healthy_deployments=healthy_deployments, # type: ignore
+ messages=messages,
+ input=input,
+ )
+ else:
+ deployment = None
+
+ if deployment is None:
+ verbose_router_logger.info(
+ f"get_available_deployment for model: {model}, No deployment available"
+ )
+ model_ids = self.get_model_ids(model_name=model)
+ _cooldown_time = self.cooldown_cache.get_min_cooldown(
+ model_ids=model_ids, parent_otel_span=parent_otel_span
+ )
+ _cooldown_list = _get_cooldown_deployments(
+ litellm_router_instance=self, parent_otel_span=parent_otel_span
+ )
+ raise RouterRateLimitError(
+ model=model,
+ cooldown_time=_cooldown_time,
+ enable_pre_call_checks=self.enable_pre_call_checks,
+ cooldown_list=_cooldown_list,
+ )
+ verbose_router_logger.info(
+ f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
+ )
+ return deployment
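+ # Hedged summary of the strategy dispatch above (strategy and attribute names come from the code):
+ #
+ #   "least-busy"             -> self.leastbusy_logger.get_available_deployments(...)
+ #   "simple-shuffle"         -> simple_shuffle(...)
+ #   "latency-based-routing"  -> self.lowestlatency_logger.get_available_deployments(...)
+ #   "usage-based-routing"    -> self.lowesttpm_logger.get_available_deployments(...)
+ #   "usage-based-routing-v2" -> self.lowesttpm_logger_v2.get_available_deployments(...)
+ #   anything else            -> no deployment selected, RouterRateLimitError raised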
+
+ def _filter_cooldown_deployments(
+ self, healthy_deployments: List[Dict], cooldown_deployments: List[str]
+ ) -> List[Dict]:
+ """
+ Filters out the deployments currently cooling down from the list of healthy deployments
+
+ Args:
+ healthy_deployments: List of healthy deployments
+ cooldown_deployments: List of model_ids currently cooling down, e.g. ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
+
+ Returns:
+ List of healthy deployments
+ """
+ # filter out the deployments currently cooling down
+ deployments_to_remove = []
+ verbose_router_logger.debug(f"cooldown deployments: {cooldown_deployments}")
+ # Find deployments in model_list whose model_id is cooling down
+ for deployment in healthy_deployments:
+ deployment_id = deployment["model_info"]["id"]
+ if deployment_id in cooldown_deployments:
+ deployments_to_remove.append(deployment)
+
+ # remove unhealthy deployments from healthy deployments
+ for deployment in deployments_to_remove:
+ healthy_deployments.remove(deployment)
+ return healthy_deployments
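+ # Hedged example (the ids are illustrative):
+ #
+ #   healthy = [{"model_info": {"id": "a"}}, {"model_info": {"id": "b"}}]
+ #   self._filter_cooldown_deployments(healthy, cooldown_deployments=["a"])
+ #   # -> [{"model_info": {"id": "b"}}]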
+
+ def _track_deployment_metrics(
+ self, deployment, parent_otel_span: Optional[Span], response=None
+ ):
+ """
+ Tracks RPM usage for successful requests.
+ """
+ try:
+ model_id = deployment.get("model_info", {}).get("id", None)
+ if response is None:
+
+ # update the in-memory usage cache for this deployment
+ if model_id is not None:
+ self._update_usage(
+ model_id, parent_otel_span
+ ) # update in-memory cache for tracking
+ except Exception as e:
+ verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")
+
+ def get_num_retries_from_retry_policy(
+ self, exception: Exception, model_group: Optional[str] = None
+ ):
+ return _get_num_retries_from_retry_policy(
+ exception=exception,
+ model_group=model_group,
+ model_group_retry_policy=self.model_group_retry_policy,
+ retry_policy=self.retry_policy,
+ )
+
+ def get_allowed_fails_from_policy(self, exception: Exception):
+ """
+ Returns the allowed-fails count for the given exception, based on the allowed fails policy:
+
+ BadRequestErrorAllowedFails: Optional[int] = None
+ AuthenticationErrorAllowedFails: Optional[int] = None
+ TimeoutErrorAllowedFails: Optional[int] = None
+ RateLimitErrorAllowedFails: Optional[int] = None
+ ContentPolicyViolationErrorAllowedFails: Optional[int] = None
+ """
+ # if the exception type matches a field in the allowed fails policy -> return that allowed-fails value
+ allowed_fails_policy: Optional[AllowedFailsPolicy] = self.allowed_fails_policy
+
+ if allowed_fails_policy is None:
+ return None
+
+ if (
+ isinstance(exception, litellm.BadRequestError)
+ and allowed_fails_policy.BadRequestErrorAllowedFails is not None
+ ):
+ return allowed_fails_policy.BadRequestErrorAllowedFails
+ if (
+ isinstance(exception, litellm.AuthenticationError)
+ and allowed_fails_policy.AuthenticationErrorAllowedFails is not None
+ ):
+ return allowed_fails_policy.AuthenticationErrorAllowedFails
+ if (
+ isinstance(exception, litellm.Timeout)
+ and allowed_fails_policy.TimeoutErrorAllowedFails is not None
+ ):
+ return allowed_fails_policy.TimeoutErrorAllowedFails
+ if (
+ isinstance(exception, litellm.RateLimitError)
+ and allowed_fails_policy.RateLimitErrorAllowedFails is not None
+ ):
+ return allowed_fails_policy.RateLimitErrorAllowedFails
+ if (
+ isinstance(exception, litellm.ContentPolicyViolationError)
+ and allowed_fails_policy.ContentPolicyViolationErrorAllowedFails is not None
+ ):
+ return allowed_fails_policy.ContentPolicyViolationErrorAllowedFails
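+ # Hedged example (the policy values and constructor call are assumptions; the field
+ # name comes from the checks above):
+ #
+ #   self.allowed_fails_policy = AllowedFailsPolicy(RateLimitErrorAllowedFails=3)
+ #   self.get_allowed_fails_from_policy(rate_limit_exc)  # -> 3
+ #   self.get_allowed_fails_from_policy(timeout_exc)     # -> None (no matching field set)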
+
+ def _initialize_alerting(self):
+ from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
+
+ if self.alerting_config is None:
+ return
+
+ router_alerting_config: AlertingConfig = self.alerting_config
+
+ _slack_alerting_logger = SlackAlerting(
+ alerting_threshold=router_alerting_config.alerting_threshold,
+ alerting=["slack"],
+ default_webhook_url=router_alerting_config.webhook_url,
+ )
+
+ self.slack_alerting_logger = _slack_alerting_logger
+
+ litellm.logging_callback_manager.add_litellm_callback(_slack_alerting_logger) # type: ignore
+ litellm.logging_callback_manager.add_litellm_success_callback(
+ _slack_alerting_logger.response_taking_too_long_callback
+ )
+ verbose_router_logger.info(
+ "\033[94m\nInitialized Alerting for litellm.Router\033[0m\n"
+ )
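+ # Hedged configuration sketch (the Router constructor kwargs, threshold and webhook URL
+ # are placeholders/assumptions; the AlertingConfig field names come from the code above):
+ #
+ #   router = Router(
+ #       model_list=[...],
+ #       alerting_config=AlertingConfig(
+ #           alerting_threshold=300,
+ #           webhook_url="https://hooks.slack.com/services/...",
+ #       ),
+ #   )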
+
+ def set_custom_routing_strategy(
+ self, CustomRoutingStrategy: CustomRoutingStrategyBase
+ ):
+ """
+ Sets get_available_deployment and async_get_available_deployment on an instance of litellm.Router
+
+ Use this to set your custom routing strategy
+
+ Args:
+ CustomRoutingStrategy: litellm.router.CustomRoutingStrategyBase
+ """
+
+ setattr(
+ self,
+ "get_available_deployment",
+ CustomRoutingStrategy.get_available_deployment,
+ )
+ setattr(
+ self,
+ "async_get_available_deployment",
+ CustomRoutingStrategy.async_get_available_deployment,
+ )
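+ # Hedged sketch of plugging in a custom strategy (the subclass body and the exact method
+ # signatures are assumptions; only the base class name comes from the annotation above):
+ #
+ #   class MyStrategy(CustomRoutingStrategyBase):
+ #       def get_available_deployment(self, model, messages=None, input=None, specific_deployment=False, request_kwargs=None):
+ #           ...
+ #       async def async_get_available_deployment(self, model, messages=None, input=None, specific_deployment=False, request_kwargs=None):
+ #           ...
+ #
+ #   router.set_custom_routing_strategy(MyStrategy())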
+
+ def flush_cache(self):
+ litellm.cache = None
+ self.cache.flush_cache()
+
+ def reset(self):
+ ## clean up on close
+ litellm.success_callback = []
+ litellm._async_success_callback = []
+ litellm.failure_callback = []
+ litellm._async_failure_callback = []
+ self.retry_policy = None
+ self.flush_cache()
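+ # Note (descriptive, based on the code above): flush_cache() clears both the module-level
+ # litellm.cache and the router's own cache; reset() additionally clears the global litellm
+ # callbacks and the retry policy, so it is intended for teardown rather than mid-request use.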