Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/health_check.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/health_check.py  183
1 file changed, 183 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/health_check.py b/.venv/lib/python3.12/site-packages/litellm/proxy/health_check.py
new file mode 100644
index 00000000..f9455387
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/health_check.py
@@ -0,0 +1,183 @@
+# This file runs health checks for LLM deployments, used by the litellm proxy
+
+import asyncio
+import logging
+import random
+from typing import List, Optional
+
+import litellm
+from litellm.constants import HEALTH_CHECK_TIMEOUT_SECONDS
+
+logger = logging.getLogger(__name__)
+
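+# Params stripped from health-check responses so prompts and credentials are not surfaced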
+ILLEGAL_DISPLAY_PARAMS = [
+ "messages",
+ "api_key",
+ "prompt",
+ "input",
+ "vertex_credentials",
+ "aws_access_key_id",
+ "aws_secret_access_key",
+]
+
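+# Params kept when detailed output is disabled (details=False)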
+MINIMAL_DISPLAY_PARAMS = ["model", "mode_error"]
+
+
+def _get_random_llm_message():
+ """
+ Get a random message from the LLM.
+ """
+ messages = ["Hey how's it going?", "What's 1 + 1?"]
+
+ return [{"role": "user", "content": random.choice(messages)}]
+
+
+def _clean_endpoint_data(endpoint_data: dict, details: Optional[bool] = True):
+ """
+ Clean the endpoint data for display to users.
+ """
+ endpoint_data.pop("litellm_logging_obj", None)
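+    # With details enabled, drop only the sensitive params; with details=False,
+    # return just the minimal whitelisted params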
+ return (
+ {k: v for k, v in endpoint_data.items() if k not in ILLEGAL_DISPLAY_PARAMS}
+ if details is not False
+ else {k: v for k, v in endpoint_data.items() if k in MINIMAL_DISPLAY_PARAMS}
+ )
+
+
+def filter_deployments_by_id(
+ model_list: List,
+) -> List:
+ seen_ids = set()
+ filtered_deployments = []
+
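+    # Keep only the first deployment seen for each model_info.id so duplicates
+    # (e.g. the same deployment listed under several aliases) are checked once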
+ for deployment in model_list:
+ _model_info = deployment.get("model_info") or {}
+ _id = _model_info.get("id") or None
+ if _id is None:
+ continue
+
+ if _id not in seen_ids:
+ seen_ids.add(_id)
+ filtered_deployments.append(deployment)
+
+ return filtered_deployments
+
+
+async def run_with_timeout(task, timeout):
+    # Wrap the coroutine in a Task so it can be cancelled explicitly on timeout
+    task = asyncio.ensure_future(task)
+    try:
+        return await asyncio.wait_for(task, timeout)
+    except asyncio.TimeoutError:
+        task.cancel()
+        # Cancel every other pending task in the loop (everything except the
+        # task this coroutine is running in) so nothing lingers after a timeout
+        current_task = asyncio.current_task()
+        for t in asyncio.all_tasks():
+            if t != current_task:
+                t.cancel()
+        try:
+            await asyncio.wait_for(task, 0.1)  # Give 100ms for cleanup
+        except (asyncio.TimeoutError, asyncio.CancelledError, Exception):
+            pass
+        return {"error": "Timeout exceeded"}
+
+
+async def _perform_health_check(model_list: list, details: Optional[bool] = True):
+ """
+ Perform a health check for each model in the list.
+ """
+
+ tasks = []
+ for model in model_list:
+ litellm_params = model["litellm_params"]
+ model_info = model.get("model_info", {})
+ mode = model_info.get("mode", None)
+ litellm_params = _update_litellm_params_for_health_check(
+ model_info, litellm_params
+ )
+ timeout = model_info.get("health_check_timeout") or HEALTH_CHECK_TIMEOUT_SECONDS
+
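+        # Each deployment gets its own health-check coroutine, bounded by its
+        # configured timeout (or the global default)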
+ task = run_with_timeout(
+ litellm.ahealth_check(
+ model["litellm_params"],
+ mode=mode,
+ prompt="test from litellm",
+ input=["test from litellm"],
+ ),
+ timeout,
+ )
+
+ tasks.append(task)
+
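+    # return_exceptions=True so one failing deployment doesn't abort the rest;
+    # exceptions land in `results` and are treated as unhealthy below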
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+
+ healthy_endpoints = []
+ unhealthy_endpoints = []
+
+ for is_healthy, model in zip(results, model_list):
+ litellm_params = model["litellm_params"]
+
+ if isinstance(is_healthy, dict) and "error" not in is_healthy:
+ healthy_endpoints.append(
+ _clean_endpoint_data({**litellm_params, **is_healthy}, details)
+ )
+ elif isinstance(is_healthy, dict):
+ unhealthy_endpoints.append(
+ _clean_endpoint_data({**litellm_params, **is_healthy}, details)
+ )
+ else:
+ unhealthy_endpoints.append(_clean_endpoint_data(litellm_params, details))
+
+ return healthy_endpoints, unhealthy_endpoints
+
+
+def _update_litellm_params_for_health_check(
+ model_info: dict, litellm_params: dict
+) -> dict:
+ """
+ Update the litellm params for health check.
+
+ - gets a short `messages` param for health check
+    - updates the `model` param with the `health_check_model` if it exists. Doc: https://docs.litellm.ai/docs/proxy/health#wildcard-routes
+ """
+ litellm_params["messages"] = _get_random_llm_message()
+ _health_check_model = model_info.get("health_check_model", None)
+ if _health_check_model is not None:
+ litellm_params["model"] = _health_check_model
+ return litellm_params
+
+
+async def perform_health_check(
+ model_list: list,
+ model: Optional[str] = None,
+ cli_model: Optional[str] = None,
+ details: Optional[bool] = True,
+):
+ """
+ Perform a health check on the system.
+
+    Returns:
+        (tuple): (healthy_endpoints, unhealthy_endpoints) - lists of cleaned endpoint data.
+ """
+ if not model_list:
+ if cli_model:
+ model_list = [
+ {"model_name": cli_model, "litellm_params": {"model": cli_model}}
+ ]
+ else:
+ return [], []
+
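+    # When a specific model is requested, match on the underlying litellm model
+    # first, then fall back to the public model_name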
+ if model is not None:
+ _new_model_list = [
+ x for x in model_list if x["litellm_params"]["model"] == model
+ ]
+ if _new_model_list == []:
+ _new_model_list = [x for x in model_list if x["model_name"] == model]
+ model_list = _new_model_list
+
+ model_list = filter_deployments_by_id(
+ model_list=model_list
+    )  # filter duplicate deployments (e.g. when model aliases are used)
+ healthy_endpoints, unhealthy_endpoints = await _perform_health_check(
+ model_list, details
+ )
+
+ return healthy_endpoints, unhealthy_endpoints
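+
+# Illustrative usage sketch (not part of the proxy itself): run a health check
+# against a minimal in-memory model list; the model name below is only an example.
+#
+#   model_list = [{"model_name": "gpt-4o", "litellm_params": {"model": "gpt-4o"}}]
+#   healthy, unhealthy = asyncio.run(perform_health_check(model_list, details=False))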