aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py197
1 files changed, 197 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py
new file mode 100644
index 00000000..a48ef6ae
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py
@@ -0,0 +1,197 @@
+# What is this?
+## Common checks for /v1/models and `/model/info`
+import copy
+from typing import Dict, List, Optional, Set
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
+from litellm.utils import get_valid_models
+
+
+def _check_wildcard_routing(model: str) -> bool:
+ """
+ Returns True if a model is a provider wildcard.
+
+ eg:
+ - anthropic/*
+ - openai/*
+ - *
+ """
+ if "*" in model:
+ return True
+ return False
+
+
+def get_provider_models(provider: str) -> Optional[List[str]]:
+ """
+ Returns the list of known models by provider
+ """
+ if provider == "*":
+ return get_valid_models()
+
+ if provider in litellm.models_by_provider:
+ provider_models = copy.deepcopy(litellm.models_by_provider[provider])
+ for idx, _model in enumerate(provider_models):
+ if provider not in _model:
+ provider_models[idx] = f"{provider}/{_model}"
+ return provider_models
+ return None
+
+
+def _get_models_from_access_groups(
+ model_access_groups: Dict[str, List[str]],
+ all_models: List[str],
+) -> List[str]:
+ idx_to_remove = []
+ new_models = []
+ for idx, model in enumerate(all_models):
+ if model in model_access_groups:
+ idx_to_remove.append(idx)
+ new_models.extend(model_access_groups[model])
+
+ for idx in sorted(idx_to_remove, reverse=True):
+ all_models.pop(idx)
+
+ all_models.extend(new_models)
+ return all_models
+
+
+def get_key_models(
+ user_api_key_dict: UserAPIKeyAuth,
+ proxy_model_list: List[str],
+ model_access_groups: Dict[str, List[str]],
+) -> List[str]:
+ """
+ Returns:
+ - List of model name strings
+ - Empty list if no models set
+ - If model_access_groups is provided, only return models that are in the access groups
+ """
+ all_models: List[str] = []
+ if len(user_api_key_dict.models) > 0:
+ all_models = user_api_key_dict.models
+ if SpecialModelNames.all_team_models.value in all_models:
+ all_models = user_api_key_dict.team_models
+ if SpecialModelNames.all_proxy_models.value in all_models:
+ all_models = proxy_model_list
+
+ all_models = _get_models_from_access_groups(
+ model_access_groups=model_access_groups, all_models=all_models
+ )
+
+ verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models)))
+ return all_models
+
+
+def get_team_models(
+ team_models: List[str],
+ proxy_model_list: List[str],
+ model_access_groups: Dict[str, List[str]],
+) -> List[str]:
+ """
+ Returns:
+ - List of model name strings
+ - Empty list if no models set
+ - If model_access_groups is provided, only return models that are in the access groups
+ """
+ all_models = []
+ if len(team_models) > 0:
+ all_models = team_models
+ if SpecialModelNames.all_team_models.value in all_models:
+ all_models = team_models
+ if SpecialModelNames.all_proxy_models.value in all_models:
+ all_models = proxy_model_list
+
+ all_models = _get_models_from_access_groups(
+ model_access_groups=model_access_groups, all_models=all_models
+ )
+
+ verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
+ return all_models
+
+
+def get_complete_model_list(
+ key_models: List[str],
+ team_models: List[str],
+ proxy_model_list: List[str],
+ user_model: Optional[str],
+ infer_model_from_keys: Optional[bool],
+ return_wildcard_routes: Optional[bool] = False,
+) -> List[str]:
+ """Logic for returning complete model list for a given key + team pair"""
+
+ """
+ - If key list is empty -> defer to team list
+ - If team list is empty -> defer to proxy model list
+
+ If list contains wildcard -> return known provider models
+ """
+ unique_models: Set[str] = set()
+ if key_models:
+ unique_models.update(key_models)
+ elif team_models:
+ unique_models.update(team_models)
+ else:
+ unique_models.update(proxy_model_list)
+
+ if user_model:
+ unique_models.add(user_model)
+
+ if infer_model_from_keys:
+ valid_models = get_valid_models()
+ unique_models.update(valid_models)
+
+ all_wildcard_models = _get_wildcard_models(
+ unique_models=unique_models, return_wildcard_routes=return_wildcard_routes
+ )
+
+ return list(unique_models) + all_wildcard_models
+
+
+def get_known_models_from_wildcard(wildcard_model: str) -> List[str]:
+ try:
+ provider, model = wildcard_model.split("/", 1)
+ except ValueError: # safely fail
+ return []
+ # get all known provider models
+ wildcard_models = get_provider_models(provider=provider)
+ if wildcard_models is None:
+ return []
+ if model == "*":
+ return wildcard_models or []
+ else:
+ model_prefix = model.replace("*", "")
+ filtered_wildcard_models = [
+ wc_model
+ for wc_model in wildcard_models
+ if wc_model.split("/")[1].startswith(model_prefix)
+ ]
+
+ return filtered_wildcard_models
+
+
+def _get_wildcard_models(
+ unique_models: Set[str], return_wildcard_routes: Optional[bool] = False
+) -> List[str]:
+ models_to_remove = set()
+ all_wildcard_models = []
+ for model in unique_models:
+ if _check_wildcard_routing(model=model):
+
+ if (
+ return_wildcard_routes
+ ): # will add the wildcard route to the list eg: anthropic/*.
+ all_wildcard_models.append(model)
+
+ # get all known provider models
+ wildcard_models = get_known_models_from_wildcard(wildcard_model=model)
+
+ if wildcard_models is not None:
+ models_to_remove.add(model)
+ all_wildcard_models.extend(wildcard_models)
+
+ for model in models_to_remove:
+ unique_models.remove(model)
+
+ return all_wildcard_models