Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/router_strategy')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/router_strategy/base_routing_strategy.py  190
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/router_strategy/budget_limiter.py  818
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py  252
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_cost.py  333
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_latency.py  590
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_tpm_rpm.py  243
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_tpm_rpm_v2.py  671
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/router_strategy/simple_shuffle.py  96
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/router_strategy/tag_based_routing.py  146
9 files changed, 3339 insertions, 0 deletions
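
For orientation: a minimal sketch of how these strategies get selected when a litellm Router is constructed. The Router kwargs and the "cost-based-routing" strategy name are assumptions about litellm's public API rather than something defined in the files below; the provider budget config shape mirrors the budget_limiter.py docstring.

    import asyncio

    from litellm import Router

    async def build_router() -> Router:
        # Built inside a running event loop because the budget limiter schedules
        # async sync tasks at construction time (see RouterBudgetLimiting.__init__).
        return Router(
            model_list=[
                {
                    "model_name": "gpt-3.5-turbo",  # hypothetical deployment entry
                    "litellm_params": {
                        "model": "openai/gpt-3.5-turbo",
                        "api_key": "sk-placeholder",
                    },
                }
            ],
            routing_strategy="cost-based-routing",  # assumed name; served by lowest_cost.py
            # Budgets act as a filter on top of the chosen strategy (budget_limiter.py).
            provider_budget_config={"openai": {"budget_limit": 100, "time_period": "1d"}},
        )

    # router = asyncio.run(build_router())
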
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/base_routing_strategy.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/base_routing_strategy.py
new file mode 100644
index 00000000..a39d17e3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/base_routing_strategy.py
@@ -0,0 +1,190 @@
+"""
+Base class across routing strategies to abstract common functions like batched Redis increments
+"""
+
+import asyncio
+import threading
+from abc import ABC
+from typing import List, Optional, Set, Union
+
+from litellm._logging import verbose_router_logger
+from litellm.caching.caching import DualCache
+from litellm.caching.redis_cache import RedisPipelineIncrementOperation
+from litellm.constants import DEFAULT_REDIS_SYNC_INTERVAL
+
+
+class BaseRoutingStrategy(ABC):
+ def __init__(
+ self,
+ dual_cache: DualCache,
+ should_batch_redis_writes: bool,
+ default_sync_interval: Optional[Union[int, float]],
+ ):
+ self.dual_cache = dual_cache
+ self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
+ if should_batch_redis_writes:
+ try:
+ # Try to get existing event loop
+ loop = asyncio.get_event_loop()
+ if loop.is_running():
+ # If loop exists and is running, create task in existing loop
+ loop.create_task(
+ self.periodic_sync_in_memory_spend_with_redis(
+ default_sync_interval=default_sync_interval
+ )
+ )
+ else:
+ self._create_sync_thread(default_sync_interval)
+ except RuntimeError: # No event loop in current thread
+ self._create_sync_thread(default_sync_interval)
+
+ self.in_memory_keys_to_update: set[str] = (
+ set()
+ ) # keys written locally since the last sync; refreshed from Redis on the next periodic sync
+
+ async def _increment_value_in_current_window(
+ self, key: str, value: Union[int, float], ttl: int
+ ):
+ """
+ Increment spend within existing budget window
+
+ Runs once the budget start time exists in Redis Cache (on the 2nd and subsequent requests to the same provider)
+
+ - Increments the spend in memory cache (so spend instantly updated in memory)
+ - Queues the increment operation to Redis Pipeline (using batched pipeline to optimize performance. Using Redis for multi instance environment of LiteLLM)
+ """
+ result = await self.dual_cache.in_memory_cache.async_increment(
+ key=key,
+ value=value,
+ ttl=ttl,
+ )
+ increment_op = RedisPipelineIncrementOperation(
+ key=key,
+ increment_value=value,
+ ttl=ttl,
+ )
+ self.redis_increment_operation_queue.append(increment_op)
+ self.add_to_in_memory_keys_to_update(key=key)
+ return result
+
+ async def periodic_sync_in_memory_spend_with_redis(
+ self, default_sync_interval: Optional[Union[int, float]]
+ ):
+ """
+ Handler that triggers sync_in_memory_spend_with_redis every DEFAULT_REDIS_SYNC_INTERVAL seconds
+
+ Required for multi-instance environment usage of provider budgets
+ """
+ default_sync_interval = default_sync_interval or DEFAULT_REDIS_SYNC_INTERVAL
+ while True:
+ try:
+ await self._sync_in_memory_spend_with_redis()
+ await asyncio.sleep(
+ default_sync_interval
+ ) # Wait for DEFAULT_REDIS_SYNC_INTERVAL seconds before next sync
+ except Exception as e:
+ verbose_router_logger.error(f"Error in periodic sync task: {str(e)}")
+ await asyncio.sleep(
+ default_sync_interval
+ ) # Still wait DEFAULT_REDIS_SYNC_INTERVAL seconds on error before retrying
+
+ async def _push_in_memory_increments_to_redis(self):
+ """
+ How this works:
+ - async_log_success_event collects all provider spend increments in `redis_increment_operation_queue`
+ - This function pushes all increments to Redis in a batched pipeline to optimize performance
+
+ Only runs if Redis is initialized
+ """
+ try:
+ if not self.dual_cache.redis_cache:
+ return # Redis is not initialized
+
+ verbose_router_logger.debug(
+ "Pushing Redis Increment Pipeline for queue: %s",
+ self.redis_increment_operation_queue,
+ )
+ if len(self.redis_increment_operation_queue) > 0:
+ asyncio.create_task(
+ self.dual_cache.redis_cache.async_increment_pipeline(
+ increment_list=self.redis_increment_operation_queue,
+ )
+ )
+
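+ # NOTE: the pipeline coroutine above holds a reference to the current queue
+ # list, so rebinding the attribute below starts a fresh queue without dropping
+ # the increments that were just handed to Redis.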
+ self.redis_increment_operation_queue = []
+
+ except Exception as e:
+ verbose_router_logger.error(
+ f"Error syncing in-memory cache with Redis: {str(e)}"
+ )
+ self.redis_increment_operation_queue = []
+
+ def add_to_in_memory_keys_to_update(self, key: str):
+ self.in_memory_keys_to_update.add(key)
+
+ def get_in_memory_keys_to_update(self) -> Set[str]:
+ return self.in_memory_keys_to_update
+
+ def reset_in_memory_keys_to_update(self):
+ self.in_memory_keys_to_update = set()
+
+ async def _sync_in_memory_spend_with_redis(self):
+ """
+ Ensures in-memory cache is updated with latest Redis values for all provider spends.
+
+ Why do we need this?
+ - Optimization to hit sub-100ms latency; performance degraded when Redis was read/written on every request
+ - To use provider budgets in a multi-instance environment, Redis is used to sync spend across all instances
+
+ What this does:
+ 1. Push all provider spend increments to Redis
+ 2. Fetch all current provider spend from Redis to update in-memory cache
+ """
+
+ try:
+ # No need to sync if Redis cache is not initialized
+ if self.dual_cache.redis_cache is None:
+ return
+
+ # 1. Push all provider spend increments to Redis
+ await self._push_in_memory_increments_to_redis()
+
+ # 2. Fetch all current provider spend from Redis to update in-memory cache
+ cache_keys = self.get_in_memory_keys_to_update()
+
+ cache_keys_list = list(cache_keys)
+
+ # Batch fetch current spend values from Redis
+ redis_values = await self.dual_cache.redis_cache.async_batch_get_cache(
+ key_list=cache_keys_list
+ )
+
+ # Update in-memory cache with Redis values
+ if isinstance(redis_values, dict): # Check if redis_values is a dictionary
+ for key, value in redis_values.items():
+ if value is not None:
+ await self.dual_cache.in_memory_cache.async_set_cache(
+ key=key, value=float(value)
+ )
+ verbose_router_logger.debug(
+ f"Updated in-memory cache for {key}: {value}"
+ )
+
+ self.reset_in_memory_keys_to_update()
+ except Exception as e:
+ verbose_router_logger.exception(
+ f"Error syncing in-memory cache with Redis: {str(e)}"
+ )
+
+ def _create_sync_thread(self, default_sync_interval):
+ """Helper method to create a new thread for periodic sync"""
+ thread = threading.Thread(
+ target=asyncio.run,
+ args=(
+ self.periodic_sync_in_memory_spend_with_redis(
+ default_sync_interval=default_sync_interval
+ ),
+ ),
+ daemon=True,
+ )
+ thread.start()
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/budget_limiter.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/budget_limiter.py
new file mode 100644
index 00000000..4f123df2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/budget_limiter.py
@@ -0,0 +1,818 @@
+"""
+Provider budget limiting
+
+Use this if you want to set $ budget limits for each provider.
+
+Note: this is a filter, like tag-routing: it accepts the healthy deployments and then filters out those that have exceeded their budget limit.
+
+This means it can be combined with weighted-pick, lowest-latency, simple-shuffle, and other routing strategies.
+
+Example:
+```
+openai:
+ budget_limit: 0.000000000001
+ time_period: 1d
+anthropic:
+ budget_limit: 100
+ time_period: 7d
+```
+"""
+
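+# Spend is tracked under cache keys of the form (see the helpers below):
+#   provider_spend:{provider}:{budget_duration}        e.g. provider_spend:openai:1d
+#   deployment_spend:{model_id}:{budget_duration}
+#   tag_spend:{tag}:{budget_duration}
+# with matching *_budget_start_time keys marking the start of each budget window.
+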
+import asyncio
+from datetime import datetime, timedelta, timezone
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import litellm
+from litellm._logging import verbose_router_logger
+from litellm.caching.caching import DualCache
+from litellm.caching.redis_cache import RedisPipelineIncrementOperation
+from litellm.integrations.custom_logger import CustomLogger, Span
+from litellm.litellm_core_utils.duration_parser import duration_in_seconds
+from litellm.router_strategy.tag_based_routing import _get_tags_from_request_kwargs
+from litellm.router_utils.cooldown_callbacks import (
+ _get_prometheus_logger_from_callbacks,
+)
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.router import DeploymentTypedDict, LiteLLM_Params, RouterErrors
+from litellm.types.utils import BudgetConfig
+from litellm.types.utils import BudgetConfig as GenericBudgetInfo
+from litellm.types.utils import GenericBudgetConfigType, StandardLoggingPayload
+
+DEFAULT_REDIS_SYNC_INTERVAL = 1
+
+
+class RouterBudgetLimiting(CustomLogger):
+ def __init__(
+ self,
+ dual_cache: DualCache,
+ provider_budget_config: Optional[dict],
+ model_list: Optional[
+ Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
+ ] = None,
+ ):
+ self.dual_cache = dual_cache
+ self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
+ asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis())
+ self.provider_budget_config: Optional[GenericBudgetConfigType] = (
+ provider_budget_config
+ )
+ self.deployment_budget_config: Optional[GenericBudgetConfigType] = None
+ self.tag_budget_config: Optional[GenericBudgetConfigType] = None
+ self._init_provider_budgets()
+ self._init_deployment_budgets(model_list=model_list)
+ self._init_tag_budgets()
+
+ # Add self to litellm callbacks if it's a list
+ if isinstance(litellm.callbacks, list):
+ litellm.logging_callback_manager.add_litellm_callback(self) # type: ignore
+
+ async def async_filter_deployments(
+ self,
+ model: str,
+ healthy_deployments: List,
+ messages: Optional[List[AllMessageValues]],
+ request_kwargs: Optional[dict] = None,
+ parent_otel_span: Optional[Span] = None, # type: ignore
+ ) -> List[dict]:
+ """
+ Filter out deployments that have exceeded their provider budget limit.
+
+
+ Example:
+ if deployment = openai/gpt-3.5-turbo
+ and openai spend > openai budget limit
+ then skip this deployment
+ """
+
+ # If a single deployment is passed, convert it to a list
+ if isinstance(healthy_deployments, dict):
+ healthy_deployments = [healthy_deployments]
+
+ # Don't do any filtering if there are no healthy deployments
+ if len(healthy_deployments) == 0:
+ return healthy_deployments
+
+ potential_deployments: List[Dict] = []
+
+ cache_keys, provider_configs, deployment_configs = (
+ await self._async_get_cache_keys_for_router_budget_limiting(
+ healthy_deployments=healthy_deployments,
+ request_kwargs=request_kwargs,
+ )
+ )
+
+ # Single cache read for all spend values
+ if len(cache_keys) > 0:
+ _current_spends = await self.dual_cache.async_batch_get_cache(
+ keys=cache_keys,
+ parent_otel_span=parent_otel_span,
+ )
+ current_spends: List = _current_spends or [0.0] * len(cache_keys)
+
+ # Map spends to their respective keys
+ spend_map: Dict[str, float] = {}
+ for idx, key in enumerate(cache_keys):
+ spend_map[key] = float(current_spends[idx] or 0.0)
+
+ potential_deployments, deployment_above_budget_info = (
+ self._filter_out_deployments_above_budget(
+ healthy_deployments=healthy_deployments,
+ provider_configs=provider_configs,
+ deployment_configs=deployment_configs,
+ spend_map=spend_map,
+ potential_deployments=potential_deployments,
+ request_tags=_get_tags_from_request_kwargs(
+ request_kwargs=request_kwargs
+ ),
+ )
+ )
+
+ if len(potential_deployments) == 0:
+ raise ValueError(
+ f"{RouterErrors.no_deployments_with_provider_budget_routing.value}: {deployment_above_budget_info}"
+ )
+
+ return potential_deployments
+ else:
+ return healthy_deployments
+
+ def _filter_out_deployments_above_budget(
+ self,
+ potential_deployments: List[Dict[str, Any]],
+ healthy_deployments: List[Dict[str, Any]],
+ provider_configs: Dict[str, GenericBudgetInfo],
+ deployment_configs: Dict[str, GenericBudgetInfo],
+ spend_map: Dict[str, float],
+ request_tags: List[str],
+ ) -> Tuple[List[Dict[str, Any]], str]:
+ """
+ Filter out deployments that have exceeded their budget limit.
+ The following budget checks are run here:
+ - Provider budget
+ - Deployment budget
+ - Request tags budget
+ Returns:
+ Tuple[List[Dict[str, Any]], str]:
+ - A tuple containing the filtered deployments
+ - A string containing debug information about deployments that exceeded their budget limit.
+ """
+ # Filter deployments based on both provider and deployment budgets
+ deployment_above_budget_info: str = ""
+ for deployment in healthy_deployments:
+ is_within_budget = True
+
+ # Check provider budget
+ if self.provider_budget_config:
+ provider = self._get_llm_provider_for_deployment(deployment)
+ if provider in provider_configs:
+ config = provider_configs[provider]
+ if config.max_budget is None:
+ continue
+ current_spend = spend_map.get(
+ f"provider_spend:{provider}:{config.budget_duration}", 0.0
+ )
+ self._track_provider_remaining_budget_prometheus(
+ provider=provider,
+ spend=current_spend,
+ budget_limit=config.max_budget,
+ )
+
+ if config.max_budget and current_spend >= config.max_budget:
+ debug_msg = f"Exceeded budget for provider {provider}: {current_spend} >= {config.max_budget}"
+ deployment_above_budget_info += f"{debug_msg}\n"
+ is_within_budget = False
+ continue
+
+ # Check deployment budget
+ if self.deployment_budget_config and is_within_budget:
+ _model_name = deployment.get("model_name")
+ _litellm_params = deployment.get("litellm_params") or {}
+ _litellm_model_name = _litellm_params.get("model")
+ model_id = deployment.get("model_info", {}).get("id")
+ if model_id in deployment_configs:
+ config = deployment_configs[model_id]
+ current_spend = spend_map.get(
+ f"deployment_spend:{model_id}:{config.budget_duration}", 0.0
+ )
+ if config.max_budget and current_spend >= config.max_budget:
+ debug_msg = f"Exceeded budget for deployment model_name: {_model_name}, litellm_params.model: {_litellm_model_name}, model_id: {model_id}: {current_spend} >= {config.budget_duration}"
+ verbose_router_logger.debug(debug_msg)
+ deployment_above_budget_info += f"{debug_msg}\n"
+ is_within_budget = False
+ continue
+ # Check tag budget
+ if self.tag_budget_config and is_within_budget:
+ for _tag in request_tags:
+ _tag_budget_config = self._get_budget_config_for_tag(_tag)
+ if _tag_budget_config:
+ _tag_spend = spend_map.get(
+ f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}",
+ 0.0,
+ )
+ if (
+ _tag_budget_config.max_budget
+ and _tag_spend >= _tag_budget_config.max_budget
+ ):
+ debug_msg = f"Exceeded budget for tag='{_tag}', tag_spend={_tag_spend}, tag_budget_limit={_tag_budget_config.max_budget}"
+ verbose_router_logger.debug(debug_msg)
+ deployment_above_budget_info += f"{debug_msg}\n"
+ is_within_budget = False
+ continue
+ if is_within_budget:
+ potential_deployments.append(deployment)
+
+ return potential_deployments, deployment_above_budget_info
+
+ async def _async_get_cache_keys_for_router_budget_limiting(
+ self,
+ healthy_deployments: List[Dict[str, Any]],
+ request_kwargs: Optional[Dict] = None,
+ ) -> Tuple[List[str], Dict[str, GenericBudgetInfo], Dict[str, GenericBudgetInfo]]:
+ """
+ Returns the list of cache keys to fetch from the router cache for budget limiting, along with the provider and deployment budget configs
+
+ Returns:
+ Tuple[List[str], Dict[str, GenericBudgetInfo], Dict[str, GenericBudgetInfo]]:
+ - List of cache keys to fetch from router cache for budget limiting
+ - Dict of provider budget configs `provider_configs`
+ - Dict of deployment budget configs `deployment_configs`
+ """
+ cache_keys: List[str] = []
+ provider_configs: Dict[str, GenericBudgetInfo] = {}
+ deployment_configs: Dict[str, GenericBudgetInfo] = {}
+
+ for deployment in healthy_deployments:
+ # Check provider budgets
+ if self.provider_budget_config:
+ provider = self._get_llm_provider_for_deployment(deployment)
+ if provider is not None:
+ budget_config = self._get_budget_config_for_provider(provider)
+ if (
+ budget_config is not None
+ and budget_config.budget_duration is not None
+ ):
+ provider_configs[provider] = budget_config
+ cache_keys.append(
+ f"provider_spend:{provider}:{budget_config.budget_duration}"
+ )
+
+ # Check deployment budgets
+ if self.deployment_budget_config:
+ model_id = deployment.get("model_info", {}).get("id")
+ if model_id is not None:
+ budget_config = self._get_budget_config_for_deployment(model_id)
+ if budget_config is not None:
+ deployment_configs[model_id] = budget_config
+ cache_keys.append(
+ f"deployment_spend:{model_id}:{budget_config.budget_duration}"
+ )
+ # Check tag budgets
+ if self.tag_budget_config:
+ request_tags = _get_tags_from_request_kwargs(
+ request_kwargs=request_kwargs
+ )
+ for _tag in request_tags:
+ _tag_budget_config = self._get_budget_config_for_tag(_tag)
+ if _tag_budget_config:
+ cache_keys.append(
+ f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}"
+ )
+ return cache_keys, provider_configs, deployment_configs
+
+ async def _get_or_set_budget_start_time(
+ self, start_time_key: str, current_time: float, ttl_seconds: int
+ ) -> float:
+ """
+ Checks if the key = `provider_budget_start_time:{provider}` exists in cache.
+
+ If it does, return the value.
+ If it does not, set the key to `current_time` and return the value.
+ """
+ budget_start = await self.dual_cache.async_get_cache(start_time_key)
+ if budget_start is None:
+ await self.dual_cache.async_set_cache(
+ key=start_time_key, value=current_time, ttl=ttl_seconds
+ )
+ return current_time
+ return float(budget_start)
+
+ async def _handle_new_budget_window(
+ self,
+ spend_key: str,
+ start_time_key: str,
+ current_time: float,
+ response_cost: float,
+ ttl_seconds: int,
+ ) -> float:
+ """
+ Handle start of new budget window by resetting spend and start time
+
+ Enters this when:
+ - The budget does not exist in cache, so we need to set it
+ - The budget window has expired, so we need to reset everything
+
+ Does 2 things:
+ - stores key: `provider_spend:{provider}:1d`, value: response_cost
+ - stores key: `provider_budget_start_time:{provider}`, value: current_time.
+ This stores the start time of the new budget window
+ """
+ await self.dual_cache.async_set_cache(
+ key=spend_key, value=response_cost, ttl=ttl_seconds
+ )
+ await self.dual_cache.async_set_cache(
+ key=start_time_key, value=current_time, ttl=ttl_seconds
+ )
+ return current_time
+
+ async def _increment_spend_in_current_window(
+ self, spend_key: str, response_cost: float, ttl: int
+ ):
+ """
+ Increment spend within existing budget window
+
+ Runs once the budget start time exists in Redis Cache (on the 2nd and subsequent requests to the same provider)
+
+ - Increments the spend in memory cache (so spend instantly updated in memory)
+ - Queues the increment operation to Redis Pipeline (using batched pipeline to optimize performance. Using Redis for multi instance environment of LiteLLM)
+ """
+ await self.dual_cache.in_memory_cache.async_increment(
+ key=spend_key,
+ value=response_cost,
+ ttl=ttl,
+ )
+ increment_op = RedisPipelineIncrementOperation(
+ key=spend_key,
+ increment_value=response_cost,
+ ttl=ttl,
+ )
+ self.redis_increment_operation_queue.append(increment_op)
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ """Original method now uses helper functions"""
+ verbose_router_logger.debug("in RouterBudgetLimiting.async_log_success_event")
+ standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object", None
+ )
+ if standard_logging_payload is None:
+ raise ValueError("standard_logging_payload is required")
+
+ response_cost: float = standard_logging_payload.get("response_cost", 0)
+ model_id: str = str(standard_logging_payload.get("model_id", ""))
+ custom_llm_provider: str = kwargs.get("litellm_params", {}).get(
+ "custom_llm_provider", None
+ )
+ if custom_llm_provider is None:
+ raise ValueError("custom_llm_provider is required")
+
+ budget_config = self._get_budget_config_for_provider(custom_llm_provider)
+ if budget_config:
+ # increment spend for provider
+ spend_key = (
+ f"provider_spend:{custom_llm_provider}:{budget_config.budget_duration}"
+ )
+ start_time_key = f"provider_budget_start_time:{custom_llm_provider}"
+ await self._increment_spend_for_key(
+ budget_config=budget_config,
+ spend_key=spend_key,
+ start_time_key=start_time_key,
+ response_cost=response_cost,
+ )
+
+ deployment_budget_config = self._get_budget_config_for_deployment(model_id)
+ if deployment_budget_config:
+ # increment spend for specific deployment id
+ deployment_spend_key = f"deployment_spend:{model_id}:{deployment_budget_config.budget_duration}"
+ deployment_start_time_key = f"deployment_budget_start_time:{model_id}"
+ await self._increment_spend_for_key(
+ budget_config=deployment_budget_config,
+ spend_key=deployment_spend_key,
+ start_time_key=deployment_start_time_key,
+ response_cost=response_cost,
+ )
+
+ request_tags = _get_tags_from_request_kwargs(kwargs)
+ if len(request_tags) > 0:
+ for _tag in request_tags:
+ _tag_budget_config = self._get_budget_config_for_tag(_tag)
+ if _tag_budget_config:
+ _tag_spend_key = (
+ f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}"
+ )
+ _tag_start_time_key = f"tag_budget_start_time:{_tag}"
+ await self._increment_spend_for_key(
+ budget_config=_tag_budget_config,
+ spend_key=_tag_spend_key,
+ start_time_key=_tag_start_time_key,
+ response_cost=response_cost,
+ )
+
+ async def _increment_spend_for_key(
+ self,
+ budget_config: GenericBudgetInfo,
+ spend_key: str,
+ start_time_key: str,
+ response_cost: float,
+ ):
+ if budget_config.budget_duration is None:
+ return
+
+ current_time = datetime.now(timezone.utc).timestamp()
+ ttl_seconds = duration_in_seconds(budget_config.budget_duration)
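+ # e.g. budget_duration="1d" -> ttl_seconds=86400; the window is reset below once
+ # (current_time - budget_start) exceeds this many seconds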
+
+ budget_start = await self._get_or_set_budget_start_time(
+ start_time_key=start_time_key,
+ current_time=current_time,
+ ttl_seconds=ttl_seconds,
+ )
+
+ if budget_start is None:
+ # First spend for this provider
+ budget_start = await self._handle_new_budget_window(
+ spend_key=spend_key,
+ start_time_key=start_time_key,
+ current_time=current_time,
+ response_cost=response_cost,
+ ttl_seconds=ttl_seconds,
+ )
+ elif (current_time - budget_start) > ttl_seconds:
+ # Budget window expired - reset everything
+ verbose_router_logger.debug("Budget window expired - resetting everything")
+ budget_start = await self._handle_new_budget_window(
+ spend_key=spend_key,
+ start_time_key=start_time_key,
+ current_time=current_time,
+ response_cost=response_cost,
+ ttl_seconds=ttl_seconds,
+ )
+ else:
+ # Within existing window - increment spend
+ remaining_time = ttl_seconds - (current_time - budget_start)
+ ttl_for_increment = int(remaining_time)
+
+ await self._increment_spend_in_current_window(
+ spend_key=spend_key, response_cost=response_cost, ttl=ttl_for_increment
+ )
+
+ verbose_router_logger.debug(
+ f"Incremented spend for {spend_key} by {response_cost}"
+ )
+
+ async def periodic_sync_in_memory_spend_with_redis(self):
+ """
+ Handler that triggers sync_in_memory_spend_with_redis every DEFAULT_REDIS_SYNC_INTERVAL seconds
+
+ Required for multi-instance environment usage of provider budgets
+ """
+ while True:
+ try:
+ await self._sync_in_memory_spend_with_redis()
+ await asyncio.sleep(
+ DEFAULT_REDIS_SYNC_INTERVAL
+ ) # Wait for DEFAULT_REDIS_SYNC_INTERVAL seconds before next sync
+ except Exception as e:
+ verbose_router_logger.error(f"Error in periodic sync task: {str(e)}")
+ await asyncio.sleep(
+ DEFAULT_REDIS_SYNC_INTERVAL
+ ) # Still wait DEFAULT_REDIS_SYNC_INTERVAL seconds on error before retrying
+
+ async def _push_in_memory_increments_to_redis(self):
+ """
+ How this works:
+ - async_log_success_event collects all provider spend increments in `redis_increment_operation_queue`
+ - This function pushes all increments to Redis in a batched pipeline to optimize performance
+
+ Only runs if Redis is initialized
+ """
+ try:
+ if not self.dual_cache.redis_cache:
+ return # Redis is not initialized
+
+ verbose_router_logger.debug(
+ "Pushing Redis Increment Pipeline for queue: %s",
+ self.redis_increment_operation_queue,
+ )
+ if len(self.redis_increment_operation_queue) > 0:
+ asyncio.create_task(
+ self.dual_cache.redis_cache.async_increment_pipeline(
+ increment_list=self.redis_increment_operation_queue,
+ )
+ )
+
+ self.redis_increment_operation_queue = []
+
+ except Exception as e:
+ verbose_router_logger.error(
+ f"Error syncing in-memory cache with Redis: {str(e)}"
+ )
+
+ async def _sync_in_memory_spend_with_redis(self):
+ """
+ Ensures in-memory cache is updated with latest Redis values for all provider spends.
+
+ Why do we need this?
+ - Optimization to hit sub-100ms latency; performance degraded when Redis was read/written on every request
+ - To use provider budgets in a multi-instance environment, Redis is used to sync spend across all instances
+
+ What this does:
+ 1. Push all provider spend increments to Redis
+ 2. Fetch all current provider spend from Redis to update in-memory cache
+ """
+
+ try:
+ # No need to sync if Redis cache is not initialized
+ if self.dual_cache.redis_cache is None:
+ return
+
+ # 1. Push all provider spend increments to Redis
+ await self._push_in_memory_increments_to_redis()
+
+ # 2. Fetch all current provider spend from Redis to update in-memory cache
+ cache_keys = []
+
+ if self.provider_budget_config is not None:
+ for provider, config in self.provider_budget_config.items():
+ if config is None:
+ continue
+ cache_keys.append(
+ f"provider_spend:{provider}:{config.budget_duration}"
+ )
+
+ if self.deployment_budget_config is not None:
+ for model_id, config in self.deployment_budget_config.items():
+ if config is None:
+ continue
+ cache_keys.append(
+ f"deployment_spend:{model_id}:{config.budget_duration}"
+ )
+
+ if self.tag_budget_config is not None:
+ for tag, config in self.tag_budget_config.items():
+ if config is None:
+ continue
+ cache_keys.append(f"tag_spend:{tag}:{config.budget_duration}")
+
+ # Batch fetch current spend values from Redis
+ redis_values = await self.dual_cache.redis_cache.async_batch_get_cache(
+ key_list=cache_keys
+ )
+
+ # Update in-memory cache with Redis values
+ if isinstance(redis_values, dict): # Check if redis_values is a dictionary
+ for key, value in redis_values.items():
+ if value is not None:
+ await self.dual_cache.in_memory_cache.async_set_cache(
+ key=key, value=float(value)
+ )
+ verbose_router_logger.debug(
+ f"Updated in-memory cache for {key}: {value}"
+ )
+
+ except Exception as e:
+ verbose_router_logger.error(
+ f"Error syncing in-memory cache with Redis: {str(e)}"
+ )
+
+ def _get_budget_config_for_deployment(
+ self,
+ model_id: str,
+ ) -> Optional[GenericBudgetInfo]:
+ if self.deployment_budget_config is None:
+ return None
+ return self.deployment_budget_config.get(model_id, None)
+
+ def _get_budget_config_for_provider(
+ self, provider: str
+ ) -> Optional[GenericBudgetInfo]:
+ if self.provider_budget_config is None:
+ return None
+ return self.provider_budget_config.get(provider, None)
+
+ def _get_budget_config_for_tag(self, tag: str) -> Optional[GenericBudgetInfo]:
+ if self.tag_budget_config is None:
+ return None
+ return self.tag_budget_config.get(tag, None)
+
+ def _get_llm_provider_for_deployment(self, deployment: Dict) -> Optional[str]:
+ try:
+ _litellm_params: LiteLLM_Params = LiteLLM_Params(
+ **deployment.get("litellm_params", {"model": ""})
+ )
+ _, custom_llm_provider, _, _ = litellm.get_llm_provider(
+ model=_litellm_params.model,
+ litellm_params=_litellm_params,
+ )
+ except Exception:
+ verbose_router_logger.error(
+ f"Error getting LLM provider for deployment: {deployment}"
+ )
+ return None
+ return custom_llm_provider
+
+ def _track_provider_remaining_budget_prometheus(
+ self, provider: str, spend: float, budget_limit: float
+ ):
+ """
+ Optional helper - emit provider remaining budget metric to Prometheus
+
+ This is helpful for debugging and monitoring provider budget limits.
+ """
+
+ prometheus_logger = _get_prometheus_logger_from_callbacks()
+ if prometheus_logger:
+ prometheus_logger.track_provider_remaining_budget(
+ provider=provider,
+ spend=spend,
+ budget_limit=budget_limit,
+ )
+
+ async def _get_current_provider_spend(self, provider: str) -> Optional[float]:
+ """
+ GET the current spend for a provider from cache
+
+ used for GET /provider/budgets endpoint in spend_management_endpoints.py
+
+ Args:
+ provider (str): The provider to get spend for (e.g., "openai", "anthropic")
+
+ Returns:
+ Optional[float]: The current spend for the provider, or None if not found
+ """
+ budget_config = self._get_budget_config_for_provider(provider)
+ if budget_config is None:
+ return None
+
+ spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
+
+ if self.dual_cache.redis_cache:
+ # use Redis as source of truth since that has spend across all instances
+ current_spend = await self.dual_cache.redis_cache.async_get_cache(spend_key)
+ else:
+ # use in-memory cache if Redis is not initialized
+ current_spend = await self.dual_cache.async_get_cache(spend_key)
+ return float(current_spend) if current_spend is not None else 0.0
+
+ async def _get_current_provider_budget_reset_at(
+ self, provider: str
+ ) -> Optional[str]:
+ budget_config = self._get_budget_config_for_provider(provider)
+ if budget_config is None:
+ return None
+
+ spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
+ if self.dual_cache.redis_cache:
+ ttl_seconds = await self.dual_cache.redis_cache.async_get_ttl(spend_key)
+ else:
+ ttl_seconds = await self.dual_cache.async_get_ttl(spend_key)
+
+ if ttl_seconds is None:
+ return None
+
+ return (datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)).isoformat()
+
+ async def _init_provider_budget_in_cache(
+ self, provider: str, budget_config: GenericBudgetInfo
+ ):
+ """
+ Initialize provider budget in cache by storing the following keys if they don't exist:
+ - provider_spend:{provider}:{budget_config.budget_duration} - stores the current spend
+ - provider_budget_start_time:{provider} - stores the start time of the budget window
+
+ """
+
+ spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
+ start_time_key = f"provider_budget_start_time:{provider}"
+ ttl_seconds: Optional[int] = None
+ if budget_config.budget_duration is not None:
+ ttl_seconds = duration_in_seconds(budget_config.budget_duration)
+
+ budget_start = await self.dual_cache.async_get_cache(start_time_key)
+ if budget_start is None:
+ budget_start = datetime.now(timezone.utc).timestamp()
+ await self.dual_cache.async_set_cache(
+ key=start_time_key, value=budget_start, ttl=ttl_seconds
+ )
+
+ _spend_key = await self.dual_cache.async_get_cache(spend_key)
+ if _spend_key is None:
+ await self.dual_cache.async_set_cache(
+ key=spend_key, value=0.0, ttl=ttl_seconds
+ )
+
+ @staticmethod
+ def should_init_router_budget_limiter(
+ provider_budget_config: Optional[dict],
+ model_list: Optional[
+ Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
+ ] = None,
+ ):
+ """
+ Returns `True` if the router budget routing settings are set and RouterBudgetLimiting should be initialized
+
+ Either:
+ - provider_budget_config is set
+ - budgets are set for deployments in the model_list
+ - tag_budget_config is set
+ """
+ if provider_budget_config is not None:
+ return True
+
+ if litellm.tag_budget_config is not None:
+ return True
+
+ if model_list is None:
+ return False
+
+ for _model in model_list:
+ _litellm_params = _model.get("litellm_params", {})
+ if (
+ _litellm_params.get("max_budget")
+ or _litellm_params.get("budget_duration") is not None
+ ):
+ return True
+ return False
+
+ def _init_provider_budgets(self):
+ if self.provider_budget_config is not None:
+ # cast elements of provider_budget_config to GenericBudgetInfo
+ for provider, config in self.provider_budget_config.items():
+ if config is None:
+ raise ValueError(
+ f"No budget config found for provider {provider}, provider_budget_config: {self.provider_budget_config}"
+ )
+
+ if not isinstance(config, GenericBudgetInfo):
+ self.provider_budget_config[provider] = GenericBudgetInfo(
+ budget_limit=config.get("budget_limit"),
+ time_period=config.get("time_period"),
+ )
+ asyncio.create_task(
+ self._init_provider_budget_in_cache(
+ provider=provider,
+ budget_config=self.provider_budget_config[provider],
+ )
+ )
+
+ verbose_router_logger.debug(
+ f"Initalized Provider budget config: {self.provider_budget_config}"
+ )
+
+ def _init_deployment_budgets(
+ self,
+ model_list: Optional[
+ Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
+ ] = None,
+ ):
+ if model_list is None:
+ return
+ for _model in model_list:
+ _litellm_params = _model.get("litellm_params", {})
+ _model_info: Dict = _model.get("model_info") or {}
+ _model_id = _model_info.get("id")
+ _max_budget = _litellm_params.get("max_budget")
+ _budget_duration = _litellm_params.get("budget_duration")
+
+ verbose_router_logger.debug(
+ f"Init Deployment Budget: max_budget: {_max_budget}, budget_duration: {_budget_duration}, model_id: {_model_id}"
+ )
+ if (
+ _max_budget is not None
+ and _budget_duration is not None
+ and _model_id is not None
+ ):
+ _budget_config = GenericBudgetInfo(
+ time_period=_budget_duration,
+ budget_limit=_max_budget,
+ )
+ if self.deployment_budget_config is None:
+ self.deployment_budget_config = {}
+ self.deployment_budget_config[_model_id] = _budget_config
+
+ verbose_router_logger.debug(
+ f"Initialized Deployment Budget Config: {self.deployment_budget_config}"
+ )
+
+ def _init_tag_budgets(self):
+ if litellm.tag_budget_config is None:
+ return
+ from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
+
+ if premium_user is not True:
+ raise ValueError(
+ f"Tag budgets are an Enterprise only feature, {CommonProxyErrors.not_premium_user}"
+ )
+
+ if self.tag_budget_config is None:
+ self.tag_budget_config = {}
+
+ for _tag, _tag_budget_config in litellm.tag_budget_config.items():
+ if isinstance(_tag_budget_config, dict):
+ _tag_budget_config = BudgetConfig(**_tag_budget_config)
+ _generic_budget_config = GenericBudgetInfo(
+ time_period=_tag_budget_config.budget_duration,
+ budget_limit=_tag_budget_config.max_budget,
+ )
+ self.tag_budget_config[_tag] = _generic_budget_config
+
+ verbose_router_logger.debug(
+ f"Initialized Tag Budget Config: {self.tag_budget_config}"
+ )
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py
new file mode 100644
index 00000000..12f3f01c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py
@@ -0,0 +1,252 @@
+#### What this does ####
+# identifies least busy deployment
+# How is this achieved?
+# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
+# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
+# - use litellm.success + failure callbacks to log when a request completed
+# - in get_available_deployment, for a given model group name -> pick based on traffic
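+# Cache layout maintained by the handlers below, per model group:
+#   "{model_group}_request_count": {"<deployment-id>": <requests currently in flight>}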
+
+import random
+from typing import Optional
+
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+
+
+class LeastBusyLoggingHandler(CustomLogger):
+ test_flag: bool = False
+ logged_success: int = 0
+ logged_failure: int = 0
+
+ def __init__(self, router_cache: DualCache, model_list: list):
+ self.router_cache = router_cache
+ self.mapping_deployment_to_id: dict = {}
+ self.model_list = model_list
+
+ def log_pre_api_call(self, model, messages, kwargs):
+ """
+ Log when a model is being used.
+
+ Caching based on model group.
+ """
+ try:
+ if kwargs["litellm_params"].get("metadata") is None:
+ pass
+ else:
+ model_group = kwargs["litellm_params"]["metadata"].get(
+ "model_group", None
+ )
+ id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+
+ request_count_api_key = f"{model_group}_request_count"
+ # update cache
+ request_count_dict = (
+ self.router_cache.get_cache(key=request_count_api_key) or {}
+ )
+ request_count_dict[id] = request_count_dict.get(id, 0) + 1
+
+ self.router_cache.set_cache(
+ key=request_count_api_key, value=request_count_dict
+ )
+ except Exception:
+ pass
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ if kwargs["litellm_params"].get("metadata") is None:
+ pass
+ else:
+ model_group = kwargs["litellm_params"]["metadata"].get(
+ "model_group", None
+ )
+
+ id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+
+ request_count_api_key = f"{model_group}_request_count"
+ # decrement count in cache
+ request_count_dict = (
+ self.router_cache.get_cache(key=request_count_api_key) or {}
+ )
+ request_count_value: Optional[int] = request_count_dict.get(id, 0)
+ if request_count_value is None:
+ return
+ request_count_dict[id] = request_count_value - 1
+ self.router_cache.set_cache(
+ key=request_count_api_key, value=request_count_dict
+ )
+
+ ### TESTING ###
+ if self.test_flag:
+ self.logged_success += 1
+ except Exception:
+ pass
+
+ def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ if kwargs["litellm_params"].get("metadata") is None:
+ pass
+ else:
+ model_group = kwargs["litellm_params"]["metadata"].get(
+ "model_group", None
+ )
+ id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+
+ request_count_api_key = f"{model_group}_request_count"
+ # decrement count in cache
+ request_count_dict = (
+ self.router_cache.get_cache(key=request_count_api_key) or {}
+ )
+ request_count_value: Optional[int] = request_count_dict.get(id, 0)
+ if request_count_value is None:
+ return
+ request_count_dict[id] = request_count_value - 1
+ self.router_cache.set_cache(
+ key=request_count_api_key, value=request_count_dict
+ )
+
+ ### TESTING ###
+ if self.test_flag:
+ self.logged_failure += 1
+ except Exception:
+ pass
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ if kwargs["litellm_params"].get("metadata") is None:
+ pass
+ else:
+ model_group = kwargs["litellm_params"]["metadata"].get(
+ "model_group", None
+ )
+
+ id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+
+ request_count_api_key = f"{model_group}_request_count"
+ # decrement count in cache
+ request_count_dict = (
+ await self.router_cache.async_get_cache(key=request_count_api_key)
+ or {}
+ )
+ request_count_value: Optional[int] = request_count_dict.get(id, 0)
+ if request_count_value is None:
+ return
+ request_count_dict[id] = request_count_value - 1
+ await self.router_cache.async_set_cache(
+ key=request_count_api_key, value=request_count_dict
+ )
+
+ ### TESTING ###
+ if self.test_flag:
+ self.logged_success += 1
+ except Exception:
+ pass
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ if kwargs["litellm_params"].get("metadata") is None:
+ pass
+ else:
+ model_group = kwargs["litellm_params"]["metadata"].get(
+ "model_group", None
+ )
+ id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+
+ request_count_api_key = f"{model_group}_request_count"
+ # decrement count in cache
+ request_count_dict = (
+ await self.router_cache.async_get_cache(key=request_count_api_key)
+ or {}
+ )
+ request_count_value: Optional[int] = request_count_dict.get(id, 0)
+ if request_count_value is None:
+ return
+ request_count_dict[id] = request_count_value - 1
+ await self.router_cache.async_set_cache(
+ key=request_count_api_key, value=request_count_dict
+ )
+
+ ### TESTING ###
+ if self.test_flag:
+ self.logged_failure += 1
+ except Exception:
+ pass
+
+ def _get_available_deployments(
+ self,
+ healthy_deployments: list,
+ all_deployments: dict,
+ ):
+ """
+ Helper to get deployments using least busy strategy
+ """
+ for d in healthy_deployments:
+ ## if healthy deployment not yet used
+ if d["model_info"]["id"] not in all_deployments:
+ all_deployments[d["model_info"]["id"]] = 0
+ # map deployment to id
+ # pick least busy deployment
+ min_traffic = float("inf")
+ min_deployment = None
+ for k, v in all_deployments.items():
+ if v < min_traffic:
+ min_traffic = v
+ min_deployment = k
+ if min_deployment is not None:
+ ## return the full deployment dict for the least-busy id
+ for m in healthy_deployments:
+ if m["model_info"]["id"] == min_deployment:
+ return m
+ min_deployment = random.choice(healthy_deployments)
+ else:
+ min_deployment = random.choice(healthy_deployments)
+ return min_deployment
+
+ def get_available_deployments(
+ self,
+ model_group: str,
+ healthy_deployments: list,
+ ):
+ """
+ Sync helper to get deployments using least busy strategy
+ """
+ request_count_api_key = f"{model_group}_request_count"
+ all_deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
+ return self._get_available_deployments(
+ healthy_deployments=healthy_deployments,
+ all_deployments=all_deployments,
+ )
+
+ async def async_get_available_deployments(
+ self, model_group: str, healthy_deployments: list
+ ):
+ """
+ Async helper to get deployments using least busy strategy
+ """
+ request_count_api_key = f"{model_group}_request_count"
+ all_deployments = (
+ await self.router_cache.async_get_cache(key=request_count_api_key) or {}
+ )
+ return self._get_available_deployments(
+ healthy_deployments=healthy_deployments,
+ all_deployments=all_deployments,
+ )
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_cost.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_cost.py
new file mode 100644
index 00000000..bd28f6dc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_cost.py
@@ -0,0 +1,333 @@
+#### What this does ####
+# picks the deployment with the lowest (input + output) cost per token
+from datetime import datetime, timedelta
+from typing import Dict, List, Optional, Union
+
+import litellm
+from litellm import ModelResponse, token_counter, verbose_logger
+from litellm._logging import verbose_router_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+
+
+class LowestCostLoggingHandler(CustomLogger):
+ test_flag: bool = False
+ logged_success: int = 0
+ logged_failure: int = 0
+
+ def __init__(
+ self, router_cache: DualCache, model_list: list, routing_args: dict = {}
+ ):
+ self.router_cache = router_cache
+ self.model_list = model_list
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ """
+ Update usage on success
+ """
+ if kwargs["litellm_params"].get("metadata") is None:
+ pass
+ else:
+ model_group = kwargs["litellm_params"]["metadata"].get(
+ "model_group", None
+ )
+
+ id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+
+ # ------------
+ # Setup values
+ # ------------
+ """
+ {
+ {model_group}_map: {
+ id: {
+ f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
+ }
+ }
+ }
+ """
+ current_date = datetime.now().strftime("%Y-%m-%d")
+ current_hour = datetime.now().strftime("%H")
+ current_minute = datetime.now().strftime("%M")
+ precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+ cost_key = f"{model_group}_map"
+
+ response_ms: timedelta = end_time - start_time
+
+ total_tokens = 0
+
+ if isinstance(response_obj, ModelResponse):
+ _usage = getattr(response_obj, "usage", None)
+ if _usage is not None and isinstance(_usage, litellm.Usage):
+ completion_tokens = _usage.completion_tokens
+ total_tokens = _usage.total_tokens
+
+ # ------------
+ # Update usage
+ # ------------
+
+ request_count_dict = self.router_cache.get_cache(key=cost_key) or {}
+
+ # check local result first
+
+ if id not in request_count_dict:
+ request_count_dict[id] = {}
+
+ if precise_minute not in request_count_dict[id]:
+ request_count_dict[id][precise_minute] = {}
+
+ ## TPM
+ request_count_dict[id][precise_minute]["tpm"] = (
+ request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
+ )
+
+ ## RPM
+ request_count_dict[id][precise_minute]["rpm"] = (
+ request_count_dict[id][precise_minute].get("rpm", 0) + 1
+ )
+
+ self.router_cache.set_cache(key=cost_key, value=request_count_dict)
+
+ ### TESTING ###
+ if self.test_flag:
+ self.logged_success += 1
+ except Exception as e:
+ verbose_logger.exception(
+ "litellm.router_strategy.lowest_cost.py::log_success_event(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ pass
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ """
+ Update cost usage on success
+ """
+ if kwargs["litellm_params"].get("metadata") is None:
+ pass
+ else:
+ model_group = kwargs["litellm_params"]["metadata"].get(
+ "model_group", None
+ )
+
+ id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+
+ # ------------
+ # Setup values
+ # ------------
+ """
+ {
+ {model_group}_map: {
+ id: {
+ "cost": [..]
+ f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
+ }
+ }
+ }
+ """
+ cost_key = f"{model_group}_map"
+
+ current_date = datetime.now().strftime("%Y-%m-%d")
+ current_hour = datetime.now().strftime("%H")
+ current_minute = datetime.now().strftime("%M")
+ precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+ response_ms: timedelta = end_time - start_time
+
+ total_tokens = 0
+
+ if isinstance(response_obj, ModelResponse):
+ _usage = getattr(response_obj, "usage", None)
+ if _usage is not None and isinstance(_usage, litellm.Usage):
+ completion_tokens = _usage.completion_tokens
+ total_tokens = _usage.total_tokens
+
+ # ------------
+ # Update usage
+ # ------------
+
+ request_count_dict = (
+ await self.router_cache.async_get_cache(key=cost_key) or {}
+ )
+
+ if id not in request_count_dict:
+ request_count_dict[id] = {}
+ if precise_minute not in request_count_dict[id]:
+ request_count_dict[id][precise_minute] = {}
+
+ ## TPM
+ request_count_dict[id][precise_minute]["tpm"] = (
+ request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
+ )
+
+ ## RPM
+ request_count_dict[id][precise_minute]["rpm"] = (
+ request_count_dict[id][precise_minute].get("rpm", 0) + 1
+ )
+
+ await self.router_cache.async_set_cache(
+ key=cost_key, value=request_count_dict
+ ) # reset map within window
+
+ ### TESTING ###
+ if self.test_flag:
+ self.logged_success += 1
+ except Exception as e:
+ verbose_logger.exception(
+ "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ pass
+
+ async def async_get_available_deployments( # noqa: PLR0915
+ self,
+ model_group: str,
+ healthy_deployments: list,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ request_kwargs: Optional[Dict] = None,
+ ):
+ """
+ Returns a deployment with the lowest cost
+ """
+ cost_key = f"{model_group}_map"
+
+ request_count_dict = await self.router_cache.async_get_cache(key=cost_key) or {}
+
+ # ----------------------------
+ # Find lowest cost deployment
+ # ----------------------------
+
+ current_date = datetime.now().strftime("%Y-%m-%d")
+ current_hour = datetime.now().strftime("%H")
+ current_minute = datetime.now().strftime("%M")
+ precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+ if request_count_dict is None: # base case
+ return
+
+ all_deployments = request_count_dict
+ for d in healthy_deployments:
+ ## if healthy deployment not yet used
+ if d["model_info"]["id"] not in all_deployments:
+ all_deployments[d["model_info"]["id"]] = {
+ precise_minute: {"tpm": 0, "rpm": 0},
+ }
+
+ try:
+ input_tokens = token_counter(messages=messages, text=input)
+ except Exception:
+ input_tokens = 0
+
+ ### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits
+ potential_deployments = []
+ _cost_per_deployment = {}
+ for item, item_map in all_deployments.items():
+ ## get the item from model list
+ _deployment = None
+ for m in healthy_deployments:
+ if item == m["model_info"]["id"]:
+ _deployment = m
+
+ if _deployment is None:
+ continue # skip to next one
+
+ _deployment_tpm = (
+ _deployment.get("tpm", None)
+ or _deployment.get("litellm_params", {}).get("tpm", None)
+ or _deployment.get("model_info", {}).get("tpm", None)
+ or float("inf")
+ )
+
+ _deployment_rpm = (
+ _deployment.get("rpm", None)
+ or _deployment.get("litellm_params", {}).get("rpm", None)
+ or _deployment.get("model_info", {}).get("rpm", None)
+ or float("inf")
+ )
+ item_litellm_model_name = _deployment.get("litellm_params", {}).get("model")
+ item_litellm_model_cost_map = litellm.model_cost.get(
+ item_litellm_model_name, {}
+ )
+
+ # check if user provided input_cost_per_token and output_cost_per_token in litellm_params
+ item_input_cost = None
+ item_output_cost = None
+ if _deployment.get("litellm_params", {}).get("input_cost_per_token", None):
+ item_input_cost = _deployment.get("litellm_params", {}).get(
+ "input_cost_per_token"
+ )
+
+ if _deployment.get("litellm_params", {}).get("output_cost_per_token", None):
+ item_output_cost = _deployment.get("litellm_params", {}).get(
+ "output_cost_per_token"
+ )
+
+ if item_input_cost is None:
+ item_input_cost = item_litellm_model_cost_map.get(
+ "input_cost_per_token", 5.0
+ )
+
+ if item_output_cost is None:
+ item_output_cost = item_litellm_model_cost_map.get(
+ "output_cost_per_token", 5.0
+ )
+
+ # if litellm["model"] is not in model_cost map -> use item_cost = $10
+
+ item_cost = item_input_cost + item_output_cost
+
+ item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
+ item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)
+
+ verbose_router_logger.debug(
+ f"item_cost: {item_cost}, item_tpm: {item_tpm}, item_rpm: {item_rpm}, model_id: {_deployment.get('model_info', {}).get('id')}"
+ )
+
+ # -------------- #
+ # Debugging Logic
+ # -------------- #
+ # We use _cost_per_deployment to log to langfuse, slack - this is not used to make a decision on routing
+ # this helps a user to debug why the router picked a specific deployment #
+ _deployment_api_base = _deployment.get("litellm_params", {}).get(
+ "api_base", ""
+ )
+ if _deployment_api_base is not None:
+ _cost_per_deployment[_deployment_api_base] = item_cost
+ # -------------- #
+ # End of Debugging Logic
+ # -------------- #
+
+ if (
+ item_tpm + input_tokens > _deployment_tpm
+ or item_rpm + 1 > _deployment_rpm
+ ): # if user passed in tpm / rpm in the model_list
+ continue
+ else:
+ potential_deployments.append((_deployment, item_cost))
+
+ if len(potential_deployments) == 0:
+ return None
+
+ potential_deployments = sorted(potential_deployments, key=lambda x: x[1])
+
+ selected_deployment = potential_deployments[0][0]
+ return selected_deployment
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_latency.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_latency.py
new file mode 100644
index 00000000..b049c942
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_latency.py
@@ -0,0 +1,590 @@
+#### What this does ####
+# picks based on response time (for streaming, this is time to first token)
+import random
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import litellm
+from litellm import ModelResponse, token_counter, verbose_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
+from litellm.types.utils import LiteLLMPydanticObjectBase
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+
+ Span = _Span
+else:
+ Span = Any
+
+
+class RoutingArgs(LiteLLMPydanticObjectBase):
+ ttl: float = 1 * 60 * 60 # 1 hour
+ lowest_latency_buffer: float = 0
+ max_latency_list_size: int = 10
+
+
+class LowestLatencyLoggingHandler(CustomLogger):
+ test_flag: bool = False
+ logged_success: int = 0
+ logged_failure: int = 0
+
+ def __init__(
+ self, router_cache: DualCache, model_list: list, routing_args: dict = {}
+ ):
+ self.router_cache = router_cache
+ self.model_list = model_list
+ self.routing_args = RoutingArgs(**routing_args)
+
+ def log_success_event( # noqa: PLR0915
+ self, kwargs, response_obj, start_time, end_time
+ ):
+ try:
+ """
+ Update latency usage on success
+ """
+ if kwargs["litellm_params"].get("metadata") is None:
+ pass
+ else:
+ model_group = kwargs["litellm_params"]["metadata"].get(
+ "model_group", None
+ )
+
+ id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+
+ # ------------
+ # Setup values
+ # ------------
+ """
+ {
+ {model_group}_map: {
+ id: {
+ "latency": [..]
+ f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
+ }
+ }
+ }
+ """
+ latency_key = f"{model_group}_map"
+
+ current_date = datetime.now().strftime("%Y-%m-%d")
+ current_hour = datetime.now().strftime("%H")
+ current_minute = datetime.now().strftime("%M")
+ precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+ response_ms: timedelta = end_time - start_time
+ time_to_first_token_response_time: Optional[timedelta] = None
+
+ if kwargs.get("stream", None) is not None and kwargs["stream"] is True:
+ # only log ttft for streaming request
+ time_to_first_token_response_time = (
+ kwargs.get("completion_start_time", end_time) - start_time
+ )
+
+ final_value: Union[float, timedelta] = response_ms
+ time_to_first_token: Optional[float] = None
+ total_tokens = 0
+
+ if isinstance(response_obj, ModelResponse):
+ _usage = getattr(response_obj, "usage", None)
+ if _usage is not None:
+ completion_tokens = _usage.completion_tokens
+ total_tokens = _usage.total_tokens
+ final_value = float(
+ response_ms.total_seconds() / completion_tokens
+ )
+
+ if time_to_first_token_response_time is not None:
+ time_to_first_token = float(
+ time_to_first_token_response_time.total_seconds()
+ / completion_tokens
+ )
+
+ # ------------
+ # Update usage
+ # ------------
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ request_count_dict = (
+ self.router_cache.get_cache(
+ key=latency_key, parent_otel_span=parent_otel_span
+ )
+ or {}
+ )
+
+ if id not in request_count_dict:
+ request_count_dict[id] = {}
+
+ ## Latency
+ if (
+ len(request_count_dict[id].get("latency", []))
+ < self.routing_args.max_latency_list_size
+ ):
+ request_count_dict[id].setdefault("latency", []).append(final_value)
+ else:
+ request_count_dict[id]["latency"] = request_count_dict[id][
+ "latency"
+ ][: self.routing_args.max_latency_list_size - 1] + [final_value]
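+ # keeps the stored latency list bounded at max_latency_list_size entries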
+
+ ## Time to first token
+ if time_to_first_token is not None:
+ if (
+ len(request_count_dict[id].get("time_to_first_token", []))
+ < self.routing_args.max_latency_list_size
+ ):
+ request_count_dict[id].setdefault(
+ "time_to_first_token", []
+ ).append(time_to_first_token)
+ else:
+ request_count_dict[id][
+ "time_to_first_token"
+ ] = request_count_dict[id]["time_to_first_token"][
+ : self.routing_args.max_latency_list_size - 1
+ ] + [
+ time_to_first_token
+ ]
+
+ if precise_minute not in request_count_dict[id]:
+ request_count_dict[id][precise_minute] = {}
+
+ ## TPM
+ request_count_dict[id][precise_minute]["tpm"] = (
+ request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
+ )
+
+ ## RPM
+ request_count_dict[id][precise_minute]["rpm"] = (
+ request_count_dict[id][precise_minute].get("rpm", 0) + 1
+ )
+
+ self.router_cache.set_cache(
+ key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
+ ) # reset map within window
+
+ ### TESTING ###
+ if self.test_flag:
+ self.logged_success += 1
+ except Exception as e:
+ verbose_logger.exception(
+ "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ pass
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ """
+ Check if Timeout Error, if timeout set deployment latency -> 1000 (penalty)
+ """
+ try:
+ _exception = kwargs.get("exception", None)
+ if isinstance(_exception, litellm.Timeout):
+ if kwargs["litellm_params"].get("metadata") is None:
+ pass
+ else:
+ model_group = kwargs["litellm_params"]["metadata"].get(
+ "model_group", None
+ )
+
+ id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+
+ # ------------
+ # Setup values
+ # ------------
+ """
+ {
+ {model_group}_map: {
+ id: {
+ "latency": [..]
+ f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
+ }
+ }
+ }
+ """
+ latency_key = f"{model_group}_map"
+ request_count_dict = (
+ await self.router_cache.async_get_cache(key=latency_key) or {}
+ )
+
+ if id not in request_count_dict:
+ request_count_dict[id] = {}
+
+ ## Latency - give 1000s penalty for failing
+ if (
+ len(request_count_dict[id].get("latency", []))
+ < self.routing_args.max_latency_list_size
+ ):
+ request_count_dict[id].setdefault("latency", []).append(1000.0)
+ else:
+ request_count_dict[id]["latency"] = request_count_dict[id][
+ "latency"
+ ][: self.routing_args.max_latency_list_size - 1] + [1000.0]
+
+ await self.router_cache.async_set_cache(
+ key=latency_key,
+ value=request_count_dict,
+ ttl=self.routing_args.ttl,
+ ) # reset map within window
+ else:
+ # do nothing if it's not a timeout error
+ return
+ except Exception as e:
+ verbose_logger.exception(
+ "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ pass
+
+ async def async_log_success_event( # noqa: PLR0915
+ self, kwargs, response_obj, start_time, end_time
+ ):
+ try:
+ """
+ Update latency usage on success
+ """
+ if kwargs["litellm_params"].get("metadata") is None:
+ pass
+ else:
+ model_group = kwargs["litellm_params"]["metadata"].get(
+ "model_group", None
+ )
+
+ id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+
+ # ------------
+ # Setup values
+ # ------------
+ """
+ {
+ {model_group}_map: {
+ id: {
+ "latency": [..]
+ "time_to_first_token": [..]
+ f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
+ }
+ }
+ }
+ """
+ latency_key = f"{model_group}_map"
+
+ current_date = datetime.now().strftime("%Y-%m-%d")
+ current_hour = datetime.now().strftime("%H")
+ current_minute = datetime.now().strftime("%M")
+ precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+ response_ms: timedelta = end_time - start_time
+ time_to_first_token_response_time: Optional[timedelta] = None
+ if kwargs.get("stream", None) is not None and kwargs["stream"] is True:
+ # only log ttft for streaming request
+ time_to_first_token_response_time = (
+ kwargs.get("completion_start_time", end_time) - start_time
+ )
+
+ final_value: Union[float, timedelta] = response_ms
+ total_tokens = 0
+ time_to_first_token: Optional[float] = None
+
+ if isinstance(response_obj, ModelResponse):
+ _usage = getattr(response_obj, "usage", None)
+ if _usage is not None:
+ completion_tokens = _usage.completion_tokens
+ total_tokens = _usage.total_tokens
+ final_value = float(
+ response_ms.total_seconds() / completion_tokens
+ )
+
+ if time_to_first_token_response_time is not None:
+ time_to_first_token = float(
+ time_to_first_token_response_time.total_seconds()
+ / completion_tokens
+ )
+ # ------------
+ # Update usage
+ # ------------
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ request_count_dict = (
+ await self.router_cache.async_get_cache(
+ key=latency_key,
+ parent_otel_span=parent_otel_span,
+ local_only=True,
+ )
+ or {}
+ )
+
+ if id not in request_count_dict:
+ request_count_dict[id] = {}
+
+ ## Latency
+ if (
+ len(request_count_dict[id].get("latency", []))
+ < self.routing_args.max_latency_list_size
+ ):
+ request_count_dict[id].setdefault("latency", []).append(final_value)
+ else:
+ request_count_dict[id]["latency"] = request_count_dict[id][
+ "latency"
+ ][: self.routing_args.max_latency_list_size - 1] + [final_value]
+
+ ## Time to first token
+ if time_to_first_token is not None:
+ if (
+ len(request_count_dict[id].get("time_to_first_token", []))
+ < self.routing_args.max_latency_list_size
+ ):
+ request_count_dict[id].setdefault(
+ "time_to_first_token", []
+ ).append(time_to_first_token)
+ else:
+ request_count_dict[id][
+ "time_to_first_token"
+ ] = request_count_dict[id]["time_to_first_token"][
+ : self.routing_args.max_latency_list_size - 1
+ ] + [
+ time_to_first_token
+ ]
+
+ if precise_minute not in request_count_dict[id]:
+ request_count_dict[id][precise_minute] = {}
+
+ ## TPM
+ request_count_dict[id][precise_minute]["tpm"] = (
+ request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
+ )
+
+ ## RPM
+ request_count_dict[id][precise_minute]["rpm"] = (
+ request_count_dict[id][precise_minute].get("rpm", 0) + 1
+ )
+
+ await self.router_cache.async_set_cache(
+ key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
+ ) # reset map within window
+
+ ### TESTING ###
+ if self.test_flag:
+ self.logged_success += 1
+ except Exception as e:
+ verbose_logger.exception(
+ "litellm.router_strategy.lowest_latency.py::async_log_success_event(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ pass
+
+ def _get_available_deployments( # noqa: PLR0915
+ self,
+ model_group: str,
+ healthy_deployments: list,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ request_kwargs: Optional[Dict] = None,
+ request_count_dict: Optional[Dict] = None,
+ ):
+ """Common logic for both sync and async get_available_deployments"""
+
+ # -----------------------
+ # Find lowest used model
+ # ----------------------
+ _latency_per_deployment = {}
+ lowest_latency = float("inf")
+
+ current_date = datetime.now().strftime("%Y-%m-%d")
+ current_hour = datetime.now().strftime("%H")
+ current_minute = datetime.now().strftime("%M")
+ precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+ deployment = None
+
+ if request_count_dict is None: # base case
+ return
+
+ all_deployments = request_count_dict
+ for d in healthy_deployments:
+ ## if healthy deployment not yet used
+ if d["model_info"]["id"] not in all_deployments:
+ all_deployments[d["model_info"]["id"]] = {
+ "latency": [0],
+ precise_minute: {"tpm": 0, "rpm": 0},
+ }
+
+ try:
+ input_tokens = token_counter(messages=messages, text=input)
+ except Exception:
+ input_tokens = 0
+
+ # randomly sample from all_deployments, in case all deployments have latency=0.0
+ _items = all_deployments.items()
+
+ _all_deployments = random.sample(list(_items), len(_items))
+ all_deployments = dict(_all_deployments)
+ ### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits
+
+ potential_deployments = []
+ for item, item_map in all_deployments.items():
+ ## get the item from model list
+ _deployment = None
+ for m in healthy_deployments:
+ if item == m["model_info"]["id"]:
+ _deployment = m
+
+ if _deployment is None:
+ continue # skip to next one
+
+ _deployment_tpm = (
+ _deployment.get("tpm", None)
+ or _deployment.get("litellm_params", {}).get("tpm", None)
+ or _deployment.get("model_info", {}).get("tpm", None)
+ or float("inf")
+ )
+
+ _deployment_rpm = (
+ _deployment.get("rpm", None)
+ or _deployment.get("litellm_params", {}).get("rpm", None)
+ or _deployment.get("model_info", {}).get("rpm", None)
+ or float("inf")
+ )
+ item_latency = item_map.get("latency", [])
+ item_ttft_latency = item_map.get("time_to_first_token", [])
+ item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
+ item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)
+
+ # get average latency or average ttft (depending on streaming/non-streaming)
+ total: float = 0.0
+ if (
+ request_kwargs is not None
+ and request_kwargs.get("stream", None) is not None
+ and request_kwargs["stream"] is True
+ and len(item_ttft_latency) > 0
+ ):
+ # streaming request -> average the recorded time-to-first-token samples
+ for _call_latency in item_ttft_latency:
+ if isinstance(_call_latency, float):
+ total += _call_latency
+ item_latency = total / len(item_ttft_latency)
+ else:
+ # non-streaming request -> average the recorded latency samples
+ for _call_latency in item_latency:
+ if isinstance(_call_latency, float):
+ total += _call_latency
+ item_latency = total / len(item_latency)
+
+ # -------------- #
+ # Debugging Logic
+ # -------------- #
+ # We use _latency_per_deployment to log to langfuse, slack - this is not used to make a decision on routing
+ # this helps a user debug why the router picked a specific deployment #
+ _deployment_api_base = _deployment.get("litellm_params", {}).get(
+ "api_base", ""
+ )
+ if _deployment_api_base is not None:
+ _latency_per_deployment[_deployment_api_base] = item_latency
+ # -------------- #
+ # End of Debugging Logic
+ # -------------- #
+
+ if (
+ item_tpm + input_tokens > _deployment_tpm
+ or item_rpm + 1 > _deployment_rpm
+ ): # if user passed in tpm / rpm in the model_list
+ continue
+ else:
+ potential_deployments.append((_deployment, item_latency))
+
+ if len(potential_deployments) == 0:
+ return None
+
+ # Sort potential deployments by latency
+ sorted_deployments = sorted(potential_deployments, key=lambda x: x[1])
+
+ # Find lowest latency deployment
+ lowest_latency = sorted_deployments[0][1]
+
+ # Find deployments within buffer of lowest latency
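+ # (e.g. lowest_latency_buffer=0.5 keeps any deployment within 50% of the fastest average latency)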
+ buffer = self.routing_args.lowest_latency_buffer * lowest_latency
+
+ valid_deployments = [
+ x for x in sorted_deployments if x[1] <= lowest_latency + buffer
+ ]
+
+ # Pick a random deployment from valid deployments
+ random_valid_deployment = random.choice(valid_deployments)
+ deployment = random_valid_deployment[0]
+
+ if request_kwargs is not None and "metadata" in request_kwargs:
+ request_kwargs["metadata"][
+ "_latency_per_deployment"
+ ] = _latency_per_deployment
+ return deployment
+
+ async def async_get_available_deployments(
+ self,
+ model_group: str,
+ healthy_deployments: list,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ request_kwargs: Optional[Dict] = None,
+ ):
+ # get list of potential deployments
+ latency_key = f"{model_group}_map"
+
+ parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
+ request_kwargs
+ )
+ request_count_dict = (
+ await self.router_cache.async_get_cache(
+ key=latency_key, parent_otel_span=parent_otel_span
+ )
+ or {}
+ )
+
+ return self._get_available_deployments(
+ model_group,
+ healthy_deployments,
+ messages,
+ input,
+ request_kwargs,
+ request_count_dict,
+ )
+
+ def get_available_deployments(
+ self,
+ model_group: str,
+ healthy_deployments: list,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ request_kwargs: Optional[Dict] = None,
+ ):
+ """
+ Returns a deployment with the lowest latency
+ """
+ # get list of potential deployments
+ latency_key = f"{model_group}_map"
+
+ parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
+ request_kwargs
+ )
+ request_count_dict = (
+ self.router_cache.get_cache(
+ key=latency_key, parent_otel_span=parent_otel_span
+ )
+ or {}
+ )
+
+ return self._get_available_deployments(
+ model_group,
+ healthy_deployments,
+ messages,
+ input,
+ request_kwargs,
+ request_count_dict,
+ )
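+
+
+# ------------------------------------------------------------------------------
+# Illustrative sketch (with made-up ids and numbers) of the buffer-based pick in
+# `_get_available_deployments` above: deployments are compared on average latency,
+# and anything within `lowest_latency_buffer * lowest_latency` of the fastest
+# deployment stays eligible for a random pick.
+if __name__ == "__main__":
+    sample_avg_latency = {"deployment-a": 0.8, "deployment-b": 1.0, "deployment-c": 2.5}
+    lowest = min(sample_avg_latency.values())
+    buffer = 0.5 * lowest  # lowest_latency_buffer=0.5 -> accept anything <= 1.2s here
+    eligible = [d for d, avg in sample_avg_latency.items() if avg <= lowest + buffer]
+    print(random.choice(eligible))  # picks "deployment-a" or "deployment-b"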
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_tpm_rpm.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_tpm_rpm.py
new file mode 100644
index 00000000..86587939
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_tpm_rpm.py
@@ -0,0 +1,243 @@
+#### What this does ####
+# identifies lowest tpm deployment
+import traceback
+from datetime import datetime
+from typing import Dict, List, Optional, Union
+
+from litellm import token_counter
+from litellm._logging import verbose_router_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.utils import LiteLLMPydanticObjectBase
+from litellm.utils import print_verbose
+
+
+class RoutingArgs(LiteLLMPydanticObjectBase):
+ ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
+
+
+class LowestTPMLoggingHandler(CustomLogger):
+ test_flag: bool = False
+ logged_success: int = 0
+ logged_failure: int = 0
+ default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
+
+ def __init__(
+ self, router_cache: DualCache, model_list: list, routing_args: dict = {}
+ ):
+ self.router_cache = router_cache
+ self.model_list = model_list
+ self.routing_args = RoutingArgs(**routing_args)
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ """
+ Update TPM/RPM usage on success
+ """
+ if kwargs["litellm_params"].get("metadata") is None:
+ pass
+ else:
+ model_group = kwargs["litellm_params"]["metadata"].get(
+ "model_group", None
+ )
+
+ id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+
+ total_tokens = response_obj["usage"]["total_tokens"]
+
+ # ------------
+ # Setup values
+ # ------------
+ current_minute = datetime.now().strftime("%H-%M")
+ tpm_key = f"{model_group}:tpm:{current_minute}"
+ rpm_key = f"{model_group}:rpm:{current_minute}"
+
+ # ------------
+ # Update usage
+ # ------------
+
+ ## TPM
+ request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
+ request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
+
+ self.router_cache.set_cache(
+ key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
+ )
+
+ ## RPM
+ request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
+ request_count_dict[id] = request_count_dict.get(id, 0) + 1
+
+ self.router_cache.set_cache(
+ key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
+ )
+
+ ### TESTING ###
+ if self.test_flag:
+ self.logged_success += 1
+ except Exception as e:
+ verbose_router_logger.error(
+ "litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_router_logger.debug(traceback.format_exc())
+ pass
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ """
+ Update TPM/RPM usage on success
+ """
+ if kwargs["litellm_params"].get("metadata") is None:
+ pass
+ else:
+ model_group = kwargs["litellm_params"]["metadata"].get(
+ "model_group", None
+ )
+
+ id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+
+ total_tokens = response_obj["usage"]["total_tokens"]
+
+ # ------------
+ # Setup values
+ # ------------
+ current_minute = datetime.now().strftime("%H-%M")
+ tpm_key = f"{model_group}:tpm:{current_minute}"
+ rpm_key = f"{model_group}:rpm:{current_minute}"
+
+ # ------------
+ # Update usage
+ # ------------
+ # update cache
+
+ ## TPM
+ request_count_dict = (
+ await self.router_cache.async_get_cache(key=tpm_key) or {}
+ )
+ request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
+
+ await self.router_cache.async_set_cache(
+ key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
+ )
+
+ ## RPM
+ request_count_dict = (
+ await self.router_cache.async_get_cache(key=rpm_key) or {}
+ )
+ request_count_dict[id] = request_count_dict.get(id, 0) + 1
+
+ await self.router_cache.async_set_cache(
+ key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
+ )
+
+ ### TESTING ###
+ if self.test_flag:
+ self.logged_success += 1
+ except Exception as e:
+ verbose_router_logger.error(
+ "litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_router_logger.debug(traceback.format_exc())
+ pass
+
+ def get_available_deployments( # noqa: PLR0915
+ self,
+ model_group: str,
+ healthy_deployments: list,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ ):
+ """
+ Returns a deployment with the lowest TPM/RPM usage.
+ """
+ # get list of potential deployments
+ verbose_router_logger.debug(
+ f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
+ )
+ current_minute = datetime.now().strftime("%H-%M")
+ tpm_key = f"{model_group}:tpm:{current_minute}"
+ rpm_key = f"{model_group}:rpm:{current_minute}"
+
+ tpm_dict = self.router_cache.get_cache(key=tpm_key)
+ rpm_dict = self.router_cache.get_cache(key=rpm_key)
+
+ verbose_router_logger.debug(
+ f"tpm_key={tpm_key}, tpm_dict: {tpm_dict}, rpm_dict: {rpm_dict}"
+ )
+ try:
+ input_tokens = token_counter(messages=messages, text=input)
+ except Exception:
+ input_tokens = 0
+ verbose_router_logger.debug(f"input_tokens={input_tokens}")
+ # -----------------------
+ # Find lowest used model
+ # ----------------------
+ lowest_tpm = float("inf")
+
+ if tpm_dict is None: # base case - none of the deployments have been used
+ # initialize a tpm dict with {model_id: 0}
+ tpm_dict = {}
+ for deployment in healthy_deployments:
+ tpm_dict[deployment["model_info"]["id"]] = 0
+ else:
+ for d in healthy_deployments:
+ ## if healthy deployment not yet used
+ if d["model_info"]["id"] not in tpm_dict:
+ tpm_dict[d["model_info"]["id"]] = 0
+
+ all_deployments = tpm_dict
+
+ deployment = None
+ for item, item_tpm in all_deployments.items():
+ ## get the item from model list
+ _deployment = None
+ for m in healthy_deployments:
+ if item == m["model_info"]["id"]:
+ _deployment = m
+
+ if _deployment is None:
+ continue # skip to next one
+
+ _deployment_tpm = None
+ if _deployment_tpm is None:
+ _deployment_tpm = _deployment.get("tpm")
+ if _deployment_tpm is None:
+ _deployment_tpm = _deployment.get("litellm_params", {}).get("tpm")
+ if _deployment_tpm is None:
+ _deployment_tpm = _deployment.get("model_info", {}).get("tpm")
+ if _deployment_tpm is None:
+ _deployment_tpm = float("inf")
+
+ _deployment_rpm = None
+ if _deployment_rpm is None:
+ _deployment_rpm = _deployment.get("rpm")
+ if _deployment_rpm is None:
+ _deployment_rpm = _deployment.get("litellm_params", {}).get("rpm")
+ if _deployment_rpm is None:
+ _deployment_rpm = _deployment.get("model_info", {}).get("rpm")
+ if _deployment_rpm is None:
+ _deployment_rpm = float("inf")
+
+ if item_tpm + input_tokens > _deployment_tpm:
+ continue
+ elif (rpm_dict is not None and item in rpm_dict) and (
+ rpm_dict[item] + 1 >= _deployment_rpm
+ ):
+ continue
+ elif item_tpm < lowest_tpm:
+ lowest_tpm = item_tpm
+ deployment = _deployment
+ print_verbose("returning picked lowest tpm/rpm deployment.")
+ return deployment
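+
+
+# ------------------------------------------------------------------------------
+# Illustrative sketch (with made-up ids and numbers) of the selection rule in
+# `get_available_deployments` above: skip any deployment whose projected usage
+# would exceed its TPM limit, then keep the lowest-TPM deployment among the rest.
+if __name__ == "__main__":
+    current_tpm = {"a": 900, "b": 200, "c": 50}    # tokens already used this minute
+    tpm_limits = {"a": 1000, "b": 1000, "c": 100}  # user-defined tpm per deployment
+    input_tokens = 150                             # tokens in the incoming request
+    picked, lowest = None, float("inf")
+    for _id, used in current_tpm.items():
+        if used + input_tokens > tpm_limits[_id]:
+            continue  # "a" and "c" would exceed their limits and get skipped
+        if used < lowest:
+            lowest, picked = used, _id
+    print(picked)  # -> "b"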
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_tpm_rpm_v2.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_tpm_rpm_v2.py
new file mode 100644
index 00000000..d1a46b7e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_tpm_rpm_v2.py
@@ -0,0 +1,671 @@
+#### What this does ####
+# identifies lowest tpm deployment
+import random
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm import token_counter
+from litellm._logging import verbose_logger, verbose_router_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
+from litellm.types.router import RouterErrors
+from litellm.types.utils import LiteLLMPydanticObjectBase, StandardLoggingPayload
+from litellm.utils import get_utc_datetime, print_verbose
+
+from .base_routing_strategy import BaseRoutingStrategy
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+
+ Span = _Span
+else:
+ Span = Any
+
+
+class RoutingArgs(LiteLLMPydanticObjectBase):
+ ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
+
+
+class LowestTPMLoggingHandler_v2(BaseRoutingStrategy, CustomLogger):
+ """
+ Updated version of TPM/RPM Logging.
+
+ Meant to work across instances.
+
+ Caches individual models, not model_groups
+
+ Uses batch get (redis.mget)
+
+ Increments tpm/rpm usage using redis.incr
+ """
+
+ test_flag: bool = False
+ logged_success: int = 0
+ logged_failure: int = 0
+ default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
+
+ def __init__(
+ self, router_cache: DualCache, model_list: list, routing_args: dict = {}
+ ):
+ self.router_cache = router_cache
+ self.model_list = model_list
+ self.routing_args = RoutingArgs(**routing_args)
+ BaseRoutingStrategy.__init__(
+ self,
+ dual_cache=router_cache,
+ should_batch_redis_writes=True,
+ default_sync_interval=0.1,
+ )
+
+ def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
+ """
+ Pre-call check + update model rpm
+
+ Returns - deployment
+
+ Raises - RateLimitError if deployment over defined RPM limit
+ """
+ try:
+
+ # ------------
+ # Setup values
+ # ------------
+
+ dt = get_utc_datetime()
+ current_minute = dt.strftime("%H-%M")
+ model_id = deployment.get("model_info", {}).get("id")
+ deployment_name = deployment.get("litellm_params", {}).get("model")
+ rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
+
+ local_result = self.router_cache.get_cache(
+ key=rpm_key, local_only=True
+ ) # check local result first
+
+ deployment_rpm = None
+ if deployment_rpm is None:
+ deployment_rpm = deployment.get("rpm")
+ if deployment_rpm is None:
+ deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
+ if deployment_rpm is None:
+ deployment_rpm = deployment.get("model_info", {}).get("rpm")
+ if deployment_rpm is None:
+ deployment_rpm = float("inf")
+
+ if local_result is not None and local_result >= deployment_rpm:
+ raise litellm.RateLimitError(
+ message="Deployment over defined rpm limit={}. current usage={}".format(
+ deployment_rpm, local_result
+ ),
+ llm_provider="",
+ model=deployment.get("litellm_params", {}).get("model"),
+ response=httpx.Response(
+ status_code=429,
+ content="{} rpm limit={}. current usage={}. id={}, model_group={}. Get the model info by calling 'router.get_model_info(id)".format(
+ RouterErrors.user_defined_ratelimit_error.value,
+ deployment_rpm,
+ local_result,
+ model_id,
+ deployment.get("model_name", ""),
+ ),
+ request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
+ ),
+ )
+ else:
+ # if local result below limit, check redis ## prevent unnecessary redis checks
+
+ result = self.router_cache.increment_cache(
+ key=rpm_key, value=1, ttl=self.routing_args.ttl
+ )
+ if result is not None and result > deployment_rpm:
+ raise litellm.RateLimitError(
+ message="Deployment over defined rpm limit={}. current usage={}".format(
+ deployment_rpm, result
+ ),
+ llm_provider="",
+ model=deployment.get("litellm_params", {}).get("model"),
+ response=httpx.Response(
+ status_code=429,
+ content="{} rpm limit={}. current usage={}".format(
+ RouterErrors.user_defined_ratelimit_error.value,
+ deployment_rpm,
+ result,
+ ),
+ request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
+ ),
+ )
+ return deployment
+ except Exception as e:
+ if isinstance(e, litellm.RateLimitError):
+ raise e
+ return deployment # don't fail calls if eg. redis fails to connect
+
+ async def async_pre_call_check(
+ self, deployment: Dict, parent_otel_span: Optional[Span]
+ ) -> Optional[Dict]:
+ """
+ Pre-call check + update model rpm
+ - Used inside semaphore
+ - raise rate limit error if deployment over limit
+
+ Why? solves concurrency issue - https://github.com/BerriAI/litellm/issues/2994
+
+ Returns - deployment
+
+ Raises - RateLimitError if deployment over defined RPM limit
+ """
+ try:
+ # ------------
+ # Setup values
+ # ------------
+ dt = get_utc_datetime()
+ current_minute = dt.strftime("%H-%M")
+ model_id = deployment.get("model_info", {}).get("id")
+ deployment_name = deployment.get("litellm_params", {}).get("model")
+
+ rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
+ local_result = await self.router_cache.async_get_cache(
+ key=rpm_key, local_only=True
+ ) # check local result first
+
+ deployment_rpm = None
+ if deployment_rpm is None:
+ deployment_rpm = deployment.get("rpm")
+ if deployment_rpm is None:
+ deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
+ if deployment_rpm is None:
+ deployment_rpm = deployment.get("model_info", {}).get("rpm")
+ if deployment_rpm is None:
+ deployment_rpm = float("inf")
+ if local_result is not None and local_result >= deployment_rpm:
+ raise litellm.RateLimitError(
+ message="Deployment over defined rpm limit={}. current usage={}".format(
+ deployment_rpm, local_result
+ ),
+ llm_provider="",
+ model=deployment.get("litellm_params", {}).get("model"),
+ response=httpx.Response(
+ status_code=429,
+ content="{} rpm limit={}. current usage={}".format(
+ RouterErrors.user_defined_ratelimit_error.value,
+ deployment_rpm,
+ local_result,
+ ),
+ headers={"retry-after": str(60)}, # type: ignore
+ request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
+ ),
+ num_retries=deployment.get("num_retries"),
+ )
+ else:
+ # if local result below limit, check redis ## prevent unnecessary redis checks
+ result = await self._increment_value_in_current_window(
+ key=rpm_key, value=1, ttl=self.routing_args.ttl
+ )
+ if result is not None and result > deployment_rpm:
+ raise litellm.RateLimitError(
+ message="Deployment over defined rpm limit={}. current usage={}".format(
+ deployment_rpm, result
+ ),
+ llm_provider="",
+ model=deployment.get("litellm_params", {}).get("model"),
+ response=httpx.Response(
+ status_code=429,
+ content="{} rpm limit={}. current usage={}".format(
+ RouterErrors.user_defined_ratelimit_error.value,
+ deployment_rpm,
+ result,
+ ),
+ headers={"retry-after": str(60)}, # type: ignore
+ request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
+ ),
+ num_retries=deployment.get("num_retries"),
+ )
+ return deployment
+ except Exception as e:
+ if isinstance(e, litellm.RateLimitError):
+ raise e
+ return deployment # don't fail calls if eg. redis fails to connect
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ """
+ Update TPM/RPM usage on success
+ """
+ standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object"
+ )
+ if standard_logging_object is None:
+ raise ValueError("standard_logging_object not passed in.")
+ model_group = standard_logging_object.get("model_group")
+ model = standard_logging_object["hidden_params"].get("litellm_model_name")
+ id = standard_logging_object.get("model_id")
+ if model_group is None or id is None or model is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+
+ total_tokens = standard_logging_object.get("total_tokens")
+
+ # ------------
+ # Setup values
+ # ------------
+ dt = get_utc_datetime()
+ current_minute = dt.strftime(
+ "%H-%M"
+ ) # use the same timezone regardless of system clock
+
+ tpm_key = f"{id}:{model}:tpm:{current_minute}"
+ # ------------
+ # Update usage
+ # ------------
+ # update cache
+
+ ## TPM
+ self.router_cache.increment_cache(
+ key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
+ )
+ ### TESTING ###
+ if self.test_flag:
+ self.logged_success += 1
+ except Exception as e:
+ verbose_logger.exception(
+ "litellm.proxy.hooks.lowest_tpm_rpm_v2.py::log_success_event(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ pass
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ """
+ Update TPM usage on success
+ """
+ standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object"
+ )
+ if standard_logging_object is None:
+ raise ValueError("standard_logging_object not passed in.")
+ model_group = standard_logging_object.get("model_group")
+ model = standard_logging_object["hidden_params"]["litellm_model_name"]
+ id = standard_logging_object.get("model_id")
+ if model_group is None or id is None:
+ return
+ elif isinstance(id, int):
+ id = str(id)
+ total_tokens = standard_logging_object.get("total_tokens")
+ # ------------
+ # Setup values
+ # ------------
+ dt = get_utc_datetime()
+ current_minute = dt.strftime(
+ "%H-%M"
+ ) # use the same timezone regardless of system clock
+
+ tpm_key = f"{id}:{model}:tpm:{current_minute}"
+ # ------------
+ # Update usage
+ # ------------
+ # update cache
+ parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+ ## TPM
+ await self.router_cache.async_increment_cache(
+ key=tpm_key,
+ value=total_tokens,
+ ttl=self.routing_args.ttl,
+ parent_otel_span=parent_otel_span,
+ )
+
+ ### TESTING ###
+ if self.test_flag:
+ self.logged_success += 1
+ except Exception as e:
+ verbose_logger.exception(
+ "litellm.proxy.hooks.lowest_tpm_rpm_v2.py::async_log_success_event(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ pass
+
+ def _return_potential_deployments(
+ self,
+ healthy_deployments: List[Dict],
+ all_deployments: Dict,
+ input_tokens: int,
+ rpm_dict: Dict,
+ ):
+ lowest_tpm = float("inf")
+ potential_deployments = [] # if multiple deployments have the same low value
+ for item, item_tpm in all_deployments.items():
+ ## get the item from model list
+ _deployment = None
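+ # item may be a full cache key ("<model_id>:<deployment>:tpm:<minute>"); keep only the model_id for matching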
+ item = item.split(":")[0]
+ for m in healthy_deployments:
+ if item == m["model_info"]["id"]:
+ _deployment = m
+ if _deployment is None:
+ continue # skip to next one
+ elif item_tpm is None:
+ continue # skip if no tpm value is recorded for this deployment
+
+ _deployment_tpm = None
+ if _deployment_tpm is None:
+ _deployment_tpm = _deployment.get("tpm")
+ if _deployment_tpm is None:
+ _deployment_tpm = _deployment.get("litellm_params", {}).get("tpm")
+ if _deployment_tpm is None:
+ _deployment_tpm = _deployment.get("model_info", {}).get("tpm")
+ if _deployment_tpm is None:
+ _deployment_tpm = float("inf")
+
+ _deployment_rpm = None
+ if _deployment_rpm is None:
+ _deployment_rpm = _deployment.get("rpm")
+ if _deployment_rpm is None:
+ _deployment_rpm = _deployment.get("litellm_params", {}).get("rpm")
+ if _deployment_rpm is None:
+ _deployment_rpm = _deployment.get("model_info", {}).get("rpm")
+ if _deployment_rpm is None:
+ _deployment_rpm = float("inf")
+ if item_tpm + input_tokens > _deployment_tpm:
+ continue
+ elif (
+ (rpm_dict is not None and item in rpm_dict)
+ and rpm_dict[item] is not None
+ and (rpm_dict[item] + 1 >= _deployment_rpm)
+ ):
+ continue
+ elif item_tpm == lowest_tpm:
+ potential_deployments.append(_deployment)
+ elif item_tpm < lowest_tpm:
+ lowest_tpm = item_tpm
+ potential_deployments = [_deployment]
+ return potential_deployments
+
+ def _common_checks_available_deployment( # noqa: PLR0915
+ self,
+ model_group: str,
+ healthy_deployments: list,
+ tpm_keys: list,
+ tpm_values: Optional[list],
+ rpm_keys: list,
+ rpm_values: Optional[list],
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ ) -> Optional[dict]:
+ """
+ Common checks for get available deployment, across sync + async implementations
+ """
+
+ if tpm_values is None or rpm_values is None:
+ return None
+
+ tpm_dict = {} # {model_id: 1, ..}
+ for idx, key in enumerate(tpm_keys):
+ tpm_dict[tpm_keys[idx].split(":")[0]] = tpm_values[idx]
+
+ rpm_dict = {} # {model_id: 1, ..}
+ for idx, key in enumerate(rpm_keys):
+ rpm_dict[rpm_keys[idx].split(":")[0]] = rpm_values[idx]
+
+ try:
+ input_tokens = token_counter(messages=messages, text=input)
+ except Exception:
+ input_tokens = 0
+ verbose_router_logger.debug(f"input_tokens={input_tokens}")
+ # -----------------------
+ # Find lowest used model
+ # ----------------------
+
+ if tpm_dict is None: # base case - none of the deployments have been used
+ # initialize a tpm dict with {model_id: 0}
+ tpm_dict = {}
+ for deployment in healthy_deployments:
+ tpm_dict[deployment["model_info"]["id"]] = 0
+ else:
+ for d in healthy_deployments:
+ ## if healthy deployment not yet used
+ tpm_key = d["model_info"]["id"]
+ if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None:
+ tpm_dict[tpm_key] = 0
+
+ all_deployments = tpm_dict
+ potential_deployments = self._return_potential_deployments(
+ healthy_deployments=healthy_deployments,
+ all_deployments=all_deployments,
+ input_tokens=input_tokens,
+ rpm_dict=rpm_dict,
+ )
+ print_verbose("returning picked lowest tpm/rpm deployment.")
+
+ if len(potential_deployments) > 0:
+ return random.choice(potential_deployments)
+ else:
+ return None
+
+ async def async_get_available_deployments(
+ self,
+ model_group: str,
+ healthy_deployments: list,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ ):
+ """
+ Async implementation of get deployments.
+
+ Reduces time to retrieve the tpm/rpm values from cache
+ """
+ # get list of potential deployments
+ verbose_router_logger.debug(
+ f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
+ )
+
+ dt = get_utc_datetime()
+ current_minute = dt.strftime("%H-%M")
+
+ tpm_keys = []
+ rpm_keys = []
+ for m in healthy_deployments:
+ if isinstance(m, dict):
+ id = m.get("model_info", {}).get(
+ "id"
+ ) # a deployment should always have an 'id'. this is set in router.py
+ deployment_name = m.get("litellm_params", {}).get("model")
+ tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
+ rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)
+
+ tpm_keys.append(tpm_key)
+ rpm_keys.append(rpm_key)
+
+ combined_tpm_rpm_keys = tpm_keys + rpm_keys
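+ # fetch tpm and rpm counters in a single batched cache read; the combined list is split back into tpm/rpm values below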
+
+ combined_tpm_rpm_values = await self.router_cache.async_batch_get_cache(
+ keys=combined_tpm_rpm_keys
+ ) # [1, 2, None, ..]
+
+ if combined_tpm_rpm_values is not None:
+ tpm_values = combined_tpm_rpm_values[: len(tpm_keys)]
+ rpm_values = combined_tpm_rpm_values[len(tpm_keys) :]
+ else:
+ tpm_values = None
+ rpm_values = None
+
+ deployment = self._common_checks_available_deployment(
+ model_group=model_group,
+ healthy_deployments=healthy_deployments,
+ tpm_keys=tpm_keys,
+ tpm_values=tpm_values,
+ rpm_keys=rpm_keys,
+ rpm_values=rpm_values,
+ messages=messages,
+ input=input,
+ )
+
+ try:
+ assert deployment is not None
+ return deployment
+ except Exception:
+ ### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ###
+ deployment_dict = {}
+ for index, _deployment in enumerate(healthy_deployments):
+ if isinstance(_deployment, dict):
+ id = _deployment.get("model_info", {}).get("id")
+ ### GET DEPLOYMENT TPM LIMIT ###
+ _deployment_tpm = None
+ if _deployment_tpm is None:
+ _deployment_tpm = _deployment.get("tpm", None)
+ if _deployment_tpm is None:
+ _deployment_tpm = _deployment.get("litellm_params", {}).get(
+ "tpm", None
+ )
+ if _deployment_tpm is None:
+ _deployment_tpm = _deployment.get("model_info", {}).get(
+ "tpm", None
+ )
+ if _deployment_tpm is None:
+ _deployment_tpm = float("inf")
+
+ ### GET CURRENT TPM ###
+ current_tpm = tpm_values[index] if tpm_values else 0
+
+ ### GET DEPLOYMENT TPM LIMIT ###
+ _deployment_rpm = None
+ if _deployment_rpm is None:
+ _deployment_rpm = _deployment.get("rpm", None)
+ if _deployment_rpm is None:
+ _deployment_rpm = _deployment.get("litellm_params", {}).get(
+ "rpm", None
+ )
+ if _deployment_rpm is None:
+ _deployment_rpm = _deployment.get("model_info", {}).get(
+ "rpm", None
+ )
+ if _deployment_rpm is None:
+ _deployment_rpm = float("inf")
+
+ ### GET CURRENT RPM ###
+ current_rpm = rpm_values[index] if rpm_values else 0
+
+ deployment_dict[id] = {
+ "current_tpm": current_tpm,
+ "tpm_limit": _deployment_tpm,
+ "current_rpm": current_rpm,
+ "rpm_limit": _deployment_rpm,
+ }
+ raise litellm.RateLimitError(
+ message=f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}",
+ llm_provider="",
+ model=model_group,
+ response=httpx.Response(
+ status_code=429,
+ content="",
+ headers={"retry-after": str(60)}, # type: ignore
+ request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
+ ),
+ )
+
+ def get_available_deployments(
+ self,
+ model_group: str,
+ healthy_deployments: list,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ parent_otel_span: Optional[Span] = None,
+ ):
+ """
+ Returns a deployment with the lowest TPM/RPM usage.
+ """
+ # get list of potential deployments
+ verbose_router_logger.debug(
+ f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
+ )
+
+ dt = get_utc_datetime()
+ current_minute = dt.strftime("%H-%M")
+ tpm_keys = []
+ rpm_keys = []
+ for m in healthy_deployments:
+ if isinstance(m, dict):
+ id = m.get("model_info", {}).get(
+ "id"
+ ) # a deployment should always have an 'id'. this is set in router.py
+ deployment_name = m.get("litellm_params", {}).get("model")
+ tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
+ rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)
+
+ tpm_keys.append(tpm_key)
+ rpm_keys.append(rpm_key)
+
+ tpm_values = self.router_cache.batch_get_cache(
+ keys=tpm_keys, parent_otel_span=parent_otel_span
+ ) # [1, 2, None, ..]
+ rpm_values = self.router_cache.batch_get_cache(
+ keys=rpm_keys, parent_otel_span=parent_otel_span
+ ) # [1, 2, None, ..]
+
+ deployment = self._common_checks_available_deployment(
+ model_group=model_group,
+ healthy_deployments=healthy_deployments,
+ tpm_keys=tpm_keys,
+ tpm_values=tpm_values,
+ rpm_keys=rpm_keys,
+ rpm_values=rpm_values,
+ messages=messages,
+ input=input,
+ )
+
+ try:
+ assert deployment is not None
+ return deployment
+ except Exception:
+ ### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ###
+ deployment_dict = {}
+ for index, _deployment in enumerate(healthy_deployments):
+ if isinstance(_deployment, dict):
+ id = _deployment.get("model_info", {}).get("id")
+ ### GET DEPLOYMENT TPM LIMIT ###
+ _deployment_tpm = None
+ if _deployment_tpm is None:
+ _deployment_tpm = _deployment.get("tpm", None)
+ if _deployment_tpm is None:
+ _deployment_tpm = _deployment.get("litellm_params", {}).get(
+ "tpm", None
+ )
+ if _deployment_tpm is None:
+ _deployment_tpm = _deployment.get("model_info", {}).get(
+ "tpm", None
+ )
+ if _deployment_tpm is None:
+ _deployment_tpm = float("inf")
+
+ ### GET CURRENT TPM ###
+ current_tpm = tpm_values[index] if tpm_values else 0
+
+ ### GET DEPLOYMENT TPM LIMIT ###
+ _deployment_rpm = None
+ if _deployment_rpm is None:
+ _deployment_rpm = _deployment.get("rpm", None)
+ if _deployment_rpm is None:
+ _deployment_rpm = _deployment.get("litellm_params", {}).get(
+ "rpm", None
+ )
+ if _deployment_rpm is None:
+ _deployment_rpm = _deployment.get("model_info", {}).get(
+ "rpm", None
+ )
+ if _deployment_rpm is None:
+ _deployment_rpm = float("inf")
+
+ ### GET CURRENT RPM ###
+ current_rpm = rpm_values[index] if rpm_values else 0
+
+ deployment_dict[id] = {
+ "current_tpm": current_tpm,
+ "tpm_limit": _deployment_tpm,
+ "current_rpm": current_rpm,
+ "rpm_limit": _deployment_rpm,
+ }
+ raise ValueError(
+ f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}"
+ )
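+
+
+# ------------------------------------------------------------------------------
+# Illustrative sketch (with made-up ids and values) of the per-minute cache keys
+# this strategy uses, and of how the single batched lookup in
+# `async_get_available_deployments` is split back into tpm/rpm values.
+if __name__ == "__main__":
+    current_minute = get_utc_datetime().strftime("%H-%M")
+    deployments = [("model-id-1", "azure/gpt-4o"), ("model-id-2", "gpt-4o-mini")]
+    tpm_keys = ["{}:{}:tpm:{}".format(_id, name, current_minute) for _id, name in deployments]
+    rpm_keys = ["{}:{}:rpm:{}".format(_id, name, current_minute) for _id, name in deployments]
+    combined_values = [120, None, 3, 1]  # stand-in for the batched cache response
+    tpm_values = combined_values[: len(tpm_keys)]
+    rpm_values = combined_values[len(tpm_keys):]
+    print(tpm_values, rpm_values)  # -> [120, None] [3, 1]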
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/simple_shuffle.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/simple_shuffle.py
new file mode 100644
index 00000000..da24c02f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/simple_shuffle.py
@@ -0,0 +1,96 @@
+"""
+Returns a random deployment from the list of healthy deployments.
+
+If weights are provided, it will return a deployment based on the weights.
+
+"""
+
+import random
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from litellm._logging import verbose_router_logger
+
+if TYPE_CHECKING:
+ from litellm.router import Router as _Router
+
+ LitellmRouter = _Router
+else:
+ LitellmRouter = Any
+
+
+def simple_shuffle(
+ llm_router_instance: LitellmRouter,
+ healthy_deployments: Union[List[Any], Dict[Any, Any]],
+ model: str,
+) -> Dict:
+ """
+ Returns a random deployment from the list of healthy deployments.
+
+ If weights are provided, it will return a deployment based on the weights.
+
+ If users pass `rpm` or `tpm`, we do a random weighted pick - based on `rpm`/`tpm`.
+
+ Args:
+ llm_router_instance: LitellmRouter instance
+ healthy_deployments: List of healthy deployments
+ model: Model name
+
+ Returns:
+ Dict: A single healthy deployment
+ """
+
+ ############## Check if 'weight' param set for a weighted pick #################
+ weight = healthy_deployments[0].get("litellm_params").get("weight", None)
+ if weight is not None:
+ # use weight-random pick if weights provided
+ weights = [m["litellm_params"].get("weight", 0) for m in healthy_deployments]
+ verbose_router_logger.debug(f"\nweight {weights}")
+ total_weight = sum(weights)
+ weights = [weight / total_weight for weight in weights]
+ verbose_router_logger.debug(f"\n weights {weights}")
+ # Perform weighted random pick
+ selected_index = random.choices(range(len(weights)), weights=weights)[0]
+ verbose_router_logger.debug(f"\n selected index, {selected_index}")
+ deployment = healthy_deployments[selected_index]
+ verbose_router_logger.info(
+ f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}"
+ )
+ return deployment or deployment[0]
+ ############## Check if we can do a RPM/TPM based weighted pick #################
+ rpm = healthy_deployments[0].get("litellm_params").get("rpm", None)
+ if rpm is not None:
+ # use weight-random pick if rpms provided
+ rpms = [m["litellm_params"].get("rpm", 0) for m in healthy_deployments]
+ verbose_router_logger.debug(f"\nrpms {rpms}")
+ total_rpm = sum(rpms)
+ weights = [rpm / total_rpm for rpm in rpms]
+ verbose_router_logger.debug(f"\n weights {weights}")
+ # Perform weighted random pick
+ selected_index = random.choices(range(len(rpms)), weights=weights)[0]
+ verbose_router_logger.debug(f"\n selected index, {selected_index}")
+ deployment = healthy_deployments[selected_index]
+ verbose_router_logger.info(
+ f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}"
+ )
+ return deployment or deployment[0]
+ ############## Check if we can do a RPM/TPM based weighted pick #################
+ tpm = healthy_deployments[0].get("litellm_params").get("tpm", None)
+ if tpm is not None:
+ # use weight-random pick if tpms provided
+ tpms = [m["litellm_params"].get("tpm", 0) for m in healthy_deployments]
+ verbose_router_logger.debug(f"\ntpms {tpms}")
+ total_tpm = sum(tpms)
+ weights = [tpm / total_tpm for tpm in tpms]
+ verbose_router_logger.debug(f"\n weights {weights}")
+ # Perform weighted random pick
+ selected_index = random.choices(range(len(tpms)), weights=weights)[0]
+ verbose_router_logger.debug(f"\n selected index, {selected_index}")
+ deployment = healthy_deployments[selected_index]
+ verbose_router_logger.info(
+ f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}"
+ )
+ return deployment or deployment[0]
+
+ ############## No RPM/TPM passed, we do a random pick #################
+ item = random.choice(healthy_deployments)
+ return item or item[0]
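+
+
+# ------------------------------------------------------------------------------
+# Illustrative sketch (with made-up rpm values) of the weighted pick above: each
+# deployment's share of the total rpm becomes its probability of being selected.
+if __name__ == "__main__":
+    rpms = [100, 300, 600]                   # per-deployment rpm limits
+    weights = [r / sum(rpms) for r in rpms]  # -> [0.1, 0.3, 0.6]
+    selected_index = random.choices(range(len(rpms)), weights=weights)[0]
+    print(selected_index)  # index 2 is returned ~60% of the time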
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/tag_based_routing.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/tag_based_routing.py
new file mode 100644
index 00000000..e6a93614
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/tag_based_routing.py
@@ -0,0 +1,146 @@
+"""
+Use this to route requests between Teams
+
+- If any tag in the request matches a tag on a deployment, return that deployment
+- If deployments are set with default tags, return all default deployments
+- If no default deployments are set, return all deployments
+"""
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from litellm._logging import verbose_logger
+from litellm.types.router import RouterErrors
+
+if TYPE_CHECKING:
+ from litellm.router import Router as _Router
+
+ LitellmRouter = _Router
+else:
+ LitellmRouter = Any
+
+
+def is_valid_deployment_tag(
+ deployment_tags: List[str], request_tags: List[str]
+) -> bool:
+ """
+ Check if a tag is valid
+ """
+
+ if any(tag in deployment_tags for tag in request_tags):
+ verbose_logger.debug(
+ "adding deployment with tags: %s, request tags: %s",
+ deployment_tags,
+ request_tags,
+ )
+ return True
+ elif "default" in deployment_tags:
+ verbose_logger.debug(
+ "adding default deployment with tags: %s, request tags: %s",
+ deployment_tags,
+ request_tags,
+ )
+ return True
+ return False
+
+
+async def get_deployments_for_tag(
+ llm_router_instance: LitellmRouter,
+ model: str, # used to raise the correct error
+ healthy_deployments: Union[List[Any], Dict[Any, Any]],
+ request_kwargs: Optional[Dict[Any, Any]] = None,
+):
+ """
+ Returns a list of deployments that match the requested model and tags in the request.
+
+ Executes tag based filtering based on the tags in request metadata and the tags on the deployments
+ """
+ if llm_router_instance.enable_tag_filtering is not True:
+ return healthy_deployments
+
+ if request_kwargs is None:
+ verbose_logger.debug(
+ "get_deployments_for_tag: request_kwargs is None returning healthy_deployments: %s",
+ healthy_deployments,
+ )
+ return healthy_deployments
+
+ if healthy_deployments is None:
+ verbose_logger.debug(
+ "get_deployments_for_tag: healthy_deployments is None returning healthy_deployments"
+ )
+ return healthy_deployments
+
+ verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
+ if "metadata" in request_kwargs:
+ metadata = request_kwargs["metadata"]
+ request_tags = metadata.get("tags")
+
+ new_healthy_deployments = []
+ if request_tags:
+ verbose_logger.debug(
+ "get_deployments_for_tag routing: router_keys: %s", request_tags
+ )
+ # example: request_tags=["free", "custom"]
+ # get all deployments whose tags overlap with these request tags
+ for deployment in healthy_deployments:
+ deployment_litellm_params = deployment.get("litellm_params")
+ deployment_tags = deployment_litellm_params.get("tags")
+
+ verbose_logger.debug(
+ "deployment: %s, deployment_router_keys: %s",
+ deployment,
+ deployment_tags,
+ )
+
+ if deployment_tags is None:
+ continue
+
+ if is_valid_deployment_tag(deployment_tags, request_tags):
+ new_healthy_deployments.append(deployment)
+
+ if len(new_healthy_deployments) == 0:
+ raise ValueError(
+ f"{RouterErrors.no_deployments_with_tag_routing.value}. Passed model={model} and tags={request_tags}"
+ )
+
+ return new_healthy_deployments
+
+ # for untagged requests, use default deployments if set
+ _default_deployments_with_tags = []
+ for deployment in healthy_deployments:
+ if "default" in deployment.get("litellm_params", {}).get("tags", []):
+ _default_deployments_with_tags.append(deployment)
+
+ if len(_default_deployments_with_tags) > 0:
+ return _default_deployments_with_tags
+
+ # if no default deployment is found, return healthy_deployments
+ verbose_logger.debug(
+ "no tier found in metadata, returning healthy_deployments: %s",
+ healthy_deployments,
+ )
+ return healthy_deployments
+
+
+def _get_tags_from_request_kwargs(
+ request_kwargs: Optional[Dict[Any, Any]] = None
+) -> List[str]:
+ """
+ Helper to get tags from request kwargs
+
+ Args:
+ request_kwargs: The request kwargs to get tags from
+
+ Returns:
+ List[str]: The tags from the request kwargs
+ """
+ if request_kwargs is None:
+ return []
+ if "metadata" in request_kwargs:
+ metadata = request_kwargs["metadata"]
+ return metadata.get("tags", [])
+ elif "litellm_params" in request_kwargs:
+ litellm_params = request_kwargs["litellm_params"]
+ _metadata = litellm_params.get("metadata", {})
+ return _metadata.get("tags", [])
+ return []
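+
+
+# ------------------------------------------------------------------------------
+# Illustrative sketch (with made-up tags) of the tag check above: a deployment is
+# eligible if it shares any tag with the request, or if it is tagged "default".
+if __name__ == "__main__":
+    print(is_valid_deployment_tag(["free", "eu"], ["free"]))     # True - shared tag
+    print(is_valid_deployment_tag(["default"], ["enterprise"]))  # True - default deployment
+    print(is_valid_deployment_tag(["paid"], ["free"]))           # False - no overlap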