Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py | 1001
1 file changed, 1001 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
new file mode 100644
index 00000000..a13b0dc2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -0,0 +1,1001 @@
+import ast
+import asyncio
+import json
+import uuid
+from base64 import b64encode
+from datetime import datetime
+from typing import Dict, List, Optional, Union
+from urllib.parse import parse_qs, urlencode, urlparse
+
+import httpx
+from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
+from fastapi.responses import StreamingResponse
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
+from litellm.proxy._types import (
+    ConfigFieldInfo,
+    ConfigFieldUpdate,
+    PassThroughEndpointResponse,
+    PassThroughGenericEndpoint,
+    ProxyException,
+    UserAPIKeyAuth,
+)
+from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
+from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.custom_http import httpxSpecialProvider
+from litellm.types.utils import StandardLoggingUserAPIKeyMetadata
+
+from .streaming_handler import PassThroughStreamingHandler
+from .success_handler import PassThroughEndpointLogging
+from .types import EndpointType, PassthroughStandardLoggingPayload
+
+router = APIRouter()
+
+pass_through_endpoint_logging = PassThroughEndpointLogging()
+
+
+def get_response_body(response: httpx.Response) -> Optional[dict]:
+    try:
+        return response.json()
+    except Exception:
+        return None
+
+
+async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]:
+    """
+    Checks if any headers in config.yaml are defined as os.environ/COHERE_API_KEY etc.
+
+    Only runs for headers defined in config.yaml.
+
+    Example header:
+
+    {"Authorization": "bearer os.environ/COHERE_API_KEY"}
+    """
+    if custom_headers is None:
+        return None
+    headers = {}
+    for key, value in custom_headers.items():
+        # the langfuse API requires base64 encoded headers - it's simpler to just ask litellm users to set their langfuse public and secret keys
+        # we can then get the b64 encoded keys here
+        if key == "LANGFUSE_PUBLIC_KEY" or key == "LANGFUSE_SECRET_KEY":
+            # langfuse requires b64 encoded headers - we construct that here
+            _langfuse_public_key = custom_headers["LANGFUSE_PUBLIC_KEY"]
+            _langfuse_secret_key = custom_headers["LANGFUSE_SECRET_KEY"]
+            if isinstance(
+                _langfuse_public_key, str
+            ) and _langfuse_public_key.startswith("os.environ/"):
+                _langfuse_public_key = get_secret_str(_langfuse_public_key)
+            if isinstance(
+                _langfuse_secret_key, str
+            ) and _langfuse_secret_key.startswith("os.environ/"):
+                _langfuse_secret_key = get_secret_str(_langfuse_secret_key)
+            headers["Authorization"] = "Basic " + b64encode(
+                f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8")
+            ).decode("ascii")
+        else:
+            # for all other headers
+            headers[key] = value
+            if isinstance(value, str) and "os.environ/" in value:
+                verbose_proxy_logger.debug(
+                    "pass through endpoint - looking up 'os.environ/' variable"
+                )
+                # get string section that is os.environ/
+                start_index = value.find("os.environ/")
+                _variable_name = value[start_index:]
+
+                verbose_proxy_logger.debug(
+                    "pass through endpoint - getting secret for variable name: %s",
+                    _variable_name,
+                )
+                _secret_value = get_secret_str(_variable_name)
+                if _secret_value is not None:
+                    new_value = value.replace(_variable_name, _secret_value)
+                    headers[key] = new_value
+    return headers
+
+
+async def chat_completion_pass_through_endpoint(  # noqa: PLR0915
+    fastapi_response: Response,
+    request: Request,
+    adapter_id: str,
+    user_api_key_dict: UserAPIKeyAuth,
+):
+    from litellm.proxy.proxy_server import (
+        add_litellm_data_to_request,
+        general_settings,
+        llm_router,
+        proxy_config,
+        proxy_logging_obj,
+        user_api_base,
+        user_max_tokens,
+        user_model,
+        user_request_timeout,
+        user_temperature,
+        version,
+    )
+
+    data = {}
+    try:
+        body = await request.body()
+        body_str = body.decode()
+        try:
+            data = ast.literal_eval(body_str)
+        except Exception:
+            data = json.loads(body_str)
+
+        data["adapter_id"] = adapter_id
+
+        verbose_proxy_logger.debug(
+            "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
+        )
+        data["model"] = (
+            general_settings.get("completion_model", None)  # server default
+            or user_model  # model name passed via cli args
+            or data.get("model", None)  # default passed in http request
+        )
+        if user_model:
+            data["model"] = user_model
+
+        data = await add_litellm_data_to_request(
+            data=data,  # type: ignore
+            request=request,
+            general_settings=general_settings,
+            user_api_key_dict=user_api_key_dict,
+            version=version,
+            proxy_config=proxy_config,
+        )
+
+        # override with user settings, these are params passed via cli
+        if user_temperature:
+            data["temperature"] = user_temperature
+        if user_request_timeout:
+            data["request_timeout"] = user_request_timeout
+        if user_max_tokens:
+            data["max_tokens"] = user_max_tokens
+        if user_api_base:
+            data["api_base"] = user_api_base
+
+        ### MODEL ALIAS MAPPING ###
+        # check if model name in model alias map
+        # get the actual model name
+        if data["model"] in litellm.model_alias_map:
+            data["model"] = litellm.model_alias_map[data["model"]]
+
+        ### CALL HOOKS ### - modify incoming data before calling the model
+        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
+            user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
+        )
+
+        ### ROUTE THE REQUEST ###
+        router_model_names = llm_router.model_names if llm_router is not None else []
+        # skip router if user passed their key
+        if "api_key" in data:
+            llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
+        elif (
+            llm_router is not None and data["model"] in router_model_names
+        ):  # model in router model list
+            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
+        elif (
+            llm_router is not None
+            and llm_router.model_group_alias is not None
+            and data["model"] in llm_router.model_group_alias
+        ):  # model set in model_group_alias
+            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
+        elif (
+            llm_router is not None and data["model"] in llm_router.deployment_names
+        ):  # model in router deployments, calling a specific deployment on the router
+            llm_response = asyncio.create_task(
+                llm_router.aadapter_completion(**data, specific_deployment=True)
+            )
+        elif (
+            llm_router is not None and data["model"] in llm_router.get_model_ids()
+        ):  # model in router model list
+            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
+        elif user_model is not None:  # `litellm --model <your-model-name>`
+            llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "error": "completion: Invalid model name passed in model="
+                    + data.get("model", "")
+                },
+            )
+
+        # Await the llm_response task
+        response = await llm_response
+
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
+        response_cost = hidden_params.get("response_cost", None) or ""
+
+        ### ALERTING ###
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
+
+        verbose_proxy_logger.debug("final response: %s", response)
+
+        fastapi_response.headers.update(
+            ProxyBaseLLMRequestProcessing.get_custom_headers(
+                user_api_key_dict=user_api_key_dict,
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+                response_cost=response_cost,
+            )
+        )
+
+        verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
+        return response
+    except Exception as e:
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+        )
+        verbose_proxy_logger.exception(
+            "litellm.proxy.proxy_server.completion(): Exception occurred - {}".format(
+                str(e)
+            )
+        )
+        error_msg = f"{str(e)}"
+        raise ProxyException(
+            message=getattr(e, "message", error_msg),
+            type=getattr(e, "type", "None"),
+            param=getattr(e, "param", "None"),
+            code=getattr(e, "status_code", 500),
+        )
+
+
+class HttpPassThroughEndpointHelpers:
+    @staticmethod
+    def forward_headers_from_request(
+        request: Request,
+        headers: dict,
+        forward_headers: Optional[bool] = False,
+    ):
+        """
+        Helper to forward headers from original request
+        """
+        if forward_headers is True:
+            request_headers = dict(request.headers)
+
+            # Headers we should NOT forward
+            request_headers.pop("content-length", None)
+            request_headers.pop("host", None)
+
+            # Combine request headers with custom headers
+            headers = {**request_headers, **headers}
+        return headers
+
+    @staticmethod
+    def get_response_headers(
+        headers: httpx.Headers,
+        litellm_call_id: Optional[str] = None,
+        custom_headers: Optional[dict] = None,
+    ) -> dict:
+        excluded_headers = {"transfer-encoding", "content-encoding"}
+
+        return_headers = {
+            key: value
+            for key, value in headers.items()
+            if key.lower() not in excluded_headers
+        }
+        if litellm_call_id:
+            return_headers["x-litellm-call-id"] = litellm_call_id
+        if custom_headers:
+            return_headers.update(custom_headers)
+
+        return return_headers
+
+    @staticmethod
+    def get_endpoint_type(url: str) -> EndpointType:
+        parsed_url = urlparse(url)
+        if ("generateContent") in url or ("streamGenerateContent") in url:
+            return EndpointType.VERTEX_AI
+        elif parsed_url.hostname == "api.anthropic.com":
+            return EndpointType.ANTHROPIC
+        return EndpointType.GENERIC
+
+    @staticmethod
+    def get_merged_query_parameters(
+        existing_url: httpx.URL, request_query_params: Dict[str, Union[str, list]]
+    ) -> Dict[str, Union[str, List[str]]]:
+        # Get the existing query params from the target URL
+        existing_query_string = existing_url.query.decode("utf-8")
+        existing_query_params = parse_qs(existing_query_string)
+
+        # parse_qs returns a dict where each value is a list, so let's flatten it
+        updated_existing_query_params = {
+            k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
+        }
+        # Merge the query params, giving priority to the existing ones
+        return {**request_query_params, **updated_existing_query_params}
+
+    @staticmethod
+    async def _make_non_streaming_http_request(
+        request: Request,
+        async_client: httpx.AsyncClient,
+        url: str,
+        headers: dict,
+        requested_query_params: Optional[dict] = None,
+        custom_body: Optional[dict] = None,
+    ) -> httpx.Response:
+        """
+        Make a non-streaming HTTP request
+
+        If request is GET, don't include a JSON body
+        """
+        if request.method == "GET":
+            response = await async_client.request(
+                method=request.method,
+                url=url,
+                headers=headers,
+                params=requested_query_params,
+            )
+        else:
+            response = await async_client.request(
+                method=request.method,
+                url=url,
+                headers=headers,
+                params=requested_query_params,
+                json=custom_body,
+            )
+        return response
+
+
+async def pass_through_request(  # noqa: PLR0915
+    request: Request,
+    target: str,
+    custom_headers: dict,
+    user_api_key_dict: UserAPIKeyAuth,
+    custom_body: Optional[dict] = None,
+    forward_headers: Optional[bool] = False,
+    merge_query_params: Optional[bool] = False,
+    query_params: Optional[dict] = None,
+    stream: Optional[bool] = None,
+):
+    litellm_call_id = str(uuid.uuid4())
+    url: Optional[httpx.URL] = None
+    try:
+
+        from litellm.litellm_core_utils.litellm_logging import Logging
+        from litellm.proxy.proxy_server import proxy_logging_obj
+
+        url = httpx.URL(target)
+        headers = custom_headers
+        headers = HttpPassThroughEndpointHelpers.forward_headers_from_request(
+            request=request, headers=headers, forward_headers=forward_headers
+        )
+
+        if merge_query_params:
+
+            # Create a new URL with the merged query params
+            url = url.copy_with(
+                query=urlencode(
+                    HttpPassThroughEndpointHelpers.get_merged_query_parameters(
+                        existing_url=url,
+                        request_query_params=dict(request.query_params),
+                    )
+                ).encode("ascii")
+            )
+
+        endpoint_type: EndpointType = HttpPassThroughEndpointHelpers.get_endpoint_type(
+            str(url)
+        )
+
+        _parsed_body = None
+        if custom_body:
+            _parsed_body = custom_body
+        else:
+            _parsed_body = await _read_request_body(request)
+        verbose_proxy_logger.debug(
+            "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
+                url, headers, _parsed_body
+            )
+        )
+
+        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+        _parsed_body = await proxy_logging_obj.pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            data=_parsed_body,
+            call_type="pass_through_endpoint",
+        )
+        async_client_obj = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.PassThroughEndpoint,
+            params={"timeout": 600},
+        )
+        async_client = async_client_obj.client
+
+        # create logging object
+        start_time = datetime.now()
+        logging_obj = Logging(
+            model="unknown",
+            messages=[{"role": "user", "content": json.dumps(_parsed_body)}],
+            stream=False,
+            call_type="pass_through_endpoint",
+            start_time=start_time,
+            litellm_call_id=litellm_call_id,
+            function_id="1245",
+        )
+        passthrough_logging_payload = PassthroughStandardLoggingPayload(
+            url=str(url),
+            request_body=_parsed_body,
+        )
+        kwargs = _init_kwargs_for_pass_through_endpoint(
+            user_api_key_dict=user_api_key_dict,
+            _parsed_body=_parsed_body,
+            passthrough_logging_payload=passthrough_logging_payload,
+            litellm_call_id=litellm_call_id,
+            request=request,
+        )
+        # done for supporting 'parallel_request_limiter.py' with pass-through endpoints
+        logging_obj.update_environment_variables(
+            model="unknown",
+            user="unknown",
+            optional_params={},
+            litellm_params=kwargs["litellm_params"],
+            call_type="pass_through_endpoint",
+        )
+        logging_obj.model_call_details["litellm_call_id"] = litellm_call_id
+
+        # combine url with query params for logging
+
+        requested_query_params: Optional[dict] = (
+            query_params or request.query_params.__dict__
+        )
+        if requested_query_params == request.query_params.__dict__:
+            requested_query_params = None
+
+        requested_query_params_str = None
+        if requested_query_params:
+            requested_query_params_str = "&".join(
+                f"{k}={v}" for k, v in requested_query_params.items()
+            )
+
+        logging_url = str(url)
+        if requested_query_params_str:
+            if "?" in str(url):
+                logging_url = str(url) + "&" + requested_query_params_str
+            else:
+                logging_url = str(url) + "?" + requested_query_params_str
+
+        logging_obj.pre_call(
+            input=[{"role": "user", "content": json.dumps(_parsed_body)}],
+            api_key="",
+            additional_args={
+                "complete_input_dict": _parsed_body,
+                "api_base": str(logging_url),
+                "headers": headers,
+            },
+        )
+        if stream:
+            req = async_client.build_request(
+                "POST",
+                url,
+                json=_parsed_body,
+                params=requested_query_params,
+                headers=headers,
+            )
+
+            response = await async_client.send(req, stream=stream)
+
+            try:
+                response.raise_for_status()
+            except httpx.HTTPStatusError as e:
+                raise HTTPException(
+                    status_code=e.response.status_code, detail=await e.response.aread()
+                )
+
+            return StreamingResponse(
+                PassThroughStreamingHandler.chunk_processor(
+                    response=response,
+                    request_body=_parsed_body,
+                    litellm_logging_obj=logging_obj,
+                    endpoint_type=endpoint_type,
+                    start_time=start_time,
+                    passthrough_success_handler_obj=pass_through_endpoint_logging,
+                    url_route=str(url),
+                ),
+                headers=HttpPassThroughEndpointHelpers.get_response_headers(
+                    headers=response.headers,
+                    litellm_call_id=litellm_call_id,
+                ),
+                status_code=response.status_code,
+            )
+
+        verbose_proxy_logger.debug("request method: {}".format(request.method))
+        verbose_proxy_logger.debug("request url: {}".format(url))
+        verbose_proxy_logger.debug("request headers: {}".format(headers))
+        verbose_proxy_logger.debug(
+            "requested_query_params={}".format(requested_query_params)
+        )
+        verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
+
+        if request.method == "GET":
+            response = await async_client.request(
+                method=request.method,
+                url=url,
+                headers=headers,
+                params=requested_query_params,
+            )
+        else:
+            response = await async_client.request(
+                method=request.method,
+                url=url,
+                headers=headers,
+                params=requested_query_params,
+                json=_parsed_body,
+            )
+
+        verbose_proxy_logger.debug("response.headers= %s", response.headers)
+
+        if _is_streaming_response(response) is True:
+            try:
+                response.raise_for_status()
+            except httpx.HTTPStatusError as e:
+                raise HTTPException(
+                    status_code=e.response.status_code, detail=await e.response.aread()
+                )
+
+            return StreamingResponse(
+                PassThroughStreamingHandler.chunk_processor(
+                    response=response,
+                    request_body=_parsed_body,
+                    litellm_logging_obj=logging_obj,
+                    endpoint_type=endpoint_type,
+                    start_time=start_time,
+                    passthrough_success_handler_obj=pass_through_endpoint_logging,
+                    url_route=str(url),
+                ),
+                headers=HttpPassThroughEndpointHelpers.get_response_headers(
+                    headers=response.headers,
+                    litellm_call_id=litellm_call_id,
+                ),
+                status_code=response.status_code,
+            )
+
+        try:
+            response.raise_for_status()
+        except httpx.HTTPStatusError as e:
+            raise HTTPException(
+                status_code=e.response.status_code, detail=e.response.text
+            )
+
+        if response.status_code >= 300:
+            raise HTTPException(status_code=response.status_code, detail=response.text)
+
+        content = await response.aread()
+
+        ## LOG SUCCESS
+        response_body: Optional[dict] = get_response_body(response)
+        passthrough_logging_payload["response_body"] = response_body
+        end_time = datetime.now()
+        asyncio.create_task(
+            pass_through_endpoint_logging.pass_through_async_success_handler(
+                httpx_response=response,
+                response_body=response_body,
+                url_route=str(url),
+                result="",
+                start_time=start_time,
+                end_time=end_time,
+                logging_obj=logging_obj,
+                cache_hit=False,
+                **kwargs,
+            )
+        )
+
+        ## CUSTOM HEADERS - `x-litellm-*`
+        custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
+            call_id=litellm_call_id,
+            model_id=None,
+            cache_key=None,
+            api_base=str(url._uri_reference),
+        )
+
+        return Response(
+            content=content,
+            status_code=response.status_code,
+            headers=HttpPassThroughEndpointHelpers.get_response_headers(
+                headers=response.headers,
+                custom_headers=custom_headers,
+            ),
+        )
+    except Exception as e:
+        custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
+            call_id=litellm_call_id,
+            model_id=None,
+            cache_key=None,
+            api_base=str(url._uri_reference) if url else None,
+        )
+        verbose_proxy_logger.exception(
+            "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occurred - {}".format(
+                str(e)
+            )
+        )
+        if isinstance(e, HTTPException):
+            raise ProxyException(
+                message=getattr(e, "message", str(e.detail)),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+                headers=custom_headers,
+            )
+        else:
+            error_msg = f"{str(e)}"
+            raise ProxyException(
+                message=getattr(e, "message", error_msg),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", 500),
+                headers=custom_headers,
+            )
+
+
+def _init_kwargs_for_pass_through_endpoint(
+    request: Request,
+    user_api_key_dict: UserAPIKeyAuth,
+    passthrough_logging_payload: PassthroughStandardLoggingPayload,
+    _parsed_body: Optional[dict] = None,
+    litellm_call_id: Optional[str] = None,
+) -> dict:
+    _parsed_body = _parsed_body or {}
+    _litellm_metadata: Optional[dict] = _parsed_body.pop("litellm_metadata", None)
+    _metadata = dict(
+        StandardLoggingUserAPIKeyMetadata(
+            user_api_key_hash=user_api_key_dict.api_key,
+            user_api_key_alias=user_api_key_dict.key_alias,
+            user_api_key_user_email=user_api_key_dict.user_email,
+            user_api_key_user_id=user_api_key_dict.user_id,
+            user_api_key_team_id=user_api_key_dict.team_id,
+            user_api_key_org_id=user_api_key_dict.org_id,
+            user_api_key_team_alias=user_api_key_dict.team_alias,
+            user_api_key_end_user_id=user_api_key_dict.end_user_id,
+        )
+    )
+    _metadata["user_api_key"] = user_api_key_dict.api_key
+    if _litellm_metadata:
+        _metadata.update(_litellm_metadata)
+
+    _metadata = _update_metadata_with_tags_in_header(
+        request=request,
+        metadata=_metadata,
+    )
+
+    kwargs = {
+        "litellm_params": {
+            "metadata": _metadata,
+        },
+        "call_type": "pass_through_endpoint",
+        "litellm_call_id": litellm_call_id,
+        "passthrough_logging_payload": passthrough_logging_payload,
+    }
+    return kwargs
+
+
+def _update_metadata_with_tags_in_header(request: Request, metadata: dict) -> dict:
+    """
+    If tags are in the request headers, add them to the metadata
+
+    Used for google and vertex JS SDKs
+    """
+    _tags = request.headers.get("tags")
+    if _tags:
+        metadata["tags"] = _tags.split(",")
+    return metadata
+
+
+def create_pass_through_route(
+    endpoint,
+    target: str,
+    custom_headers: Optional[dict] = None,
+    _forward_headers: Optional[bool] = False,
+    _merge_query_params: Optional[bool] = False,
+    dependencies: Optional[List] = None,
+):
+    # check if target is an adapter.py or a url
+    import uuid
+
+    from litellm.proxy.types_utils.utils import get_instance_fn
+
+    try:
+        if isinstance(target, CustomLogger):
+            adapter = target
+        else:
+            adapter = get_instance_fn(value=target)
+        adapter_id = str(uuid.uuid4())
+        litellm.adapters = [{"id": adapter_id, "adapter": adapter}]
+
+        async def endpoint_func(  # type: ignore
+            request: Request,
+            fastapi_response: Response,
+            user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+        ):
+            return await chat_completion_pass_through_endpoint(
+                fastapi_response=fastapi_response,
+                request=request,
+                adapter_id=adapter_id,
+                user_api_key_dict=user_api_key_dict,
+            )
+
+    except Exception:
+        verbose_proxy_logger.debug("Defaulting to target being a url.")
+
+        async def endpoint_func(  # type: ignore
+            request: Request,
+            fastapi_response: Response,
+            user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+            query_params: Optional[dict] = None,
+            custom_body: Optional[dict] = None,
+            stream: Optional[
+                bool
+            ] = None,  # if pass-through endpoint is a streaming request
+        ):
+            return await pass_through_request(  # type: ignore
+                request=request,
+                target=target,
+                custom_headers=custom_headers or {},
+                user_api_key_dict=user_api_key_dict,
+                forward_headers=_forward_headers,
+                merge_query_params=_merge_query_params,
+                query_params=query_params,
+                stream=stream,
+                custom_body=custom_body,
+            )
+
+    return endpoint_func
+
+
+def _is_streaming_response(response: httpx.Response) -> bool:
+    _content_type = response.headers.get("content-type")
+    if _content_type is not None and "text/event-stream" in _content_type:
+        return True
+    return False
+
+
+async def initialize_pass_through_endpoints(pass_through_endpoints: list):
+
+    verbose_proxy_logger.debug("initializing pass through endpoints")
+    from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes
+    from litellm.proxy.proxy_server import app, premium_user
+
+    for endpoint in pass_through_endpoints:
+        _target = endpoint.get("target", None)
+        _path = endpoint.get("path", None)
+        _custom_headers = endpoint.get("headers", None)
+        _custom_headers = await set_env_variables_in_header(
+            custom_headers=_custom_headers
+        )
+        _forward_headers = endpoint.get("forward_headers", None)
+        _merge_query_params = endpoint.get("merge_query_params", None)
+        _auth = endpoint.get("auth", None)
+        _dependencies = None
+        if _auth is not None and str(_auth).lower() == "true":
+            if premium_user is not True:
+                raise ValueError(
+                    "Error Setting Authentication on Pass Through Endpoint: {}".format(
+                        CommonProxyErrors.not_premium_user.value
+                    )
+                )
+            _dependencies = [Depends(user_api_key_auth)]
+            LiteLLMRoutes.openai_routes.value.append(_path)
+
+        if _target is None:
+            continue
+
+        verbose_proxy_logger.debug(
+            "adding pass through endpoint: %s, dependencies: %s", _path, _dependencies
+        )
+        app.add_api_route(  # type: ignore
+            path=_path,
+            endpoint=create_pass_through_route(  # type: ignore
+                _path,
+                _target,
+                _custom_headers,
+                _forward_headers,
+                _merge_query_params,
+                _dependencies,
+            ),
+            methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
+            dependencies=_dependencies,
+        )
+
+        verbose_proxy_logger.debug("Added new pass through endpoint: %s", _path)
+
+
+@router.get(
+    "/config/pass_through_endpoint",
+    dependencies=[Depends(user_api_key_auth)],
+    response_model=PassThroughEndpointResponse,
+)
+async def get_pass_through_endpoints(
+    endpoint_id: Optional[str] = None,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    GET configured pass through endpoint.
+
+    If no endpoint_id given, return all configured endpoints.
+    """
+    from litellm.proxy.proxy_server import get_config_general_settings
+
+    ## Get existing pass-through endpoint field value
+    try:
+        response: ConfigFieldInfo = await get_config_general_settings(
+            field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
+        )
+    except Exception:
+        return PassThroughEndpointResponse(endpoints=[])
+
+    pass_through_endpoint_data: Optional[List] = response.field_value
+    if pass_through_endpoint_data is None:
+        return PassThroughEndpointResponse(endpoints=[])
+
+    returned_endpoints = []
+    if endpoint_id is None:
+        for endpoint in pass_through_endpoint_data:
+            if isinstance(endpoint, dict):
+                returned_endpoints.append(PassThroughGenericEndpoint(**endpoint))
+            elif isinstance(endpoint, PassThroughGenericEndpoint):
+                returned_endpoints.append(endpoint)
+    elif endpoint_id is not None:
+        for endpoint in pass_through_endpoint_data:
+            _endpoint: Optional[PassThroughGenericEndpoint] = None
+            if isinstance(endpoint, dict):
+                _endpoint = PassThroughGenericEndpoint(**endpoint)
+            elif isinstance(endpoint, PassThroughGenericEndpoint):
+                _endpoint = endpoint
+
+            if _endpoint is not None and _endpoint.path == endpoint_id:
+                returned_endpoints.append(_endpoint)
+
+    return PassThroughEndpointResponse(endpoints=returned_endpoints)
+
+
+@router.post(
+    "/config/pass_through_endpoint/{endpoint_id}",
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def update_pass_through_endpoints(request: Request, endpoint_id: str):
+    """
+    Update a pass-through endpoint
+    """
+    pass
+
+
+@router.post(
+    "/config/pass_through_endpoint",
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def create_pass_through_endpoints(
+    data: PassThroughGenericEndpoint,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    Create new pass-through endpoint
+    """
+    from litellm.proxy.proxy_server import (
+        get_config_general_settings,
+        update_config_general_settings,
+    )
+
+    ## Get existing pass-through endpoint field value
+
+    try:
+        response: ConfigFieldInfo = await get_config_general_settings(
+            field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
+        )
+    except Exception:
+        response = ConfigFieldInfo(
+            field_name="pass_through_endpoints", field_value=None
+        )
+
+    ## Update field with new endpoint
+    data_dict = data.model_dump()
+    if response.field_value is None:
+        response.field_value = [data_dict]
+    elif isinstance(response.field_value, List):
+        response.field_value.append(data_dict)
+
+    ## Update db
+    updated_data = ConfigFieldUpdate(
+        field_name="pass_through_endpoints",
+        field_value=response.field_value,
+        config_type="general_settings",
+    )
+    await update_config_general_settings(
+        data=updated_data, user_api_key_dict=user_api_key_dict
+    )
+
+
+@router.delete(
+    "/config/pass_through_endpoint",
+    dependencies=[Depends(user_api_key_auth)],
+    response_model=PassThroughEndpointResponse,
+)
+async def delete_pass_through_endpoints(
+    endpoint_id: str,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    Delete a pass-through endpoint
+
+    Returns - the deleted endpoint
+    """
+    from litellm.proxy.proxy_server import (
+        get_config_general_settings,
+        update_config_general_settings,
+    )
+
+    ## Get existing pass-through endpoint field value
+
+    try:
+        response: ConfigFieldInfo = await get_config_general_settings(
+            field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
+        )
+    except Exception:
+        response = ConfigFieldInfo(
+            field_name="pass_through_endpoints", field_value=None
+        )
+
+    ## Update field by removing endpoint
+    pass_through_endpoint_data: Optional[List] = response.field_value
+    response_obj: Optional[PassThroughGenericEndpoint] = None
+    if response.field_value is None or pass_through_endpoint_data is None:
+        raise HTTPException(
+            status_code=400,
+            detail={"error": "There are no pass-through endpoints set up."},
+        )
+    elif isinstance(response.field_value, List):
+        invalid_idx: Optional[int] = None
+        for idx, endpoint in enumerate(pass_through_endpoint_data):
+            _endpoint: Optional[PassThroughGenericEndpoint] = None
+            if isinstance(endpoint, dict):
+                _endpoint = PassThroughGenericEndpoint(**endpoint)
+            elif isinstance(endpoint, PassThroughGenericEndpoint):
+                _endpoint = endpoint
+
+            if _endpoint is not None and _endpoint.path == endpoint_id:
+                invalid_idx = idx
+                response_obj = _endpoint
+
+        if invalid_idx is not None:
+            pass_through_endpoint_data.pop(invalid_idx)
+
+    ## Update db
+    updated_data = ConfigFieldUpdate(
+        field_name="pass_through_endpoints",
+        field_value=pass_through_endpoint_data,
+        config_type="general_settings",
+    )
+    await update_config_general_settings(
+        data=updated_data, user_api_key_dict=user_api_key_dict
+    )
+
+    if response_obj is None:
+        raise HTTPException(
+            status_code=400,
+            detail={
+                "error": "Endpoint={} was not found in pass-through endpoint list.".format(
+                    endpoint_id
+                )
+            },
+        )
+    return PassThroughEndpointResponse(endpoints=[response_obj])
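
For reference, initialize_pass_through_endpoints() above consumes a list of dicts with the keys path, target, headers, forward_headers, merge_query_params, and auth, read from the proxy's general_settings. A minimal sketch of a matching config.yaml entry, assuming a hypothetical Cohere rerank pass-through (the path and target values here are illustrative, not taken from this diff; the Authorization header example comes from the set_env_variables_in_header docstring):

    general_settings:
      pass_through_endpoints:
        - path: "/v1/rerank"                          # route registered on the proxy app (illustrative)
          target: "https://api.cohere.com/v1/rerank"  # upstream URL requests are forwarded to (illustrative)
          headers:
            # "os.environ/..." values are resolved at startup by set_env_variables_in_header()
            Authorization: "bearer os.environ/COHERE_API_KEY"
          forward_headers: true                       # also forward the caller's request headers
          auth: true                                  # require a litellm API key (premium-only, per the check above)

Each entry is registered via app.add_api_route() for GET/POST/PUT/DELETE/PATCH, and the same list can be managed at runtime through the /config/pass_through_endpoint routes defined above.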