author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)

    two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/router_utils/fallback_event_handlers.py')

 .venv/lib/python3.12/site-packages/litellm/router_utils/fallback_event_handlers.py | 303
 1 file changed, 303 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/fallback_event_handlers.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/fallback_event_handlers.py
new file mode 100644
index 00000000..df805e49
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/fallback_event_handlers.py
@@ -0,0 +1,303 @@
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import litellm
+from litellm._logging import verbose_router_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.router_utils.add_retry_fallback_headers import (
+    add_fallback_headers_to_response,
+)
+from litellm.types.router import LiteLLMParamsTypedDict
+
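+# Router is imported under TYPE_CHECKING only, so type hints can reference it
+# without triggering a circular import at runtime.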
+if TYPE_CHECKING:
+    from litellm.router import Router as _Router
+
+    LitellmRouter = _Router
+else:
+    LitellmRouter = Any
+
+
+def _check_stripped_model_group(model_group: str, fallback_key: str) -> bool:
+    """
+    Handles the wildcard routing scenario, where fallbacks are set like:
+        [{"gpt-3.5-turbo": ["claude-3-haiku"]}]
+
+    but the incoming model_group carries a provider prefix, like:
+        "openai/gpt-3.5-turbo"
+
+    Returns:
+    - True if the provider-stripped model group == fallback_key, else False.
+    """
+    for provider in litellm.provider_list:
+        if isinstance(provider, Enum):
+            _provider = provider.value
+        else:
+            _provider = provider
+        if model_group.startswith(f"{_provider}/"):
+            stripped_model_group = model_group.replace(f"{_provider}/", "")
+            if stripped_model_group == fallback_key:
+                return True
+    return False
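+
+# Illustrative behaviour (assumes "openai" appears in litellm.provider_list):
+#   _check_stripped_model_group("openai/gpt-3.5-turbo", "gpt-3.5-turbo")  -> True
+#   _check_stripped_model_group("gpt-3.5-turbo", "gpt-3.5-turbo")         -> False (no provider prefix)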
+
+
+def get_fallback_model_group(
+    fallbacks: List[Any], model_group: str
+) -> Tuple[Optional[List[str]], Optional[int]]:
+    """
+    Returns:
+    - fallback_model_group: List[str] of fallback model groups, e.g. ["gpt-4", "gpt-3.5-turbo"]
+    - generic_fallback_idx: index of the generic ("*") fallback in the fallbacks list, if one exists.
+
+    Checks, in order of precedence:
+    - exact model group match
+    - provider-stripped model group match
+    - generic ("*") fallback
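+
+    Illustrative example (assumed fallbacks config, standard format):
+        >>> get_fallback_model_group(
+        ...     fallbacks=[{"gpt-3.5-turbo": ["claude-3-haiku"]}, {"*": ["gpt-4o"]}],
+        ...     model_group="gpt-3.5-turbo",
+        ... )
+        (['claude-3-haiku'], None)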
+    """
+    generic_fallback_idx: Optional[int] = None
+    stripped_model_fallback: Optional[List[str]] = None
+    fallback_model_group: Optional[List[str]] = None
+    ## check for model-group-specific fallbacks
+    for idx, item in enumerate(fallbacks):
+        if isinstance(item, dict):
+            if list(item.keys())[0] == model_group:  # check exact match
+                fallback_model_group = item[model_group]
+                break
+            elif _check_stripped_model_group(
+                model_group=model_group, fallback_key=list(item.keys())[0]
+            ):  # check stripped model group match
+                stripped_model_fallback = item[list(item.keys())[0]]
+            elif list(item.keys())[0] == "*":  # check generic fallback
+                generic_fallback_idx = idx
+        elif isinstance(item, str):
+            fallback_model_group = [item]  # single-item list; avoid mutating `fallbacks` mid-iteration
+            break
+    ## if no exact match, fall back to the stripped-model match, then the generic fallback
+    if fallback_model_group is None:
+        if stripped_model_fallback is not None:
+            fallback_model_group = stripped_model_fallback
+        elif generic_fallback_idx is not None:
+            fallback_model_group = fallbacks[generic_fallback_idx]["*"]
+
+    return fallback_model_group, generic_fallback_idx
+
+
+async def run_async_fallback(
+    *args: Any,
+    litellm_router: LitellmRouter,
+    fallback_model_group: List[str],
+    original_model_group: str,
+    original_exception: Exception,
+    max_fallbacks: int,
+    fallback_depth: int,
+    **kwargs,
+) -> Any:
+    """
+    Loops through the fallback model groups, retrying the request via litellm_router.async_function_with_fallbacks with each fallback model until one succeeds.
+
+    If the call is successful, it logs the success and returns the response.
+    If the call fails, it logs the failure and continues to the next fallback model group.
+    If all fallback model groups fail, it raises the most recent exception.
+
+    Args:
+        litellm_router: The litellm router instance.
+        *args: Positional arguments.
+        fallback_model_group: List[str] of fallback model groups. example: ["gpt-4", "gpt-3.5-turbo"]
+        original_model_group: The original model group. example: "gpt-3.5-turbo"
+        original_exception: The original exception.
+        max_fallbacks: Maximum number of fallback attempts allowed.
+        fallback_depth: Number of fallback attempts made so far.
+        **kwargs: Keyword arguments.
+
+    Returns:
+        The response from the successful fallback model group.
+    Raises:
+        The most recent exception if all fallback model groups fail.
+    """
+
+    ### BASE CASE ### MAX FALLBACK DEPTH REACHED
+    if fallback_depth >= max_fallbacks:
+        raise original_exception
+
+    error_from_fallbacks = original_exception
+
+    for mg in fallback_model_group:
+        if mg == original_model_group:
+            continue
+        try:
+            # LOGGING
+            kwargs = litellm_router.log_retry(kwargs=kwargs, e=original_exception)
+            verbose_router_logger.info(f"Falling back to model_group = {mg}")
+            if isinstance(mg, str):
+                kwargs["model"] = mg
+            elif isinstance(mg, dict):
+                kwargs.update(mg)
+            kwargs.setdefault("metadata", {}).update(
+                {"model_group": kwargs.get("model", None)}
+            )  # record the model group actually used after the fallback
+            fallback_depth = fallback_depth + 1
+            kwargs["fallback_depth"] = fallback_depth
+            kwargs["max_fallbacks"] = max_fallbacks
+            response = await litellm_router.async_function_with_fallbacks(
+                *args, **kwargs
+            )
+            verbose_router_logger.info("Successful fallback between models.")
+            response = add_fallback_headers_to_response(
+                response=response,
+                attempted_fallbacks=fallback_depth,
+            )
+            # fire the success-fallback callbacks
+            await log_success_fallback_event(
+                original_model_group=original_model_group,
+                kwargs=kwargs,
+                original_exception=original_exception,
+            )
+            return response
+        except Exception as e:
+            error_from_fallbacks = e
+            await log_failure_fallback_event(
+                original_model_group=original_model_group,
+                kwargs=kwargs,
+                original_exception=original_exception,
+            )
+    raise error_from_fallbacks
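+
+# Hedged flow sketch (assumed router config; model names are illustrative):
+#   router = litellm.Router(
+#       model_list=[...],
+#       fallbacks=[{"gpt-3.5-turbo": ["claude-3-haiku"]}],
+#   )
+#   When a "gpt-3.5-turbo" call fails, run_async_fallback retries with
+#   kwargs["model"] = "claude-3-haiku", incrementing fallback_depth until
+#   max_fallbacks is reached, then re-raises the original exception.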
+
+
+async def log_success_fallback_event(
+    original_model_group: str, kwargs: dict, original_exception: Exception
+):
+    """
+    Log a successful fallback event to all registered callbacks.
+
+    This function iterates through all registered callbacks, initializing custom-logger-compatible
+    callbacks if needed, and calls the log_success_fallback_event method on CustomLogger instances.
+
+    Args:
+        original_model_group (str): The original model group before fallback.
+        kwargs (dict): kwargs for the request.
+        original_exception (Exception): The exception that triggered the fallback.
+
+    Note:
+        Errors during logging are caught and reported but do not interrupt the process.
+    """
+    from litellm.litellm_core_utils.litellm_logging import (
+        _init_custom_logger_compatible_class,
+    )
+
+    for _callback in litellm.callbacks:
+        if isinstance(_callback, CustomLogger) or (
+            _callback in litellm._known_custom_logger_compatible_callbacks
+        ):
+            try:
+                _callback_custom_logger: Optional[CustomLogger] = None
+                if _callback in litellm._known_custom_logger_compatible_callbacks:
+                    _callback_custom_logger = _init_custom_logger_compatible_class(
+                        logging_integration=_callback,  # type: ignore
+                        llm_router=None,
+                        internal_usage_cache=None,
+                    )
+                elif isinstance(_callback, CustomLogger):
+                    _callback_custom_logger = _callback
+                else:
+                    verbose_router_logger.exception(
+                        f"{_callback} logger not found / initialized properly"
+                    )
+                    continue
+
+                if _callback_custom_logger is None:
+                    verbose_router_logger.exception(
+                        f"{_callback} logger not found / initialized properly, callback is None"
+                    )
+                    continue
+
+                await _callback_custom_logger.log_success_fallback_event(
+                    original_model_group=original_model_group,
+                    kwargs=kwargs,
+                    original_exception=original_exception,
+                )
+            except Exception as e:
+                verbose_router_logger.error(
+                    f"Error in log_success_fallback_event: {str(e)}"
+                )
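+
+# Hedged usage sketch (the class name is illustrative, not part of litellm):
+# a CustomLogger subclass can observe fallback outcomes by overriding the
+# hooks invoked above, then registering itself on litellm.callbacks:
+#
+#   class FallbackTracker(CustomLogger):
+#       async def log_success_fallback_event(
+#           self, original_model_group: str, kwargs: dict, original_exception: Exception
+#       ):
+#           verbose_router_logger.info(f"fallback away from {original_model_group} succeeded")
+#
+#       async def log_failure_fallback_event(
+#           self, original_model_group: str, kwargs: dict, original_exception: Exception
+#       ):
+#           verbose_router_logger.info(f"fallback away from {original_model_group} failed")
+#
+#   litellm.callbacks.append(FallbackTracker())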
+
+
+async def log_failure_fallback_event(
+    original_model_group: str, kwargs: dict, original_exception: Exception
+):
+    """
+    Log a failed fallback event to all registered callbacks.
+
+    This function iterates through all registered callbacks, initializing custom-logger-compatible
+    callbacks if needed, and calls the log_failure_fallback_event method on CustomLogger instances.
+
+    Args:
+        original_model_group (str): The original model group before fallback.
+        kwargs (dict): kwargs for the request.
+        original_exception (Exception): The exception that triggered the fallback.
+
+    Note:
+        Errors during logging are caught and reported but do not interrupt the process.
+    """
+    from litellm.litellm_core_utils.litellm_logging import (
+        _init_custom_logger_compatible_class,
+    )
+
+    for _callback in litellm.callbacks:
+        if isinstance(_callback, CustomLogger) or (
+            _callback in litellm._known_custom_logger_compatible_callbacks
+        ):
+            try:
+                _callback_custom_logger: Optional[CustomLogger] = None
+                if _callback in litellm._known_custom_logger_compatible_callbacks:
+                    _callback_custom_logger = _init_custom_logger_compatible_class(
+                        logging_integration=_callback,  # type: ignore
+                        llm_router=None,
+                        internal_usage_cache=None,
+                    )
+                elif isinstance(_callback, CustomLogger):
+                    _callback_custom_logger = _callback
+                else:
+                    verbose_router_logger.exception(
+                        f"{_callback} logger not found / initialized properly"
+                    )
+                    continue
+
+                if _callback_custom_logger is None:
+                    verbose_router_logger.exception(
+                        f"{_callback} logger not found / initialized properly"
+                    )
+                    continue
+
+                await _callback_custom_logger.log_failure_fallback_event(
+                    original_model_group=original_model_group,
+                    kwargs=kwargs,
+                    original_exception=original_exception,
+                )
+            except Exception as e:
+                verbose_router_logger.error(
+                    f"Error in log_failure_fallback_event: {str(e)}"
+                )
+
+
+def _check_non_standard_fallback_format(fallbacks: Optional[List[Any]]) -> bool:
+    """
+    Checks whether the fallbacks list uses a non-standard format.
+
+    Returns True if fallbacks is either:
+    - List[str], e.g. ["claude-3-haiku", "openai/o-1"]
+    - a list of dicts keyed by LiteLLMParamsTypedDict fields, e.g. [{"model": "claude-3-haiku", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}]
+
+    Returns False for the standard format, e.g. [{"gpt-3.5-turbo": ["claude-3-haiku"]}].
+    """
+    if fallbacks is None or not isinstance(fallbacks, list) or len(fallbacks) == 0:
+        return False
+    if all(isinstance(item, str) for item in fallbacks):
+        return True
+    elif all(isinstance(item, dict) for item in fallbacks):
+        for key in LiteLLMParamsTypedDict.__annotations__.keys():
+            if key in fallbacks[0].keys():
+                return True
+
+    return False
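+
+# Illustrative classification (assumed inputs):
+#   _check_non_standard_fallback_format(["claude-3-haiku", "openai/o-1"])         -> True
+#   _check_non_standard_fallback_format([{"model": "claude-3-haiku"}])            -> True  ("model" is a LiteLLMParamsTypedDict key)
+#   _check_non_standard_fallback_format([{"gpt-3.5-turbo": ["claude-3-haiku"]}])  -> False (standard format)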
+
+
+def run_non_standard_fallback_format(
+    fallbacks: Union[List[str], List[Dict[str, Any]]], model_group: str
+):
+    pass