Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py  252
1 file changed, 252 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py
new file mode 100644
index 00000000..12f3f01c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py
@@ -0,0 +1,252 @@
+#### What this does ####
+#   Identifies the least busy deployment.
+#   How is this achieved?
+#   - Before each call, the router tracks the state of in-flight requests: {"deployment": "requests_in_flight"}
+#   - use litellm.input_callbacks to log when a request is about to be made to a model: {"deployment-id": traffic}
+#   - use litellm success + failure callbacks to log when a request has completed
+#   - in get_available_deployments, for a given model group name, pick the deployment with the least traffic
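+#
+#   Illustrative cache state (hypothetical ids and counts): for model group
+#   "gpt-3.5-turbo", the cache key "gpt-3.5-turbo_request_count" might hold
+#   {"deployment-a": 2, "deployment-b": 0}, in which case the least busy
+#   strategy picks "deployment-b".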
+
+import random
+from typing import Optional
+
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+
+
+class LeastBusyLoggingHandler(CustomLogger):
+    test_flag: bool = False
+    logged_success: int = 0
+    logged_failure: int = 0
+
+    def __init__(self, router_cache: DualCache, model_list: list):
+        self.router_cache = router_cache
+        self.mapping_deployment_to_id: dict = {}  # not referenced within this handler
+        self.model_list = model_list
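+
+        # Wiring sketch (hypothetical; mirrors the callback flow described at
+        # the top of this file): the router constructs this handler with its
+        # shared DualCache and registers it so the pre-call / success / failure
+        # hooks below fire around each request, e.g.:
+        #   handler = LeastBusyLoggingHandler(router_cache=DualCache(), model_list=[])
+        #   litellm.callbacks.append(handler)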
+
+    def log_pre_api_call(self, model, messages, kwargs):
+        """
+        Log when a model is about to be called.
+
+        Counts are cached per model group.
+        """
+        try:
+            if kwargs["litellm_params"].get("metadata") is None:
+                return
+            model_group = kwargs["litellm_params"]["metadata"].get(
+                "model_group", None
+            )
+            id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+            if model_group is None or id is None:
+                return
+            elif isinstance(id, int):
+                id = str(id)
+
+            request_count_api_key = f"{model_group}_request_count"
+            # increment the in-flight count for this deployment
+            request_count_dict = (
+                self.router_cache.get_cache(key=request_count_api_key) or {}
+            )
+            request_count_dict[id] = request_count_dict.get(id, 0) + 1
+
+            self.router_cache.set_cache(
+                key=request_count_api_key, value=request_count_dict
+            )
+        except Exception:
+            pass  # tracking errors must never break the request
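+
+    # Note: the increment above is a read-modify-write (get_cache -> mutate ->
+    # set_cache), not an atomic operation, so concurrent callers can race;
+    # the counts are a best-effort traffic signal rather than an exact ledger.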
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            if kwargs["litellm_params"].get("metadata") is None:
+                return
+            model_group = kwargs["litellm_params"]["metadata"].get(
+                "model_group", None
+            )
+
+            id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+            if model_group is None or id is None:
+                return
+            elif isinstance(id, int):
+                id = str(id)
+
+            request_count_api_key = f"{model_group}_request_count"
+            # decrement the in-flight count for this deployment
+            request_count_dict = (
+                self.router_cache.get_cache(key=request_count_api_key) or {}
+            )
+            request_count_value: Optional[int] = request_count_dict.get(id)
+            if request_count_value is None:
+                return
+            request_count_dict[id] = request_count_value - 1
+            self.router_cache.set_cache(
+                key=request_count_api_key, value=request_count_dict
+            )
+
+            ### TESTING ###
+            if self.test_flag:
+                self.logged_success += 1
+        except Exception:
+            pass
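+
+    # Both the success and failure handlers decrement the same counter: either
+    # way the request is no longer in flight. E.g. (hypothetical values), if
+    # the cache holds {"deployment-a": 2} and deployment-a's request errors,
+    # the failure hook brings the count back to {"deployment-a": 1}.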
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            if kwargs["litellm_params"].get("metadata") is None:
+                return
+            model_group = kwargs["litellm_params"]["metadata"].get(
+                "model_group", None
+            )
+            id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+            if model_group is None or id is None:
+                return
+            elif isinstance(id, int):
+                id = str(id)
+
+            request_count_api_key = f"{model_group}_request_count"
+            # decrement the in-flight count for this deployment
+            request_count_dict = (
+                self.router_cache.get_cache(key=request_count_api_key) or {}
+            )
+            request_count_value: Optional[int] = request_count_dict.get(id)
+            if request_count_value is None:
+                return
+            request_count_dict[id] = request_count_value - 1
+            self.router_cache.set_cache(
+                key=request_count_api_key, value=request_count_dict
+            )
+
+            ### TESTING ###
+            if self.test_flag:
+                self.logged_failure += 1
+        except Exception:
+            pass
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            if kwargs["litellm_params"].get("metadata") is None:
+                return
+            model_group = kwargs["litellm_params"]["metadata"].get(
+                "model_group", None
+            )
+
+            id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+            if model_group is None or id is None:
+                return
+            elif isinstance(id, int):
+                id = str(id)
+
+            request_count_api_key = f"{model_group}_request_count"
+            # decrement the in-flight count for this deployment
+            request_count_dict = (
+                await self.router_cache.async_get_cache(key=request_count_api_key)
+                or {}
+            )
+            request_count_value: Optional[int] = request_count_dict.get(id)
+            if request_count_value is None:
+                return
+            request_count_dict[id] = request_count_value - 1
+            await self.router_cache.async_set_cache(
+                key=request_count_api_key, value=request_count_dict
+            )
+
+            ### TESTING ###
+            if self.test_flag:
+                self.logged_success += 1
+        except Exception:
+            pass
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            if kwargs["litellm_params"].get("metadata") is None:
+                return
+            model_group = kwargs["litellm_params"]["metadata"].get(
+                "model_group", None
+            )
+            id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+            if model_group is None or id is None:
+                return
+            elif isinstance(id, int):
+                id = str(id)
+
+            request_count_api_key = f"{model_group}_request_count"
+            # decrement the in-flight count for this deployment
+            request_count_dict = (
+                await self.router_cache.async_get_cache(key=request_count_api_key)
+                or {}
+            )
+            request_count_value: Optional[int] = request_count_dict.get(id)
+            if request_count_value is None:
+                return
+            request_count_dict[id] = request_count_value - 1
+            await self.router_cache.async_set_cache(
+                key=request_count_api_key, value=request_count_dict
+            )
+
+            ### TESTING ###
+            if self.test_flag:
+                self.logged_failure += 1
+        except Exception:
+            pass
+
+    def _get_available_deployments(
+        self,
+        healthy_deployments: list,
+        all_deployments: dict,
+    ):
+        """
+        Helper to get deployments using least busy strategy
+        """
+        for d in healthy_deployments:
+            ## if healthy deployment not yet used
+            if d["model_info"]["id"] not in all_deployments:
+                all_deployments[d["model_info"]["id"]] = 0
+        # map deployment to id
+        # pick least busy deployment
+        min_traffic = float("inf")
+        min_deployment = None
+        for k, v in all_deployments.items():
+            if v < min_traffic:
+                min_traffic = v
+                min_deployment = k
+        if min_deployment is not None:
+            ## check if min deployment is a string, if so, cast it to int
+            for m in healthy_deployments:
+                if m["model_info"]["id"] == min_deployment:
+                    return m
+            min_deployment = random.choice(healthy_deployments)
+        else:
+            min_deployment = random.choice(healthy_deployments)
+        return min_deployment
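+
+    # Worked example (hypothetical ids): with all_deployments == {"a": 3, "b": 1}
+    # and both ids present in healthy_deployments, min_traffic ends at 1,
+    # min_deployment is "b", and the deployment dict whose
+    # model_info["id"] == "b" is returned.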
+
+    def get_available_deployments(
+        self,
+        model_group: str,
+        healthy_deployments: list,
+    ):
+        """
+        Sync helper to pick a deployment using the least busy strategy
+        """
+        request_count_api_key = f"{model_group}_request_count"
+        all_deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
+        return self._get_available_deployments(
+            healthy_deployments=healthy_deployments,
+            all_deployments=all_deployments,
+        )
+
+    async def async_get_available_deployments(
+        self, model_group: str, healthy_deployments: list
+    ):
+        """
+        Async helper to pick a deployment using the least busy strategy
+        """
+        request_count_api_key = f"{model_group}_request_count"
+        all_deployments = (
+            await self.router_cache.async_get_cache(key=request_count_api_key) or {}
+        )
+        return self._get_available_deployments(
+            healthy_deployments=healthy_deployments,
+            all_deployments=all_deployments,
+        )
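+
+# Hypothetical end-to-end flow (illustrative only; the kwargs shape matches
+# what the handlers above read):
+#   handler = LeastBusyLoggingHandler(router_cache=DualCache(), model_list=[])
+#   kwargs = {
+#       "litellm_params": {
+#           "metadata": {"model_group": "gpt-3.5-turbo"},
+#           "model_info": {"id": "deployment-a"},
+#       }
+#   }
+#   handler.log_pre_api_call(model="gpt-3.5-turbo", messages=[], kwargs=kwargs)
+#   picked = handler.get_available_deployments(
+#       model_group="gpt-3.5-turbo",
+#       healthy_deployments=[{"model_info": {"id": "deployment-a"}}],
+#   )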