| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/router_utils | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-master.tar.gz | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/router_utils')
15 files changed, 2000 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/add_retry_fallback_headers.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/add_retry_fallback_headers.py new file mode 100644 index 00000000..0984f61b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/add_retry_fallback_headers.py @@ -0,0 +1,68 @@ +from typing import Any, Optional, Union + +from pydantic import BaseModel + +from litellm.types.utils import HiddenParams + + +def _add_headers_to_response(response: Any, headers: dict) -> Any: + """ + Helper function to add headers to a response's hidden params + """ + if response is None or not isinstance(response, BaseModel): + return response + + hidden_params: Optional[Union[dict, HiddenParams]] = getattr( + response, "_hidden_params", {} + ) + + if hidden_params is None: + hidden_params = {} + elif isinstance(hidden_params, HiddenParams): + hidden_params = hidden_params.model_dump() + + hidden_params.setdefault("additional_headers", {}) + hidden_params["additional_headers"].update(headers) + + setattr(response, "_hidden_params", hidden_params) + return response + + +def add_retry_headers_to_response( + response: Any, + attempted_retries: int, + max_retries: Optional[int] = None, +) -> Any: + """ + Add retry headers to the request + """ + retry_headers = { + "x-litellm-attempted-retries": attempted_retries, + } + if max_retries is not None: + retry_headers["x-litellm-max-retries"] = max_retries + + return _add_headers_to_response(response, retry_headers) + + +def add_fallback_headers_to_response( + response: Any, + attempted_fallbacks: int, +) -> Any: + """ + Add fallback headers to the response + + Args: + response: The response to add the headers to + attempted_fallbacks: The number of fallbacks attempted + + Returns: + The response with the headers added + + Note: It's intentional that we don't add max_fallbacks in response headers + Want to avoid bloat in the response headers for performance. 
+ """ + fallback_headers = { + "x-litellm-attempted-fallbacks": attempted_fallbacks, + } + return _add_headers_to_response(response, fallback_headers) diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/batch_utils.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/batch_utils.py new file mode 100644 index 00000000..a41bae25 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/batch_utils.py @@ -0,0 +1,63 @@ +import io +import json +from typing import Optional, Tuple, Union + + +class InMemoryFile(io.BytesIO): + def __init__(self, content: bytes, name: str): + super().__init__(content) + self.name = name + + +def replace_model_in_jsonl( + file_content: Union[bytes, Tuple[str, bytes, str]], new_model_name: str +) -> Optional[InMemoryFile]: + try: + # Decode the bytes to a string and split into lines + # If file_content is a file-like object, read the bytes + if hasattr(file_content, "read"): + file_content_bytes = file_content.read() # type: ignore + elif isinstance(file_content, tuple): + file_content_bytes = file_content[1] + else: + file_content_bytes = file_content + + # Decode the bytes to a string and split into lines + if isinstance(file_content_bytes, bytes): + file_content_str = file_content_bytes.decode("utf-8") + else: + file_content_str = file_content_bytes + lines = file_content_str.splitlines() + modified_lines = [] + for line in lines: + # Parse each line as a JSON object + json_object = json.loads(line.strip()) + + # Replace the model name if it exists + if "body" in json_object: + json_object["body"]["model"] = new_model_name + + # Convert the modified JSON object back to a string + modified_lines.append(json.dumps(json_object)) + + # Reassemble the modified lines and return as bytes + modified_file_content = "\n".join(modified_lines).encode("utf-8") + return InMemoryFile(modified_file_content, name="modified_file.jsonl") # type: ignore + + except (json.JSONDecodeError, UnicodeDecodeError, TypeError): + return None + + +def _get_router_metadata_variable_name(function_name) -> str: + """ + Helper to return what the "metadata" field should be called in the request data + + For all /thread or /assistant endpoints we need to call this "litellm_metadata" + + For ALL other endpoints we call this "metadata + """ + ROUTER_METHODS_USING_LITELLM_METADATA = set(["batch", "generic_api_call"]) + if function_name in ROUTER_METHODS_USING_LITELLM_METADATA: + return "litellm_metadata" + else: + return "metadata" diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py new file mode 100644 index 00000000..e24d2378 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py @@ -0,0 +1,37 @@ +import asyncio +from typing import TYPE_CHECKING, Any + +from litellm.utils import calculate_max_parallel_requests + +if TYPE_CHECKING: + from litellm.router import Router as _Router + + LitellmRouter = _Router +else: + LitellmRouter = Any + + +class InitalizeCachedClient: + @staticmethod + def set_max_parallel_requests_client( + litellm_router_instance: LitellmRouter, model: dict + ): + litellm_params = model.get("litellm_params", {}) + model_id = model["model_info"]["id"] + rpm = litellm_params.get("rpm", None) + tpm = litellm_params.get("tpm", None) + max_parallel_requests = litellm_params.get("max_parallel_requests", None) + calculated_max_parallel_requests = 
calculate_max_parallel_requests( + rpm=rpm, + max_parallel_requests=max_parallel_requests, + tpm=tpm, + default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests, + ) + if calculated_max_parallel_requests: + semaphore = asyncio.Semaphore(calculated_max_parallel_requests) + cache_key = f"{model_id}_max_parallel_requests_client" + litellm_router_instance.cache.set_cache( + key=cache_key, + value=semaphore, + local_only=True, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/clientside_credential_handler.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/clientside_credential_handler.py new file mode 100644 index 00000000..c98f6143 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/clientside_credential_handler.py @@ -0,0 +1,37 @@ +""" +Utils for handling clientside credentials + +Supported clientside credentials: +- api_key +- api_base +- base_url + +If given, generate a unique model_id for the deployment. + +Ensures cooldowns are applied correctly. +""" + +clientside_credential_keys = ["api_key", "api_base", "base_url"] + + +def is_clientside_credential(request_kwargs: dict) -> bool: + """ + Check if the credential is a clientside credential. + """ + return any(key in request_kwargs for key in clientside_credential_keys) + + +def get_dynamic_litellm_params(litellm_params: dict, request_kwargs: dict) -> dict: + """ + Generate a unique model_id for the deployment. + + Returns + - litellm_params: dict + + for generating a unique model_id. + """ + # update litellm_params with clientside credentials + for key in clientside_credential_keys: + if key in request_kwargs: + litellm_params[key] = request_kwargs[key] + return litellm_params diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_cache.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_cache.py new file mode 100644 index 00000000..f096b026 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_cache.py @@ -0,0 +1,170 @@ +""" +Wrapper around router cache. 
Meant to handle model cooldown logic +""" + +import time +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict + +from litellm import verbose_logger +from litellm.caching.caching import DualCache +from litellm.caching.in_memory_cache import InMemoryCache + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + + +class CooldownCacheValue(TypedDict): + exception_received: str + status_code: str + timestamp: float + cooldown_time: float + + +class CooldownCache: + def __init__(self, cache: DualCache, default_cooldown_time: float): + self.cache = cache + self.default_cooldown_time = default_cooldown_time + self.in_memory_cache = InMemoryCache() + + def _common_add_cooldown_logic( + self, model_id: str, original_exception, exception_status, cooldown_time: float + ) -> Tuple[str, CooldownCacheValue]: + try: + current_time = time.time() + cooldown_key = f"deployment:{model_id}:cooldown" + + # Store the cooldown information for the deployment separately + cooldown_data = CooldownCacheValue( + exception_received=str(original_exception), + status_code=str(exception_status), + timestamp=current_time, + cooldown_time=cooldown_time, + ) + + return cooldown_key, cooldown_data + except Exception as e: + verbose_logger.error( + "CooldownCache::_common_add_cooldown_logic - Exception occurred - {}".format( + str(e) + ) + ) + raise e + + def add_deployment_to_cooldown( + self, + model_id: str, + original_exception: Exception, + exception_status: int, + cooldown_time: Optional[float], + ): + try: + _cooldown_time = cooldown_time or self.default_cooldown_time + cooldown_key, cooldown_data = self._common_add_cooldown_logic( + model_id=model_id, + original_exception=original_exception, + exception_status=exception_status, + cooldown_time=_cooldown_time, + ) + + # Set the cache with a TTL equal to the cooldown time + self.cache.set_cache( + value=cooldown_data, + key=cooldown_key, + ttl=_cooldown_time, + ) + except Exception as e: + verbose_logger.error( + "CooldownCache::add_deployment_to_cooldown - Exception occurred - {}".format( + str(e) + ) + ) + raise e + + @staticmethod + def get_cooldown_cache_key(model_id: str) -> str: + return f"deployment:{model_id}:cooldown" + + async def async_get_active_cooldowns( + self, model_ids: List[str], parent_otel_span: Optional[Span] + ) -> List[Tuple[str, CooldownCacheValue]]: + # Generate the keys for the deployments + keys = [ + CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids + ] + + # Retrieve the values for the keys using mget + ## more likely to be none if no models ratelimited. So just check redis every 1s + ## each redis call adds ~100ms latency. 
+ + ## check in memory cache first + results = await self.cache.async_batch_get_cache( + keys=keys, parent_otel_span=parent_otel_span + ) + active_cooldowns: List[Tuple[str, CooldownCacheValue]] = [] + + if results is None: + return active_cooldowns + + # Process the results + for model_id, result in zip(model_ids, results): + if result and isinstance(result, dict): + cooldown_cache_value = CooldownCacheValue(**result) # type: ignore + active_cooldowns.append((model_id, cooldown_cache_value)) + + return active_cooldowns + + def get_active_cooldowns( + self, model_ids: List[str], parent_otel_span: Optional[Span] + ) -> List[Tuple[str, CooldownCacheValue]]: + # Generate the keys for the deployments + keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids] + # Retrieve the values for the keys using mget + results = ( + self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span) + or [] + ) + + active_cooldowns = [] + # Process the results + for model_id, result in zip(model_ids, results): + if result and isinstance(result, dict): + cooldown_cache_value = CooldownCacheValue(**result) # type: ignore + active_cooldowns.append((model_id, cooldown_cache_value)) + + return active_cooldowns + + def get_min_cooldown( + self, model_ids: List[str], parent_otel_span: Optional[Span] + ) -> float: + """Return min cooldown time required for a group of model id's.""" + + # Generate the keys for the deployments + keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids] + + # Retrieve the values for the keys using mget + results = ( + self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span) + or [] + ) + + min_cooldown_time: Optional[float] = None + # Process the results + for model_id, result in zip(model_ids, results): + if result and isinstance(result, dict): + cooldown_cache_value = CooldownCacheValue(**result) # type: ignore + if min_cooldown_time is None: + min_cooldown_time = cooldown_cache_value["cooldown_time"] + elif cooldown_cache_value["cooldown_time"] < min_cooldown_time: + min_cooldown_time = cooldown_cache_value["cooldown_time"] + + return min_cooldown_time or self.default_cooldown_time + + +# Usage example: +# cooldown_cache = CooldownCache(cache=your_cache_instance, cooldown_time=your_cooldown_time) +# cooldown_cache.add_deployment_to_cooldown(deployment, original_exception, exception_status) +# active_cooldowns = cooldown_cache.get_active_cooldowns() diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_callbacks.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_callbacks.py new file mode 100644 index 00000000..54a016d3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_callbacks.py @@ -0,0 +1,98 @@ +""" +Callbacks triggered on cooling down deployments +""" + +import copy +from typing import TYPE_CHECKING, Any, Optional, Union + +import litellm +from litellm._logging import verbose_logger + +if TYPE_CHECKING: + from litellm.router import Router as _Router + + LitellmRouter = _Router + from litellm.integrations.prometheus import PrometheusLogger +else: + LitellmRouter = Any + PrometheusLogger = Any + + +async def router_cooldown_event_callback( + litellm_router_instance: LitellmRouter, + deployment_id: str, + exception_status: Union[str, int], + cooldown_time: float, +): + """ + Callback triggered when a deployment is put into cooldown by litellm + + - Updates deployment state on Prometheus + - Increments cooldown metric for deployment on Prometheus + """ + 
verbose_logger.debug("In router_cooldown_event_callback - updating prometheus") + _deployment = litellm_router_instance.get_deployment(model_id=deployment_id) + if _deployment is None: + verbose_logger.warning( + f"in router_cooldown_event_callback but _deployment is None for deployment_id={deployment_id}. Doing nothing" + ) + return + _litellm_params = _deployment["litellm_params"] + temp_litellm_params = copy.deepcopy(_litellm_params) + temp_litellm_params = dict(temp_litellm_params) + _model_name = _deployment.get("model_name", None) or "" + _api_base = ( + litellm.get_api_base(model=_model_name, optional_params=temp_litellm_params) + or "" + ) + model_info = _deployment["model_info"] + model_id = model_info.id + + litellm_model_name = temp_litellm_params.get("model") or "" + llm_provider = "" + try: + _, llm_provider, _, _ = litellm.get_llm_provider( + model=litellm_model_name, + custom_llm_provider=temp_litellm_params.get("custom_llm_provider"), + ) + except Exception: + pass + + # get the prometheus logger from in memory loggers + prometheusLogger: Optional[PrometheusLogger] = ( + _get_prometheus_logger_from_callbacks() + ) + + if prometheusLogger is not None: + prometheusLogger.set_deployment_complete_outage( + litellm_model_name=_model_name, + model_id=model_id, + api_base=_api_base, + api_provider=llm_provider, + ) + + prometheusLogger.increment_deployment_cooled_down( + litellm_model_name=_model_name, + model_id=model_id, + api_base=_api_base, + api_provider=llm_provider, + exception_status=str(exception_status), + ) + + return + + +def _get_prometheus_logger_from_callbacks() -> Optional[PrometheusLogger]: + """ + Checks if prometheus is a initalized callback, if yes returns it + """ + from litellm.integrations.prometheus import PrometheusLogger + + for _callback in litellm._async_success_callback: + if isinstance(_callback, PrometheusLogger): + return _callback + for global_callback in litellm.callbacks: + if isinstance(global_callback, PrometheusLogger): + return global_callback + + return None diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_handlers.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_handlers.py new file mode 100644 index 00000000..52babc27 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/cooldown_handlers.py @@ -0,0 +1,438 @@ +""" +Router cooldown handlers +- _set_cooldown_deployments: puts a deployment in the cooldown list +- get_cooldown_deployments: returns the list of deployments in the cooldown list +- async_get_cooldown_deployments: ASYNC: returns the list of deployments in the cooldown list + +""" + +import asyncio +from typing import TYPE_CHECKING, Any, List, Optional, Union + +import litellm +from litellm._logging import verbose_router_logger +from litellm.constants import ( + DEFAULT_COOLDOWN_TIME_SECONDS, + DEFAULT_FAILURE_THRESHOLD_PERCENT, + SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD, +) +from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback + +from .router_callbacks.track_deployment_metrics import ( + get_deployment_failures_for_current_minute, + get_deployment_successes_for_current_minute, +) + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + from litellm.router import Router as _Router + + LitellmRouter = _Router + Span = _Span +else: + LitellmRouter = Any + Span = Any + + +def _is_cooldown_required( + litellm_router_instance: LitellmRouter, + model_id: str, + exception_status: Union[str, int], + exception_str: 
Optional[str] = None, +) -> bool: + """ + A function to determine if a cooldown is required based on the exception status. + + Parameters: + model_id (str) The id of the model in the model list + exception_status (Union[str, int]): The status of the exception. + + Returns: + bool: True if a cooldown is required, False otherwise. + """ + try: + ignored_strings = ["APIConnectionError"] + if ( + exception_str is not None + ): # don't cooldown on litellm api connection errors errors + for ignored_string in ignored_strings: + if ignored_string in exception_str: + return False + + if isinstance(exception_status, str): + exception_status = int(exception_status) + + if exception_status >= 400 and exception_status < 500: + if exception_status == 429: + # Cool down 429 Rate Limit Errors + return True + + elif exception_status == 401: + # Cool down 401 Auth Errors + return True + + elif exception_status == 408: + return True + + elif exception_status == 404: + return True + + else: + # Do NOT cool down all other 4XX Errors + return False + + else: + # should cool down for all other errors + return True + + except Exception: + # Catch all - if any exceptions default to cooling down + return True + + +def _should_run_cooldown_logic( + litellm_router_instance: LitellmRouter, + deployment: Optional[str], + exception_status: Union[str, int], + original_exception: Any, +) -> bool: + """ + Helper that decides if cooldown logic should be run + Returns False if cooldown logic should not be run + + Does not run cooldown logic when: + - router.disable_cooldowns is True + - deployment is None + - _is_cooldown_required() returns False + - deployment is in litellm_router_instance.provider_default_deployment_ids + - exception_status is not one that should be immediately retried (e.g. 401) + """ + if ( + deployment is None + or litellm_router_instance.get_model_group(id=deployment) is None + ): + verbose_router_logger.debug( + "Should Not Run Cooldown Logic: deployment id is none or model group can't be found." 
+ ) + return False + + if litellm_router_instance.disable_cooldowns: + verbose_router_logger.debug( + "Should Not Run Cooldown Logic: disable_cooldowns is True" + ) + return False + + if deployment is None: + verbose_router_logger.debug("Should Not Run Cooldown Logic: deployment is None") + return False + + if not _is_cooldown_required( + litellm_router_instance=litellm_router_instance, + model_id=deployment, + exception_status=exception_status, + exception_str=str(original_exception), + ): + verbose_router_logger.debug( + "Should Not Run Cooldown Logic: _is_cooldown_required returned False" + ) + return False + + if deployment in litellm_router_instance.provider_default_deployment_ids: + verbose_router_logger.debug( + "Should Not Run Cooldown Logic: deployment is in provider_default_deployment_ids" + ) + return False + + return True + + +def _should_cooldown_deployment( + litellm_router_instance: LitellmRouter, + deployment: str, + exception_status: Union[str, int], + original_exception: Any, +) -> bool: + """ + Helper that decides if a deployment should be put in cooldown + + Returns True if the deployment should be put in cooldown + Returns False if the deployment should not be put in cooldown + + + Deployment is put in cooldown when: + - v2 logic (Current): + cooldown if: + - got a 429 error from LLM API + - if %fails/%(successes + fails) > ALLOWED_FAILURE_RATE_PER_MINUTE + - got 401 Auth error, 404 NotFounder - checked by litellm._should_retry() + + + + - v1 logic (Legacy): if allowed fails or allowed fail policy set, coolsdown if num fails in this minute > allowed fails + """ + ## BASE CASE - single deployment + model_group = litellm_router_instance.get_model_group(id=deployment) + is_single_deployment_model_group = False + if model_group is not None and len(model_group) == 1: + is_single_deployment_model_group = True + if ( + litellm_router_instance.allowed_fails_policy is None + and _is_allowed_fails_set_on_router( + litellm_router_instance=litellm_router_instance + ) + is False + ): + num_successes_this_minute = get_deployment_successes_for_current_minute( + litellm_router_instance=litellm_router_instance, deployment_id=deployment + ) + num_fails_this_minute = get_deployment_failures_for_current_minute( + litellm_router_instance=litellm_router_instance, deployment_id=deployment + ) + + total_requests_this_minute = num_successes_this_minute + num_fails_this_minute + percent_fails = 0.0 + if total_requests_this_minute > 0: + percent_fails = num_fails_this_minute / ( + num_successes_this_minute + num_fails_this_minute + ) + verbose_router_logger.debug( + "percent fails for deployment = %s, percent fails = %s, num successes = %s, num fails = %s", + deployment, + percent_fails, + num_successes_this_minute, + num_fails_this_minute, + ) + + exception_status_int = cast_exception_status_to_int(exception_status) + if exception_status_int == 429 and not is_single_deployment_model_group: + return True + elif ( + percent_fails == 1.0 + and total_requests_this_minute + >= SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD + ): + # Cooldown if all requests failed and we have reasonable traffic + return True + elif ( + percent_fails > DEFAULT_FAILURE_THRESHOLD_PERCENT + and not is_single_deployment_model_group # by default we should avoid cooldowns on single deployment model groups + ): + return True + + elif ( + litellm._should_retry( + status_code=cast_exception_status_to_int(exception_status) + ) + is False + ): + return True + + return False + else: + return 
should_cooldown_based_on_allowed_fails_policy( + litellm_router_instance=litellm_router_instance, + deployment=deployment, + original_exception=original_exception, + ) + + return False + + +def _set_cooldown_deployments( + litellm_router_instance: LitellmRouter, + original_exception: Any, + exception_status: Union[str, int], + deployment: Optional[str] = None, + time_to_cooldown: Optional[float] = None, +) -> bool: + """ + Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute + + or + + the exception is not one that should be immediately retried (e.g. 401) + + Returns: + - True if the deployment should be put in cooldown + - False if the deployment should not be put in cooldown + """ + verbose_router_logger.debug("checks 'should_run_cooldown_logic'") + + if ( + _should_run_cooldown_logic( + litellm_router_instance, deployment, exception_status, original_exception + ) + is False + or deployment is None + ): + verbose_router_logger.debug("should_run_cooldown_logic returned False") + return False + + exception_status_int = cast_exception_status_to_int(exception_status) + + verbose_router_logger.debug(f"Attempting to add {deployment} to cooldown list") + cooldown_time = litellm_router_instance.cooldown_time or 1 + if time_to_cooldown is not None: + cooldown_time = time_to_cooldown + + if _should_cooldown_deployment( + litellm_router_instance, deployment, exception_status, original_exception + ): + litellm_router_instance.cooldown_cache.add_deployment_to_cooldown( + model_id=deployment, + original_exception=original_exception, + exception_status=exception_status_int, + cooldown_time=cooldown_time, + ) + + # Trigger cooldown callback handler + asyncio.create_task( + router_cooldown_event_callback( + litellm_router_instance=litellm_router_instance, + deployment_id=deployment, + exception_status=exception_status, + cooldown_time=cooldown_time, + ) + ) + return True + return False + + +async def _async_get_cooldown_deployments( + litellm_router_instance: LitellmRouter, + parent_otel_span: Optional[Span], +) -> List[str]: + """ + Async implementation of '_get_cooldown_deployments' + """ + model_ids = litellm_router_instance.get_model_ids() + cooldown_models = ( + await litellm_router_instance.cooldown_cache.async_get_active_cooldowns( + model_ids=model_ids, + parent_otel_span=parent_otel_span, + ) + ) + + cached_value_deployment_ids = [] + if ( + cooldown_models is not None + and isinstance(cooldown_models, list) + and len(cooldown_models) > 0 + and isinstance(cooldown_models[0], tuple) + ): + cached_value_deployment_ids = [cv[0] for cv in cooldown_models] + + verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") + return cached_value_deployment_ids + + +async def _async_get_cooldown_deployments_with_debug_info( + litellm_router_instance: LitellmRouter, + parent_otel_span: Optional[Span], +) -> List[tuple]: + """ + Async implementation of '_get_cooldown_deployments' + """ + model_ids = litellm_router_instance.get_model_ids() + cooldown_models = ( + await litellm_router_instance.cooldown_cache.async_get_active_cooldowns( + model_ids=model_ids, parent_otel_span=parent_otel_span + ) + ) + + verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") + return cooldown_models + + +def _get_cooldown_deployments( + litellm_router_instance: LitellmRouter, parent_otel_span: Optional[Span] +) -> List[str]: + """ + Get the list of models being cooled down for this minute + """ + # get the current cooldown list for 
that minute + + # ---------------------- + # Return cooldown models + # ---------------------- + model_ids = litellm_router_instance.get_model_ids() + + cooldown_models = litellm_router_instance.cooldown_cache.get_active_cooldowns( + model_ids=model_ids, parent_otel_span=parent_otel_span + ) + + cached_value_deployment_ids = [] + if ( + cooldown_models is not None + and isinstance(cooldown_models, list) + and len(cooldown_models) > 0 + and isinstance(cooldown_models[0], tuple) + ): + cached_value_deployment_ids = [cv[0] for cv in cooldown_models] + + return cached_value_deployment_ids + + +def should_cooldown_based_on_allowed_fails_policy( + litellm_router_instance: LitellmRouter, + deployment: str, + original_exception: Any, +) -> bool: + """ + Check if fails are within the allowed limit and update the number of fails. + + Returns: + - True if fails exceed the allowed limit (should cooldown) + - False if fails are within the allowed limit (should not cooldown) + """ + allowed_fails = ( + litellm_router_instance.get_allowed_fails_from_policy( + exception=original_exception, + ) + or litellm_router_instance.allowed_fails + ) + cooldown_time = ( + litellm_router_instance.cooldown_time or DEFAULT_COOLDOWN_TIME_SECONDS + ) + + current_fails = litellm_router_instance.failed_calls.get_cache(key=deployment) or 0 + updated_fails = current_fails + 1 + + if updated_fails > allowed_fails: + return True + else: + litellm_router_instance.failed_calls.set_cache( + key=deployment, value=updated_fails, ttl=cooldown_time + ) + + return False + + +def _is_allowed_fails_set_on_router( + litellm_router_instance: LitellmRouter, +) -> bool: + """ + Check if Router.allowed_fails is set or is Non-default Value + + Returns: + - True if Router.allowed_fails is set or is Non-default Value + - False if Router.allowed_fails is None or is Default Value + """ + if litellm_router_instance.allowed_fails is None: + return False + if litellm_router_instance.allowed_fails != litellm.allowed_fails: + return True + return False + + +def cast_exception_status_to_int(exception_status: Union[str, int]) -> int: + if isinstance(exception_status, str): + try: + exception_status = int(exception_status) + except Exception: + verbose_router_logger.debug( + f"Unable to cast exception status to int {exception_status}. Defaulting to status=500." 
+ ) + exception_status = 500 + return exception_status diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/fallback_event_handlers.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/fallback_event_handlers.py new file mode 100644 index 00000000..df805e49 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/fallback_event_handlers.py @@ -0,0 +1,303 @@ +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import litellm +from litellm._logging import verbose_router_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.router_utils.add_retry_fallback_headers import ( + add_fallback_headers_to_response, +) +from litellm.types.router import LiteLLMParamsTypedDict + +if TYPE_CHECKING: + from litellm.router import Router as _Router + + LitellmRouter = _Router +else: + LitellmRouter = Any + + +def _check_stripped_model_group(model_group: str, fallback_key: str) -> bool: + """ + Handles wildcard routing scenario + + where fallbacks set like: + [{"gpt-3.5-turbo": ["claude-3-haiku"]}] + + but model_group is like: + "openai/gpt-3.5-turbo" + + Returns: + - True if the stripped model group == fallback_key + """ + for provider in litellm.provider_list: + if isinstance(provider, Enum): + _provider = provider.value + else: + _provider = provider + if model_group.startswith(f"{_provider}/"): + stripped_model_group = model_group.replace(f"{_provider}/", "") + if stripped_model_group == fallback_key: + return True + return False + + +def get_fallback_model_group( + fallbacks: List[Any], model_group: str +) -> Tuple[Optional[List[str]], Optional[int]]: + """ + Returns: + - fallback_model_group: List[str] of fallback model groups. example: ["gpt-4", "gpt-3.5-turbo"] + - generic_fallback_idx: int of the index of the generic fallback in the fallbacks list. + + Checks: + - exact match + - stripped model group match + - generic fallback + """ + generic_fallback_idx: Optional[int] = None + stripped_model_fallback: Optional[List[str]] = None + fallback_model_group: Optional[List[str]] = None + ## check for specific model group-specific fallbacks + for idx, item in enumerate(fallbacks): + if isinstance(item, dict): + if list(item.keys())[0] == model_group: # check exact match + fallback_model_group = item[model_group] + break + elif _check_stripped_model_group( + model_group=model_group, fallback_key=list(item.keys())[0] + ): # check generic fallback + stripped_model_fallback = item[list(item.keys())[0]] + elif list(item.keys())[0] == "*": # check generic fallback + generic_fallback_idx = idx + elif isinstance(item, str): + fallback_model_group = [fallbacks.pop(idx)] # returns single-item list + ## if none, check for generic fallback + if fallback_model_group is None: + if stripped_model_fallback is not None: + fallback_model_group = stripped_model_fallback + elif generic_fallback_idx is not None: + fallback_model_group = fallbacks[generic_fallback_idx]["*"] + + return fallback_model_group, generic_fallback_idx + + +async def run_async_fallback( + *args: Tuple[Any], + litellm_router: LitellmRouter, + fallback_model_group: List[str], + original_model_group: str, + original_exception: Exception, + max_fallbacks: int, + fallback_depth: int, + **kwargs, +) -> Any: + """ + Loops through all the fallback model groups and calls kwargs["original_function"] with the arguments and keyword arguments provided. + + If the call is successful, it logs the success and returns the response. 
+ If the call fails, it logs the failure and continues to the next fallback model group. + If all fallback model groups fail, it raises the most recent exception. + + Args: + litellm_router: The litellm router instance. + *args: Positional arguments. + fallback_model_group: List[str] of fallback model groups. example: ["gpt-4", "gpt-3.5-turbo"] + original_model_group: The original model group. example: "gpt-3.5-turbo" + original_exception: The original exception. + **kwargs: Keyword arguments. + + Returns: + The response from the successful fallback model group. + Raises: + The most recent exception if all fallback model groups fail. + """ + + ### BASE CASE ### MAX FALLBACK DEPTH REACHED + if fallback_depth >= max_fallbacks: + raise original_exception + + error_from_fallbacks = original_exception + + for mg in fallback_model_group: + if mg == original_model_group: + continue + try: + # LOGGING + kwargs = litellm_router.log_retry(kwargs=kwargs, e=original_exception) + verbose_router_logger.info(f"Falling back to model_group = {mg}") + if isinstance(mg, str): + kwargs["model"] = mg + elif isinstance(mg, dict): + kwargs.update(mg) + kwargs.setdefault("metadata", {}).update( + {"model_group": kwargs.get("model", None)} + ) # update model_group used, if fallbacks are done + fallback_depth = fallback_depth + 1 + kwargs["fallback_depth"] = fallback_depth + kwargs["max_fallbacks"] = max_fallbacks + response = await litellm_router.async_function_with_fallbacks( + *args, **kwargs + ) + verbose_router_logger.info("Successful fallback b/w models.") + response = add_fallback_headers_to_response( + response=response, + attempted_fallbacks=fallback_depth, + ) + # callback for successfull_fallback_event(): + await log_success_fallback_event( + original_model_group=original_model_group, + kwargs=kwargs, + original_exception=original_exception, + ) + return response + except Exception as e: + error_from_fallbacks = e + await log_failure_fallback_event( + original_model_group=original_model_group, + kwargs=kwargs, + original_exception=original_exception, + ) + raise error_from_fallbacks + + +async def log_success_fallback_event( + original_model_group: str, kwargs: dict, original_exception: Exception +): + """ + Log a successful fallback event to all registered callbacks. + + This function iterates through all callbacks, initializing _known_custom_logger_compatible_callbacks if needed, + and calls the log_success_fallback_event method on CustomLogger instances. + + Args: + original_model_group (str): The original model group before fallback. + kwargs (dict): kwargs for the request + + Note: + Errors during logging are caught and reported but do not interrupt the process. 
+ """ + from litellm.litellm_core_utils.litellm_logging import ( + _init_custom_logger_compatible_class, + ) + + for _callback in litellm.callbacks: + if isinstance(_callback, CustomLogger) or ( + _callback in litellm._known_custom_logger_compatible_callbacks + ): + try: + _callback_custom_logger: Optional[CustomLogger] = None + if _callback in litellm._known_custom_logger_compatible_callbacks: + _callback_custom_logger = _init_custom_logger_compatible_class( + logging_integration=_callback, # type: ignore + llm_router=None, + internal_usage_cache=None, + ) + elif isinstance(_callback, CustomLogger): + _callback_custom_logger = _callback + else: + verbose_router_logger.exception( + f"{_callback} logger not found / initialized properly" + ) + continue + + if _callback_custom_logger is None: + verbose_router_logger.exception( + f"{_callback} logger not found / initialized properly, callback is None" + ) + continue + + await _callback_custom_logger.log_success_fallback_event( + original_model_group=original_model_group, + kwargs=kwargs, + original_exception=original_exception, + ) + except Exception as e: + verbose_router_logger.error( + f"Error in log_success_fallback_event: {str(e)}" + ) + + +async def log_failure_fallback_event( + original_model_group: str, kwargs: dict, original_exception: Exception +): + """ + Log a failed fallback event to all registered callbacks. + + This function iterates through all callbacks, initializing _known_custom_logger_compatible_callbacks if needed, + and calls the log_failure_fallback_event method on CustomLogger instances. + + Args: + original_model_group (str): The original model group before fallback. + kwargs (dict): kwargs for the request + + Note: + Errors during logging are caught and reported but do not interrupt the process. + """ + from litellm.litellm_core_utils.litellm_logging import ( + _init_custom_logger_compatible_class, + ) + + for _callback in litellm.callbacks: + if isinstance(_callback, CustomLogger) or ( + _callback in litellm._known_custom_logger_compatible_callbacks + ): + try: + _callback_custom_logger: Optional[CustomLogger] = None + if _callback in litellm._known_custom_logger_compatible_callbacks: + _callback_custom_logger = _init_custom_logger_compatible_class( + logging_integration=_callback, # type: ignore + llm_router=None, + internal_usage_cache=None, + ) + elif isinstance(_callback, CustomLogger): + _callback_custom_logger = _callback + else: + verbose_router_logger.exception( + f"{_callback} logger not found / initialized properly" + ) + continue + + if _callback_custom_logger is None: + verbose_router_logger.exception( + f"{_callback} logger not found / initialized properly" + ) + continue + + await _callback_custom_logger.log_failure_fallback_event( + original_model_group=original_model_group, + kwargs=kwargs, + original_exception=original_exception, + ) + except Exception as e: + verbose_router_logger.error( + f"Error in log_failure_fallback_event: {str(e)}" + ) + + +def _check_non_standard_fallback_format(fallbacks: Optional[List[Any]]) -> bool: + """ + Checks if the fallbacks list is a list of strings or a list of dictionaries. + + If + - List[str]: e.g. ["claude-3-haiku", "openai/o-1"] + - List[Dict[<LiteLLMParamsTypedDict>, Any]]: e.g. [{"model": "claude-3-haiku", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}] + + If [{"gpt-3.5-turbo": ["claude-3-haiku"]}] then standard format. 
+ """ + if fallbacks is None or not isinstance(fallbacks, list) or len(fallbacks) == 0: + return False + if all(isinstance(item, str) for item in fallbacks): + return True + elif all(isinstance(item, dict) for item in fallbacks): + for key in LiteLLMParamsTypedDict.__annotations__.keys(): + if key in fallbacks[0].keys(): + return True + + return False + + +def run_non_standard_fallback_format( + fallbacks: Union[List[str], List[Dict[str, Any]]], model_group: str +): + pass diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/get_retry_from_policy.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/get_retry_from_policy.py new file mode 100644 index 00000000..48df43ef --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/get_retry_from_policy.py @@ -0,0 +1,71 @@ +""" +Get num retries for an exception. + +- Account for retry policy by exception type. +""" + +from typing import Dict, Optional, Union + +from litellm.exceptions import ( + AuthenticationError, + BadRequestError, + ContentPolicyViolationError, + RateLimitError, + Timeout, +) +from litellm.types.router import RetryPolicy + + +def get_num_retries_from_retry_policy( + exception: Exception, + retry_policy: Optional[Union[RetryPolicy, dict]] = None, + model_group: Optional[str] = None, + model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = None, +): + """ + BadRequestErrorRetries: Optional[int] = None + AuthenticationErrorRetries: Optional[int] = None + TimeoutErrorRetries: Optional[int] = None + RateLimitErrorRetries: Optional[int] = None + ContentPolicyViolationErrorRetries: Optional[int] = None + """ + # if we can find the exception then in the retry policy -> return the number of retries + + if ( + model_group_retry_policy is not None + and model_group is not None + and model_group in model_group_retry_policy + ): + retry_policy = model_group_retry_policy.get(model_group, None) # type: ignore + + if retry_policy is None: + return None + if isinstance(retry_policy, dict): + retry_policy = RetryPolicy(**retry_policy) + + if ( + isinstance(exception, BadRequestError) + and retry_policy.BadRequestErrorRetries is not None + ): + return retry_policy.BadRequestErrorRetries + if ( + isinstance(exception, AuthenticationError) + and retry_policy.AuthenticationErrorRetries is not None + ): + return retry_policy.AuthenticationErrorRetries + if isinstance(exception, Timeout) and retry_policy.TimeoutErrorRetries is not None: + return retry_policy.TimeoutErrorRetries + if ( + isinstance(exception, RateLimitError) + and retry_policy.RateLimitErrorRetries is not None + ): + return retry_policy.RateLimitErrorRetries + if ( + isinstance(exception, ContentPolicyViolationError) + and retry_policy.ContentPolicyViolationErrorRetries is not None + ): + return retry_policy.ContentPolicyViolationErrorRetries + + +def reset_retry_policy() -> RetryPolicy: + return RetryPolicy() diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/handle_error.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/handle_error.py new file mode 100644 index 00000000..132440cb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/handle_error.py @@ -0,0 +1,89 @@ +from typing import TYPE_CHECKING, Any, Optional + +from litellm._logging import verbose_router_logger +from litellm.router_utils.cooldown_handlers import ( + _async_get_cooldown_deployments_with_debug_info, +) +from litellm.types.integrations.slack_alerting import AlertType +from litellm.types.router import 
RouterRateLimitError + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + from litellm.router import Router as _Router + + LitellmRouter = _Router + Span = _Span +else: + LitellmRouter = Any + Span = Any + + +async def send_llm_exception_alert( + litellm_router_instance: LitellmRouter, + request_kwargs: dict, + error_traceback_str: str, + original_exception, +): + """ + Only runs if router.slack_alerting_logger is set + Sends a Slack / MS Teams alert for the LLM API call failure. Only if router.slack_alerting_logger is set. + + Parameters: + litellm_router_instance (_Router): The LitellmRouter instance. + original_exception (Any): The original exception that occurred. + + Returns: + None + """ + if litellm_router_instance is None: + return + + if not hasattr(litellm_router_instance, "slack_alerting_logger"): + return + + if litellm_router_instance.slack_alerting_logger is None: + return + + if "proxy_server_request" in request_kwargs: + # Do not send any alert if it's a request from litellm proxy server request + # the proxy is already instrumented to send LLM API call failures + return + + litellm_debug_info = getattr(original_exception, "litellm_debug_info", None) + exception_str = str(original_exception) + if litellm_debug_info is not None: + exception_str += litellm_debug_info + exception_str += f"\n\n{error_traceback_str[:2000]}" + + await litellm_router_instance.slack_alerting_logger.send_alert( + message=f"LLM API call failed: `{exception_str}`", + level="High", + alert_type=AlertType.llm_exceptions, + alerting_metadata={}, + ) + + +async def async_raise_no_deployment_exception( + litellm_router_instance: LitellmRouter, model: str, parent_otel_span: Optional[Span] +): + """ + Raises a RouterRateLimitError if no deployment is found for the given model. + """ + verbose_router_logger.info( + f"get_available_deployment for model: {model}, No deployment available" + ) + model_ids = litellm_router_instance.get_model_ids(model_name=model) + _cooldown_time = litellm_router_instance.cooldown_cache.get_min_cooldown( + model_ids=model_ids, parent_otel_span=parent_otel_span + ) + _cooldown_list = await _async_get_cooldown_deployments_with_debug_info( + litellm_router_instance=litellm_router_instance, + parent_otel_span=parent_otel_span, + ) + return RouterRateLimitError( + model=model, + cooldown_time=_cooldown_time, + enable_pre_call_checks=litellm_router_instance.enable_pre_call_checks, + cooldown_list=_cooldown_list, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/pattern_match_deployments.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/pattern_match_deployments.py new file mode 100644 index 00000000..72951057 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/pattern_match_deployments.py @@ -0,0 +1,266 @@ +""" +Class to handle llm wildcard routing and regex pattern matching +""" + +import copy +import re +from re import Match +from typing import Dict, List, Optional, Tuple + +from litellm import get_llm_provider +from litellm._logging import verbose_router_logger + + +class PatternUtils: + @staticmethod + def calculate_pattern_specificity(pattern: str) -> Tuple[int, int]: + """ + Calculate pattern specificity based on length and complexity. 
+ + Args: + pattern: Regex pattern to analyze + + Returns: + Tuple of (length, complexity) for sorting + """ + complexity_chars = ["*", "+", "?", "\\", "^", "$", "|", "(", ")"] + ret_val = ( + len(pattern), # Longer patterns more specific + sum( + pattern.count(char) for char in complexity_chars + ), # More regex complexity + ) + return ret_val + + @staticmethod + def sorted_patterns( + patterns: Dict[str, List[Dict]] + ) -> List[Tuple[str, List[Dict]]]: + """ + Cached property for patterns sorted by specificity. + + Returns: + Sorted list of pattern-deployment tuples + """ + return sorted( + patterns.items(), + key=lambda x: PatternUtils.calculate_pattern_specificity(x[0]), + reverse=True, + ) + + +class PatternMatchRouter: + """ + Class to handle llm wildcard routing and regex pattern matching + + doc: https://docs.litellm.ai/docs/proxy/configs#provider-specific-wildcard-routing + + This class will store a mapping for regex pattern: List[Deployments] + """ + + def __init__(self): + self.patterns: Dict[str, List] = {} + + def add_pattern(self, pattern: str, llm_deployment: Dict): + """ + Add a regex pattern and the corresponding llm deployments to the patterns + + Args: + pattern: str + llm_deployment: str or List[str] + """ + # Convert the pattern to a regex + regex = self._pattern_to_regex(pattern) + if regex not in self.patterns: + self.patterns[regex] = [] + self.patterns[regex].append(llm_deployment) + + def _pattern_to_regex(self, pattern: str) -> str: + """ + Convert a wildcard pattern to a regex pattern + + example: + pattern: openai/* + regex: openai/.* + + pattern: openai/fo::*::static::* + regex: openai/fo::.*::static::.* + + Args: + pattern: str + + Returns: + str: regex pattern + """ + # # Replace '*' with '.*' for regex matching + # regex = pattern.replace("*", ".*") + # # Escape other special characters + # regex = re.escape(regex).replace(r"\.\*", ".*") + # return f"^{regex}$" + return re.escape(pattern).replace(r"\*", "(.*)") + + def _return_pattern_matched_deployments( + self, matched_pattern: Match, deployments: List[Dict] + ) -> List[Dict]: + new_deployments = [] + for deployment in deployments: + new_deployment = copy.deepcopy(deployment) + new_deployment["litellm_params"]["model"] = ( + PatternMatchRouter.set_deployment_model_name( + matched_pattern=matched_pattern, + litellm_deployment_litellm_model=deployment["litellm_params"][ + "model" + ], + ) + ) + new_deployments.append(new_deployment) + + return new_deployments + + def route( + self, request: Optional[str], filtered_model_names: Optional[List[str]] = None + ) -> Optional[List[Dict]]: + """ + Route a requested model to the corresponding llm deployments based on the regex pattern + + loop through all the patterns and find the matching pattern + if a pattern is found, return the corresponding llm deployments + if no pattern is found, return None + + Args: + request: str - the received model name from the user (can be a wildcard route). If none, No deployments will be returned. 
+ filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names + Returns: + Optional[List[Deployment]]: llm deployments + """ + try: + if request is None: + return None + + sorted_patterns = PatternUtils.sorted_patterns(self.patterns) + regex_filtered_model_names = ( + [self._pattern_to_regex(m) for m in filtered_model_names] + if filtered_model_names is not None + else [] + ) + for pattern, llm_deployments in sorted_patterns: + if ( + filtered_model_names is not None + and pattern not in regex_filtered_model_names + ): + continue + pattern_match = re.match(pattern, request) + if pattern_match: + return self._return_pattern_matched_deployments( + matched_pattern=pattern_match, deployments=llm_deployments + ) + except Exception as e: + verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}") + + return None # No matching pattern found + + @staticmethod + def set_deployment_model_name( + matched_pattern: Match, + litellm_deployment_litellm_model: str, + ) -> str: + """ + Set the model name for the matched pattern llm deployment + + E.g.: + + Case 1: + model_name: llmengine/* (can be any regex pattern or wildcard pattern) + litellm_params: + model: openai/* + + if model_name = "llmengine/foo" -> model = "openai/foo" + + Case 2: + model_name: llmengine/fo::*::static::* + litellm_params: + model: openai/fo::*::static::* + + if model_name = "llmengine/foo::bar::static::baz" -> model = "openai/foo::bar::static::baz" + + Case 3: + model_name: *meta.llama3* + litellm_params: + model: bedrock/meta.llama3* + + if model_name = "hello-world-meta.llama3-70b" -> model = "bedrock/meta.llama3-70b" + """ + + ## BASE CASE: if the deployment model name does not contain a wildcard, return the deployment model name + if "*" not in litellm_deployment_litellm_model: + return litellm_deployment_litellm_model + + wildcard_count = litellm_deployment_litellm_model.count("*") + + # Extract all dynamic segments from the request + dynamic_segments = matched_pattern.groups() + + if len(dynamic_segments) > wildcard_count: + return ( + matched_pattern.string + ) # default to the user input, if unable to map based on wildcards. 
+ # Replace the corresponding wildcards in the litellm model pattern with extracted segments + for segment in dynamic_segments: + litellm_deployment_litellm_model = litellm_deployment_litellm_model.replace( + "*", segment, 1 + ) + + return litellm_deployment_litellm_model + + def get_pattern( + self, model: str, custom_llm_provider: Optional[str] = None + ) -> Optional[List[Dict]]: + """ + Check if a pattern exists for the given model and custom llm provider + + Args: + model: str + custom_llm_provider: Optional[str] + + Returns: + bool: True if pattern exists, False otherwise + """ + if custom_llm_provider is None: + try: + ( + _, + custom_llm_provider, + _, + _, + ) = get_llm_provider(model=model) + except Exception: + # get_llm_provider raises exception when provider is unknown + pass + return self.route(model) or self.route(f"{custom_llm_provider}/{model}") + + def get_deployments_by_pattern( + self, model: str, custom_llm_provider: Optional[str] = None + ) -> List[Dict]: + """ + Get the deployments by pattern + + Args: + model: str + custom_llm_provider: Optional[str] + + Returns: + List[Dict]: llm deployments matching the pattern + """ + pattern_match = self.get_pattern(model, custom_llm_provider) + if pattern_match: + return pattern_match + return [] + + +# Example usage: +# router = PatternRouter() +# router.add_pattern('openai/*', [Deployment(), Deployment()]) +# router.add_pattern('openai/fo::*::static::*', Deployment()) +# print(router.route('openai/gpt-4')) # Output: [Deployment(), Deployment()] +# print(router.route('openai/fo::hi::static::hi')) # Output: [Deployment()] +# print(router.route('something/else')) # Output: None diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/pre_call_checks/prompt_caching_deployment_check.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/pre_call_checks/prompt_caching_deployment_check.py new file mode 100644 index 00000000..d3d237d9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/pre_call_checks/prompt_caching_deployment_check.py @@ -0,0 +1,99 @@ +""" +Check if prompt caching is valid for a given deployment + +Route to previously cached model id, if valid +""" + +from typing import List, Optional, cast + +from litellm import verbose_logger +from litellm.caching.dual_cache import DualCache +from litellm.integrations.custom_logger import CustomLogger, Span +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import CallTypes, StandardLoggingPayload +from litellm.utils import is_prompt_caching_valid_prompt + +from ..prompt_caching_cache import PromptCachingCache + + +class PromptCachingDeploymentCheck(CustomLogger): + def __init__(self, cache: DualCache): + self.cache = cache + + async def async_filter_deployments( + self, + model: str, + healthy_deployments: List, + messages: Optional[List[AllMessageValues]], + request_kwargs: Optional[dict] = None, + parent_otel_span: Optional[Span] = None, + ) -> List[dict]: + if messages is not None and is_prompt_caching_valid_prompt( + messages=messages, + model=model, + ): # prompt > 1024 tokens + prompt_cache = PromptCachingCache( + cache=self.cache, + ) + + model_id_dict = await prompt_cache.async_get_model_id( + messages=cast(List[AllMessageValues], messages), + tools=None, + ) + if model_id_dict is not None: + model_id = model_id_dict["model_id"] + for deployment in healthy_deployments: + if deployment["model_info"]["id"] == model_id: + return [deployment] + + return healthy_deployments + + async def 
async_log_success_event(self, kwargs, response_obj, start_time, end_time): + standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + + if standard_logging_object is None: + return + + call_type = standard_logging_object["call_type"] + + if ( + call_type != CallTypes.completion.value + and call_type != CallTypes.acompletion.value + ): # only use prompt caching for completion calls + verbose_logger.debug( + "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, CALL TYPE IS NOT COMPLETION" + ) + return + + model = standard_logging_object["model"] + messages = standard_logging_object["messages"] + model_id = standard_logging_object["model_id"] + + if messages is None or not isinstance(messages, list): + verbose_logger.debug( + "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MESSAGES IS NOT A LIST" + ) + return + if model_id is None: + verbose_logger.debug( + "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MODEL ID IS NONE" + ) + return + + ## PROMPT CACHING - cache model id, if prompt caching valid prompt + provider + if is_prompt_caching_valid_prompt( + model=model, + messages=cast(List[AllMessageValues], messages), + ): + cache = PromptCachingCache( + cache=self.cache, + ) + await cache.async_add_model_id( + model_id=model_id, + messages=messages, + tools=None, # [TODO]: add tools once standard_logging_object supports it + ) + + return diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/prompt_caching_cache.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/prompt_caching_cache.py new file mode 100644 index 00000000..1bf686d6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/prompt_caching_cache.py @@ -0,0 +1,171 @@ +""" +Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called. 
+""" + +import hashlib +import json +from typing import TYPE_CHECKING, Any, List, Optional, TypedDict + +from litellm.caching.caching import DualCache +from litellm.caching.in_memory_cache import InMemoryCache +from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + from litellm.router import Router + + litellm_router = Router + Span = _Span +else: + Span = Any + litellm_router = Any + + +class PromptCachingCacheValue(TypedDict): + model_id: str + + +class PromptCachingCache: + def __init__(self, cache: DualCache): + self.cache = cache + self.in_memory_cache = InMemoryCache() + + @staticmethod + def serialize_object(obj: Any) -> Any: + """Helper function to serialize Pydantic objects, dictionaries, or fallback to string.""" + if hasattr(obj, "dict"): + # If the object is a Pydantic model, use its `dict()` method + return obj.dict() + elif isinstance(obj, dict): + # If the object is a dictionary, serialize it with sorted keys + return json.dumps( + obj, sort_keys=True, separators=(",", ":") + ) # Standardize serialization + + elif isinstance(obj, list): + # Serialize lists by ensuring each element is handled properly + return [PromptCachingCache.serialize_object(item) for item in obj] + elif isinstance(obj, (int, float, bool)): + return obj # Keep primitive types as-is + return str(obj) + + @staticmethod + def get_prompt_caching_cache_key( + messages: Optional[List[AllMessageValues]], + tools: Optional[List[ChatCompletionToolParam]], + ) -> Optional[str]: + if messages is None and tools is None: + return None + # Use serialize_object for consistent and stable serialization + data_to_hash = {} + if messages is not None: + serialized_messages = PromptCachingCache.serialize_object(messages) + data_to_hash["messages"] = serialized_messages + if tools is not None: + serialized_tools = PromptCachingCache.serialize_object(tools) + data_to_hash["tools"] = serialized_tools + + # Combine serialized data into a single string + data_to_hash_str = json.dumps( + data_to_hash, + sort_keys=True, + separators=(",", ":"), + ) + + # Create a hash of the serialized data for a stable cache key + hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest() + return f"deployment:{hashed_data}:prompt_caching" + + def add_model_id( + self, + model_id: str, + messages: Optional[List[AllMessageValues]], + tools: Optional[List[ChatCompletionToolParam]], + ) -> None: + if messages is None and tools is None: + return None + + cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools) + self.cache.set_cache( + cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300 + ) + return None + + async def async_add_model_id( + self, + model_id: str, + messages: Optional[List[AllMessageValues]], + tools: Optional[List[ChatCompletionToolParam]], + ) -> None: + if messages is None and tools is None: + return None + + cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools) + await self.cache.async_set_cache( + cache_key, + PromptCachingCacheValue(model_id=model_id), + ttl=300, # store for 5 minutes + ) + return None + + async def async_get_model_id( + self, + messages: Optional[List[AllMessageValues]], + tools: Optional[List[ChatCompletionToolParam]], + ) -> Optional[PromptCachingCacheValue]: + """ + if messages is not none + - check full messages + - check messages[:-1] + - check messages[:-2] + - check messages[:-3] + + use self.cache.async_batch_get_cache(keys=potential_cache_keys]) 
+        """
+        if messages is None and tools is None:
+            return None
+
+        # Generate potential cache keys by slicing messages
+
+        potential_cache_keys = []
+
+        if messages is not None:
+            full_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
+                messages, tools
+            )
+            potential_cache_keys.append(full_cache_key)
+
+            # Check progressively shorter message slices
+            for i in range(1, min(4, len(messages))):
+                partial_messages = messages[:-i]
+                partial_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
+                    partial_messages, tools
+                )
+                potential_cache_keys.append(partial_cache_key)
+
+        # Perform batch cache lookup
+        cache_results = await self.cache.async_batch_get_cache(
+            keys=potential_cache_keys
+        )
+
+        if cache_results is None:
+            return None
+
+        # Return the first non-None cache result
+        for result in cache_results:
+            if result is not None:
+                return result
+
+        return None
+
+    def get_model_id(
+        self,
+        messages: Optional[List[AllMessageValues]],
+        tools: Optional[List[ChatCompletionToolParam]],
+    ) -> Optional[PromptCachingCacheValue]:
+        if messages is None and tools is None:
+            return None
+
+        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
+        return self.cache.get_cache(cache_key)
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/response_headers.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/response_headers.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/response_headers.py
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/router_callbacks/track_deployment_metrics.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/router_callbacks/track_deployment_metrics.py
new file mode 100644
index 00000000..1f226879
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/router_callbacks/track_deployment_metrics.py
@@ -0,0 +1,90 @@
+"""
+Helper functions to get/increment the number of successes and failures per deployment for the current minute
+
+
+increment_deployment_failures_for_current_minute
+increment_deployment_successes_for_current_minute
+
+get_deployment_failures_for_current_minute
+get_deployment_successes_for_current_minute
+"""
+
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from litellm.router import Router as _Router
+
+    LitellmRouter = _Router
+else:
+    LitellmRouter = Any
+
+
+def increment_deployment_successes_for_current_minute(
+    litellm_router_instance: LitellmRouter,
+    deployment_id: str,
+) -> str:
+    """
+    In-Memory: Increments the number of successes for the current minute for a deployment_id
+    """
+    key = f"{deployment_id}:successes"
+    litellm_router_instance.cache.increment_cache(
+        local_only=True,
+        key=key,
+        value=1,
+        ttl=60,
+    )
+    return key
+
+
+def increment_deployment_failures_for_current_minute(
+    litellm_router_instance: LitellmRouter,
+    deployment_id: str,
+):
+    """
+    In-Memory: Increments the number of failures for the current minute for a deployment_id
+    """
+    key = f"{deployment_id}:fails"
+    litellm_router_instance.cache.increment_cache(
+        local_only=True,
+        key=key,
+        value=1,
+        ttl=60,
+    )
+
+
+def get_deployment_successes_for_current_minute(
+    litellm_router_instance: LitellmRouter,
+    deployment_id: str,
+) -> int:
+    """
+    Returns the number of successes for the current minute for a deployment_id
+
+    Returns 0 if no value found
+    """
+    key = f"{deployment_id}:successes"
+    return (
+        litellm_router_instance.cache.get_cache(
+            local_only=True,
+            key=key,
+        )
+        or 0
+    )
+
+
+def get_deployment_failures_for_current_minute(
+    litellm_router_instance: LitellmRouter,
+    deployment_id: str,
+) -> int:
+    """
+    Returns the number of failures for the current minute for a deployment_id
+
+    Returns 0 if no value found
+    """
+    key = f"{deployment_id}:fails"
+    return (
+        litellm_router_instance.cache.get_cache(
+            local_only=True,
+            key=key,
+        )
+        or 0
+    )
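Editor's note: the PromptCachingCache file in the diff above pins a conversation to the deployment that already holds its prompt cache. The key is a SHA-256 hash of a canonical JSON dump of the messages/tools, and lookups retry with up to three shorter message prefixes so a follow-up turn still maps to the same deployment. The snippet below is a minimal standalone sketch of that idea, not litellm's API: it uses plain dicts for messages, a plain dict as the cache, and a simplified serialization (a single json.dumps instead of serialize_object), so the keys it produces will not match litellm's. The names prompt_cache_key and find_cached_deployment are illustrative only.

# Standalone sketch (illustrative, not litellm's API) of the key derivation and
# prefix-slice lookup used by PromptCachingCache above.
import hashlib
import json
from typing import Any, Dict, List, Optional


def prompt_cache_key(messages: Optional[List[dict]], tools: Optional[List[dict]]) -> Optional[str]:
    # Mirrors get_prompt_caching_cache_key: hash a canonical JSON dump of the inputs.
    if messages is None and tools is None:
        return None
    data_to_hash: Dict[str, Any] = {}
    if messages is not None:
        data_to_hash["messages"] = json.dumps(messages, sort_keys=True, separators=(",", ":"))
    if tools is not None:
        data_to_hash["tools"] = json.dumps(tools, sort_keys=True, separators=(",", ":"))
    blob = json.dumps(data_to_hash, sort_keys=True, separators=(",", ":"))
    return f"deployment:{hashlib.sha256(blob.encode()).hexdigest()}:prompt_caching"


def find_cached_deployment(
    cache: Dict[str, dict], messages: List[dict], tools: Optional[List[dict]] = None
) -> Optional[dict]:
    # Mirrors async_get_model_id: try the full conversation first, then up to
    # three progressively shorter prefixes (messages[:-1], [:-2], [:-3]).
    candidate_keys = [prompt_cache_key(messages, tools)]
    for i in range(1, min(4, len(messages))):
        candidate_keys.append(prompt_cache_key(messages[:-i], tools))
    for key in candidate_keys:
        if key is not None and key in cache:
            return cache[key]
    return None


if __name__ == "__main__":
    cache: Dict[str, dict] = {}
    history = [{"role": "user", "content": "hello"}]
    key = prompt_cache_key(history, None)
    assert key is not None
    cache[key] = {"model_id": "deployment-1"}
    # A follow-up turn still resolves to the same deployment via the prefix lookup.
    history.append({"role": "assistant", "content": "hi!"})
    print(find_cached_deployment(cache, history))  # -> {'model_id': 'deployment-1'}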
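Editor's note: track_deployment_metrics.py above leans on the Router's DualCache, using increment_cache with ttl=60 so per-deployment success/failure counts are scoped to roughly the current minute, and reads fall back to 0 once the key expires. Below is a small self-contained sketch of that counter pattern with an in-memory TTL counter standing in for the Router cache; the TTLCounter class and the record_success/record_failure helpers are assumptions for illustration, and the real cache's TTL semantics may differ in detail.

# Minimal sketch (illustrative, not litellm's Router cache) of per-deployment
# counters that expire after 60 seconds, with reads defaulting to 0.
import time
from typing import Dict, Tuple


class TTLCounter:
    def __init__(self) -> None:
        self._store: Dict[str, Tuple[int, float]] = {}  # key -> (value, expiry timestamp)

    def increment(self, key: str, value: int = 1, ttl: int = 60) -> int:
        now = time.time()
        current, expiry = self._store.get(key, (0, 0.0))
        if now > expiry:  # stale entry: start a fresh window
            current, expiry = 0, now + ttl
        current += value
        self._store[key] = (current, expiry)
        return current

    def get(self, key: str) -> int:
        current, expiry = self._store.get(key, (0, 0.0))
        return current if time.time() <= expiry else 0


def record_success(counter: TTLCounter, deployment_id: str) -> None:
    # Same key shape as increment_deployment_successes_for_current_minute above.
    counter.increment(f"{deployment_id}:successes", ttl=60)


def record_failure(counter: TTLCounter, deployment_id: str) -> None:
    # Same key shape as increment_deployment_failures_for_current_minute above.
    counter.increment(f"{deployment_id}:fails", ttl=60)


if __name__ == "__main__":
    counter = TTLCounter()
    record_success(counter, "gpt-4o-deployment-1")
    record_success(counter, "gpt-4o-deployment-1")
    record_failure(counter, "gpt-4o-deployment-1")
    print(counter.get("gpt-4o-deployment-1:successes"))  # 2
    print(counter.get("gpt-4o-deployment-1:fails"))      # 1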