about | summary | refs | log | tree | commit | diff
path: root/.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py')
-rw-r--r-- .venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py 37
1 files changed, 37 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py
new file mode 100644
index 00000000..e24d2378
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py
@@ -0,0 +1,37 @@
+import asyncio
+from typing import TYPE_CHECKING, Any
+
+from litellm.utils import calculate_max_parallel_requests
+
+if TYPE_CHECKING:
+ from litellm.router import Router as _Router
+
+ LitellmRouter = _Router
+else:
+ LitellmRouter = Any
+
+
class InitalizeCachedClient:
    """Static helpers that place pre-built per-deployment objects in the router cache."""

    @staticmethod
    def set_max_parallel_requests_client(
        litellm_router_instance: LitellmRouter, model: dict
    ):
        """Cache an ``asyncio.Semaphore`` sized for one deployment.

        ``rpm``, ``tpm`` and ``max_parallel_requests`` are read from
        ``model["litellm_params"]`` (an empty mapping is used when that key
        is absent) and handed, together with the router's
        ``default_max_parallel_requests``, to
        ``calculate_max_parallel_requests``. When the result is truthy, a
        semaphore with that many slots is stored in the router cache under
        ``"<model id>_max_parallel_requests_client"`` with ``local_only=True``.
        When the result is falsy, nothing is cached.

        Args:
            litellm_router_instance: Router providing
                ``default_max_parallel_requests`` and ``cache``.
            model: Deployment dict; must contain ``model["model_info"]["id"]``.

        Raises:
            KeyError: If ``model`` lacks ``"model_info"`` or its ``"id"``.
        """
        params = model.get("litellm_params", {})
        model_id = model["model_info"]["id"]

        limit = calculate_max_parallel_requests(
            rpm=params.get("rpm"),
            max_parallel_requests=params.get("max_parallel_requests"),
            tpm=params.get("tpm"),
            default_max_parallel_requests=(
                litellm_router_instance.default_max_parallel_requests
            ),
        )
        if not limit:
            # No usable limit was computed, so no semaphore is registered.
            return

        litellm_router_instance.cache.set_cache(
            key=f"{model_id}_max_parallel_requests_client",
            value=asyncio.Semaphore(limit),
            local_only=True,
        )