diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py | 37 |
1 files changed, 37 insertions, 0 deletions
# Router helper: caches a per-deployment asyncio.Semaphore in the router cache.
import asyncio
from typing import TYPE_CHECKING, Any

from litellm.utils import calculate_max_parallel_requests

# Import the real Router only for static type checking; at runtime the alias
# falls back to Any so this module never imports litellm.router.
if TYPE_CHECKING:
    from litellm.router import Router as _Router

    LitellmRouter = _Router
else:
    LitellmRouter = Any


class InitalizeCachedClient:
    """Namespace for static helpers that populate a router's cache."""

    @staticmethod
    def set_max_parallel_requests_client(
        litellm_router_instance: LitellmRouter, model: dict
    ):
        """Cache an ``asyncio.Semaphore`` for one model entry.

        Args:
            litellm_router_instance: Object exposing
                ``default_max_parallel_requests`` and a ``cache`` with a
                ``set_cache(key=..., value=..., local_only=...)`` method.
            model: Mapping that must contain ``model["model_info"]["id"]``.
                ``model["litellm_params"]`` is optional; its ``rpm``, ``tpm``
                and ``max_parallel_requests`` entries are read when present.

        Returns:
            None. When ``calculate_max_parallel_requests`` yields a falsy
            value (e.g. ``None`` or ``0``) nothing is written to the cache.

        Raises:
            KeyError: If ``model`` has no ``"model_info"`` / ``"id"`` entry.
        """
        params = model.get("litellm_params", {})
        deployment_id = model["model_info"]["id"]

        # The limit is derived by the litellm helper from the per-model
        # settings plus the router-wide default; its exact precedence rules
        # live in litellm.utils and are not visible here.
        limit = calculate_max_parallel_requests(
            rpm=params.get("rpm", None),
            tpm=params.get("tpm", None),
            max_parallel_requests=params.get("max_parallel_requests", None),
            default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests,
        )
        if not limit:
            # No usable limit -> leave the cache untouched.
            return

        # local_only=True is passed through to the cache unchanged.
        litellm_router_instance.cache.set_cache(
            key=f"{deployment_id}_max_parallel_requests_client",
            value=asyncio.Semaphore(limit),
            local_only=True,
        )