about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py b/.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py
new file mode 100644
index 00000000..e24d2378
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/router_utils/client_initalization_utils.py
@@ -0,0 +1,37 @@
+import asyncio
+from typing import TYPE_CHECKING, Any
+
+from litellm.utils import calculate_max_parallel_requests
+
+if TYPE_CHECKING:
+    from litellm.router import Router as _Router
+
+    LitellmRouter = _Router
+else:
+    LitellmRouter = Any
+
+
+class InitalizeCachedClient:
+    @staticmethod
+    def set_max_parallel_requests_client(
+        litellm_router_instance: LitellmRouter, model: dict
+    ):
+        litellm_params = model.get("litellm_params", {})
+        model_id = model["model_info"]["id"]
+        rpm = litellm_params.get("rpm", None)
+        tpm = litellm_params.get("tpm", None)
+        max_parallel_requests = litellm_params.get("max_parallel_requests", None)
+        calculated_max_parallel_requests = calculate_max_parallel_requests(
+            rpm=rpm,
+            max_parallel_requests=max_parallel_requests,
+            tpm=tpm,
+            default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests,
+        )
+        if calculated_max_parallel_requests:
+            semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
+            cache_key = f"{model_id}_max_parallel_requests_client"
+            litellm_router_instance.cache.set_cache(
+                key=cache_key,
+                value=semaphore,
+                local_only=True,
+            )