author     S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py')
-rw-r--r--   .venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py | 252
1 file changed, 252 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py b/.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py
new file mode 100644
index 00000000..12f3f01c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_strategy/least_busy.py
@@ -0,0 +1,252 @@
+#### What this does ####
+# identifies least busy deployment
+# How is this achieved?
+# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
+# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
+# - use litellm.success + failure callbacks to log when a request completed
+# - in get_available_deployment, for a given model group name -> pick based on traffic
+
+import random
+from typing import Optional
+
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+
+
+class LeastBusyLoggingHandler(CustomLogger):
+    test_flag: bool = False
+    logged_success: int = 0
+    logged_failure: int = 0
+
+    def __init__(self, router_cache: DualCache, model_list: list):
+        self.router_cache = router_cache
+        self.mapping_deployment_to_id: dict = {}
+        self.model_list = model_list
+
+    def log_pre_api_call(self, model, messages, kwargs):
+        """
+        Log when a model is being used.
+
+        Caching based on model group.
+        """
+        try:
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+                elif isinstance(id, int):
+                    id = str(id)
+
+                request_count_api_key = f"{model_group}_request_count"
+                # update cache
+                request_count_dict = (
+                    self.router_cache.get_cache(key=request_count_api_key) or {}
+                )
+                request_count_dict[id] = request_count_dict.get(id, 0) + 1
+
+                self.router_cache.set_cache(
+                    key=request_count_api_key, value=request_count_dict
+                )
+        except Exception:
+            pass
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+                elif isinstance(id, int):
+                    id = str(id)
+
+                request_count_api_key = f"{model_group}_request_count"
+                # decrement count in cache
+                request_count_dict = (
+                    self.router_cache.get_cache(key=request_count_api_key) or {}
+                )
+                request_count_value: Optional[int] = request_count_dict.get(id, 0)
+                if request_count_value is None:
+                    return
+                request_count_dict[id] = request_count_value - 1
+                self.router_cache.set_cache(
+                    key=request_count_api_key, value=request_count_dict
+                )
+
+                ### TESTING ###
+                if self.test_flag:
+                    self.logged_success += 1
+        except Exception:
+            pass
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+                elif isinstance(id, int):
+                    id = str(id)
+
+                request_count_api_key = f"{model_group}_request_count"
+                # decrement count in cache
+                request_count_dict = (
+                    self.router_cache.get_cache(key=request_count_api_key) or {}
+                )
+                request_count_value: Optional[int] = request_count_dict.get(id, 0)
+                if request_count_value is None:
+                    return
+                request_count_dict[id] = request_count_value - 1
+                self.router_cache.set_cache(
+                    key=request_count_api_key, value=request_count_dict
+                )
+
+                ### TESTING ###
+                if self.test_flag:
+                    self.logged_failure += 1
+        except Exception:
+            pass
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+                elif isinstance(id, int):
+                    id = str(id)
+
+                request_count_api_key = f"{model_group}_request_count"
+                # decrement count in cache
+                request_count_dict = (
+                    await self.router_cache.async_get_cache(key=request_count_api_key)
+                    or {}
+                )
+                request_count_value: Optional[int] = request_count_dict.get(id, 0)
+                if request_count_value is None:
+                    return
+                request_count_dict[id] = request_count_value - 1
+                await self.router_cache.async_set_cache(
+                    key=request_count_api_key, value=request_count_dict
+                )
+
+                ### TESTING ###
+                if self.test_flag:
+                    self.logged_success += 1
+        except Exception:
+            pass
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+                elif isinstance(id, int):
+                    id = str(id)
+
+                request_count_api_key = f"{model_group}_request_count"
+                # decrement count in cache
+                request_count_dict = (
+                    await self.router_cache.async_get_cache(key=request_count_api_key)
+                    or {}
+                )
+                request_count_value: Optional[int] = request_count_dict.get(id, 0)
+                if request_count_value is None:
+                    return
+                request_count_dict[id] = request_count_value - 1
+                await self.router_cache.async_set_cache(
+                    key=request_count_api_key, value=request_count_dict
+                )
+
+                ### TESTING ###
+                if self.test_flag:
+                    self.logged_failure += 1
+        except Exception:
+            pass
+
+    def _get_available_deployments(
+        self,
+        healthy_deployments: list,
+        all_deployments: dict,
+    ):
+        """
+        Helper to get deployments using least busy strategy
+        """
+        for d in healthy_deployments:
+            ## if healthy deployment not yet used
+            if d["model_info"]["id"] not in all_deployments:
+                all_deployments[d["model_info"]["id"]] = 0
+        # map deployment to id
+        # pick least busy deployment
+        min_traffic = float("inf")
+        min_deployment = None
+        for k, v in all_deployments.items():
+            if v < min_traffic:
+                min_traffic = v
+                min_deployment = k
+        if min_deployment is not None:
+            ## check if min deployment is a string, if so, cast it to int
+            for m in healthy_deployments:
+                if m["model_info"]["id"] == min_deployment:
+                    return m
+            min_deployment = random.choice(healthy_deployments)
+        else:
+            min_deployment = random.choice(healthy_deployments)
+        return min_deployment
+
+    def get_available_deployments(
+        self,
+        model_group: str,
+        healthy_deployments: list,
+    ):
+        """
+        Sync helper to get deployments using least busy strategy
+        """
+        request_count_api_key = f"{model_group}_request_count"
+        all_deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
+        return self._get_available_deployments(
+            healthy_deployments=healthy_deployments,
+            all_deployments=all_deployments,
+        )
+
+    async def async_get_available_deployments(
+        self, model_group: str, healthy_deployments: list
+    ):
+        """
+        Async helper to get deployments using least busy strategy
+        """
+        request_count_api_key = f"{model_group}_request_count"
+        all_deployments = (
+            await self.router_cache.async_get_cache(key=request_count_api_key) or {}
+        )
+        return self._get_available_deployments(
+            healthy_deployments=healthy_deployments,
+            all_deployments=all_deployments,
+        )
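
For orientation, the handler above implements the least-busy strategy described in the file's header comment: one counter dict per model group lives in the router's DualCache, log_pre_api_call increments the entry for a deployment id, the success/failure callbacks decrement it, and get_available_deployments returns the healthy deployment with the smallest count. Below is a minimal sketch of exercising that flow directly, assuming the litellm package in this tree is importable; the model group name, deployment ids, and seeded counts are hypothetical examples, not values from the diff.

    # Minimal sketch: pick the least busy of two hypothetical deployments.
    # The names below ("gpt-4", "deployment-a", "deployment-b") are made up.
    from litellm.caching.caching import DualCache
    from litellm.router_strategy.least_busy import LeastBusyLoggingHandler

    model_list = [
        {"model_name": "gpt-4", "model_info": {"id": "deployment-a"}},
        {"model_name": "gpt-4", "model_info": {"id": "deployment-b"}},
    ]

    cache = DualCache()  # in-memory by default
    handler = LeastBusyLoggingHandler(router_cache=cache, model_list=model_list)

    # Seed the per-model-group counter the way log_pre_api_call would after
    # three in-flight requests to deployment-a and one to deployment-b.
    cache.set_cache(
        key="gpt-4_request_count",
        value={"deployment-a": 3, "deployment-b": 1},
    )

    picked = handler.get_available_deployments(
        model_group="gpt-4", healthy_deployments=model_list
    )
    print(picked["model_info"]["id"])  # expected: deployment-b

Two behaviors of _get_available_deployments are worth noting: a healthy deployment with no counter entry yet is treated as having zero in-flight requests, so new deployments are favored until traffic reaches them, and if the minimum-count id no longer matches any healthy deployment the handler falls back to random.choice over the healthy list.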