about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/model_management_endpoints.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/model_management_endpoints.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/model_management_endpoints.py719
1 files changed, 719 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/model_management_endpoints.py b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/model_management_endpoints.py
new file mode 100644
index 00000000..2a2b7eae
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/management_endpoints/model_management_endpoints.py
@@ -0,0 +1,719 @@
+"""
+Allow proxy admin to add/update/delete models in the db
+
+Currently most endpoints are in `proxy_server.py`, but those should be moved here over time.
+
+Endpoints here: 
+
+model/{model_id}/update - PATCH endpoint for model update.
+"""
+
+#### MODEL MANAGEMENT ####
+
+import asyncio
+import json
+import uuid
+from typing import Optional, cast
+
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from pydantic import BaseModel
+
+from litellm._logging import verbose_proxy_logger
+from litellm.constants import LITELLM_PROXY_ADMIN_NAME
+from litellm.proxy._types import (
+    CommonProxyErrors,
+    LiteLLM_ProxyModelTable,
+    LitellmTableNames,
+    LitellmUserRoles,
+    ModelInfoDelete,
+    PrismaCompatibleUpdateDBModel,
+    ProxyErrorTypes,
+    ProxyException,
+    TeamModelAddRequest,
+    UpdateTeamRequest,
+    UserAPIKeyAuth,
+)
+from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
+from litellm.proxy.management_endpoints.team_endpoints import (
+    team_model_add,
+    update_team,
+)
+from litellm.proxy.management_helpers.audit_logs import create_object_audit_log
+from litellm.proxy.utils import PrismaClient
+from litellm.types.router import (
+    Deployment,
+    DeploymentTypedDict,
+    LiteLLMParamsTypedDict,
+    updateDeployment,
+)
+from litellm.utils import get_utc_datetime
+
+router = APIRouter()
+
+
async def get_db_model(
    model_id: str, prisma_client: PrismaClient
) -> Optional[Deployment]:
    """Fetch a model row from LiteLLM_ProxyModelTable and return it as a Deployment.

    Args:
        model_id: primary key of the model row to look up.
        prisma_client: connected Prisma client.

    Returns:
        The stored model re-validated as a `Deployment` pydantic object,
        or None when no row with the given model_id exists.
    """
    record = cast(
        Optional[BaseModel],
        await prisma_client.db.litellm_proxymodeltable.find_unique(
            where={"model_id": model_id}
        ),
    )

    if record is None:
        return None

    # Re-validate the raw DB row as a Deployment object for callers.
    return Deployment(**record.model_dump(exclude_none=True))
+
+
def update_db_model(
    db_model: Deployment, updated_patch: updateDeployment
) -> PrismaCompatibleUpdateDBModel:
    """Merge a PATCH payload into an existing DB-stored deployment.

    Only the fields present in `updated_patch` override the stored values;
    everything else is carried over from `db_model`. The result is converted
    to the JSON-string format the Prisma `update` call expects.

    Args:
        db_model: the deployment currently stored in the DB.
        updated_patch: partial update; unset (None) fields are ignored.

    Returns:
        A PrismaCompatibleUpdateDBModel with `litellm_params` / `model_info`
        serialized via json.dumps.
    """
    # Start from the stored deployment's current state.
    merged = DeploymentTypedDict(
        model_name=db_model.model_name,
        litellm_params=LiteLLMParamsTypedDict(
            **db_model.litellm_params.model_dump(exclude_none=True)  # type: ignore
        ),
    )

    # Apply model_name override, if provided.
    if updated_patch.model_name:
        merged["model_name"] = updated_patch.model_name

    # Apply litellm_params overrides; sensitive values are stored encrypted.
    if updated_patch.litellm_params:
        patched_params = updated_patch.litellm_params.model_dump(exclude_none=True)
        merged["litellm_params"].update(  # type: ignore
            {key: encrypt_value_helper(val) for key, val in patched_params.items()}
        )

    # Apply model_info overrides.
    if updated_patch.model_info:
        merged.setdefault("model_info", {})
        merged["model_info"].update(
            updated_patch.model_info.model_dump(exclude_none=True)
        )

    # Convert the merged dict to the prisma-compatible (JSON string) format.
    prisma_update = PrismaCompatibleUpdateDBModel()
    if "model_name" in merged:
        prisma_update["model_name"] = merged["model_name"]
    if "litellm_params" in merged:
        prisma_update["litellm_params"] = json.dumps(merged["litellm_params"])
    if "model_info" in merged:
        prisma_update["model_info"] = json.dumps(merged["model_info"])
    return prisma_update
+
+
@router.patch(
    "/model/{model_id}/update",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def patch_model(
    model_id: str,  # model id taken from the URL path
    patch_data: updateDeployment,  # partial-update schema; only set fields apply
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    PATCH Endpoint for partial model updates.

    Only updates the fields specified in the request while preserving other existing values.
    Follows proper PATCH semantics by only modifying provided fields.

    Args:
        model_id: The ID of the model to update
        patch_data: The fields to update and their new values
        user_api_key_dict: User authentication information

    Returns:
        Updated model information

    Raises:
        ProxyException: For various error conditions including authentication and database errors
    """
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        llm_router,
        prisma_client,
        store_model_in_db,
    )

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        # Updates are only possible for models persisted in the DB.
        if not store_model_in_db:
            raise ProxyException(
                message="Model updates only supported for DB-stored models",
                type=ProxyErrorTypes.validation_error.value,
                code=status.HTTP_400_BAD_REQUEST,
                param=None,
            )

        # Fetch existing model
        db_model = await get_db_model(model_id=model_id, prisma_client=prisma_client)

        if db_model is None:
            # Model may exist in the router config but not the DB -- give a clearer error.
            if llm_router and llm_router.get_deployment(model_id=model_id) is not None:
                raise ProxyException(
                    message="Cannot edit config-based model. Store model in DB via /model/new first.",
                    type=ProxyErrorTypes.validation_error.value,
                    code=status.HTTP_400_BAD_REQUEST,
                    param=None,
                )
            raise ProxyException(
                message=f"Model {model_id} not found on proxy.",
                # use `.value` for consistency with the other ProxyException raises in this file
                type=ProxyErrorTypes.not_found_error.value,
                code=status.HTTP_404_NOT_FOUND,
                param=None,
            )

        # Merge the patch into the stored model (only provided fields change).
        update_data = update_db_model(db_model=db_model, updated_patch=patch_data)

        # Record who performed the update and when.
        update_data["updated_by"] = (
            user_api_key_dict.user_id or litellm_proxy_admin_name
        )
        update_data["updated_at"] = cast(str, get_utc_datetime())

        # Perform partial update
        updated_model = await prisma_client.db.litellm_proxymodeltable.update(
            where={"model_id": model_id},
            data=update_data,
        )

        return updated_model

    except Exception as e:
        verbose_proxy_logger.exception(f"Error in patch_model: {str(e)}")

        # Re-raise known error types untouched so their status codes are preserved.
        if isinstance(e, (HTTPException, ProxyException)):
            raise e

        raise ProxyException(
            message=f"Error updating model: {str(e)}",
            # use `.value` for consistency with the other ProxyException raises in this file
            type=ProxyErrorTypes.internal_server_error.value,
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            param=None,
        )
+
+
+################################# Helper Functions #################################
+####################################################################################
+####################################################################################
+####################################################################################
+
+
async def _add_model_to_db(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: PrismaClient,
    new_encryption_key: Optional[str] = None,
    should_create_model_in_db: bool = True,
) -> Optional[LiteLLM_ProxyModelTable]:
    """Persist a new model deployment to LiteLLM_ProxyModelTable.

    Encrypts all `litellm_params` values (api keys etc.) in place on
    `model_params` before serializing.

    Args:
        model_params: deployment to store; mutated (params encrypted).
        user_api_key_dict: caller identity, recorded as created_by/updated_by.
        prisma_client: connected Prisma client.
        new_encryption_key: optional key override for encryption.
        should_create_model_in_db: when False, skip the DB write and just
            return the row as a LiteLLM_ProxyModelTable object.

    Returns:
        The created (or constructed) model table row.
    """
    # encrypt litellm params #
    # (use model_dump, consistent with pydantic v2 usage elsewhere in this file)
    _litellm_params_dict = model_params.litellm_params.model_dump(exclude_none=True)
    for k, v in _litellm_params_dict.items():
        encrypted_value = encrypt_value_helper(
            value=v, new_encryption_key=new_encryption_key
        )
        model_params.litellm_params[k] = encrypted_value
    _data: dict = {
        # model_id may be None here; the DB layer is then responsible for it
        "model_id": model_params.model_info.id,
        "model_name": model_params.model_name,
        "litellm_params": model_params.litellm_params.model_dump_json(exclude_none=True),  # type: ignore
        "model_info": model_params.model_info.model_dump_json(  # type: ignore
            exclude_none=True
        ),
        "created_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
        "updated_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
    }
    if should_create_model_in_db:
        model_response = await prisma_client.db.litellm_proxymodeltable.create(
            data=_data  # type: ignore
        )
    else:
        model_response = LiteLLM_ProxyModelTable(**_data)
    return model_response
+
+
async def _add_team_model_to_db(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: PrismaClient,
) -> Optional[LiteLLM_ProxyModelTable]:
    """
    If 'team_id' is provided,

    - generate a unique 'model_name' for the model (e.g. 'model_name_{team_id}_{uuid})
    - store the model in the db with the unique 'model_name'
    - store a team model alias mapping {"model_name": "model_name_{team_id}_{uuid}"}

    Returns None (and does nothing) when model_info.team_id is not set.
    NOTE: mutates `model_params` (model_name and model_info.team_public_model_name).
    """
    _team_id = model_params.model_info.team_id
    if _team_id is None:
        # not a team model -- nothing to do here
        return None
    original_model_name = model_params.model_name
    if original_model_name:
        # remember the caller-facing name so the UI can display it
        model_params.model_info.team_public_model_name = original_model_name

    # internal, collision-free name stored in the DB
    unique_model_name = f"model_name_{_team_id}_{uuid.uuid4()}"

    model_params.model_name = unique_model_name

    ## CREATE MODEL IN DB ##
    model_response = await _add_model_to_db(
        model_params=model_params,
        user_api_key_dict=user_api_key_dict,
        prisma_client=prisma_client,
    )

    ## CREATE MODEL ALIAS IN DB ##
    # maps the public name -> unique stored name on the team object
    # NOTE(review): if original_model_name is falsy this still inserts it as the
    # alias key -- confirm callers always provide a model_name.
    await update_team(
        data=UpdateTeamRequest(
            team_id=_team_id,
            model_aliases={original_model_name: unique_model_name},
        ),
        user_api_key_dict=user_api_key_dict,
        http_request=Request(scope={"type": "http"}),
    )

    # add model to team object
    await team_model_add(
        data=TeamModelAddRequest(
            team_id=_team_id,
            models=[original_model_name],
        ),
        http_request=Request(scope={"type": "http"}),
        user_api_key_dict=user_api_key_dict,
    )

    return model_response
+
+
def check_if_team_id_matches_key(
    team_id: Optional[str], user_api_key_dict: UserAPIKeyAuth
) -> bool:
    """Decide whether the calling key may act on the given team.

    Proxy admins may always act. Otherwise, a team_id must be supplied and
    must match the key's own team_id.
    """
    role = user_api_key_dict.user_role
    # Proxy admins bypass all team checks.
    if role and role == LitellmUserRoles.PROXY_ADMIN:
        return True
    # No team specified: only a proxy admin may proceed (already handled above,
    # so this comparison only passes in the degenerate equal-but-falsy case).
    if team_id is None:
        return role == LitellmUserRoles.PROXY_ADMIN
    # Team specified: the key must belong to that team.
    return user_api_key_dict.team_id == team_id
+
+
+#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
@router.post(
    "/model/delete",
    description="Allows deleting models in the model list in the config.yaml",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_model(
    model_info: ModelInfoDelete,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    # NOTE: this import is immediately shadowed by the grouped import inside
    # the `try` block below; kept as-is.
    from litellm.proxy.proxy_server import llm_router

    try:
        """
        [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964

        - Check if id in db
        - Delete
        """

        from litellm.proxy.proxy_server import (
            llm_router,
            prisma_client,
            store_model_in_db,
        )

        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
                },
            )

        # update DB
        if store_model_in_db is True:
            """
            - store model_list in db
            - store keys separately
            """
            # encrypt litellm params #
            result = await prisma_client.db.litellm_proxymodeltable.delete(
                where={"model_id": model_info.id}
            )

            if result is None:
                raise HTTPException(
                    status_code=400,
                    detail={"error": f"Model with id={model_info.id} not found in db"},
                )

            ## DELETE FROM ROUTER ##
            if llm_router is not None:
                llm_router.delete_deployment(id=model_info.id)

            ## CREATE AUDIT LOG ##
            # fire-and-forget: don't block the response on audit logging
            asyncio.create_task(
                create_object_audit_log(
                    object_id=model_info.id,
                    action="deleted",
                    user_api_key_dict=user_api_key_dict,
                    table_name=LitellmTableNames.PROXY_MODEL_TABLE_NAME,
                    before_value=result.model_dump_json(exclude_none=True),
                    after_value=None,
                    litellm_changed_by=user_api_key_dict.user_id,
                    litellm_proxy_admin_name=LITELLM_PROXY_ADMIN_NAME,
                )
            )
            return {"message": f"Model: {result.model_id} deleted successfully"}
        else:
            # deletion requires DB-backed model storage
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
                },
            )

    except Exception as e:
        verbose_proxy_logger.exception(
            f"Failed to delete model. Due to error - {str(e)}"
        )
        # NOTE(review): all failures here are surfaced with auth-flavored
        # messages/types even when the cause is not auth-related -- confirm
        # before changing the client-facing message.
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.auth_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
+
+
+#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
@router.post(
    "/model/new",
    description="Allows adding new models to the model list in the config.yaml",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def add_new_model(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [BETA] Add a new model to the proxy's DB-backed model list.

    Requires a connected DB and `store_model_in_db` enabled. Team models
    (model_info.team_id set) are premium-only and are stored via
    `_add_team_model_to_db`; others via `_add_model_to_db`. On success the
    router config is refreshed, an optional Slack alert is sent, and an
    audit log entry is created asynchronously.
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        premium_user,
        prisma_client,
        proxy_config,
        proxy_logging_obj,
        store_model_in_db,
    )

    try:

        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
                },
            )

        # team models are a premium (enterprise) feature
        if model_params.model_info.team_id is not None and premium_user is not True:
            raise HTTPException(
                status_code=403,
                detail={"error": CommonProxyErrors.not_premium_user.value},
            )

        # non-admin callers may only add models for their own team
        if not check_if_team_id_matches_key(
            team_id=model_params.model_info.team_id, user_api_key_dict=user_api_key_dict
        ):
            raise HTTPException(
                status_code=403,
                detail={"error": "Team ID does not match the API key's team ID"},
            )

        model_response: Optional[LiteLLM_ProxyModelTable] = None
        # update DB
        if store_model_in_db is True:
            """
            - store model_list in db
            - store keys separately
            """

            try:
                # keep the original name: _add_team_model_to_db may rename it
                _original_litellm_model_name = model_params.model_name
                if model_params.model_info.team_id is None:
                    model_response = await _add_model_to_db(
                        model_params=model_params,
                        user_api_key_dict=user_api_key_dict,
                        prisma_client=prisma_client,
                    )
                else:
                    model_response = await _add_team_model_to_db(
                        model_params=model_params,
                        user_api_key_dict=user_api_key_dict,
                        prisma_client=prisma_client,
                    )
                # refresh the router so the new deployment is live immediately
                await proxy_config.add_deployment(
                    prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
                )
                # don't let failed slack alert block the /model/new response
                _alerting = general_settings.get("alerting", []) or []
                if "slack" in _alerting:
                    # send notification - new model added
                    await proxy_logging_obj.slack_alerting_instance.model_added_alert(
                        model_name=model_params.model_name,
                        litellm_model_name=_original_litellm_model_name,
                        passed_model_info=model_params.model_info,
                    )
            except Exception as e:
                # NOTE(review): failures here are swallowed (logged only);
                # model_response stays None and the caller gets the generic
                # 500 below rather than the underlying error.
                verbose_proxy_logger.exception(f"Exception in add_new_model: {e}")

        else:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
                },
            )

        if model_response is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Failed to add model to db. Check your server logs for more details."
                },
            )

        ## CREATE AUDIT LOG ##
        # fire-and-forget: don't block the response on audit logging
        asyncio.create_task(
            create_object_audit_log(
                object_id=model_response.model_id,
                action="created",
                user_api_key_dict=user_api_key_dict,
                table_name=LitellmTableNames.PROXY_MODEL_TABLE_NAME,
                before_value=None,
                after_value=(
                    model_response.model_dump_json(exclude_none=True)
                    if isinstance(model_response, BaseModel)
                    else None
                ),
                litellm_changed_by=user_api_key_dict.user_id,
                litellm_proxy_admin_name=LITELLM_PROXY_ADMIN_NAME,
            )
        )

        return model_response

    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.add_new_model(): Exception occured - {}".format(
                str(e)
            )
        )
        # NOTE(review): generic failures are surfaced with auth-flavored
        # messages/types even when unrelated to auth -- confirm before changing.
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.auth_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
+
+
+#### MODEL MANAGEMENT ####
@router.post(
    "/model/update",
    description="Edit existing model params",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_model(
    model_params: updateDeployment,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Old endpoint for model update. Makes a PUT request.

    Use `/model/{model_id}/update` to PATCH the stored model in db.
    """
    from litellm.proxy.proxy_server import (
        LITELLM_PROXY_ADMIN_NAME,
        llm_router,
        prisma_client,
        store_model_in_db,
    )

    try:

        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
                },
            )
        # update DB
        if store_model_in_db is True:
            _model_id = None
            _model_info = getattr(model_params, "model_info", None)
            if _model_info is None:
                raise Exception("model_info not provided")

            _model_id = _model_info.id
            if _model_id is None:
                raise Exception("model_info.id not provided")
            _existing_litellm_params = (
                await prisma_client.db.litellm_proxymodeltable.find_unique(
                    where={"model_id": _model_id}
                )
            )
            if _existing_litellm_params is None:
                # model may live only in config.yaml -- those can't be edited here
                if (
                    llm_router is not None
                    and llm_router.get_deployment(model_id=_model_id) is not None
                ):
                    raise HTTPException(
                        status_code=400,
                        detail={
                            "error": "Can't edit model. Model in config. Store model in db via `/model/new`. to edit."
                        },
                    )
                raise Exception("model not found")
            _existing_litellm_params_dict = dict(
                _existing_litellm_params.litellm_params
            )

            if model_params.litellm_params is None:
                raise Exception("litellm_params not provided")

            _new_litellm_params_dict = model_params.litellm_params.dict(
                exclude_none=True
            )

            ### ENCRYPT PARAMS ###
            # sensitive values (api keys etc.) are stored encrypted in the DB
            for k, v in _new_litellm_params_dict.items():
                encrypted_value = encrypt_value_helper(value=v)
                model_params.litellm_params[k] = encrypted_value

            ### MERGE WITH EXISTING DATA ###
            # new non-None values win; otherwise keep existing non-None values
            merged_dictionary = {}
            _mp = model_params.litellm_params.dict()

            for key, value in _mp.items():
                if value is not None:
                    merged_dictionary[key] = value
                elif (
                    key in _existing_litellm_params_dict
                    and _existing_litellm_params_dict[key] is not None
                ):
                    merged_dictionary[key] = _existing_litellm_params_dict[key]
                else:
                    pass

            _data: dict = {
                "litellm_params": json.dumps(merged_dictionary),  # type: ignore
                "updated_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
            }
            model_response = await prisma_client.db.litellm_proxymodeltable.update(
                where={"model_id": _model_id},
                data=_data,  # type: ignore
            )

            ## CREATE AUDIT LOG ##
            # fire-and-forget: don't block the response on audit logging
            asyncio.create_task(
                create_object_audit_log(
                    object_id=_model_id,
                    action="updated",
                    user_api_key_dict=user_api_key_dict,
                    table_name=LitellmTableNames.PROXY_MODEL_TABLE_NAME,
                    before_value=(
                        _existing_litellm_params.model_dump_json(exclude_none=True)
                        if isinstance(_existing_litellm_params, BaseModel)
                        else None
                    ),
                    after_value=(
                        model_response.model_dump_json(exclude_none=True)
                        if isinstance(model_response, BaseModel)
                        else None
                    ),
                    litellm_changed_by=user_api_key_dict.user_id,
                    litellm_proxy_admin_name=LITELLM_PROXY_ADMIN_NAME,
                )
            )

            return model_response
        else:
            # FIX: previously this path fell through and returned None
            # (HTTP 200 with a null body). Raise instead, consistent with
            # /model/new and /model/delete.
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
                },
            )
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.update_model(): Exception occured - {}".format(
                str(e)
            )
        )
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.auth_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )