about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/health_check.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/health_check.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/health_check.py183
1 files changed, 183 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/health_check.py b/.venv/lib/python3.12/site-packages/litellm/proxy/health_check.py
new file mode 100644
index 00000000..f9455387
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/health_check.py
@@ -0,0 +1,183 @@
+# This file runs a health check for the LLM, used on litellm/proxy
+
+import asyncio
+import logging
+import random
+from typing import List, Optional
+
+import litellm
+
+logger = logging.getLogger(__name__)
+from litellm.constants import HEALTH_CHECK_TIMEOUT_SECONDS
+
+ILLEGAL_DISPLAY_PARAMS = [
+    "messages",
+    "api_key",
+    "prompt",
+    "input",
+    "vertex_credentials",
+    "aws_access_key_id",
+    "aws_secret_access_key",
+]
+
+MINIMAL_DISPLAY_PARAMS = ["model", "mode_error"]
+
+
+def _get_random_llm_message():
+    """
+    Get a random message from the LLM.
+    """
+    messages = ["Hey how's it going?", "What's 1 + 1?"]
+
+    return [{"role": "user", "content": random.choice(messages)}]
+
+
+def _clean_endpoint_data(endpoint_data: dict, details: Optional[bool] = True):
+    """
+    Clean the endpoint data for display to users.
+    """
+    endpoint_data.pop("litellm_logging_obj", None)
+    return (
+        {k: v for k, v in endpoint_data.items() if k not in ILLEGAL_DISPLAY_PARAMS}
+        if details is not False
+        else {k: v for k, v in endpoint_data.items() if k in MINIMAL_DISPLAY_PARAMS}
+    )
+
+
+def filter_deployments_by_id(
+    model_list: List,
+) -> List:
+    seen_ids = set()
+    filtered_deployments = []
+
+    for deployment in model_list:
+        _model_info = deployment.get("model_info") or {}
+        _id = _model_info.get("id") or None
+        if _id is None:
+            continue
+
+        if _id not in seen_ids:
+            seen_ids.add(_id)
+            filtered_deployments.append(deployment)
+
+    return filtered_deployments
+
+
+async def run_with_timeout(task, timeout):
+    try:
+        return await asyncio.wait_for(task, timeout)
+    except asyncio.TimeoutError:
+        task.cancel()
+        # Only cancel child tasks of the current task
+        current_task = asyncio.current_task()
+        for t in asyncio.all_tasks():
+            if t != current_task:
+                t.cancel()
+        try:
+            await asyncio.wait_for(task, 0.1)  # Give 100ms for cleanup
+        except (asyncio.TimeoutError, asyncio.CancelledError, Exception):
+            pass
+        return {"error": "Timeout exceeded"}
+
+
+async def _perform_health_check(model_list: list, details: Optional[bool] = True):
+    """
+    Perform a health check for each model in the list.
+    """
+
+    tasks = []
+    for model in model_list:
+        litellm_params = model["litellm_params"]
+        model_info = model.get("model_info", {})
+        mode = model_info.get("mode", None)
+        litellm_params = _update_litellm_params_for_health_check(
+            model_info, litellm_params
+        )
+        timeout = model_info.get("health_check_timeout") or HEALTH_CHECK_TIMEOUT_SECONDS
+
+        task = run_with_timeout(
+            litellm.ahealth_check(
+                model["litellm_params"],
+                mode=mode,
+                prompt="test from litellm",
+                input=["test from litellm"],
+            ),
+            timeout,
+        )
+
+        tasks.append(task)
+
+    results = await asyncio.gather(*tasks, return_exceptions=True)
+
+    healthy_endpoints = []
+    unhealthy_endpoints = []
+
+    for is_healthy, model in zip(results, model_list):
+        litellm_params = model["litellm_params"]
+
+        if isinstance(is_healthy, dict) and "error" not in is_healthy:
+            healthy_endpoints.append(
+                _clean_endpoint_data({**litellm_params, **is_healthy}, details)
+            )
+        elif isinstance(is_healthy, dict):
+            unhealthy_endpoints.append(
+                _clean_endpoint_data({**litellm_params, **is_healthy}, details)
+            )
+        else:
+            unhealthy_endpoints.append(_clean_endpoint_data(litellm_params, details))
+
+    return healthy_endpoints, unhealthy_endpoints
+
+
+def _update_litellm_params_for_health_check(
+    model_info: dict, litellm_params: dict
+) -> dict:
+    """
+    Update the litellm params for health check.
+
+    - gets a short `messages` param for health check
+    - updates the `model` param with the `health_check_model` if it exists Doc: https://docs.litellm.ai/docs/proxy/health#wildcard-routes
+    """
+    litellm_params["messages"] = _get_random_llm_message()
+    _health_check_model = model_info.get("health_check_model", None)
+    if _health_check_model is not None:
+        litellm_params["model"] = _health_check_model
+    return litellm_params
+
+
+async def perform_health_check(
+    model_list: list,
+    model: Optional[str] = None,
+    cli_model: Optional[str] = None,
+    details: Optional[bool] = True,
+):
+    """
+    Perform a health check on the system.
+
+    Returns:
+        (bool): True if the health check passes, False otherwise.
+    """
+    if not model_list:
+        if cli_model:
+            model_list = [
+                {"model_name": cli_model, "litellm_params": {"model": cli_model}}
+            ]
+        else:
+            return [], []
+
+    if model is not None:
+        _new_model_list = [
+            x for x in model_list if x["litellm_params"]["model"] == model
+        ]
+        if _new_model_list == []:
+            _new_model_list = [x for x in model_list if x["model_name"] == model]
+        model_list = _new_model_list
+
+    model_list = filter_deployments_by_id(
+        model_list=model_list
+    )  # filter duplicate deployments (e.g. when model alias'es are used)
+    healthy_endpoints, unhealthy_endpoints = await _perform_health_check(
+        model_list, details
+    )
+
+    return healthy_endpoints, unhealthy_endpoints