Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/router_strategy')
9 files changed, 3339 insertions, 0 deletions
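Before the patch body: the diffed `budget_limiter.py` documents provider budgets as a per-provider `budget_limit` / `time_period` mapping and exposes a synchronous `should_init_router_budget_limiter` helper, while deployment budgets are read from `max_budget` / `budget_duration` in each deployment's `litellm_params`. A minimal sketch of how that config shape maps onto the API visible in the patch is below; the budget values and deployment names are illustrative placeholders, not taken from the patch.

```python
# Sketch based on the signatures shown in the patch below; values are placeholders.
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting

# Same shape as the YAML example in budget_limiter.py's module docstring.
provider_budget_config = {
    "openai": {"budget_limit": 100.0, "time_period": "1d"},
    "anthropic": {"budget_limit": 100.0, "time_period": "7d"},
}

# Deployment-level budgets are read from litellm_params in the model_list.
model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "openai/gpt-3.5-turbo",
            "max_budget": 10.0,
            "budget_duration": "1d",
        },
        "model_info": {"id": "deployment-1"},
    }
]

# Static check (no event loop required) for whether the router should
# construct a RouterBudgetLimiting instance at all.
assert RouterBudgetLimiting.should_init_router_budget_limiter(
    provider_budget_config=provider_budget_config,
    model_list=model_list,
)
```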
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/base_routing_strategy.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/base_routing_strategy.py new file mode 100644 index 00000000..a39d17e3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/base_routing_strategy.py @@ -0,0 +1,190 @@ +""" +Base class across routing strategies to abstract commmon functions like batch incrementing redis +""" + +import asyncio +import threading +from abc import ABC +from typing import List, Optional, Set, Union + +from litellm._logging import verbose_router_logger +from litellm.caching.caching import DualCache +from litellm.caching.redis_cache import RedisPipelineIncrementOperation +from litellm.constants import DEFAULT_REDIS_SYNC_INTERVAL + + +class BaseRoutingStrategy(ABC): + def __init__( + self, + dual_cache: DualCache, + should_batch_redis_writes: bool, + default_sync_interval: Optional[Union[int, float]], + ): + self.dual_cache = dual_cache + self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = [] + if should_batch_redis_writes: + try: + # Try to get existing event loop + loop = asyncio.get_event_loop() + if loop.is_running(): + # If loop exists and is running, create task in existing loop + loop.create_task( + self.periodic_sync_in_memory_spend_with_redis( + default_sync_interval=default_sync_interval + ) + ) + else: + self._create_sync_thread(default_sync_interval) + except RuntimeError: # No event loop in current thread + self._create_sync_thread(default_sync_interval) + + self.in_memory_keys_to_update: set[str] = ( + set() + ) # Set with max size of 1000 keys + + async def _increment_value_in_current_window( + self, key: str, value: Union[int, float], ttl: int + ): + """ + Increment spend within existing budget window + + Runs once the budget start time exists in Redis Cache (on the 2nd and subsequent requests to the same provider) + + - Increments the spend in memory cache (so spend instantly updated in memory) + - Queues the increment operation to Redis Pipeline (using batched pipeline to optimize performance. 
Using Redis for multi instance environment of LiteLLM) + """ + result = await self.dual_cache.in_memory_cache.async_increment( + key=key, + value=value, + ttl=ttl, + ) + increment_op = RedisPipelineIncrementOperation( + key=key, + increment_value=value, + ttl=ttl, + ) + self.redis_increment_operation_queue.append(increment_op) + self.add_to_in_memory_keys_to_update(key=key) + return result + + async def periodic_sync_in_memory_spend_with_redis( + self, default_sync_interval: Optional[Union[int, float]] + ): + """ + Handler that triggers sync_in_memory_spend_with_redis every DEFAULT_REDIS_SYNC_INTERVAL seconds + + Required for multi-instance environment usage of provider budgets + """ + default_sync_interval = default_sync_interval or DEFAULT_REDIS_SYNC_INTERVAL + while True: + try: + await self._sync_in_memory_spend_with_redis() + await asyncio.sleep( + default_sync_interval + ) # Wait for DEFAULT_REDIS_SYNC_INTERVAL seconds before next sync + except Exception as e: + verbose_router_logger.error(f"Error in periodic sync task: {str(e)}") + await asyncio.sleep( + default_sync_interval + ) # Still wait DEFAULT_REDIS_SYNC_INTERVAL seconds on error before retrying + + async def _push_in_memory_increments_to_redis(self): + """ + How this works: + - async_log_success_event collects all provider spend increments in `redis_increment_operation_queue` + - This function pushes all increments to Redis in a batched pipeline to optimize performance + + Only runs if Redis is initialized + """ + try: + if not self.dual_cache.redis_cache: + return # Redis is not initialized + + verbose_router_logger.debug( + "Pushing Redis Increment Pipeline for queue: %s", + self.redis_increment_operation_queue, + ) + if len(self.redis_increment_operation_queue) > 0: + asyncio.create_task( + self.dual_cache.redis_cache.async_increment_pipeline( + increment_list=self.redis_increment_operation_queue, + ) + ) + + self.redis_increment_operation_queue = [] + + except Exception as e: + verbose_router_logger.error( + f"Error syncing in-memory cache with Redis: {str(e)}" + ) + self.redis_increment_operation_queue = [] + + def add_to_in_memory_keys_to_update(self, key: str): + self.in_memory_keys_to_update.add(key) + + def get_in_memory_keys_to_update(self) -> Set[str]: + return self.in_memory_keys_to_update + + def reset_in_memory_keys_to_update(self): + self.in_memory_keys_to_update = set() + + async def _sync_in_memory_spend_with_redis(self): + """ + Ensures in-memory cache is updated with latest Redis values for all provider spends. + + Why Do we need this? + - Optimization to hit sub 100ms latency. Performance was impacted when redis was used for read/write per request + - Use provider budgets in multi-instance environment, we use Redis to sync spend across all instances + + What this does: + 1. Push all provider spend increments to Redis + 2. Fetch all current provider spend from Redis to update in-memory cache + """ + + try: + # No need to sync if Redis cache is not initialized + if self.dual_cache.redis_cache is None: + return + + # 1. Push all provider spend increments to Redis + await self._push_in_memory_increments_to_redis() + + # 2. 
Fetch all current provider spend from Redis to update in-memory cache + cache_keys = self.get_in_memory_keys_to_update() + + cache_keys_list = list(cache_keys) + + # Batch fetch current spend values from Redis + redis_values = await self.dual_cache.redis_cache.async_batch_get_cache( + key_list=cache_keys_list + ) + + # Update in-memory cache with Redis values + if isinstance(redis_values, dict): # Check if redis_values is a dictionary + for key, value in redis_values.items(): + if value is not None: + await self.dual_cache.in_memory_cache.async_set_cache( + key=key, value=float(value) + ) + verbose_router_logger.debug( + f"Updated in-memory cache for {key}: {value}" + ) + + self.reset_in_memory_keys_to_update() + except Exception as e: + verbose_router_logger.exception( + f"Error syncing in-memory cache with Redis: {str(e)}" + ) + + def _create_sync_thread(self, default_sync_interval): + """Helper method to create a new thread for periodic sync""" + thread = threading.Thread( + target=asyncio.run, + args=( + self.periodic_sync_in_memory_spend_with_redis( + default_sync_interval=default_sync_interval + ), + ), + daemon=True, + ) + thread.start() diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/budget_limiter.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/budget_limiter.py new file mode 100644 index 00000000..4f123df2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/budget_limiter.py @@ -0,0 +1,818 @@ +""" +Provider budget limiting + +Use this if you want to set $ budget limits for each provider. + +Note: This is a filter, like tag-routing. Meaning it will accept healthy deployments and then filter out deployments that have exceeded their budget limit. + +This means you can use this with weighted-pick, lowest-latency, simple-shuffle, routing etc + +Example: +``` +openai: + budget_limit: 0.000000000001 + time_period: 1d +anthropic: + budget_limit: 100 + time_period: 7d +``` +""" + +import asyncio +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Optional, Tuple, Union + +import litellm +from litellm._logging import verbose_router_logger +from litellm.caching.caching import DualCache +from litellm.caching.redis_cache import RedisPipelineIncrementOperation +from litellm.integrations.custom_logger import CustomLogger, Span +from litellm.litellm_core_utils.duration_parser import duration_in_seconds +from litellm.router_strategy.tag_based_routing import _get_tags_from_request_kwargs +from litellm.router_utils.cooldown_callbacks import ( + _get_prometheus_logger_from_callbacks, +) +from litellm.types.llms.openai import AllMessageValues +from litellm.types.router import DeploymentTypedDict, LiteLLM_Params, RouterErrors +from litellm.types.utils import BudgetConfig +from litellm.types.utils import BudgetConfig as GenericBudgetInfo +from litellm.types.utils import GenericBudgetConfigType, StandardLoggingPayload + +DEFAULT_REDIS_SYNC_INTERVAL = 1 + + +class RouterBudgetLimiting(CustomLogger): + def __init__( + self, + dual_cache: DualCache, + provider_budget_config: Optional[dict], + model_list: Optional[ + Union[List[DeploymentTypedDict], List[Dict[str, Any]]] + ] = None, + ): + self.dual_cache = dual_cache + self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = [] + asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis()) + self.provider_budget_config: Optional[GenericBudgetConfigType] = ( + provider_budget_config + ) + self.deployment_budget_config: 
Optional[GenericBudgetConfigType] = None + self.tag_budget_config: Optional[GenericBudgetConfigType] = None + self._init_provider_budgets() + self._init_deployment_budgets(model_list=model_list) + self._init_tag_budgets() + + # Add self to litellm callbacks if it's a list + if isinstance(litellm.callbacks, list): + litellm.logging_callback_manager.add_litellm_callback(self) # type: ignore + + async def async_filter_deployments( + self, + model: str, + healthy_deployments: List, + messages: Optional[List[AllMessageValues]], + request_kwargs: Optional[dict] = None, + parent_otel_span: Optional[Span] = None, # type: ignore + ) -> List[dict]: + """ + Filter out deployments that have exceeded their provider budget limit. + + + Example: + if deployment = openai/gpt-3.5-turbo + and openai spend > openai budget limit + then skip this deployment + """ + + # If a single deployment is passed, convert it to a list + if isinstance(healthy_deployments, dict): + healthy_deployments = [healthy_deployments] + + # Don't do any filtering if there are no healthy deployments + if len(healthy_deployments) == 0: + return healthy_deployments + + potential_deployments: List[Dict] = [] + + cache_keys, provider_configs, deployment_configs = ( + await self._async_get_cache_keys_for_router_budget_limiting( + healthy_deployments=healthy_deployments, + request_kwargs=request_kwargs, + ) + ) + + # Single cache read for all spend values + if len(cache_keys) > 0: + _current_spends = await self.dual_cache.async_batch_get_cache( + keys=cache_keys, + parent_otel_span=parent_otel_span, + ) + current_spends: List = _current_spends or [0.0] * len(cache_keys) + + # Map spends to their respective keys + spend_map: Dict[str, float] = {} + for idx, key in enumerate(cache_keys): + spend_map[key] = float(current_spends[idx] or 0.0) + + potential_deployments, deployment_above_budget_info = ( + self._filter_out_deployments_above_budget( + healthy_deployments=healthy_deployments, + provider_configs=provider_configs, + deployment_configs=deployment_configs, + spend_map=spend_map, + potential_deployments=potential_deployments, + request_tags=_get_tags_from_request_kwargs( + request_kwargs=request_kwargs + ), + ) + ) + + if len(potential_deployments) == 0: + raise ValueError( + f"{RouterErrors.no_deployments_with_provider_budget_routing.value}: {deployment_above_budget_info}" + ) + + return potential_deployments + else: + return healthy_deployments + + def _filter_out_deployments_above_budget( + self, + potential_deployments: List[Dict[str, Any]], + healthy_deployments: List[Dict[str, Any]], + provider_configs: Dict[str, GenericBudgetInfo], + deployment_configs: Dict[str, GenericBudgetInfo], + spend_map: Dict[str, float], + request_tags: List[str], + ) -> Tuple[List[Dict[str, Any]], str]: + """ + Filter out deployments that have exceeded their budget limit. + Follow budget checks are run here: + - Provider budget + - Deployment budget + - Request tags budget + Returns: + Tuple[List[Dict[str, Any]], str]: + - A tuple containing the filtered deployments + - A string containing debug information about deployments that exceeded their budget limit. 
+ """ + # Filter deployments based on both provider and deployment budgets + deployment_above_budget_info: str = "" + for deployment in healthy_deployments: + is_within_budget = True + + # Check provider budget + if self.provider_budget_config: + provider = self._get_llm_provider_for_deployment(deployment) + if provider in provider_configs: + config = provider_configs[provider] + if config.max_budget is None: + continue + current_spend = spend_map.get( + f"provider_spend:{provider}:{config.budget_duration}", 0.0 + ) + self._track_provider_remaining_budget_prometheus( + provider=provider, + spend=current_spend, + budget_limit=config.max_budget, + ) + + if config.max_budget and current_spend >= config.max_budget: + debug_msg = f"Exceeded budget for provider {provider}: {current_spend} >= {config.max_budget}" + deployment_above_budget_info += f"{debug_msg}\n" + is_within_budget = False + continue + + # Check deployment budget + if self.deployment_budget_config and is_within_budget: + _model_name = deployment.get("model_name") + _litellm_params = deployment.get("litellm_params") or {} + _litellm_model_name = _litellm_params.get("model") + model_id = deployment.get("model_info", {}).get("id") + if model_id in deployment_configs: + config = deployment_configs[model_id] + current_spend = spend_map.get( + f"deployment_spend:{model_id}:{config.budget_duration}", 0.0 + ) + if config.max_budget and current_spend >= config.max_budget: + debug_msg = f"Exceeded budget for deployment model_name: {_model_name}, litellm_params.model: {_litellm_model_name}, model_id: {model_id}: {current_spend} >= {config.budget_duration}" + verbose_router_logger.debug(debug_msg) + deployment_above_budget_info += f"{debug_msg}\n" + is_within_budget = False + continue + # Check tag budget + if self.tag_budget_config and is_within_budget: + for _tag in request_tags: + _tag_budget_config = self._get_budget_config_for_tag(_tag) + if _tag_budget_config: + _tag_spend = spend_map.get( + f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}", + 0.0, + ) + if ( + _tag_budget_config.max_budget + and _tag_spend >= _tag_budget_config.max_budget + ): + debug_msg = f"Exceeded budget for tag='{_tag}', tag_spend={_tag_spend}, tag_budget_limit={_tag_budget_config.max_budget}" + verbose_router_logger.debug(debug_msg) + deployment_above_budget_info += f"{debug_msg}\n" + is_within_budget = False + continue + if is_within_budget: + potential_deployments.append(deployment) + + return potential_deployments, deployment_above_budget_info + + async def _async_get_cache_keys_for_router_budget_limiting( + self, + healthy_deployments: List[Dict[str, Any]], + request_kwargs: Optional[Dict] = None, + ) -> Tuple[List[str], Dict[str, GenericBudgetInfo], Dict[str, GenericBudgetInfo]]: + """ + Returns list of cache keys to fetch from router cache for budget limiting and provider and deployment configs + + Returns: + Tuple[List[str], Dict[str, GenericBudgetInfo], Dict[str, GenericBudgetInfo]]: + - List of cache keys to fetch from router cache for budget limiting + - Dict of provider budget configs `provider_configs` + - Dict of deployment budget configs `deployment_configs` + """ + cache_keys: List[str] = [] + provider_configs: Dict[str, GenericBudgetInfo] = {} + deployment_configs: Dict[str, GenericBudgetInfo] = {} + + for deployment in healthy_deployments: + # Check provider budgets + if self.provider_budget_config: + provider = self._get_llm_provider_for_deployment(deployment) + if provider is not None: + budget_config = 
self._get_budget_config_for_provider(provider) + if ( + budget_config is not None + and budget_config.budget_duration is not None + ): + provider_configs[provider] = budget_config + cache_keys.append( + f"provider_spend:{provider}:{budget_config.budget_duration}" + ) + + # Check deployment budgets + if self.deployment_budget_config: + model_id = deployment.get("model_info", {}).get("id") + if model_id is not None: + budget_config = self._get_budget_config_for_deployment(model_id) + if budget_config is not None: + deployment_configs[model_id] = budget_config + cache_keys.append( + f"deployment_spend:{model_id}:{budget_config.budget_duration}" + ) + # Check tag budgets + if self.tag_budget_config: + request_tags = _get_tags_from_request_kwargs( + request_kwargs=request_kwargs + ) + for _tag in request_tags: + _tag_budget_config = self._get_budget_config_for_tag(_tag) + if _tag_budget_config: + cache_keys.append( + f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}" + ) + return cache_keys, provider_configs, deployment_configs + + async def _get_or_set_budget_start_time( + self, start_time_key: str, current_time: float, ttl_seconds: int + ) -> float: + """ + Checks if the key = `provider_budget_start_time:{provider}` exists in cache. + + If it does, return the value. + If it does not, set the key to `current_time` and return the value. + """ + budget_start = await self.dual_cache.async_get_cache(start_time_key) + if budget_start is None: + await self.dual_cache.async_set_cache( + key=start_time_key, value=current_time, ttl=ttl_seconds + ) + return current_time + return float(budget_start) + + async def _handle_new_budget_window( + self, + spend_key: str, + start_time_key: str, + current_time: float, + response_cost: float, + ttl_seconds: int, + ) -> float: + """ + Handle start of new budget window by resetting spend and start time + + Enters this when: + - The budget does not exist in cache, so we need to set it + - The budget window has expired, so we need to reset everything + + Does 2 things: + - stores key: `provider_spend:{provider}:1d`, value: response_cost + - stores key: `provider_budget_start_time:{provider}`, value: current_time. + This stores the start time of the new budget window + """ + await self.dual_cache.async_set_cache( + key=spend_key, value=response_cost, ttl=ttl_seconds + ) + await self.dual_cache.async_set_cache( + key=start_time_key, value=current_time, ttl=ttl_seconds + ) + return current_time + + async def _increment_spend_in_current_window( + self, spend_key: str, response_cost: float, ttl: int + ): + """ + Increment spend within existing budget window + + Runs once the budget start time exists in Redis Cache (on the 2nd and subsequent requests to the same provider) + + - Increments the spend in memory cache (so spend instantly updated in memory) + - Queues the increment operation to Redis Pipeline (using batched pipeline to optimize performance. 
Using Redis for multi instance environment of LiteLLM) + """ + await self.dual_cache.in_memory_cache.async_increment( + key=spend_key, + value=response_cost, + ttl=ttl, + ) + increment_op = RedisPipelineIncrementOperation( + key=spend_key, + increment_value=response_cost, + ttl=ttl, + ) + self.redis_increment_operation_queue.append(increment_op) + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + """Original method now uses helper functions""" + verbose_router_logger.debug("in RouterBudgetLimiting.async_log_success_event") + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + if standard_logging_payload is None: + raise ValueError("standard_logging_payload is required") + + response_cost: float = standard_logging_payload.get("response_cost", 0) + model_id: str = str(standard_logging_payload.get("model_id", "")) + custom_llm_provider: str = kwargs.get("litellm_params", {}).get( + "custom_llm_provider", None + ) + if custom_llm_provider is None: + raise ValueError("custom_llm_provider is required") + + budget_config = self._get_budget_config_for_provider(custom_llm_provider) + if budget_config: + # increment spend for provider + spend_key = ( + f"provider_spend:{custom_llm_provider}:{budget_config.budget_duration}" + ) + start_time_key = f"provider_budget_start_time:{custom_llm_provider}" + await self._increment_spend_for_key( + budget_config=budget_config, + spend_key=spend_key, + start_time_key=start_time_key, + response_cost=response_cost, + ) + + deployment_budget_config = self._get_budget_config_for_deployment(model_id) + if deployment_budget_config: + # increment spend for specific deployment id + deployment_spend_key = f"deployment_spend:{model_id}:{deployment_budget_config.budget_duration}" + deployment_start_time_key = f"deployment_budget_start_time:{model_id}" + await self._increment_spend_for_key( + budget_config=deployment_budget_config, + spend_key=deployment_spend_key, + start_time_key=deployment_start_time_key, + response_cost=response_cost, + ) + + request_tags = _get_tags_from_request_kwargs(kwargs) + if len(request_tags) > 0: + for _tag in request_tags: + _tag_budget_config = self._get_budget_config_for_tag(_tag) + if _tag_budget_config: + _tag_spend_key = ( + f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}" + ) + _tag_start_time_key = f"tag_budget_start_time:{_tag}" + await self._increment_spend_for_key( + budget_config=_tag_budget_config, + spend_key=_tag_spend_key, + start_time_key=_tag_start_time_key, + response_cost=response_cost, + ) + + async def _increment_spend_for_key( + self, + budget_config: GenericBudgetInfo, + spend_key: str, + start_time_key: str, + response_cost: float, + ): + if budget_config.budget_duration is None: + return + + current_time = datetime.now(timezone.utc).timestamp() + ttl_seconds = duration_in_seconds(budget_config.budget_duration) + + budget_start = await self._get_or_set_budget_start_time( + start_time_key=start_time_key, + current_time=current_time, + ttl_seconds=ttl_seconds, + ) + + if budget_start is None: + # First spend for this provider + budget_start = await self._handle_new_budget_window( + spend_key=spend_key, + start_time_key=start_time_key, + current_time=current_time, + response_cost=response_cost, + ttl_seconds=ttl_seconds, + ) + elif (current_time - budget_start) > ttl_seconds: + # Budget window expired - reset everything + verbose_router_logger.debug("Budget window expired - resetting everything") + budget_start = 
await self._handle_new_budget_window( + spend_key=spend_key, + start_time_key=start_time_key, + current_time=current_time, + response_cost=response_cost, + ttl_seconds=ttl_seconds, + ) + else: + # Within existing window - increment spend + remaining_time = ttl_seconds - (current_time - budget_start) + ttl_for_increment = int(remaining_time) + + await self._increment_spend_in_current_window( + spend_key=spend_key, response_cost=response_cost, ttl=ttl_for_increment + ) + + verbose_router_logger.debug( + f"Incremented spend for {spend_key} by {response_cost}" + ) + + async def periodic_sync_in_memory_spend_with_redis(self): + """ + Handler that triggers sync_in_memory_spend_with_redis every DEFAULT_REDIS_SYNC_INTERVAL seconds + + Required for multi-instance environment usage of provider budgets + """ + while True: + try: + await self._sync_in_memory_spend_with_redis() + await asyncio.sleep( + DEFAULT_REDIS_SYNC_INTERVAL + ) # Wait for DEFAULT_REDIS_SYNC_INTERVAL seconds before next sync + except Exception as e: + verbose_router_logger.error(f"Error in periodic sync task: {str(e)}") + await asyncio.sleep( + DEFAULT_REDIS_SYNC_INTERVAL + ) # Still wait DEFAULT_REDIS_SYNC_INTERVAL seconds on error before retrying + + async def _push_in_memory_increments_to_redis(self): + """ + How this works: + - async_log_success_event collects all provider spend increments in `redis_increment_operation_queue` + - This function pushes all increments to Redis in a batched pipeline to optimize performance + + Only runs if Redis is initialized + """ + try: + if not self.dual_cache.redis_cache: + return # Redis is not initialized + + verbose_router_logger.debug( + "Pushing Redis Increment Pipeline for queue: %s", + self.redis_increment_operation_queue, + ) + if len(self.redis_increment_operation_queue) > 0: + asyncio.create_task( + self.dual_cache.redis_cache.async_increment_pipeline( + increment_list=self.redis_increment_operation_queue, + ) + ) + + self.redis_increment_operation_queue = [] + + except Exception as e: + verbose_router_logger.error( + f"Error syncing in-memory cache with Redis: {str(e)}" + ) + + async def _sync_in_memory_spend_with_redis(self): + """ + Ensures in-memory cache is updated with latest Redis values for all provider spends. + + Why Do we need this? + - Optimization to hit sub 100ms latency. Performance was impacted when redis was used for read/write per request + - Use provider budgets in multi-instance environment, we use Redis to sync spend across all instances + + What this does: + 1. Push all provider spend increments to Redis + 2. Fetch all current provider spend from Redis to update in-memory cache + """ + + try: + # No need to sync if Redis cache is not initialized + if self.dual_cache.redis_cache is None: + return + + # 1. Push all provider spend increments to Redis + await self._push_in_memory_increments_to_redis() + + # 2. 
Fetch all current provider spend from Redis to update in-memory cache + cache_keys = [] + + if self.provider_budget_config is not None: + for provider, config in self.provider_budget_config.items(): + if config is None: + continue + cache_keys.append( + f"provider_spend:{provider}:{config.budget_duration}" + ) + + if self.deployment_budget_config is not None: + for model_id, config in self.deployment_budget_config.items(): + if config is None: + continue + cache_keys.append( + f"deployment_spend:{model_id}:{config.budget_duration}" + ) + + if self.tag_budget_config is not None: + for tag, config in self.tag_budget_config.items(): + if config is None: + continue + cache_keys.append(f"tag_spend:{tag}:{config.budget_duration}") + + # Batch fetch current spend values from Redis + redis_values = await self.dual_cache.redis_cache.async_batch_get_cache( + key_list=cache_keys + ) + + # Update in-memory cache with Redis values + if isinstance(redis_values, dict): # Check if redis_values is a dictionary + for key, value in redis_values.items(): + if value is not None: + await self.dual_cache.in_memory_cache.async_set_cache( + key=key, value=float(value) + ) + verbose_router_logger.debug( + f"Updated in-memory cache for {key}: {value}" + ) + + except Exception as e: + verbose_router_logger.error( + f"Error syncing in-memory cache with Redis: {str(e)}" + ) + + def _get_budget_config_for_deployment( + self, + model_id: str, + ) -> Optional[GenericBudgetInfo]: + if self.deployment_budget_config is None: + return None + return self.deployment_budget_config.get(model_id, None) + + def _get_budget_config_for_provider( + self, provider: str + ) -> Optional[GenericBudgetInfo]: + if self.provider_budget_config is None: + return None + return self.provider_budget_config.get(provider, None) + + def _get_budget_config_for_tag(self, tag: str) -> Optional[GenericBudgetInfo]: + if self.tag_budget_config is None: + return None + return self.tag_budget_config.get(tag, None) + + def _get_llm_provider_for_deployment(self, deployment: Dict) -> Optional[str]: + try: + _litellm_params: LiteLLM_Params = LiteLLM_Params( + **deployment.get("litellm_params", {"model": ""}) + ) + _, custom_llm_provider, _, _ = litellm.get_llm_provider( + model=_litellm_params.model, + litellm_params=_litellm_params, + ) + except Exception: + verbose_router_logger.error( + f"Error getting LLM provider for deployment: {deployment}" + ) + return None + return custom_llm_provider + + def _track_provider_remaining_budget_prometheus( + self, provider: str, spend: float, budget_limit: float + ): + """ + Optional helper - emit provider remaining budget metric to Prometheus + + This is helpful for debugging and monitoring provider budget limits. 
+ """ + + prometheus_logger = _get_prometheus_logger_from_callbacks() + if prometheus_logger: + prometheus_logger.track_provider_remaining_budget( + provider=provider, + spend=spend, + budget_limit=budget_limit, + ) + + async def _get_current_provider_spend(self, provider: str) -> Optional[float]: + """ + GET the current spend for a provider from cache + + used for GET /provider/budgets endpoint in spend_management_endpoints.py + + Args: + provider (str): The provider to get spend for (e.g., "openai", "anthropic") + + Returns: + Optional[float]: The current spend for the provider, or None if not found + """ + budget_config = self._get_budget_config_for_provider(provider) + if budget_config is None: + return None + + spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}" + + if self.dual_cache.redis_cache: + # use Redis as source of truth since that has spend across all instances + current_spend = await self.dual_cache.redis_cache.async_get_cache(spend_key) + else: + # use in-memory cache if Redis is not initialized + current_spend = await self.dual_cache.async_get_cache(spend_key) + return float(current_spend) if current_spend is not None else 0.0 + + async def _get_current_provider_budget_reset_at( + self, provider: str + ) -> Optional[str]: + budget_config = self._get_budget_config_for_provider(provider) + if budget_config is None: + return None + + spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}" + if self.dual_cache.redis_cache: + ttl_seconds = await self.dual_cache.redis_cache.async_get_ttl(spend_key) + else: + ttl_seconds = await self.dual_cache.async_get_ttl(spend_key) + + if ttl_seconds is None: + return None + + return (datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)).isoformat() + + async def _init_provider_budget_in_cache( + self, provider: str, budget_config: GenericBudgetInfo + ): + """ + Initialize provider budget in cache by storing the following keys if they don't exist: + - provider_spend:{provider}:{budget_config.time_period} - stores the current spend + - provider_budget_start_time:{provider} - stores the start time of the budget window + + """ + + spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}" + start_time_key = f"provider_budget_start_time:{provider}" + ttl_seconds: Optional[int] = None + if budget_config.budget_duration is not None: + ttl_seconds = duration_in_seconds(budget_config.budget_duration) + + budget_start = await self.dual_cache.async_get_cache(start_time_key) + if budget_start is None: + budget_start = datetime.now(timezone.utc).timestamp() + await self.dual_cache.async_set_cache( + key=start_time_key, value=budget_start, ttl=ttl_seconds + ) + + _spend_key = await self.dual_cache.async_get_cache(spend_key) + if _spend_key is None: + await self.dual_cache.async_set_cache( + key=spend_key, value=0.0, ttl=ttl_seconds + ) + + @staticmethod + def should_init_router_budget_limiter( + provider_budget_config: Optional[dict], + model_list: Optional[ + Union[List[DeploymentTypedDict], List[Dict[str, Any]]] + ] = None, + ): + """ + Returns `True` if the router budget routing settings are set and RouterBudgetLimiting should be initialized + + Either: + - provider_budget_config is set + - budgets are set for deployments in the model_list + - tag_budget_config is set + """ + if provider_budget_config is not None: + return True + + if litellm.tag_budget_config is not None: + return True + + if model_list is None: + return False + + for _model in model_list: + _litellm_params = 
_model.get("litellm_params", {}) + if ( + _litellm_params.get("max_budget") + or _litellm_params.get("budget_duration") is not None + ): + return True + return False + + def _init_provider_budgets(self): + if self.provider_budget_config is not None: + # cast elements of provider_budget_config to GenericBudgetInfo + for provider, config in self.provider_budget_config.items(): + if config is None: + raise ValueError( + f"No budget config found for provider {provider}, provider_budget_config: {self.provider_budget_config}" + ) + + if not isinstance(config, GenericBudgetInfo): + self.provider_budget_config[provider] = GenericBudgetInfo( + budget_limit=config.get("budget_limit"), + time_period=config.get("time_period"), + ) + asyncio.create_task( + self._init_provider_budget_in_cache( + provider=provider, + budget_config=self.provider_budget_config[provider], + ) + ) + + verbose_router_logger.debug( + f"Initalized Provider budget config: {self.provider_budget_config}" + ) + + def _init_deployment_budgets( + self, + model_list: Optional[ + Union[List[DeploymentTypedDict], List[Dict[str, Any]]] + ] = None, + ): + if model_list is None: + return + for _model in model_list: + _litellm_params = _model.get("litellm_params", {}) + _model_info: Dict = _model.get("model_info") or {} + _model_id = _model_info.get("id") + _max_budget = _litellm_params.get("max_budget") + _budget_duration = _litellm_params.get("budget_duration") + + verbose_router_logger.debug( + f"Init Deployment Budget: max_budget: {_max_budget}, budget_duration: {_budget_duration}, model_id: {_model_id}" + ) + if ( + _max_budget is not None + and _budget_duration is not None + and _model_id is not None + ): + _budget_config = GenericBudgetInfo( + time_period=_budget_duration, + budget_limit=_max_budget, + ) + if self.deployment_budget_config is None: + self.deployment_budget_config = {} + self.deployment_budget_config[_model_id] = _budget_config + + verbose_router_logger.debug( + f"Initialized Deployment Budget Config: {self.deployment_budget_config}" + ) + + def _init_tag_budgets(self): + if litellm.tag_budget_config is None: + return + from litellm.proxy.proxy_server import CommonProxyErrors, premium_user + + if premium_user is not True: + raise ValueError( + f"Tag budgets are an Enterprise only feature, {CommonProxyErrors.not_premium_user}" + ) + + if self.tag_budget_config is None: + self.tag_budget_config = {} + + for _tag, _tag_budget_config in litellm.tag_budget_config.items(): + if isinstance(_tag_budget_config, dict): + _tag_budget_config = BudgetConfig(**_tag_budget_config) + _generic_budget_config = GenericBudgetInfo( + time_period=_tag_budget_config.budget_duration, + budget_limit=_tag_budget_config.max_budget, + ) + self.tag_budget_config[_tag] = _generic_budget_config + + verbose_router_logger.debug( + f"Initialized Tag Budget Config: {self.tag_budget_config}" + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py new file mode 100644 index 00000000..12f3f01c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py @@ -0,0 +1,252 @@ +#### What this does #### +# identifies least busy deployment +# How is this achieved? 
+# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"} +# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic} +# - use litellm.success + failure callbacks to log when a request completed +# - in get_available_deployment, for a given model group name -> pick based on traffic + +import random +from typing import Optional + +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger + + +class LeastBusyLoggingHandler(CustomLogger): + test_flag: bool = False + logged_success: int = 0 + logged_failure: int = 0 + + def __init__(self, router_cache: DualCache, model_list: list): + self.router_cache = router_cache + self.mapping_deployment_to_id: dict = {} + self.model_list = model_list + + def log_pre_api_call(self, model, messages, kwargs): + """ + Log when a model is being used. + + Caching based on model group. + """ + try: + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + request_count_api_key = f"{model_group}_request_count" + # update cache + request_count_dict = ( + self.router_cache.get_cache(key=request_count_api_key) or {} + ) + request_count_dict[id] = request_count_dict.get(id, 0) + 1 + + self.router_cache.set_cache( + key=request_count_api_key, value=request_count_dict + ) + except Exception: + pass + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + request_count_api_key = f"{model_group}_request_count" + # decrement count in cache + request_count_dict = ( + self.router_cache.get_cache(key=request_count_api_key) or {} + ) + request_count_value: Optional[int] = request_count_dict.get(id, 0) + if request_count_value is None: + return + request_count_dict[id] = request_count_value - 1 + self.router_cache.set_cache( + key=request_count_api_key, value=request_count_dict + ) + + ### TESTING ### + if self.test_flag: + self.logged_success += 1 + except Exception: + pass + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + try: + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + request_count_api_key = f"{model_group}_request_count" + # decrement count in cache + request_count_dict = ( + self.router_cache.get_cache(key=request_count_api_key) or {} + ) + request_count_value: Optional[int] = request_count_dict.get(id, 0) + if request_count_value is None: + return + request_count_dict[id] = request_count_value - 1 + self.router_cache.set_cache( + key=request_count_api_key, value=request_count_dict + ) + + ### TESTING ### + if self.test_flag: + self.logged_failure += 1 + except Exception: + pass + + async def 
async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + request_count_api_key = f"{model_group}_request_count" + # decrement count in cache + request_count_dict = ( + await self.router_cache.async_get_cache(key=request_count_api_key) + or {} + ) + request_count_value: Optional[int] = request_count_dict.get(id, 0) + if request_count_value is None: + return + request_count_dict[id] = request_count_value - 1 + await self.router_cache.async_set_cache( + key=request_count_api_key, value=request_count_dict + ) + + ### TESTING ### + if self.test_flag: + self.logged_success += 1 + except Exception: + pass + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + try: + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + request_count_api_key = f"{model_group}_request_count" + # decrement count in cache + request_count_dict = ( + await self.router_cache.async_get_cache(key=request_count_api_key) + or {} + ) + request_count_value: Optional[int] = request_count_dict.get(id, 0) + if request_count_value is None: + return + request_count_dict[id] = request_count_value - 1 + await self.router_cache.async_set_cache( + key=request_count_api_key, value=request_count_dict + ) + + ### TESTING ### + if self.test_flag: + self.logged_failure += 1 + except Exception: + pass + + def _get_available_deployments( + self, + healthy_deployments: list, + all_deployments: dict, + ): + """ + Helper to get deployments using least busy strategy + """ + for d in healthy_deployments: + ## if healthy deployment not yet used + if d["model_info"]["id"] not in all_deployments: + all_deployments[d["model_info"]["id"]] = 0 + # map deployment to id + # pick least busy deployment + min_traffic = float("inf") + min_deployment = None + for k, v in all_deployments.items(): + if v < min_traffic: + min_traffic = v + min_deployment = k + if min_deployment is not None: + ## check if min deployment is a string, if so, cast it to int + for m in healthy_deployments: + if m["model_info"]["id"] == min_deployment: + return m + min_deployment = random.choice(healthy_deployments) + else: + min_deployment = random.choice(healthy_deployments) + return min_deployment + + def get_available_deployments( + self, + model_group: str, + healthy_deployments: list, + ): + """ + Sync helper to get deployments using least busy strategy + """ + request_count_api_key = f"{model_group}_request_count" + all_deployments = self.router_cache.get_cache(key=request_count_api_key) or {} + return self._get_available_deployments( + healthy_deployments=healthy_deployments, + all_deployments=all_deployments, + ) + + async def async_get_available_deployments( + self, model_group: str, healthy_deployments: list + ): + """ + Async helper to get deployments using least busy strategy + """ + request_count_api_key = f"{model_group}_request_count" + all_deployments = ( + await self.router_cache.async_get_cache(key=request_count_api_key) or {} + ) 
+ return self._get_available_deployments( + healthy_deployments=healthy_deployments, + all_deployments=all_deployments, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_cost.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_cost.py new file mode 100644 index 00000000..bd28f6dc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_cost.py @@ -0,0 +1,333 @@ +#### What this does #### +# picks based on response time (for streaming, this is time to first token) +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Union + +import litellm +from litellm import ModelResponse, token_counter, verbose_logger +from litellm._logging import verbose_router_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger + + +class LowestCostLoggingHandler(CustomLogger): + test_flag: bool = False + logged_success: int = 0 + logged_failure: int = 0 + + def __init__( + self, router_cache: DualCache, model_list: list, routing_args: dict = {} + ): + self.router_cache = router_cache + self.model_list = model_list + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + """ + Update usage on success + """ + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + # ------------ + # Setup values + # ------------ + """ + { + {model_group}_map: { + id: { + f"{date:hour:minute}" : {"tpm": 34, "rpm": 3} + } + } + } + """ + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + cost_key = f"{model_group}_map" + + response_ms: timedelta = end_time - start_time + + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + _usage = getattr(response_obj, "usage", None) + if _usage is not None and isinstance(_usage, litellm.Usage): + completion_tokens = _usage.completion_tokens + total_tokens = _usage.total_tokens + float(response_ms.total_seconds() / completion_tokens) + + # ------------ + # Update usage + # ------------ + + request_count_dict = self.router_cache.get_cache(key=cost_key) or {} + + # check local result first + + if id not in request_count_dict: + request_count_dict[id] = {} + + if precise_minute not in request_count_dict[id]: + request_count_dict[id][precise_minute] = {} + + ## TPM + request_count_dict[id][precise_minute]["tpm"] = ( + request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens + ) + + ## RPM + request_count_dict[id][precise_minute]["rpm"] = ( + request_count_dict[id][precise_minute].get("rpm", 0) + 1 + ) + + self.router_cache.set_cache(key=cost_key, value=request_count_dict) + + ### TESTING ### + if self.test_flag: + self.logged_success += 1 + except Exception as e: + verbose_logger.exception( + "litellm.router_strategy.lowest_cost.py::log_success_event(): Exception occured - {}".format( + str(e) + ) + ) + pass + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + """ + Update cost usage on success + """ + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = 
kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + # ------------ + # Setup values + # ------------ + """ + { + {model_group}_map: { + id: { + "cost": [..] + f"{date:hour:minute}" : {"tpm": 34, "rpm": 3} + } + } + } + """ + cost_key = f"{model_group}_map" + + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + + response_ms: timedelta = end_time - start_time + + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + _usage = getattr(response_obj, "usage", None) + if _usage is not None and isinstance(_usage, litellm.Usage): + completion_tokens = _usage.completion_tokens + total_tokens = _usage.total_tokens + + float(response_ms.total_seconds() / completion_tokens) + + # ------------ + # Update usage + # ------------ + + request_count_dict = ( + await self.router_cache.async_get_cache(key=cost_key) or {} + ) + + if id not in request_count_dict: + request_count_dict[id] = {} + if precise_minute not in request_count_dict[id]: + request_count_dict[id][precise_minute] = {} + + ## TPM + request_count_dict[id][precise_minute]["tpm"] = ( + request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens + ) + + ## RPM + request_count_dict[id][precise_minute]["rpm"] = ( + request_count_dict[id][precise_minute].get("rpm", 0) + 1 + ) + + await self.router_cache.async_set_cache( + key=cost_key, value=request_count_dict + ) # reset map within window + + ### TESTING ### + if self.test_flag: + self.logged_success += 1 + except Exception as e: + verbose_logger.exception( + "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format( + str(e) + ) + ) + pass + + async def async_get_available_deployments( # noqa: PLR0915 + self, + model_group: str, + healthy_deployments: list, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + request_kwargs: Optional[Dict] = None, + ): + """ + Returns a deployment with the lowest cost + """ + cost_key = f"{model_group}_map" + + request_count_dict = await self.router_cache.async_get_cache(key=cost_key) or {} + + # ----------------------- + # Find lowest used model + # ---------------------- + float("inf") + + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + + if request_count_dict is None: # base case + return + + all_deployments = request_count_dict + for d in healthy_deployments: + ## if healthy deployment not yet used + if d["model_info"]["id"] not in all_deployments: + all_deployments[d["model_info"]["id"]] = { + precise_minute: {"tpm": 0, "rpm": 0}, + } + + try: + input_tokens = token_counter(messages=messages, text=input) + except Exception: + input_tokens = 0 + + # randomly sample from all_deployments, incase all deployments have latency=0.0 + _items = all_deployments.items() + + ### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits + potential_deployments = [] + _cost_per_deployment = {} + for item, item_map in all_deployments.items(): + ## get the item from model list + _deployment = None + for m in healthy_deployments: + if 
item == m["model_info"]["id"]: + _deployment = m + + if _deployment is None: + continue # skip to next one + + _deployment_tpm = ( + _deployment.get("tpm", None) + or _deployment.get("litellm_params", {}).get("tpm", None) + or _deployment.get("model_info", {}).get("tpm", None) + or float("inf") + ) + + _deployment_rpm = ( + _deployment.get("rpm", None) + or _deployment.get("litellm_params", {}).get("rpm", None) + or _deployment.get("model_info", {}).get("rpm", None) + or float("inf") + ) + item_litellm_model_name = _deployment.get("litellm_params", {}).get("model") + item_litellm_model_cost_map = litellm.model_cost.get( + item_litellm_model_name, {} + ) + + # check if user provided input_cost_per_token and output_cost_per_token in litellm_params + item_input_cost = None + item_output_cost = None + if _deployment.get("litellm_params", {}).get("input_cost_per_token", None): + item_input_cost = _deployment.get("litellm_params", {}).get( + "input_cost_per_token" + ) + + if _deployment.get("litellm_params", {}).get("output_cost_per_token", None): + item_output_cost = _deployment.get("litellm_params", {}).get( + "output_cost_per_token" + ) + + if item_input_cost is None: + item_input_cost = item_litellm_model_cost_map.get( + "input_cost_per_token", 5.0 + ) + + if item_output_cost is None: + item_output_cost = item_litellm_model_cost_map.get( + "output_cost_per_token", 5.0 + ) + + # if litellm["model"] is not in model_cost map -> use item_cost = $10 + + item_cost = item_input_cost + item_output_cost + + item_rpm = item_map.get(precise_minute, {}).get("rpm", 0) + item_tpm = item_map.get(precise_minute, {}).get("tpm", 0) + + verbose_router_logger.debug( + f"item_cost: {item_cost}, item_tpm: {item_tpm}, item_rpm: {item_rpm}, model_id: {_deployment.get('model_info', {}).get('id')}" + ) + + # -------------- # + # Debugging Logic + # -------------- # + # We use _cost_per_deployment to log to langfuse, slack - this is not used to make a decision on routing + # this helps a user to debug why the router picked a specfic deployment # + _deployment_api_base = _deployment.get("litellm_params", {}).get( + "api_base", "" + ) + if _deployment_api_base is not None: + _cost_per_deployment[_deployment_api_base] = item_cost + # -------------- # + # End of Debugging Logic + # -------------- # + + if ( + item_tpm + input_tokens > _deployment_tpm + or item_rpm + 1 > _deployment_rpm + ): # if user passed in tpm / rpm in the model_list + continue + else: + potential_deployments.append((_deployment, item_cost)) + + if len(potential_deployments) == 0: + return None + + potential_deployments = sorted(potential_deployments, key=lambda x: x[1]) + + selected_deployment = potential_deployments[0][0] + return selected_deployment diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_latency.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_latency.py new file mode 100644 index 00000000..b049c942 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_latency.py @@ -0,0 +1,590 @@ +#### What this does #### +# picks based on response time (for streaming, this is time to first token) +import random +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import litellm +from litellm import ModelResponse, token_counter, verbose_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.core_helpers import 
_get_parent_otel_span_from_kwargs +from litellm.types.utils import LiteLLMPydanticObjectBase + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + + +class RoutingArgs(LiteLLMPydanticObjectBase): + ttl: float = 1 * 60 * 60 # 1 hour + lowest_latency_buffer: float = 0 + max_latency_list_size: int = 10 + + +class LowestLatencyLoggingHandler(CustomLogger): + test_flag: bool = False + logged_success: int = 0 + logged_failure: int = 0 + + def __init__( + self, router_cache: DualCache, model_list: list, routing_args: dict = {} + ): + self.router_cache = router_cache + self.model_list = model_list + self.routing_args = RoutingArgs(**routing_args) + + def log_success_event( # noqa: PLR0915 + self, kwargs, response_obj, start_time, end_time + ): + try: + """ + Update latency usage on success + """ + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + # ------------ + # Setup values + # ------------ + """ + { + {model_group}_map: { + id: { + "latency": [..] + f"{date:hour:minute}" : {"tpm": 34, "rpm": 3} + } + } + } + """ + latency_key = f"{model_group}_map" + + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + + response_ms: timedelta = end_time - start_time + time_to_first_token_response_time: Optional[timedelta] = None + + if kwargs.get("stream", None) is not None and kwargs["stream"] is True: + # only log ttft for streaming request + time_to_first_token_response_time = ( + kwargs.get("completion_start_time", end_time) - start_time + ) + + final_value: Union[float, timedelta] = response_ms + time_to_first_token: Optional[float] = None + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + _usage = getattr(response_obj, "usage", None) + if _usage is not None: + completion_tokens = _usage.completion_tokens + total_tokens = _usage.total_tokens + final_value = float( + response_ms.total_seconds() / completion_tokens + ) + + if time_to_first_token_response_time is not None: + time_to_first_token = float( + time_to_first_token_response_time.total_seconds() + / completion_tokens + ) + + # ------------ + # Update usage + # ------------ + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + request_count_dict = ( + self.router_cache.get_cache( + key=latency_key, parent_otel_span=parent_otel_span + ) + or {} + ) + + if id not in request_count_dict: + request_count_dict[id] = {} + + ## Latency + if ( + len(request_count_dict[id].get("latency", [])) + < self.routing_args.max_latency_list_size + ): + request_count_dict[id].setdefault("latency", []).append(final_value) + else: + request_count_dict[id]["latency"] = request_count_dict[id][ + "latency" + ][: self.routing_args.max_latency_list_size - 1] + [final_value] + + ## Time to first token + if time_to_first_token is not None: + if ( + len(request_count_dict[id].get("time_to_first_token", [])) + < self.routing_args.max_latency_list_size + ): + request_count_dict[id].setdefault( + "time_to_first_token", [] + ).append(time_to_first_token) + else: + request_count_dict[id][ + "time_to_first_token" + ] = request_count_dict[id]["time_to_first_token"][ + : 
self.routing_args.max_latency_list_size - 1 + ] + [ + time_to_first_token + ] + + if precise_minute not in request_count_dict[id]: + request_count_dict[id][precise_minute] = {} + + ## TPM + request_count_dict[id][precise_minute]["tpm"] = ( + request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens + ) + + ## RPM + request_count_dict[id][precise_minute]["rpm"] = ( + request_count_dict[id][precise_minute].get("rpm", 0) + 1 + ) + + self.router_cache.set_cache( + key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl + ) # reset map within window + + ### TESTING ### + if self.test_flag: + self.logged_success += 1 + except Exception as e: + verbose_logger.exception( + "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format( + str(e) + ) + ) + pass + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + """ + Check if Timeout Error, if timeout set deployment latency -> 100 + """ + try: + _exception = kwargs.get("exception", None) + if isinstance(_exception, litellm.Timeout): + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + # ------------ + # Setup values + # ------------ + """ + { + {model_group}_map: { + id: { + "latency": [..] + f"{date:hour:minute}" : {"tpm": 34, "rpm": 3} + } + } + } + """ + latency_key = f"{model_group}_map" + request_count_dict = ( + await self.router_cache.async_get_cache(key=latency_key) or {} + ) + + if id not in request_count_dict: + request_count_dict[id] = {} + + ## Latency - give 1000s penalty for failing + if ( + len(request_count_dict[id].get("latency", [])) + < self.routing_args.max_latency_list_size + ): + request_count_dict[id].setdefault("latency", []).append(1000.0) + else: + request_count_dict[id]["latency"] = request_count_dict[id][ + "latency" + ][: self.routing_args.max_latency_list_size - 1] + [1000.0] + + await self.router_cache.async_set_cache( + key=latency_key, + value=request_count_dict, + ttl=self.routing_args.ttl, + ) # reset map within window + else: + # do nothing if it's not a timeout error + return + except Exception as e: + verbose_logger.exception( + "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format( + str(e) + ) + ) + pass + + async def async_log_success_event( # noqa: PLR0915 + self, kwargs, response_obj, start_time, end_time + ): + try: + """ + Update latency usage on success + """ + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + # ------------ + # Setup values + # ------------ + """ + { + {model_group}_map: { + id: { + "latency": [..] + "time_to_first_token": [..] 
+ f"{date:hour:minute}" : {"tpm": 34, "rpm": 3} + } + } + } + """ + latency_key = f"{model_group}_map" + + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + + response_ms: timedelta = end_time - start_time + time_to_first_token_response_time: Optional[timedelta] = None + if kwargs.get("stream", None) is not None and kwargs["stream"] is True: + # only log ttft for streaming request + time_to_first_token_response_time = ( + kwargs.get("completion_start_time", end_time) - start_time + ) + + final_value: Union[float, timedelta] = response_ms + total_tokens = 0 + time_to_first_token: Optional[float] = None + + if isinstance(response_obj, ModelResponse): + _usage = getattr(response_obj, "usage", None) + if _usage is not None: + completion_tokens = _usage.completion_tokens + total_tokens = _usage.total_tokens + final_value = float( + response_ms.total_seconds() / completion_tokens + ) + + if time_to_first_token_response_time is not None: + time_to_first_token = float( + time_to_first_token_response_time.total_seconds() + / completion_tokens + ) + # ------------ + # Update usage + # ------------ + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + request_count_dict = ( + await self.router_cache.async_get_cache( + key=latency_key, + parent_otel_span=parent_otel_span, + local_only=True, + ) + or {} + ) + + if id not in request_count_dict: + request_count_dict[id] = {} + + ## Latency + if ( + len(request_count_dict[id].get("latency", [])) + < self.routing_args.max_latency_list_size + ): + request_count_dict[id].setdefault("latency", []).append(final_value) + else: + request_count_dict[id]["latency"] = request_count_dict[id][ + "latency" + ][: self.routing_args.max_latency_list_size - 1] + [final_value] + + ## Time to first token + if time_to_first_token is not None: + if ( + len(request_count_dict[id].get("time_to_first_token", [])) + < self.routing_args.max_latency_list_size + ): + request_count_dict[id].setdefault( + "time_to_first_token", [] + ).append(time_to_first_token) + else: + request_count_dict[id][ + "time_to_first_token" + ] = request_count_dict[id]["time_to_first_token"][ + : self.routing_args.max_latency_list_size - 1 + ] + [ + time_to_first_token + ] + + if precise_minute not in request_count_dict[id]: + request_count_dict[id][precise_minute] = {} + + ## TPM + request_count_dict[id][precise_minute]["tpm"] = ( + request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens + ) + + ## RPM + request_count_dict[id][precise_minute]["rpm"] = ( + request_count_dict[id][precise_minute].get("rpm", 0) + 1 + ) + + await self.router_cache.async_set_cache( + key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl + ) # reset map within window + + ### TESTING ### + if self.test_flag: + self.logged_success += 1 + except Exception as e: + verbose_logger.exception( + "litellm.router_strategy.lowest_latency.py::async_log_success_event(): Exception occured - {}".format( + str(e) + ) + ) + pass + + def _get_available_deployments( # noqa: PLR0915 + self, + model_group: str, + healthy_deployments: list, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + request_kwargs: Optional[Dict] = None, + request_count_dict: Optional[Dict] = None, + ): + """Common logic for both sync and async get_available_deployments""" + + # ----------------------- + # Find lowest used model 
+ # ---------------------- + _latency_per_deployment = {} + lowest_latency = float("inf") + + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + + deployment = None + + if request_count_dict is None: # base case + return + + all_deployments = request_count_dict + for d in healthy_deployments: + ## if healthy deployment not yet used + if d["model_info"]["id"] not in all_deployments: + all_deployments[d["model_info"]["id"]] = { + "latency": [0], + precise_minute: {"tpm": 0, "rpm": 0}, + } + + try: + input_tokens = token_counter(messages=messages, text=input) + except Exception: + input_tokens = 0 + + # randomly sample from all_deployments, incase all deployments have latency=0.0 + _items = all_deployments.items() + + _all_deployments = random.sample(list(_items), len(_items)) + all_deployments = dict(_all_deployments) + ### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits + + potential_deployments = [] + for item, item_map in all_deployments.items(): + ## get the item from model list + _deployment = None + for m in healthy_deployments: + if item == m["model_info"]["id"]: + _deployment = m + + if _deployment is None: + continue # skip to next one + + _deployment_tpm = ( + _deployment.get("tpm", None) + or _deployment.get("litellm_params", {}).get("tpm", None) + or _deployment.get("model_info", {}).get("tpm", None) + or float("inf") + ) + + _deployment_rpm = ( + _deployment.get("rpm", None) + or _deployment.get("litellm_params", {}).get("rpm", None) + or _deployment.get("model_info", {}).get("rpm", None) + or float("inf") + ) + item_latency = item_map.get("latency", []) + item_ttft_latency = item_map.get("time_to_first_token", []) + item_rpm = item_map.get(precise_minute, {}).get("rpm", 0) + item_tpm = item_map.get(precise_minute, {}).get("tpm", 0) + + # get average latency or average ttft (depending on streaming/non-streaming) + total: float = 0.0 + if ( + request_kwargs is not None + and request_kwargs.get("stream", None) is not None + and request_kwargs["stream"] is True + and len(item_ttft_latency) > 0 + ): + for _call_latency in item_ttft_latency: + if isinstance(_call_latency, float): + total += _call_latency + else: + for _call_latency in item_latency: + if isinstance(_call_latency, float): + total += _call_latency + item_latency = total / len(item_latency) + + # -------------- # + # Debugging Logic + # -------------- # + # We use _latency_per_deployment to log to langfuse, slack - this is not used to make a decision on routing + # this helps a user to debug why the router picked a specfic deployment # + _deployment_api_base = _deployment.get("litellm_params", {}).get( + "api_base", "" + ) + if _deployment_api_base is not None: + _latency_per_deployment[_deployment_api_base] = item_latency + # -------------- # + # End of Debugging Logic + # -------------- # + + if ( + item_tpm + input_tokens > _deployment_tpm + or item_rpm + 1 > _deployment_rpm + ): # if user passed in tpm / rpm in the model_list + continue + else: + potential_deployments.append((_deployment, item_latency)) + + if len(potential_deployments) == 0: + return None + + # Sort potential deployments by latency + sorted_deployments = sorted(potential_deployments, key=lambda x: x[1]) + + # Find lowest latency deployment + lowest_latency = sorted_deployments[0][1] + + # Find deployments within buffer of lowest latency + buffer = 
self.routing_args.lowest_latency_buffer * lowest_latency + + valid_deployments = [ + x for x in sorted_deployments if x[1] <= lowest_latency + buffer + ] + + # Pick a random deployment from valid deployments + random_valid_deployment = random.choice(valid_deployments) + deployment = random_valid_deployment[0] + + if request_kwargs is not None and "metadata" in request_kwargs: + request_kwargs["metadata"][ + "_latency_per_deployment" + ] = _latency_per_deployment + return deployment + + async def async_get_available_deployments( + self, + model_group: str, + healthy_deployments: list, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + request_kwargs: Optional[Dict] = None, + ): + # get list of potential deployments + latency_key = f"{model_group}_map" + + parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs( + request_kwargs + ) + request_count_dict = ( + await self.router_cache.async_get_cache( + key=latency_key, parent_otel_span=parent_otel_span + ) + or {} + ) + + return self._get_available_deployments( + model_group, + healthy_deployments, + messages, + input, + request_kwargs, + request_count_dict, + ) + + def get_available_deployments( + self, + model_group: str, + healthy_deployments: list, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + request_kwargs: Optional[Dict] = None, + ): + """ + Returns a deployment with the lowest latency + """ + # get list of potential deployments + latency_key = f"{model_group}_map" + + parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs( + request_kwargs + ) + request_count_dict = ( + self.router_cache.get_cache( + key=latency_key, parent_otel_span=parent_otel_span + ) + or {} + ) + + return self._get_available_deployments( + model_group, + healthy_deployments, + messages, + input, + request_kwargs, + request_count_dict, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_tpm_rpm.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_tpm_rpm.py new file mode 100644 index 00000000..86587939 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_tpm_rpm.py @@ -0,0 +1,243 @@ +#### What this does #### +# identifies lowest tpm deployment +import traceback +from datetime import datetime +from typing import Dict, List, Optional, Union + +from litellm import token_counter +from litellm._logging import verbose_router_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.types.utils import LiteLLMPydanticObjectBase +from litellm.utils import print_verbose + + +class RoutingArgs(LiteLLMPydanticObjectBase): + ttl: int = 1 * 60 # 1min (RPM/TPM expire key) + + +class LowestTPMLoggingHandler(CustomLogger): + test_flag: bool = False + logged_success: int = 0 + logged_failure: int = 0 + default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour + + def __init__( + self, router_cache: DualCache, model_list: list, routing_args: dict = {} + ): + self.router_cache = router_cache + self.model_list = model_list + self.routing_args = RoutingArgs(**routing_args) + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + """ + Update TPM/RPM usage on success + """ + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + + id = kwargs["litellm_params"].get("model_info", {}).get("id", 
None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + total_tokens = response_obj["usage"]["total_tokens"] + + # ------------ + # Setup values + # ------------ + current_minute = datetime.now().strftime("%H-%M") + tpm_key = f"{model_group}:tpm:{current_minute}" + rpm_key = f"{model_group}:rpm:{current_minute}" + + # ------------ + # Update usage + # ------------ + + ## TPM + request_count_dict = self.router_cache.get_cache(key=tpm_key) or {} + request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens + + self.router_cache.set_cache( + key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl + ) + + ## RPM + request_count_dict = self.router_cache.get_cache(key=rpm_key) or {} + request_count_dict[id] = request_count_dict.get(id, 0) + 1 + + self.router_cache.set_cache( + key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl + ) + + ### TESTING ### + if self.test_flag: + self.logged_success += 1 + except Exception as e: + verbose_router_logger.error( + "litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format( + str(e) + ) + ) + verbose_router_logger.debug(traceback.format_exc()) + pass + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + """ + Update TPM/RPM usage on success + """ + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + total_tokens = response_obj["usage"]["total_tokens"] + + # ------------ + # Setup values + # ------------ + current_minute = datetime.now().strftime("%H-%M") + tpm_key = f"{model_group}:tpm:{current_minute}" + rpm_key = f"{model_group}:rpm:{current_minute}" + + # ------------ + # Update usage + # ------------ + # update cache + + ## TPM + request_count_dict = ( + await self.router_cache.async_get_cache(key=tpm_key) or {} + ) + request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens + + await self.router_cache.async_set_cache( + key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl + ) + + ## RPM + request_count_dict = ( + await self.router_cache.async_get_cache(key=rpm_key) or {} + ) + request_count_dict[id] = request_count_dict.get(id, 0) + 1 + + await self.router_cache.async_set_cache( + key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl + ) + + ### TESTING ### + if self.test_flag: + self.logged_success += 1 + except Exception as e: + verbose_router_logger.error( + "litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format( + str(e) + ) + ) + verbose_router_logger.debug(traceback.format_exc()) + pass + + def get_available_deployments( # noqa: PLR0915 + self, + model_group: str, + healthy_deployments: list, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + ): + """ + Returns a deployment with the lowest TPM/RPM usage. + """ + # get list of potential deployments + verbose_router_logger.debug( + f"get_available_deployments - Usage Based. 
model_group: {model_group}, healthy_deployments: {healthy_deployments}" + ) + current_minute = datetime.now().strftime("%H-%M") + tpm_key = f"{model_group}:tpm:{current_minute}" + rpm_key = f"{model_group}:rpm:{current_minute}" + + tpm_dict = self.router_cache.get_cache(key=tpm_key) + rpm_dict = self.router_cache.get_cache(key=rpm_key) + + verbose_router_logger.debug( + f"tpm_key={tpm_key}, tpm_dict: {tpm_dict}, rpm_dict: {rpm_dict}" + ) + try: + input_tokens = token_counter(messages=messages, text=input) + except Exception: + input_tokens = 0 + verbose_router_logger.debug(f"input_tokens={input_tokens}") + # ----------------------- + # Find lowest used model + # ---------------------- + lowest_tpm = float("inf") + + if tpm_dict is None: # base case - none of the deployments have been used + # initialize a tpm dict with {model_id: 0} + tpm_dict = {} + for deployment in healthy_deployments: + tpm_dict[deployment["model_info"]["id"]] = 0 + else: + for d in healthy_deployments: + ## if healthy deployment not yet used + if d["model_info"]["id"] not in tpm_dict: + tpm_dict[d["model_info"]["id"]] = 0 + + all_deployments = tpm_dict + + deployment = None + for item, item_tpm in all_deployments.items(): + ## get the item from model list + _deployment = None + for m in healthy_deployments: + if item == m["model_info"]["id"]: + _deployment = m + + if _deployment is None: + continue # skip to next one + + _deployment_tpm = None + if _deployment_tpm is None: + _deployment_tpm = _deployment.get("tpm") + if _deployment_tpm is None: + _deployment_tpm = _deployment.get("litellm_params", {}).get("tpm") + if _deployment_tpm is None: + _deployment_tpm = _deployment.get("model_info", {}).get("tpm") + if _deployment_tpm is None: + _deployment_tpm = float("inf") + + _deployment_rpm = None + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("rpm") + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("litellm_params", {}).get("rpm") + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("model_info", {}).get("rpm") + if _deployment_rpm is None: + _deployment_rpm = float("inf") + + if item_tpm + input_tokens > _deployment_tpm: + continue + elif (rpm_dict is not None and item in rpm_dict) and ( + rpm_dict[item] + 1 >= _deployment_rpm + ): + continue + elif item_tpm < lowest_tpm: + lowest_tpm = item_tpm + deployment = _deployment + print_verbose("returning picked lowest tpm/rpm deployment.") + return deployment diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_tpm_rpm_v2.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_tpm_rpm_v2.py new file mode 100644 index 00000000..d1a46b7e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/lowest_tpm_rpm_v2.py @@ -0,0 +1,671 @@ +#### What this does #### +# identifies lowest tpm deployment +import random +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import httpx + +import litellm +from litellm import token_counter +from litellm._logging import verbose_logger, verbose_router_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs +from litellm.types.router import RouterErrors +from litellm.types.utils import LiteLLMPydanticObjectBase, StandardLoggingPayload +from litellm.utils import get_utc_datetime, print_verbose + +from .base_routing_strategy import BaseRoutingStrategy + +if TYPE_CHECKING: + from 
opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + + +class RoutingArgs(LiteLLMPydanticObjectBase): + ttl: int = 1 * 60 # 1min (RPM/TPM expire key) + + +class LowestTPMLoggingHandler_v2(BaseRoutingStrategy, CustomLogger): + """ + Updated version of TPM/RPM Logging. + + Meant to work across instances. + + Caches individual models, not model_groups + + Uses batch get (redis.mget) + + Increments tpm/rpm limit using redis.incr + """ + + test_flag: bool = False + logged_success: int = 0 + logged_failure: int = 0 + default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour + + def __init__( + self, router_cache: DualCache, model_list: list, routing_args: dict = {} + ): + self.router_cache = router_cache + self.model_list = model_list + self.routing_args = RoutingArgs(**routing_args) + BaseRoutingStrategy.__init__( + self, + dual_cache=router_cache, + should_batch_redis_writes=True, + default_sync_interval=0.1, + ) + + def pre_call_check(self, deployment: Dict) -> Optional[Dict]: + """ + Pre-call check + update model rpm + + Returns - deployment + + Raises - RateLimitError if deployment over defined RPM limit + """ + try: + + # ------------ + # Setup values + # ------------ + + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + model_id = deployment.get("model_info", {}).get("id") + deployment_name = deployment.get("litellm_params", {}).get("model") + rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}" + + local_result = self.router_cache.get_cache( + key=rpm_key, local_only=True + ) # check local result first + + deployment_rpm = None + if deployment_rpm is None: + deployment_rpm = deployment.get("rpm") + if deployment_rpm is None: + deployment_rpm = deployment.get("litellm_params", {}).get("rpm") + if deployment_rpm is None: + deployment_rpm = deployment.get("model_info", {}).get("rpm") + if deployment_rpm is None: + deployment_rpm = float("inf") + + if local_result is not None and local_result >= deployment_rpm: + raise litellm.RateLimitError( + message="Deployment over defined rpm limit={}. current usage={}".format( + deployment_rpm, local_result + ), + llm_provider="", + model=deployment.get("litellm_params", {}).get("model"), + response=httpx.Response( + status_code=429, + content="{} rpm limit={}. current usage={}. id={}, model_group={}. Get the model info by calling 'router.get_model_info(id)".format( + RouterErrors.user_defined_ratelimit_error.value, + deployment_rpm, + local_result, + model_id, + deployment.get("model_name", ""), + ), + request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + else: + # if local result below limit, check redis ## prevent unnecessary redis checks + + result = self.router_cache.increment_cache( + key=rpm_key, value=1, ttl=self.routing_args.ttl + ) + if result is not None and result > deployment_rpm: + raise litellm.RateLimitError( + message="Deployment over defined rpm limit={}. current usage={}".format( + deployment_rpm, result + ), + llm_provider="", + model=deployment.get("litellm_params", {}).get("model"), + response=httpx.Response( + status_code=429, + content="{} rpm limit={}. current usage={}".format( + RouterErrors.user_defined_ratelimit_error.value, + deployment_rpm, + result, + ), + request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return deployment + except Exception as e: + if isinstance(e, litellm.RateLimitError): + raise e + return deployment # don't fail calls if eg. 
redis fails to connect + + async def async_pre_call_check( + self, deployment: Dict, parent_otel_span: Optional[Span] + ) -> Optional[Dict]: + """ + Pre-call check + update model rpm + - Used inside semaphore + - raise rate limit error if deployment over limit + + Why? solves concurrency issue - https://github.com/BerriAI/litellm/issues/2994 + + Returns - deployment + + Raises - RateLimitError if deployment over defined RPM limit + """ + try: + # ------------ + # Setup values + # ------------ + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + model_id = deployment.get("model_info", {}).get("id") + deployment_name = deployment.get("litellm_params", {}).get("model") + + rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}" + local_result = await self.router_cache.async_get_cache( + key=rpm_key, local_only=True + ) # check local result first + + deployment_rpm = None + if deployment_rpm is None: + deployment_rpm = deployment.get("rpm") + if deployment_rpm is None: + deployment_rpm = deployment.get("litellm_params", {}).get("rpm") + if deployment_rpm is None: + deployment_rpm = deployment.get("model_info", {}).get("rpm") + if deployment_rpm is None: + deployment_rpm = float("inf") + if local_result is not None and local_result >= deployment_rpm: + raise litellm.RateLimitError( + message="Deployment over defined rpm limit={}. current usage={}".format( + deployment_rpm, local_result + ), + llm_provider="", + model=deployment.get("litellm_params", {}).get("model"), + response=httpx.Response( + status_code=429, + content="{} rpm limit={}. current usage={}".format( + RouterErrors.user_defined_ratelimit_error.value, + deployment_rpm, + local_result, + ), + headers={"retry-after": str(60)}, # type: ignore + request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + num_retries=deployment.get("num_retries"), + ) + else: + # if local result below limit, check redis ## prevent unnecessary redis checks + result = await self._increment_value_in_current_window( + key=rpm_key, value=1, ttl=self.routing_args.ttl + ) + if result is not None and result > deployment_rpm: + raise litellm.RateLimitError( + message="Deployment over defined rpm limit={}. current usage={}".format( + deployment_rpm, result + ), + llm_provider="", + model=deployment.get("litellm_params", {}).get("model"), + response=httpx.Response( + status_code=429, + content="{} rpm limit={}. current usage={}".format( + RouterErrors.user_defined_ratelimit_error.value, + deployment_rpm, + result, + ), + headers={"retry-after": str(60)}, # type: ignore + request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + num_retries=deployment.get("num_retries"), + ) + return deployment + except Exception as e: + if isinstance(e, litellm.RateLimitError): + raise e + return deployment # don't fail calls if eg. 
redis fails to connect + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + """ + Update TPM/RPM usage on success + """ + standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object" + ) + if standard_logging_object is None: + raise ValueError("standard_logging_object not passed in.") + model_group = standard_logging_object.get("model_group") + model = standard_logging_object["hidden_params"].get("litellm_model_name") + id = standard_logging_object.get("model_id") + if model_group is None or id is None or model is None: + return + elif isinstance(id, int): + id = str(id) + + total_tokens = standard_logging_object.get("total_tokens") + + # ------------ + # Setup values + # ------------ + dt = get_utc_datetime() + current_minute = dt.strftime( + "%H-%M" + ) # use the same timezone regardless of system clock + + tpm_key = f"{id}:{model}:tpm:{current_minute}" + # ------------ + # Update usage + # ------------ + # update cache + + ## TPM + self.router_cache.increment_cache( + key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl + ) + ### TESTING ### + if self.test_flag: + self.logged_success += 1 + except Exception as e: + verbose_logger.exception( + "litellm.proxy.hooks.lowest_tpm_rpm_v2.py::log_success_event(): Exception occured - {}".format( + str(e) + ) + ) + pass + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + """ + Update TPM usage on success + """ + standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object" + ) + if standard_logging_object is None: + raise ValueError("standard_logging_object not passed in.") + model_group = standard_logging_object.get("model_group") + model = standard_logging_object["hidden_params"]["litellm_model_name"] + id = standard_logging_object.get("model_id") + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + total_tokens = standard_logging_object.get("total_tokens") + # ------------ + # Setup values + # ------------ + dt = get_utc_datetime() + current_minute = dt.strftime( + "%H-%M" + ) # use the same timezone regardless of system clock + + tpm_key = f"{id}:{model}:tpm:{current_minute}" + # ------------ + # Update usage + # ------------ + # update cache + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + ## TPM + await self.router_cache.async_increment_cache( + key=tpm_key, + value=total_tokens, + ttl=self.routing_args.ttl, + parent_otel_span=parent_otel_span, + ) + + ### TESTING ### + if self.test_flag: + self.logged_success += 1 + except Exception as e: + verbose_logger.exception( + "litellm.proxy.hooks.lowest_tpm_rpm_v2.py::async_log_success_event(): Exception occured - {}".format( + str(e) + ) + ) + pass + + def _return_potential_deployments( + self, + healthy_deployments: List[Dict], + all_deployments: Dict, + input_tokens: int, + rpm_dict: Dict, + ): + lowest_tpm = float("inf") + potential_deployments = [] # if multiple deployments have the same low value + for item, item_tpm in all_deployments.items(): + ## get the item from model list + _deployment = None + item = item.split(":")[0] + for m in healthy_deployments: + if item == m["model_info"]["id"]: + _deployment = m + if _deployment is None: + continue # skip to next one + elif item_tpm is None: + continue # skip if unhealthy deployment + + _deployment_tpm = None + if _deployment_tpm is None: + _deployment_tpm = _deployment.get("tpm") + if _deployment_tpm is None: + _deployment_tpm 
= _deployment.get("litellm_params", {}).get("tpm") + if _deployment_tpm is None: + _deployment_tpm = _deployment.get("model_info", {}).get("tpm") + if _deployment_tpm is None: + _deployment_tpm = float("inf") + + _deployment_rpm = None + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("rpm") + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("litellm_params", {}).get("rpm") + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("model_info", {}).get("rpm") + if _deployment_rpm is None: + _deployment_rpm = float("inf") + if item_tpm + input_tokens > _deployment_tpm: + continue + elif ( + (rpm_dict is not None and item in rpm_dict) + and rpm_dict[item] is not None + and (rpm_dict[item] + 1 >= _deployment_rpm) + ): + continue + elif item_tpm == lowest_tpm: + potential_deployments.append(_deployment) + elif item_tpm < lowest_tpm: + lowest_tpm = item_tpm + potential_deployments = [_deployment] + return potential_deployments + + def _common_checks_available_deployment( # noqa: PLR0915 + self, + model_group: str, + healthy_deployments: list, + tpm_keys: list, + tpm_values: Optional[list], + rpm_keys: list, + rpm_values: Optional[list], + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + ) -> Optional[dict]: + """ + Common checks for get available deployment, across sync + async implementations + """ + + if tpm_values is None or rpm_values is None: + return None + + tpm_dict = {} # {model_id: 1, ..} + for idx, key in enumerate(tpm_keys): + tpm_dict[tpm_keys[idx].split(":")[0]] = tpm_values[idx] + + rpm_dict = {} # {model_id: 1, ..} + for idx, key in enumerate(rpm_keys): + rpm_dict[rpm_keys[idx].split(":")[0]] = rpm_values[idx] + + try: + input_tokens = token_counter(messages=messages, text=input) + except Exception: + input_tokens = 0 + verbose_router_logger.debug(f"input_tokens={input_tokens}") + # ----------------------- + # Find lowest used model + # ---------------------- + + if tpm_dict is None: # base case - none of the deployments have been used + # initialize a tpm dict with {model_id: 0} + tpm_dict = {} + for deployment in healthy_deployments: + tpm_dict[deployment["model_info"]["id"]] = 0 + else: + for d in healthy_deployments: + ## if healthy deployment not yet used + tpm_key = d["model_info"]["id"] + if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None: + tpm_dict[tpm_key] = 0 + + all_deployments = tpm_dict + potential_deployments = self._return_potential_deployments( + healthy_deployments=healthy_deployments, + all_deployments=all_deployments, + input_tokens=input_tokens, + rpm_dict=rpm_dict, + ) + print_verbose("returning picked lowest tpm/rpm deployment.") + + if len(potential_deployments) > 0: + return random.choice(potential_deployments) + else: + return None + + async def async_get_available_deployments( + self, + model_group: str, + healthy_deployments: list, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + ): + """ + Async implementation of get deployments. + + Reduces time to retrieve the tpm/rpm values from cache + """ + # get list of potential deployments + verbose_router_logger.debug( + f"get_available_deployments - Usage Based. 
model_group: {model_group}, healthy_deployments: {healthy_deployments}" + ) + + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + + tpm_keys = [] + rpm_keys = [] + for m in healthy_deployments: + if isinstance(m, dict): + id = m.get("model_info", {}).get( + "id" + ) # a deployment should always have an 'id'. this is set in router.py + deployment_name = m.get("litellm_params", {}).get("model") + tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute) + rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute) + + tpm_keys.append(tpm_key) + rpm_keys.append(rpm_key) + + combined_tpm_rpm_keys = tpm_keys + rpm_keys + + combined_tpm_rpm_values = await self.router_cache.async_batch_get_cache( + keys=combined_tpm_rpm_keys + ) # [1, 2, None, ..] + + if combined_tpm_rpm_values is not None: + tpm_values = combined_tpm_rpm_values[: len(tpm_keys)] + rpm_values = combined_tpm_rpm_values[len(tpm_keys) :] + else: + tpm_values = None + rpm_values = None + + deployment = self._common_checks_available_deployment( + model_group=model_group, + healthy_deployments=healthy_deployments, + tpm_keys=tpm_keys, + tpm_values=tpm_values, + rpm_keys=rpm_keys, + rpm_values=rpm_values, + messages=messages, + input=input, + ) + + try: + assert deployment is not None + return deployment + except Exception: + ### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ### + deployment_dict = {} + for index, _deployment in enumerate(healthy_deployments): + if isinstance(_deployment, dict): + id = _deployment.get("model_info", {}).get("id") + ### GET DEPLOYMENT TPM LIMIT ### + _deployment_tpm = None + if _deployment_tpm is None: + _deployment_tpm = _deployment.get("tpm", None) + if _deployment_tpm is None: + _deployment_tpm = _deployment.get("litellm_params", {}).get( + "tpm", None + ) + if _deployment_tpm is None: + _deployment_tpm = _deployment.get("model_info", {}).get( + "tpm", None + ) + if _deployment_tpm is None: + _deployment_tpm = float("inf") + + ### GET CURRENT TPM ### + current_tpm = tpm_values[index] if tpm_values else 0 + + ### GET DEPLOYMENT TPM LIMIT ### + _deployment_rpm = None + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("rpm", None) + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("litellm_params", {}).get( + "rpm", None + ) + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("model_info", {}).get( + "rpm", None + ) + if _deployment_rpm is None: + _deployment_rpm = float("inf") + + ### GET CURRENT RPM ### + current_rpm = rpm_values[index] if rpm_values else 0 + + deployment_dict[id] = { + "current_tpm": current_tpm, + "tpm_limit": _deployment_tpm, + "current_rpm": current_rpm, + "rpm_limit": _deployment_rpm, + } + raise litellm.RateLimitError( + message=f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}", + llm_provider="", + model=model_group, + response=httpx.Response( + status_code=429, + content="", + headers={"retry-after": str(60)}, # type: ignore + request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + + def get_available_deployments( + self, + model_group: str, + healthy_deployments: list, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + parent_otel_span: Optional[Span] = None, + ): + """ + Returns a deployment with the lowest TPM/RPM usage. 
+ """ + # get list of potential deployments + verbose_router_logger.debug( + f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}" + ) + + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + tpm_keys = [] + rpm_keys = [] + for m in healthy_deployments: + if isinstance(m, dict): + id = m.get("model_info", {}).get( + "id" + ) # a deployment should always have an 'id'. this is set in router.py + deployment_name = m.get("litellm_params", {}).get("model") + tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute) + rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute) + + tpm_keys.append(tpm_key) + rpm_keys.append(rpm_key) + + tpm_values = self.router_cache.batch_get_cache( + keys=tpm_keys, parent_otel_span=parent_otel_span + ) # [1, 2, None, ..] + rpm_values = self.router_cache.batch_get_cache( + keys=rpm_keys, parent_otel_span=parent_otel_span + ) # [1, 2, None, ..] + + deployment = self._common_checks_available_deployment( + model_group=model_group, + healthy_deployments=healthy_deployments, + tpm_keys=tpm_keys, + tpm_values=tpm_values, + rpm_keys=rpm_keys, + rpm_values=rpm_values, + messages=messages, + input=input, + ) + + try: + assert deployment is not None + return deployment + except Exception: + ### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ### + deployment_dict = {} + for index, _deployment in enumerate(healthy_deployments): + if isinstance(_deployment, dict): + id = _deployment.get("model_info", {}).get("id") + ### GET DEPLOYMENT TPM LIMIT ### + _deployment_tpm = None + if _deployment_tpm is None: + _deployment_tpm = _deployment.get("tpm", None) + if _deployment_tpm is None: + _deployment_tpm = _deployment.get("litellm_params", {}).get( + "tpm", None + ) + if _deployment_tpm is None: + _deployment_tpm = _deployment.get("model_info", {}).get( + "tpm", None + ) + if _deployment_tpm is None: + _deployment_tpm = float("inf") + + ### GET CURRENT TPM ### + current_tpm = tpm_values[index] if tpm_values else 0 + + ### GET DEPLOYMENT TPM LIMIT ### + _deployment_rpm = None + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("rpm", None) + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("litellm_params", {}).get( + "rpm", None + ) + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("model_info", {}).get( + "rpm", None + ) + if _deployment_rpm is None: + _deployment_rpm = float("inf") + + ### GET CURRENT RPM ### + current_rpm = rpm_values[index] if rpm_values else 0 + + deployment_dict[id] = { + "current_tpm": current_tpm, + "tpm_limit": _deployment_tpm, + "current_rpm": current_rpm, + "rpm_limit": _deployment_rpm, + } + raise ValueError( + f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}" + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/simple_shuffle.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/simple_shuffle.py new file mode 100644 index 00000000..da24c02f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/simple_shuffle.py @@ -0,0 +1,96 @@ +""" +Returns a random deployment from the list of healthy deployments. + +If weights are provided, it will return a deployment based on the weights. 
+ +""" + +import random +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from litellm._logging import verbose_router_logger + +if TYPE_CHECKING: + from litellm.router import Router as _Router + + LitellmRouter = _Router +else: + LitellmRouter = Any + + +def simple_shuffle( + llm_router_instance: LitellmRouter, + healthy_deployments: Union[List[Any], Dict[Any, Any]], + model: str, +) -> Dict: + """ + Returns a random deployment from the list of healthy deployments. + + If weights are provided, it will return a deployment based on the weights. + + If users pass `rpm` or `tpm`, we do a random weighted pick - based on `rpm`/`tpm`. + + Args: + llm_router_instance: LitellmRouter instance + healthy_deployments: List of healthy deployments + model: Model name + + Returns: + Dict: A single healthy deployment + """ + + ############## Check if 'weight' param set for a weighted pick ################# + weight = healthy_deployments[0].get("litellm_params").get("weight", None) + if weight is not None: + # use weight-random pick if rpms provided + weights = [m["litellm_params"].get("weight", 0) for m in healthy_deployments] + verbose_router_logger.debug(f"\nweight {weights}") + total_weight = sum(weights) + weights = [weight / total_weight for weight in weights] + verbose_router_logger.debug(f"\n weights {weights}") + # Perform weighted random pick + selected_index = random.choices(range(len(weights)), weights=weights)[0] + verbose_router_logger.debug(f"\n selected index, {selected_index}") + deployment = healthy_deployments[selected_index] + verbose_router_logger.info( + f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}" + ) + return deployment or deployment[0] + ############## Check if we can do a RPM/TPM based weighted pick ################# + rpm = healthy_deployments[0].get("litellm_params").get("rpm", None) + if rpm is not None: + # use weight-random pick if rpms provided + rpms = [m["litellm_params"].get("rpm", 0) for m in healthy_deployments] + verbose_router_logger.debug(f"\nrpms {rpms}") + total_rpm = sum(rpms) + weights = [rpm / total_rpm for rpm in rpms] + verbose_router_logger.debug(f"\n weights {weights}") + # Perform weighted random pick + selected_index = random.choices(range(len(rpms)), weights=weights)[0] + verbose_router_logger.debug(f"\n selected index, {selected_index}") + deployment = healthy_deployments[selected_index] + verbose_router_logger.info( + f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}" + ) + return deployment or deployment[0] + ############## Check if we can do a RPM/TPM based weighted pick ################# + tpm = healthy_deployments[0].get("litellm_params").get("tpm", None) + if tpm is not None: + # use weight-random pick if rpms provided + tpms = [m["litellm_params"].get("tpm", 0) for m in healthy_deployments] + verbose_router_logger.debug(f"\ntpms {tpms}") + total_tpm = sum(tpms) + weights = [tpm / total_tpm for tpm in tpms] + verbose_router_logger.debug(f"\n weights {weights}") + # Perform weighted random pick + selected_index = random.choices(range(len(tpms)), weights=weights)[0] + verbose_router_logger.debug(f"\n selected index, {selected_index}") + deployment = healthy_deployments[selected_index] + verbose_router_logger.info( + f"get_available_deployment for model: {model}, Selected deployment: 
{llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}" + ) + return deployment or deployment[0] + + ############## No RPM/TPM passed, we do a random pick ################# + item = random.choice(healthy_deployments) + return item or item[0] diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/tag_based_routing.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/tag_based_routing.py new file mode 100644 index 00000000..e6a93614 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/tag_based_routing.py @@ -0,0 +1,146 @@ +""" +Use this to route requests between Teams + +- If tags in request is a subset of tags in deployment, return deployment +- if deployments are set with default tags, return all default deployment +- If no default_deployments are set, return all deployments +""" + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from litellm._logging import verbose_logger +from litellm.types.router import RouterErrors + +if TYPE_CHECKING: + from litellm.router import Router as _Router + + LitellmRouter = _Router +else: + LitellmRouter = Any + + +def is_valid_deployment_tag( + deployment_tags: List[str], request_tags: List[str] +) -> bool: + """ + Check if a tag is valid + """ + + if any(tag in deployment_tags for tag in request_tags): + verbose_logger.debug( + "adding deployment with tags: %s, request tags: %s", + deployment_tags, + request_tags, + ) + return True + elif "default" in deployment_tags: + verbose_logger.debug( + "adding default deployment with tags: %s, request tags: %s", + deployment_tags, + request_tags, + ) + return True + return False + + +async def get_deployments_for_tag( + llm_router_instance: LitellmRouter, + model: str, # used to raise the correct error + healthy_deployments: Union[List[Any], Dict[Any, Any]], + request_kwargs: Optional[Dict[Any, Any]] = None, +): + """ + Returns a list of deployments that match the requested model and tags in the request. 
+ + Executes tag based filtering based on the tags in request metadata and the tags on the deployments + """ + if llm_router_instance.enable_tag_filtering is not True: + return healthy_deployments + + if request_kwargs is None: + verbose_logger.debug( + "get_deployments_for_tag: request_kwargs is None returning healthy_deployments: %s", + healthy_deployments, + ) + return healthy_deployments + + if healthy_deployments is None: + verbose_logger.debug( + "get_deployments_for_tag: healthy_deployments is None returning healthy_deployments" + ) + return healthy_deployments + + verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata")) + if "metadata" in request_kwargs: + metadata = request_kwargs["metadata"] + request_tags = metadata.get("tags") + + new_healthy_deployments = [] + if request_tags: + verbose_logger.debug( + "get_deployments_for_tag routing: router_keys: %s", request_tags + ) + # example this can be router_keys=["free", "custom"] + # get all deployments that have a superset of these router keys + for deployment in healthy_deployments: + deployment_litellm_params = deployment.get("litellm_params") + deployment_tags = deployment_litellm_params.get("tags") + + verbose_logger.debug( + "deployment: %s, deployment_router_keys: %s", + deployment, + deployment_tags, + ) + + if deployment_tags is None: + continue + + if is_valid_deployment_tag(deployment_tags, request_tags): + new_healthy_deployments.append(deployment) + + if len(new_healthy_deployments) == 0: + raise ValueError( + f"{RouterErrors.no_deployments_with_tag_routing.value}. Passed model={model} and tags={request_tags}" + ) + + return new_healthy_deployments + + # for Untagged requests use default deployments if set + _default_deployments_with_tags = [] + for deployment in healthy_deployments: + if "default" in deployment.get("litellm_params", {}).get("tags", []): + _default_deployments_with_tags.append(deployment) + + if len(_default_deployments_with_tags) > 0: + return _default_deployments_with_tags + + # if no default deployment is found, return healthy_deployments + verbose_logger.debug( + "no tier found in metadata, returning healthy_deployments: %s", + healthy_deployments, + ) + return healthy_deployments + + +def _get_tags_from_request_kwargs( + request_kwargs: Optional[Dict[Any, Any]] = None +) -> List[str]: + """ + Helper to get tags from request kwargs + + Args: + request_kwargs: The request kwargs to get tags from + + Returns: + List[str]: The tags from the request kwargs + """ + if request_kwargs is None: + return [] + if "metadata" in request_kwargs: + metadata = request_kwargs["metadata"] + return metadata.get("tags", []) + elif "litellm_params" in request_kwargs: + litellm_params = request_kwargs["litellm_params"] + _metadata = litellm_params.get("metadata", {}) + return _metadata.get("tags", []) + return [] |
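
The weighted pick in `simple_shuffle` normalizes the per-deployment `rpm` (or `tpm`/`weight`) values into probabilities and draws an index with `random.choices`. A minimal standalone sketch of that selection step, with made-up deployment names and rpm values (not taken from the diff above):

```python
import random

# Hypothetical deployments; only the fields the weighted pick reads are shown.
healthy_deployments = [
    {"model_name": "gpt-4o-a", "litellm_params": {"rpm": 100}},
    {"model_name": "gpt-4o-b", "litellm_params": {"rpm": 300}},
]

rpms = [d["litellm_params"].get("rpm", 0) for d in healthy_deployments]
total_rpm = sum(rpms)
weights = [r / total_rpm for r in rpms]  # normalize rpm values into probabilities

# random.choices performs the weighted draw; the index maps back to a deployment
selected_index = random.choices(range(len(healthy_deployments)), weights=weights)[0]
deployment = healthy_deployments[selected_index]
print(deployment["model_name"])  # "gpt-4o-b" ~75% of the time
```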
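
Tag-based routing, as described in `tag_based_routing.py`, keeps a deployment when its tags intersect the request tags, or when the deployment is tagged `"default"`. The following is an illustrative sketch of that filter only; the deployment names and tags are invented for the example and it does not reproduce the library's full fallback behavior:

```python
from typing import Dict, List


def matches_tags(deployment_tags: List[str], request_tags: List[str]) -> bool:
    # keep the deployment if any request tag matches, or if it is a "default" deployment
    return any(t in deployment_tags for t in request_tags) or "default" in deployment_tags


# Hypothetical deployments for illustration only.
deployments: List[Dict] = [
    {"model_name": "gpt-4o-free", "litellm_params": {"tags": ["free"]}},
    {"model_name": "gpt-4o-paid", "litellm_params": {"tags": ["paid"]}},
    {"model_name": "gpt-4o-any", "litellm_params": {"tags": ["default"]}},
]

request_tags = ["free"]
selected = [
    d
    for d in deployments
    if matches_tags(d["litellm_params"].get("tags", []), request_tags)
]
print([d["model_name"] for d in selected])  # ['gpt-4o-free', 'gpt-4o-any']
```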