about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py197
1 files changed, 197 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py
new file mode 100644
index 00000000..a48ef6ae
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py
@@ -0,0 +1,197 @@
+# What is this?
+## Common checks for /v1/models and `/model/info`
+import copy
+from typing import Dict, List, Optional, Set
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
+from litellm.utils import get_valid_models
+
+
+def _check_wildcard_routing(model: str) -> bool:
+    """
+    Returns True if a model is a provider wildcard.
+
+    eg:
+    - anthropic/*
+    - openai/*
+    - *
+    """
+    if "*" in model:
+        return True
+    return False
+
+
+def get_provider_models(provider: str) -> Optional[List[str]]:
+    """
+    Returns the list of known models by provider
+    """
+    if provider == "*":
+        return get_valid_models()
+
+    if provider in litellm.models_by_provider:
+        provider_models = copy.deepcopy(litellm.models_by_provider[provider])
+        for idx, _model in enumerate(provider_models):
+            if provider not in _model:
+                provider_models[idx] = f"{provider}/{_model}"
+        return provider_models
+    return None
+
+
+def _get_models_from_access_groups(
+    model_access_groups: Dict[str, List[str]],
+    all_models: List[str],
+) -> List[str]:
+    idx_to_remove = []
+    new_models = []
+    for idx, model in enumerate(all_models):
+        if model in model_access_groups:
+            idx_to_remove.append(idx)
+            new_models.extend(model_access_groups[model])
+
+    for idx in sorted(idx_to_remove, reverse=True):
+        all_models.pop(idx)
+
+    all_models.extend(new_models)
+    return all_models
+
+
+def get_key_models(
+    user_api_key_dict: UserAPIKeyAuth,
+    proxy_model_list: List[str],
+    model_access_groups: Dict[str, List[str]],
+) -> List[str]:
+    """
+    Returns:
+    - List of model name strings
+    - Empty list if no models set
+    - If model_access_groups is provided, only return models that are in the access groups
+    """
+    all_models: List[str] = []
+    if len(user_api_key_dict.models) > 0:
+        all_models = user_api_key_dict.models
+        if SpecialModelNames.all_team_models.value in all_models:
+            all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_proxy_models.value in all_models:
+            all_models = proxy_model_list
+
+    all_models = _get_models_from_access_groups(
+        model_access_groups=model_access_groups, all_models=all_models
+    )
+
+    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models)))
+    return all_models
+
+
+def get_team_models(
+    team_models: List[str],
+    proxy_model_list: List[str],
+    model_access_groups: Dict[str, List[str]],
+) -> List[str]:
+    """
+    Returns:
+    - List of model name strings
+    - Empty list if no models set
+    - If model_access_groups is provided, only return models that are in the access groups
+    """
+    all_models = []
+    if len(team_models) > 0:
+        all_models = team_models
+        if SpecialModelNames.all_team_models.value in all_models:
+            all_models = team_models
+        if SpecialModelNames.all_proxy_models.value in all_models:
+            all_models = proxy_model_list
+
+    all_models = _get_models_from_access_groups(
+        model_access_groups=model_access_groups, all_models=all_models
+    )
+
+    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
+    return all_models
+
+
+def get_complete_model_list(
+    key_models: List[str],
+    team_models: List[str],
+    proxy_model_list: List[str],
+    user_model: Optional[str],
+    infer_model_from_keys: Optional[bool],
+    return_wildcard_routes: Optional[bool] = False,
+) -> List[str]:
+    """Logic for returning complete model list for a given key + team pair"""
+
+    """
+    - If key list is empty -> defer to team list
+    - If team list is empty -> defer to proxy model list
+
+    If list contains wildcard -> return known provider models
+    """
+    unique_models: Set[str] = set()
+    if key_models:
+        unique_models.update(key_models)
+    elif team_models:
+        unique_models.update(team_models)
+    else:
+        unique_models.update(proxy_model_list)
+
+        if user_model:
+            unique_models.add(user_model)
+
+        if infer_model_from_keys:
+            valid_models = get_valid_models()
+            unique_models.update(valid_models)
+
+    all_wildcard_models = _get_wildcard_models(
+        unique_models=unique_models, return_wildcard_routes=return_wildcard_routes
+    )
+
+    return list(unique_models) + all_wildcard_models
+
+
+def get_known_models_from_wildcard(wildcard_model: str) -> List[str]:
+    try:
+        provider, model = wildcard_model.split("/", 1)
+    except ValueError:  # safely fail
+        return []
+    # get all known provider models
+    wildcard_models = get_provider_models(provider=provider)
+    if wildcard_models is None:
+        return []
+    if model == "*":
+        return wildcard_models or []
+    else:
+        model_prefix = model.replace("*", "")
+        filtered_wildcard_models = [
+            wc_model
+            for wc_model in wildcard_models
+            if wc_model.split("/")[1].startswith(model_prefix)
+        ]
+
+        return filtered_wildcard_models
+
+
+def _get_wildcard_models(
+    unique_models: Set[str], return_wildcard_routes: Optional[bool] = False
+) -> List[str]:
+    models_to_remove = set()
+    all_wildcard_models = []
+    for model in unique_models:
+        if _check_wildcard_routing(model=model):
+
+            if (
+                return_wildcard_routes
+            ):  # will add the wildcard route to the list eg: anthropic/*.
+                all_wildcard_models.append(model)
+
+            # get all known provider models
+            wildcard_models = get_known_models_from_wildcard(wildcard_model=model)
+
+            if wildcard_models is not None:
+                models_to_remove.add(model)
+                all_wildcard_models.extend(wildcard_models)
+
+    for model in models_to_remove:
+        unique_models.remove(model)
+
+    return all_wildcard_models