 .venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py | 1001 ++++++++++++++
 1 file changed, 1001 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
new file mode 100644
index 00000000..a13b0dc2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -0,0 +1,1001 @@
+import ast
+import asyncio
+import json
+import uuid
+from base64 import b64encode
+from datetime import datetime
+from typing import Dict, List, Optional, Union
+from urllib.parse import parse_qs, urlencode, urlparse
+
+import httpx
+from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
+from fastapi.responses import StreamingResponse
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
+from litellm.proxy._types import (
+ ConfigFieldInfo,
+ ConfigFieldUpdate,
+ PassThroughEndpointResponse,
+ PassThroughGenericEndpoint,
+ ProxyException,
+ UserAPIKeyAuth,
+)
+from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
+from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.custom_http import httpxSpecialProvider
+from litellm.types.utils import StandardLoggingUserAPIKeyMetadata
+
+from .streaming_handler import PassThroughStreamingHandler
+from .success_handler import PassThroughEndpointLogging
+from .types import EndpointType, PassthroughStandardLoggingPayload
+
+router = APIRouter()
+
+pass_through_endpoint_logging = PassThroughEndpointLogging()
+
+
+def get_response_body(response: httpx.Response) -> Optional[dict]:
+ try:
+ return response.json()
+ except Exception:
+ return None
+
+
+async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]:
+ """
+    Checks whether any headers in config.yaml are defined as os.environ/COHERE_API_KEY etc.
+
+    Only runs for headers defined in config.yaml.
+
+    An example header:
+
+ {"Authorization": "bearer os.environ/COHERE_API_KEY"}
+ """
+ if custom_headers is None:
+ return None
+ headers = {}
+ for key, value in custom_headers.items():
+        # the Langfuse API requires base64-encoded headers - it's simpler to just ask litellm users to set their Langfuse public and secret keys
+ # we can then get the b64 encoded keys here
+ if key == "LANGFUSE_PUBLIC_KEY" or key == "LANGFUSE_SECRET_KEY":
+ # langfuse requires b64 encoded headers - we construct that here
+ _langfuse_public_key = custom_headers["LANGFUSE_PUBLIC_KEY"]
+ _langfuse_secret_key = custom_headers["LANGFUSE_SECRET_KEY"]
+ if isinstance(
+ _langfuse_public_key, str
+ ) and _langfuse_public_key.startswith("os.environ/"):
+ _langfuse_public_key = get_secret_str(_langfuse_public_key)
+ if isinstance(
+ _langfuse_secret_key, str
+ ) and _langfuse_secret_key.startswith("os.environ/"):
+ _langfuse_secret_key = get_secret_str(_langfuse_secret_key)
+ headers["Authorization"] = "Basic " + b64encode(
+ f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8")
+ ).decode("ascii")
+ else:
+ # for all other headers
+ headers[key] = value
+ if isinstance(value, str) and "os.environ/" in value:
+ verbose_proxy_logger.debug(
+ "pass through endpoint - looking up 'os.environ/' variable"
+ )
+                # get the substring starting at "os.environ/"
+ start_index = value.find("os.environ/")
+ _variable_name = value[start_index:]
+
+ verbose_proxy_logger.debug(
+ "pass through endpoint - getting secret for variable name: %s",
+ _variable_name,
+ )
+ _secret_value = get_secret_str(_variable_name)
+ if _secret_value is not None:
+ new_value = value.replace(_variable_name, _secret_value)
+ headers[key] = new_value
+ return headers
+
+
+async def chat_completion_pass_through_endpoint( # noqa: PLR0915
+ fastapi_response: Response,
+ request: Request,
+ adapter_id: str,
+ user_api_key_dict: UserAPIKeyAuth,
+):
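+    """
+    Handle a pass-through route whose target is a LiteLLM adapter rather than
+    a raw URL: parse the request body, apply proxy defaults and pre-call
+    hooks, then route the call via `aadapter_completion`.
+    """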
+ from litellm.proxy.proxy_server import (
+ add_litellm_data_to_request,
+ general_settings,
+ llm_router,
+ proxy_config,
+ proxy_logging_obj,
+ user_api_base,
+ user_max_tokens,
+ user_model,
+ user_request_timeout,
+ user_temperature,
+ version,
+ )
+
+ data = {}
+ try:
+ body = await request.body()
+ body_str = body.decode()
+ try:
+ data = ast.literal_eval(body_str)
+ except Exception:
+ data = json.loads(body_str)
+
+ data["adapter_id"] = adapter_id
+
+ verbose_proxy_logger.debug(
+ "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
+ )
+ data["model"] = (
+ general_settings.get("completion_model", None) # server default
+ or user_model # model name passed via cli args
+            or data.get("model", None)  # fallback: model passed in the http request
+ )
+ if user_model:
+ data["model"] = user_model
+
+ data = await add_litellm_data_to_request(
+ data=data, # type: ignore
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ # override with user settings, these are params passed via cli
+ if user_temperature:
+ data["temperature"] = user_temperature
+ if user_request_timeout:
+ data["request_timeout"] = user_request_timeout
+ if user_max_tokens:
+ data["max_tokens"] = user_max_tokens
+ if user_api_base:
+ data["api_base"] = user_api_base
+
+ ### MODEL ALIAS MAPPING ###
+ # check if model name in model alias map
+ # get the actual model name
+ if data["model"] in litellm.model_alias_map:
+ data["model"] = litellm.model_alias_map[data["model"]]
+
+ ### CALL HOOKS ### - modify incoming data before calling the model
+ data = await proxy_logging_obj.pre_call_hook( # type: ignore
+ user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
+ )
+
+        ### ROUTE THE REQUEST ###
+ router_model_names = llm_router.model_names if llm_router is not None else []
+ # skip router if user passed their key
+ if "api_key" in data:
+ llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
+ elif (
+ llm_router is not None and data["model"] in router_model_names
+ ): # model in router model list
+ llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
+ elif (
+ llm_router is not None
+ and llm_router.model_group_alias is not None
+ and data["model"] in llm_router.model_group_alias
+ ): # model set in model_group_alias
+ llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
+ elif (
+ llm_router is not None and data["model"] in llm_router.deployment_names
+ ): # model in router deployments, calling a specific deployment on the router
+ llm_response = asyncio.create_task(
+ llm_router.aadapter_completion(**data, specific_deployment=True)
+ )
+ elif (
+ llm_router is not None and data["model"] in llm_router.get_model_ids()
+ ): # model in router model list
+ llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
+ elif (
+ llm_router is not None
+ and data["model"] not in router_model_names
+ and llm_router.default_deployment is not None
+        ):  # model not in router model list, use the router's default deployment
+ llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
+ elif user_model is not None: # `litellm --model <your-model-name>`
+ llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail={
+ "error": "completion: Invalid model name passed in model="
+ + data.get("model", "")
+ },
+ )
+
+ # Await the llm_response task
+ response = await llm_response
+
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+ response_cost = hidden_params.get("response_cost", None) or ""
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+
+ verbose_proxy_logger.debug("final response: %s", response)
+
+ fastapi_response.headers.update(
+ ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ response_cost=response_cost,
+ )
+ )
+
+        verbose_proxy_logger.info("\nResponse from LiteLLM:\n{}".format(response))
+ return response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ verbose_proxy_logger.exception(
+ "litellm.proxy.proxy_server.completion(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", 500),
+ )
+
+
+class HttpPassThroughEndpointHelpers:
+ @staticmethod
+ def forward_headers_from_request(
+ request: Request,
+ headers: dict,
+ forward_headers: Optional[bool] = False,
+ ):
+ """
+ Helper to forward headers from original request
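+
+        Configured custom headers take precedence over the incoming
+        request's headers; `content-length` and `host` are never forwarded.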
+ """
+ if forward_headers is True:
+ request_headers = dict(request.headers)
+
+            # Headers we should NOT forward
+ request_headers.pop("content-length", None)
+ request_headers.pop("host", None)
+
+ # Combine request headers with custom headers
+ headers = {**request_headers, **headers}
+ return headers
+
+ @staticmethod
+ def get_response_headers(
+ headers: httpx.Headers,
+ litellm_call_id: Optional[str] = None,
+ custom_headers: Optional[dict] = None,
+ ) -> dict:
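+        """
+        Copy upstream response headers, dropping hop-by-hop ones
+        (`transfer-encoding`, `content-encoding`), and attach LiteLLM
+        metadata headers.
+
+        Illustrative example (hypothetical header values):
+            get_response_headers(
+                httpx.Headers({"content-encoding": "gzip", "x-request-id": "abc"}),
+                litellm_call_id="123",
+            )
+            -> {"x-request-id": "abc", "x-litellm-call-id": "123"}
+        """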
+ excluded_headers = {"transfer-encoding", "content-encoding"}
+
+ return_headers = {
+ key: value
+ for key, value in headers.items()
+ if key.lower() not in excluded_headers
+ }
+ if litellm_call_id:
+ return_headers["x-litellm-call-id"] = litellm_call_id
+ if custom_headers:
+ return_headers.update(custom_headers)
+
+ return return_headers
+
+ @staticmethod
+ def get_endpoint_type(url: str) -> EndpointType:
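+        """
+        Infer the provider-specific endpoint type from the target URL, so
+        downstream handlers (e.g. the streaming chunk processor) can parse
+        responses correctly.
+
+        Illustrative examples (hypothetical URLs):
+            ".../models/gemini-1.5-pro:generateContent" -> EndpointType.VERTEX_AI
+            "https://api.anthropic.com/v1/messages"     -> EndpointType.ANTHROPIC
+            anything else                               -> EndpointType.GENERIC
+        """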
+ parsed_url = urlparse(url)
+ if ("generateContent") in url or ("streamGenerateContent") in url:
+ return EndpointType.VERTEX_AI
+ elif parsed_url.hostname == "api.anthropic.com":
+ return EndpointType.ANTHROPIC
+ return EndpointType.GENERIC
+
+ @staticmethod
+ def get_merged_query_parameters(
+ existing_url: httpx.URL, request_query_params: Dict[str, Union[str, list]]
+ ) -> Dict[str, Union[str, List[str]]]:
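+        """
+        Merge the incoming request's query params into the target URL's
+        existing query params; on conflict the target URL's params win.
+
+        Illustrative example (hypothetical values):
+            get_merged_query_parameters(
+                httpx.URL("https://example.com/v1?alt=sse"),
+                {"alt": "json", "key": "k"},
+            )
+            -> {"alt": "sse", "key": "k"}
+        """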
+ # Get the existing query params from the target URL
+ existing_query_string = existing_url.query.decode("utf-8")
+ existing_query_params = parse_qs(existing_query_string)
+
+ # parse_qs returns a dict where each value is a list, so let's flatten it
+ updated_existing_query_params = {
+ k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
+ }
+ # Merge the query params, giving priority to the existing ones
+ return {**request_query_params, **updated_existing_query_params}
+
+ @staticmethod
+ async def _make_non_streaming_http_request(
+ request: Request,
+ async_client: httpx.AsyncClient,
+ url: str,
+ headers: dict,
+ requested_query_params: Optional[dict] = None,
+ custom_body: Optional[dict] = None,
+ ) -> httpx.Response:
+ """
+ Make a non-streaming HTTP request
+
+ If request is GET, don't include a JSON body
+ """
+ if request.method == "GET":
+ response = await async_client.request(
+ method=request.method,
+ url=url,
+ headers=headers,
+ params=requested_query_params,
+ )
+ else:
+ response = await async_client.request(
+ method=request.method,
+ url=url,
+ headers=headers,
+ params=requested_query_params,
+ json=custom_body,
+ )
+ return response
+
+
+async def pass_through_request( # noqa: PLR0915
+ request: Request,
+ target: str,
+ custom_headers: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ custom_body: Optional[dict] = None,
+ forward_headers: Optional[bool] = False,
+ merge_query_params: Optional[bool] = False,
+ query_params: Optional[dict] = None,
+ stream: Optional[bool] = None,
+):
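+    """
+    Forward the incoming request to `target`: resolve headers and query
+    params, run pre-call hooks and logging, then proxy the call with httpx,
+    returning a StreamingResponse when the upstream replies with
+    server-sent events (or when `stream=True` is passed).
+    """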
+ litellm_call_id = str(uuid.uuid4())
+ url: Optional[httpx.URL] = None
+ try:
+ from litellm.litellm_core_utils.litellm_logging import Logging
+ from litellm.proxy.proxy_server import proxy_logging_obj
+
+ url = httpx.URL(target)
+ headers = custom_headers
+ headers = HttpPassThroughEndpointHelpers.forward_headers_from_request(
+ request=request, headers=headers, forward_headers=forward_headers
+ )
+
+ if merge_query_params:
+ # Create a new URL with the merged query params
+ url = url.copy_with(
+ query=urlencode(
+ HttpPassThroughEndpointHelpers.get_merged_query_parameters(
+ existing_url=url,
+ request_query_params=dict(request.query_params),
+ )
+ ).encode("ascii")
+ )
+
+ endpoint_type: EndpointType = HttpPassThroughEndpointHelpers.get_endpoint_type(
+ str(url)
+ )
+
+ _parsed_body = None
+ if custom_body:
+ _parsed_body = custom_body
+ else:
+ _parsed_body = await _read_request_body(request)
+ verbose_proxy_logger.debug(
+ "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
+ url, headers, _parsed_body
+ )
+ )
+
+ ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+ _parsed_body = await proxy_logging_obj.pre_call_hook(
+ user_api_key_dict=user_api_key_dict,
+ data=_parsed_body,
+ call_type="pass_through_endpoint",
+ )
+ async_client_obj = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.PassThroughEndpoint,
+ params={"timeout": 600},
+ )
+ async_client = async_client_obj.client
+
+ # create logging object
+ start_time = datetime.now()
+ logging_obj = Logging(
+ model="unknown",
+ messages=[{"role": "user", "content": json.dumps(_parsed_body)}],
+ stream=False,
+ call_type="pass_through_endpoint",
+ start_time=start_time,
+ litellm_call_id=litellm_call_id,
+ function_id="1245",
+ )
+ passthrough_logging_payload = PassthroughStandardLoggingPayload(
+ url=str(url),
+ request_body=_parsed_body,
+ )
+ kwargs = _init_kwargs_for_pass_through_endpoint(
+ user_api_key_dict=user_api_key_dict,
+ _parsed_body=_parsed_body,
+ passthrough_logging_payload=passthrough_logging_payload,
+ litellm_call_id=litellm_call_id,
+ request=request,
+ )
+        # done to support 'parallel_request_limiter.py' with pass-through endpoints
+ logging_obj.update_environment_variables(
+ model="unknown",
+ user="unknown",
+ optional_params={},
+ litellm_params=kwargs["litellm_params"],
+ call_type="pass_through_endpoint",
+ )
+ logging_obj.model_call_details["litellm_call_id"] = litellm_call_id
+
+ # combine url with query params for logging
+
+ requested_query_params: Optional[dict] = (
+ query_params or request.query_params.__dict__
+ )
+ if requested_query_params == request.query_params.__dict__:
+ requested_query_params = None
+
+ requested_query_params_str = None
+ if requested_query_params:
+ requested_query_params_str = "&".join(
+ f"{k}={v}" for k, v in requested_query_params.items()
+ )
+
+ logging_url = str(url)
+ if requested_query_params_str:
+ if "?" in str(url):
+ logging_url = str(url) + "&" + requested_query_params_str
+ else:
+ logging_url = str(url) + "?" + requested_query_params_str
+
+ logging_obj.pre_call(
+ input=[{"role": "user", "content": json.dumps(_parsed_body)}],
+ api_key="",
+ additional_args={
+ "complete_input_dict": _parsed_body,
+ "api_base": str(logging_url),
+ "headers": headers,
+ },
+ )
+ if stream:
+ req = async_client.build_request(
+ "POST",
+ url,
+ json=_parsed_body,
+ params=requested_query_params,
+ headers=headers,
+ )
+
+ response = await async_client.send(req, stream=stream)
+
+ try:
+ response.raise_for_status()
+ except httpx.HTTPStatusError as e:
+ raise HTTPException(
+ status_code=e.response.status_code, detail=await e.response.aread()
+ )
+
+ return StreamingResponse(
+ PassThroughStreamingHandler.chunk_processor(
+ response=response,
+ request_body=_parsed_body,
+ litellm_logging_obj=logging_obj,
+ endpoint_type=endpoint_type,
+ start_time=start_time,
+ passthrough_success_handler_obj=pass_through_endpoint_logging,
+ url_route=str(url),
+ ),
+ headers=HttpPassThroughEndpointHelpers.get_response_headers(
+ headers=response.headers,
+ litellm_call_id=litellm_call_id,
+ ),
+ status_code=response.status_code,
+ )
+
+ verbose_proxy_logger.debug("request method: {}".format(request.method))
+ verbose_proxy_logger.debug("request url: {}".format(url))
+ verbose_proxy_logger.debug("request headers: {}".format(headers))
+ verbose_proxy_logger.debug(
+ "requested_query_params={}".format(requested_query_params)
+ )
+ verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
+
+ if request.method == "GET":
+ response = await async_client.request(
+ method=request.method,
+ url=url,
+ headers=headers,
+ params=requested_query_params,
+ )
+ else:
+ response = await async_client.request(
+ method=request.method,
+ url=url,
+ headers=headers,
+ params=requested_query_params,
+ json=_parsed_body,
+ )
+
+ verbose_proxy_logger.debug("response.headers= %s", response.headers)
+
+ if _is_streaming_response(response) is True:
+ try:
+ response.raise_for_status()
+ except httpx.HTTPStatusError as e:
+ raise HTTPException(
+ status_code=e.response.status_code, detail=await e.response.aread()
+ )
+
+ return StreamingResponse(
+ PassThroughStreamingHandler.chunk_processor(
+ response=response,
+ request_body=_parsed_body,
+ litellm_logging_obj=logging_obj,
+ endpoint_type=endpoint_type,
+ start_time=start_time,
+ passthrough_success_handler_obj=pass_through_endpoint_logging,
+ url_route=str(url),
+ ),
+ headers=HttpPassThroughEndpointHelpers.get_response_headers(
+ headers=response.headers,
+ litellm_call_id=litellm_call_id,
+ ),
+ status_code=response.status_code,
+ )
+
+ try:
+ response.raise_for_status()
+ except httpx.HTTPStatusError as e:
+ raise HTTPException(
+ status_code=e.response.status_code, detail=e.response.text
+ )
+
+ if response.status_code >= 300:
+ raise HTTPException(status_code=response.status_code, detail=response.text)
+
+ content = await response.aread()
+
+ ## LOG SUCCESS
+ response_body: Optional[dict] = get_response_body(response)
+ passthrough_logging_payload["response_body"] = response_body
+ end_time = datetime.now()
+ asyncio.create_task(
+ pass_through_endpoint_logging.pass_through_async_success_handler(
+ httpx_response=response,
+ response_body=response_body,
+ url_route=str(url),
+ result="",
+ start_time=start_time,
+ end_time=end_time,
+ logging_obj=logging_obj,
+ cache_hit=False,
+ **kwargs,
+ )
+ )
+
+ ## CUSTOM HEADERS - `x-litellm-*`
+ custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ call_id=litellm_call_id,
+ model_id=None,
+ cache_key=None,
+ api_base=str(url._uri_reference),
+ )
+
+ return Response(
+ content=content,
+ status_code=response.status_code,
+ headers=HttpPassThroughEndpointHelpers.get_response_headers(
+ headers=response.headers,
+ custom_headers=custom_headers,
+ ),
+ )
+ except Exception as e:
+ custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ call_id=litellm_call_id,
+ model_id=None,
+ cache_key=None,
+ api_base=str(url._uri_reference) if url else None,
+ )
+ verbose_proxy_logger.exception(
+ "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "message", str(e.detail)),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ headers=custom_headers,
+ )
+ else:
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", 500),
+ headers=custom_headers,
+ )
+
+
+def _init_kwargs_for_pass_through_endpoint(
+ request: Request,
+ user_api_key_dict: UserAPIKeyAuth,
+ passthrough_logging_payload: PassthroughStandardLoggingPayload,
+ _parsed_body: Optional[dict] = None,
+ litellm_call_id: Optional[str] = None,
+) -> dict:
+ _parsed_body = _parsed_body or {}
+ _litellm_metadata: Optional[dict] = _parsed_body.pop("litellm_metadata", None)
+ _metadata = dict(
+ StandardLoggingUserAPIKeyMetadata(
+ user_api_key_hash=user_api_key_dict.api_key,
+ user_api_key_alias=user_api_key_dict.key_alias,
+ user_api_key_user_email=user_api_key_dict.user_email,
+ user_api_key_user_id=user_api_key_dict.user_id,
+ user_api_key_team_id=user_api_key_dict.team_id,
+ user_api_key_org_id=user_api_key_dict.org_id,
+ user_api_key_team_alias=user_api_key_dict.team_alias,
+ user_api_key_end_user_id=user_api_key_dict.end_user_id,
+ )
+ )
+ _metadata["user_api_key"] = user_api_key_dict.api_key
+ if _litellm_metadata:
+ _metadata.update(_litellm_metadata)
+
+ _metadata = _update_metadata_with_tags_in_header(
+ request=request,
+ metadata=_metadata,
+ )
+
+ kwargs = {
+ "litellm_params": {
+ "metadata": _metadata,
+ },
+ "call_type": "pass_through_endpoint",
+ "litellm_call_id": litellm_call_id,
+ "passthrough_logging_payload": passthrough_logging_payload,
+ }
+ return kwargs
+
+
+def _update_metadata_with_tags_in_header(request: Request, metadata: dict) -> dict:
+ """
+ If tags are in the request headers, add them to the metadata
+
+ Used for google and vertex JS SDKs
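+
+    Example (illustrative): a request sent with the header
+    `tags: prod,team-a` yields metadata["tags"] == ["prod", "team-a"]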
+ """
+ _tags = request.headers.get("tags")
+ if _tags:
+ metadata["tags"] = _tags.split(",")
+ return metadata
+
+
+def create_pass_through_route(
+ endpoint,
+ target: str,
+ custom_headers: Optional[dict] = None,
+ _forward_headers: Optional[bool] = False,
+ _merge_query_params: Optional[bool] = False,
+ dependencies: Optional[List] = None,
+):
+ # check if target is an adapter.py or a url
+ import uuid
+
+ from litellm.proxy.types_utils.utils import get_instance_fn
+
+ try:
+ if isinstance(target, CustomLogger):
+ adapter = target
+ else:
+ adapter = get_instance_fn(value=target)
+ adapter_id = str(uuid.uuid4())
+ litellm.adapters = [{"id": adapter_id, "adapter": adapter}]
+
+ async def endpoint_func( # type: ignore
+ request: Request,
+ fastapi_response: Response,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+ ):
+ return await chat_completion_pass_through_endpoint(
+ fastapi_response=fastapi_response,
+ request=request,
+ adapter_id=adapter_id,
+ user_api_key_dict=user_api_key_dict,
+ )
+
+ except Exception:
+ verbose_proxy_logger.debug("Defaulting to target being a url.")
+
+ async def endpoint_func( # type: ignore
+ request: Request,
+ fastapi_response: Response,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+ query_params: Optional[dict] = None,
+ custom_body: Optional[dict] = None,
+ stream: Optional[
+ bool
+ ] = None, # if pass-through endpoint is a streaming request
+ ):
+ return await pass_through_request( # type: ignore
+ request=request,
+ target=target,
+ custom_headers=custom_headers or {},
+ user_api_key_dict=user_api_key_dict,
+ forward_headers=_forward_headers,
+ merge_query_params=_merge_query_params,
+ query_params=query_params,
+ stream=stream,
+ custom_body=custom_body,
+ )
+
+ return endpoint_func
+
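+# Illustrative sketch of wiring a route manually (normally done by
+# initialize_pass_through_endpoints below; the path, target, and header
+# values here are hypothetical):
+#
+#   endpoint_func = create_pass_through_route(
+#       endpoint="/cohere/v1/chat",
+#       target="https://api.cohere.ai/v1/chat",
+#       custom_headers={"Authorization": "bearer os.environ/COHERE_API_KEY"},
+#   )
+#   app.add_api_route(
+#       path="/cohere/v1/chat", endpoint=endpoint_func, methods=["POST"]
+#   )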
+
+def _is_streaming_response(response: httpx.Response) -> bool:
+ _content_type = response.headers.get("content-type")
+ if _content_type is not None and "text/event-stream" in _content_type:
+ return True
+ return False
+
+
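+# Expected shape of each entry in `pass_through_endpoints` (illustrative
+# values; these are the keys the loop below reads from the proxy config):
+#
+#   {
+#       "path": "/v1/rerank",
+#       "target": "https://api.cohere.com/v1/rerank",
+#       "headers": {"Authorization": "bearer os.environ/COHERE_API_KEY"},
+#       "forward_headers": True,
+#       "merge_query_params": False,
+#       "auth": "true",
+#   }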
+async def initialize_pass_through_endpoints(pass_through_endpoints: list):
+ verbose_proxy_logger.debug("initializing pass through endpoints")
+ from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes
+ from litellm.proxy.proxy_server import app, premium_user
+
+ for endpoint in pass_through_endpoints:
+ _target = endpoint.get("target", None)
+ _path = endpoint.get("path", None)
+ _custom_headers = endpoint.get("headers", None)
+ _custom_headers = await set_env_variables_in_header(
+ custom_headers=_custom_headers
+ )
+ _forward_headers = endpoint.get("forward_headers", None)
+ _merge_query_params = endpoint.get("merge_query_params", None)
+ _auth = endpoint.get("auth", None)
+ _dependencies = None
+ if _auth is not None and str(_auth).lower() == "true":
+ if premium_user is not True:
+ raise ValueError(
+ "Error Setting Authentication on Pass Through Endpoint: {}".format(
+ CommonProxyErrors.not_premium_user.value
+ )
+ )
+ _dependencies = [Depends(user_api_key_auth)]
+ LiteLLMRoutes.openai_routes.value.append(_path)
+
+ if _target is None:
+ continue
+
+ verbose_proxy_logger.debug(
+ "adding pass through endpoint: %s, dependencies: %s", _path, _dependencies
+ )
+ app.add_api_route( # type: ignore
+ path=_path,
+ endpoint=create_pass_through_route( # type: ignore
+ _path,
+ _target,
+ _custom_headers,
+ _forward_headers,
+ _merge_query_params,
+ _dependencies,
+ ),
+ methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
+ dependencies=_dependencies,
+ )
+
+ verbose_proxy_logger.debug("Added new pass through endpoint: %s", _path)
+
+
+@router.get(
+ "/config/pass_through_endpoint",
+ dependencies=[Depends(user_api_key_auth)],
+ response_model=PassThroughEndpointResponse,
+)
+async def get_pass_through_endpoints(
+ endpoint_id: Optional[str] = None,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ GET configured pass through endpoint.
+
+ If no endpoint_id given, return all configured endpoints.
+ """
+ from litellm.proxy.proxy_server import get_config_general_settings
+
+ ## Get existing pass-through endpoint field value
+ try:
+ response: ConfigFieldInfo = await get_config_general_settings(
+ field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
+ )
+ except Exception:
+ return PassThroughEndpointResponse(endpoints=[])
+
+ pass_through_endpoint_data: Optional[List] = response.field_value
+ if pass_through_endpoint_data is None:
+ return PassThroughEndpointResponse(endpoints=[])
+
+ returned_endpoints = []
+ if endpoint_id is None:
+ for endpoint in pass_through_endpoint_data:
+ if isinstance(endpoint, dict):
+ returned_endpoints.append(PassThroughGenericEndpoint(**endpoint))
+ elif isinstance(endpoint, PassThroughGenericEndpoint):
+ returned_endpoints.append(endpoint)
+ elif endpoint_id is not None:
+ for endpoint in pass_through_endpoint_data:
+ _endpoint: Optional[PassThroughGenericEndpoint] = None
+ if isinstance(endpoint, dict):
+ _endpoint = PassThroughGenericEndpoint(**endpoint)
+ elif isinstance(endpoint, PassThroughGenericEndpoint):
+ _endpoint = endpoint
+
+ if _endpoint is not None and _endpoint.path == endpoint_id:
+ returned_endpoints.append(_endpoint)
+
+ return PassThroughEndpointResponse(endpoints=returned_endpoints)
+
+
+@router.post(
+ "/config/pass_through_endpoint/{endpoint_id}",
+ dependencies=[Depends(user_api_key_auth)],
+)
+async def update_pass_through_endpoints(request: Request, endpoint_id: str):
+ """
+ Update a pass-through endpoint
+ """
+ pass
+
+
+@router.post(
+ "/config/pass_through_endpoint",
+ dependencies=[Depends(user_api_key_auth)],
+)
+async def create_pass_through_endpoints(
+ data: PassThroughGenericEndpoint,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Create new pass-through endpoint
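+
+    Example request (illustrative values):
+
+        curl -X POST 'http://0.0.0.0:4000/config/pass_through_endpoint' \
+            -H 'Authorization: Bearer <admin-key>' \
+            -H 'Content-Type: application/json' \
+            -d '{"path": "/v1/rerank", "target": "https://api.cohere.com/v1/rerank"}'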
+ """
+ from litellm.proxy.proxy_server import (
+ get_config_general_settings,
+ update_config_general_settings,
+ )
+
+ ## Get existing pass-through endpoint field value
+
+ try:
+ response: ConfigFieldInfo = await get_config_general_settings(
+ field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
+ )
+ except Exception:
+ response = ConfigFieldInfo(
+ field_name="pass_through_endpoints", field_value=None
+ )
+
+ ## Update field with new endpoint
+ data_dict = data.model_dump()
+ if response.field_value is None:
+ response.field_value = [data_dict]
+ elif isinstance(response.field_value, List):
+ response.field_value.append(data_dict)
+
+ ## Update db
+ updated_data = ConfigFieldUpdate(
+ field_name="pass_through_endpoints",
+ field_value=response.field_value,
+ config_type="general_settings",
+ )
+ await update_config_general_settings(
+ data=updated_data, user_api_key_dict=user_api_key_dict
+ )
+
+
+@router.delete(
+ "/config/pass_through_endpoint",
+ dependencies=[Depends(user_api_key_auth)],
+ response_model=PassThroughEndpointResponse,
+)
+async def delete_pass_through_endpoints(
+ endpoint_id: str,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Delete a pass-through endpoint
+
+ Returns - the deleted endpoint
+ """
+ from litellm.proxy.proxy_server import (
+ get_config_general_settings,
+ update_config_general_settings,
+ )
+
+ ## Get existing pass-through endpoint field value
+
+ try:
+ response: ConfigFieldInfo = await get_config_general_settings(
+ field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
+ )
+ except Exception:
+ response = ConfigFieldInfo(
+ field_name="pass_through_endpoints", field_value=None
+ )
+
+ ## Update field by removing endpoint
+ pass_through_endpoint_data: Optional[List] = response.field_value
+ response_obj: Optional[PassThroughGenericEndpoint] = None
+ if response.field_value is None or pass_through_endpoint_data is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": "There are no pass-through endpoints setup."},
+ )
+ elif isinstance(response.field_value, List):
+        found_idx: Optional[int] = None  # index of the endpoint to delete
+ for idx, endpoint in enumerate(pass_through_endpoint_data):
+ _endpoint: Optional[PassThroughGenericEndpoint] = None
+ if isinstance(endpoint, dict):
+ _endpoint = PassThroughGenericEndpoint(**endpoint)
+ elif isinstance(endpoint, PassThroughGenericEndpoint):
+ _endpoint = endpoint
+
+ if _endpoint is not None and _endpoint.path == endpoint_id:
+                found_idx = idx
+ response_obj = _endpoint
+
+        if found_idx is not None:
+            pass_through_endpoint_data.pop(found_idx)
+
+ ## Update db
+ updated_data = ConfigFieldUpdate(
+ field_name="pass_through_endpoints",
+ field_value=pass_through_endpoint_data,
+ config_type="general_settings",
+ )
+ await update_config_general_settings(
+ data=updated_data, user_api_key_dict=user_api_key_dict
+ )
+
+ if response_obj is None:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Endpoint={} was not found in pass-through endpoint list.".format(
+ endpoint_id
+ )
+ },
+ )
+ return PassThroughEndpointResponse(endpoints=[response_obj])