author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
---|---|---
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints')
9 files changed, 2814 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py new file mode 100644 index 00000000..4724c7f9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -0,0 +1,556 @@ +""" +What is this? + +Provider-specific Pass-Through Endpoints + +Use litellm with Anthropic SDK, Vertex AI SDK, Cohere SDK, etc. +""" + +from typing import Optional + +import httpx +from fastapi import APIRouter, Depends, HTTPException, Request, Response + +import litellm +from litellm.constants import BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES +from litellm.proxy._types import * +from litellm.proxy.auth.route_checks import RouteChecks +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + create_pass_through_route, +) +from litellm.secret_managers.main import get_secret_str + +from .passthrough_endpoint_router import PassthroughEndpointRouter + +router = APIRouter() +default_vertex_config = None + +passthrough_endpoint_router = PassthroughEndpointRouter() + + +def create_request_copy(request: Request): + return { + "method": request.method, + "url": str(request.url), + "headers": dict(request.headers), + "cookies": request.cookies, + "query_params": dict(request.query_params), + } + + +@router.api_route( + "/gemini/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["Google AI Studio Pass-through", "pass-through"], +) +async def gemini_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, +): + """ + [Docs](https://docs.litellm.ai/docs/pass_through/google_ai_studio) + """ + ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY + google_ai_studio_api_key = request.query_params.get("key") or request.headers.get( + "x-goog-api-key" + ) + + user_api_key_dict = await user_api_key_auth( + request=request, api_key=f"Bearer {google_ai_studio_api_key}" + ) + + base_target_url = "https://generativelanguage.googleapis.com" + encoded_endpoint = httpx.URL(endpoint).path + + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL using httpx + base_url = httpx.URL(base_target_url) + updated_url = base_url.copy_with(path=encoded_endpoint) + + # Add or update query parameters + gemini_api_key: Optional[str] = passthrough_endpoint_router.get_credentials( + custom_llm_provider="gemini", + region_name=None, + ) + if gemini_api_key is None: + raise Exception( + "Required 'GEMINI_API_KEY' in environment to make pass-through calls to Google AI Studio." 
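For orientation, here is a minimal client-side sketch of how the `/gemini/{endpoint:path}` pass-through above might be called. Per the route, the caller's LiteLLM virtual key is accepted either as `?key=` or as the `x-goog-api-key` header, and the proxy swaps in the real `GEMINI_API_KEY` before forwarding to `generativelanguage.googleapis.com`. The proxy base URL, model name, and key value below are assumed placeholders, not values from this diff.

```python
# Hypothetical client-side call to the /gemini pass-through route.
# PROXY_BASE_URL, the key, and the model name are assumed placeholders.
import httpx

PROXY_BASE_URL = "http://localhost:4000"   # assumed proxy address
LITELLM_VIRTUAL_KEY = "sk-1234"            # assumed LiteLLM virtual key

resp = httpx.post(
    f"{PROXY_BASE_URL}/gemini/v1beta/models/gemini-1.5-flash:generateContent",
    headers={"x-goog-api-key": LITELLM_VIRTUAL_KEY},  # authenticates against the proxy
    json={"contents": [{"parts": [{"text": "Hello"}]}]},
)
print(resp.status_code, resp.json())
```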
+ ) + # Merge query parameters, giving precedence to those in updated_url + merged_params = dict(request.query_params) + merged_params.update({"key": gemini_api_key}) + + ## check for streaming + is_streaming_request = False + if "stream" in str(updated_url): + is_streaming_request = True + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=str(updated_url), + ) # dynamically construct pass-through endpoint based on incoming path + received_value = await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + query_params=merged_params, # type: ignore + stream=is_streaming_request, # type: ignore + ) + + return received_value + + +@router.api_route( + "/cohere/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["Cohere Pass-through", "pass-through"], +) +async def cohere_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [Docs](https://docs.litellm.ai/docs/pass_through/cohere) + """ + base_target_url = "https://api.cohere.com" + encoded_endpoint = httpx.URL(endpoint).path + + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL using httpx + base_url = httpx.URL(base_target_url) + updated_url = base_url.copy_with(path=encoded_endpoint) + + # Add or update query parameters + cohere_api_key = passthrough_endpoint_router.get_credentials( + custom_llm_provider="cohere", + region_name=None, + ) + + ## check for streaming + is_streaming_request = False + if "stream" in str(updated_url): + is_streaming_request = True + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=str(updated_url), + custom_headers={"Authorization": "Bearer {}".format(cohere_api_key)}, + ) # dynamically construct pass-through endpoint based on incoming path + received_value = await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + stream=is_streaming_request, # type: ignore + ) + + return received_value + + +@router.api_route( + "/anthropic/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["Anthropic Pass-through", "pass-through"], +) +async def anthropic_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [Docs](https://docs.litellm.ai/docs/anthropic_completion) + """ + base_target_url = "https://api.anthropic.com" + encoded_endpoint = httpx.URL(endpoint).path + + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL using httpx + base_url = httpx.URL(base_target_url) + updated_url = base_url.copy_with(path=encoded_endpoint) + + # Add or update query parameters + anthropic_api_key = passthrough_endpoint_router.get_credentials( + custom_llm_provider="anthropic", + region_name=None, + ) + + ## check for streaming + is_streaming_request = False + # anthropic is streaming when 'stream' = True is in the body + if request.method == "POST": + _request_body = await request.json() + if _request_body.get("stream"): + is_streaming_request = True + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=str(updated_url), + 
custom_headers={"x-api-key": "{}".format(anthropic_api_key)}, + _forward_headers=True, + ) # dynamically construct pass-through endpoint based on incoming path + received_value = await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + stream=is_streaming_request, # type: ignore + ) + + return received_value + + +@router.api_route( + "/bedrock/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["Bedrock Pass-through", "pass-through"], +) +async def bedrock_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [Docs](https://docs.litellm.ai/docs/pass_through/bedrock) + """ + create_request_copy(request) + + try: + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + + aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME") + if _is_bedrock_agent_runtime_route(endpoint=endpoint): # handle bedrock agents + base_target_url = ( + f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com" + ) + else: + base_target_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" + encoded_endpoint = httpx.URL(endpoint).path + + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL using httpx + base_url = httpx.URL(base_target_url) + updated_url = base_url.copy_with(path=encoded_endpoint) + + # Add or update query parameters + from litellm.llms.bedrock.chat import BedrockConverseLLM + + credentials: Credentials = BedrockConverseLLM().get_credentials() + sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + headers = {"Content-Type": "application/json"} + # Assuming the body contains JSON data, parse it + try: + data = await request.json() + except Exception as e: + raise HTTPException(status_code=400, detail={"error": e}) + _request = AWSRequest( + method="POST", url=str(updated_url), data=json.dumps(data), headers=headers + ) + sigv4.add_auth(_request) + prepped = _request.prepare() + + ## check for streaming + is_streaming_request = False + if "stream" in str(updated_url): + is_streaming_request = True + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=str(prepped.url), + custom_headers=prepped.headers, # type: ignore + ) # dynamically construct pass-through endpoint based on incoming path + received_value = await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + stream=is_streaming_request, # type: ignore + custom_body=data, # type: ignore + query_params={}, # type: ignore + ) + + return received_value + + +def _is_bedrock_agent_runtime_route(endpoint: str) -> bool: + """ + Return True, if the endpoint should be routed to the `bedrock-agent-runtime` endpoint. 
+ """ + for _route in BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES: + if _route in endpoint: + return True + return False + + +@router.api_route( + "/assemblyai/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["AssemblyAI Pass-through", "pass-through"], +) +@router.api_route( + "/eu.assemblyai/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["AssemblyAI EU Pass-through", "pass-through"], +) +async def assemblyai_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import ( + AssemblyAIPassthroughLoggingHandler, + ) + + """ + [Docs](https://api.assemblyai.com) + """ + # Set base URL based on the route + assembly_region = AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url( + url=str(request.url) + ) + base_target_url = ( + AssemblyAIPassthroughLoggingHandler._get_assembly_base_url_from_region( + region=assembly_region + ) + ) + encoded_endpoint = httpx.URL(endpoint).path + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL using httpx + base_url = httpx.URL(base_target_url) + updated_url = base_url.copy_with(path=encoded_endpoint) + + # Add or update query parameters + assemblyai_api_key = passthrough_endpoint_router.get_credentials( + custom_llm_provider="assemblyai", + region_name=assembly_region, + ) + + ## check for streaming + is_streaming_request = False + # assemblyai is streaming when 'stream' = True is in the body + if request.method == "POST": + _request_body = await request.json() + if _request_body.get("stream"): + is_streaming_request = True + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=str(updated_url), + custom_headers={"Authorization": "{}".format(assemblyai_api_key)}, + ) # dynamically construct pass-through endpoint based on incoming path + received_value = await endpoint_func( + request=request, + fastapi_response=fastapi_response, + user_api_key_dict=user_api_key_dict, + stream=is_streaming_request, # type: ignore + ) + + return received_value + + +@router.api_route( + "/azure/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["Azure Pass-through", "pass-through"], +) +async def azure_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Call any azure endpoint using the proxy. + + Just use `{PROXY_BASE_URL}/azure/{endpoint:path}` + """ + base_target_url = get_secret_str(secret_name="AZURE_API_BASE") + if base_target_url is None: + raise Exception( + "Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure." + ) + # Add or update query parameters + azure_api_key = passthrough_endpoint_router.get_credentials( + custom_llm_provider=litellm.LlmProviders.AZURE.value, + region_name=None, + ) + if azure_api_key is None: + raise Exception( + "Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure." 
+ ) + + return await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler( + endpoint=endpoint, + request=request, + fastapi_response=fastapi_response, + user_api_key_dict=user_api_key_dict, + base_target_url=base_target_url, + api_key=azure_api_key, + custom_llm_provider=litellm.LlmProviders.AZURE, + ) + + +@router.api_route( + "/openai/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["OpenAI Pass-through", "pass-through"], +) +async def openai_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Simple pass-through for OpenAI. Use this if you want to directly send a request to OpenAI. + + + """ + base_target_url = "https://api.openai.com/" + # Add or update query parameters + openai_api_key = passthrough_endpoint_router.get_credentials( + custom_llm_provider=litellm.LlmProviders.OPENAI.value, + region_name=None, + ) + if openai_api_key is None: + raise Exception( + "Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI." + ) + + return await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler( + endpoint=endpoint, + request=request, + fastapi_response=fastapi_response, + user_api_key_dict=user_api_key_dict, + base_target_url=base_target_url, + api_key=openai_api_key, + custom_llm_provider=litellm.LlmProviders.OPENAI, + ) + + +class BaseOpenAIPassThroughHandler: + @staticmethod + async def _base_openai_pass_through_handler( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth, + base_target_url: str, + api_key: str, + custom_llm_provider: litellm.LlmProviders, + ): + encoded_endpoint = httpx.URL(endpoint).path + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL by properly joining the base URL and endpoint path + base_url = httpx.URL(base_target_url) + updated_url = BaseOpenAIPassThroughHandler._join_url_paths( + base_url=base_url, + path=encoded_endpoint, + custom_llm_provider=custom_llm_provider, + ) + + ## check for streaming + is_streaming_request = False + if "stream" in str(updated_url): + is_streaming_request = True + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=str(updated_url), + custom_headers=BaseOpenAIPassThroughHandler._assemble_headers( + api_key=api_key, request=request + ), + ) # dynamically construct pass-through endpoint based on incoming path + received_value = await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + stream=is_streaming_request, # type: ignore + query_params=dict(request.query_params), # type: ignore + ) + + return received_value + + @staticmethod + def _append_openai_beta_header(headers: dict, request: Request) -> dict: + """ + Appends the OpenAI-Beta header to the headers if the request is an OpenAI Assistants API request + """ + if ( + RouteChecks._is_assistants_api_request(request) is True + and "OpenAI-Beta" not in headers + ): + headers["OpenAI-Beta"] = "assistants=v2" + return headers + + @staticmethod + def _assemble_headers(api_key: str, request: Request) -> dict: + base_headers = { + "authorization": "Bearer {}".format(api_key), + "api-key": "{}".format(api_key), + } + return BaseOpenAIPassThroughHandler._append_openai_beta_header( + headers=base_headers, + request=request, + ) + + @staticmethod + def 
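A similar hedged sketch for the `/openai` route above: requests sent through the pass-through are re-signed with the proxy's `OPENAI_API_KEY`, so a stock OpenAI client can simply use the proxy as its base URL. The address, key, and model here are assumed placeholders.

```python
# Illustrative only: use the OpenAI SDK through the proxy's /openai pass-through.
# base_url, api_key, and model are assumed placeholders.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:4000/openai",  # assumed proxy address + route prefix
    api_key="sk-1234",                        # LiteLLM virtual key
)

completion = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Hello"}],
)
print(completion.choices[0].message.content)
```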
_join_url_paths( + base_url: httpx.URL, path: str, custom_llm_provider: litellm.LlmProviders + ) -> str: + """ + Properly joins a base URL with a path, preserving any existing path in the base URL. + """ + # Join paths correctly by removing trailing/leading slashes as needed + if not base_url.path or base_url.path == "/": + # If base URL has no path, just use the new path + joined_path_str = str(base_url.copy_with(path=path)) + else: + # Otherwise, combine the paths + base_path = base_url.path.rstrip("/") + clean_path = path.lstrip("/") + full_path = f"{base_path}/{clean_path}" + joined_path_str = str(base_url.copy_with(path=full_path)) + + # Apply OpenAI-specific path handling for both branches + if ( + custom_llm_provider == litellm.LlmProviders.OPENAI + and "/v1/" not in joined_path_str + ): + # Insert v1 after api.openai.com for OpenAI requests + joined_path_str = joined_path_str.replace( + "api.openai.com/", "api.openai.com/v1/" + ) + + return joined_path_str diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py new file mode 100644 index 00000000..51845956 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -0,0 +1,220 @@ +import json +from datetime import datetime +from typing import TYPE_CHECKING, Any, List, Optional, Union + +import httpx + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.anthropic.chat.handler import ( + ModelResponseIterator as AnthropicModelResponseIterator, +) +from litellm.llms.anthropic.chat.transformation import AnthropicConfig +from litellm.proxy._types import PassThroughEndpointLoggingTypedDict +from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body +from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload +from litellm.types.utils import ModelResponse, TextCompletionResponse + +if TYPE_CHECKING: + from ..success_handler import PassThroughEndpointLogging + from ..types import EndpointType +else: + PassThroughEndpointLogging = Any + EndpointType = Any + + +class AnthropicPassthroughLoggingHandler: + + @staticmethod + def anthropic_passthrough_handler( + httpx_response: httpx.Response, + response_body: dict, + logging_obj: LiteLLMLoggingObj, + url_route: str, + result: str, + start_time: datetime, + end_time: datetime, + cache_hit: bool, + **kwargs, + ) -> PassThroughEndpointLoggingTypedDict: + """ + Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled + """ + model = response_body.get("model", "") + litellm_model_response: ModelResponse = AnthropicConfig().transform_response( + raw_response=httpx_response, + model_response=litellm.ModelResponse(), + model=model, + messages=[], + logging_obj=logging_obj, + optional_params={}, + api_key="", + request_data={}, + encoding=litellm.encoding, + json_mode=False, + litellm_params={}, + ) + + kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( + litellm_model_response=litellm_model_response, + model=model, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + ) + + return 
{ + "result": litellm_model_response, + "kwargs": kwargs, + } + + @staticmethod + def _get_user_from_metadata( + passthrough_logging_payload: PassthroughStandardLoggingPayload, + ) -> Optional[str]: + request_body = passthrough_logging_payload.get("request_body") + if request_body: + return get_end_user_id_from_request_body(request_body) + return None + + @staticmethod + def _create_anthropic_response_logging_payload( + litellm_model_response: Union[ModelResponse, TextCompletionResponse], + model: str, + kwargs: dict, + start_time: datetime, + end_time: datetime, + logging_obj: LiteLLMLoggingObj, + ): + """ + Create the standard logging object for Anthropic passthrough + + handles streaming and non-streaming responses + """ + try: + response_cost = litellm.completion_cost( + completion_response=litellm_model_response, + model=model, + ) + kwargs["response_cost"] = response_cost + kwargs["model"] = model + passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = ( # type: ignore + kwargs.get("passthrough_logging_payload") + ) + if passthrough_logging_payload: + user = AnthropicPassthroughLoggingHandler._get_user_from_metadata( + passthrough_logging_payload=passthrough_logging_payload, + ) + if user: + kwargs.setdefault("litellm_params", {}) + kwargs["litellm_params"].update( + {"proxy_server_request": {"body": {"user": user}}} + ) + + # pretty print standard logging object + verbose_proxy_logger.debug( + "kwargs= %s", + json.dumps(kwargs, indent=4, default=str), + ) + + # set litellm_call_id to logging response object + litellm_model_response.id = logging_obj.litellm_call_id + litellm_model_response.model = model + logging_obj.model_call_details["model"] = model + logging_obj.model_call_details["custom_llm_provider"] = ( + litellm.LlmProviders.ANTHROPIC.value + ) + return kwargs + except Exception as e: + verbose_proxy_logger.exception( + "Error creating Anthropic response logging payload: %s", e + ) + return kwargs + + @staticmethod + def _handle_logging_anthropic_collected_chunks( + litellm_logging_obj: LiteLLMLoggingObj, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + request_body: dict, + endpoint_type: EndpointType, + start_time: datetime, + all_chunks: List[str], + end_time: datetime, + ) -> PassThroughEndpointLoggingTypedDict: + """ + Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks + + - Builds complete response from chunks + - Creates standard logging object + - Logs in litellm callbacks + """ + + model = request_body.get("model", "") + complete_streaming_response = ( + AnthropicPassthroughLoggingHandler._build_complete_streaming_response( + all_chunks=all_chunks, + litellm_logging_obj=litellm_logging_obj, + model=model, + ) + ) + if complete_streaming_response is None: + verbose_proxy_logger.error( + "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..." 
+ ) + return { + "result": None, + "kwargs": {}, + } + kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( + litellm_model_response=complete_streaming_response, + model=model, + kwargs={}, + start_time=start_time, + end_time=end_time, + logging_obj=litellm_logging_obj, + ) + + return { + "result": complete_streaming_response, + "kwargs": kwargs, + } + + @staticmethod + def _build_complete_streaming_response( + all_chunks: List[str], + litellm_logging_obj: LiteLLMLoggingObj, + model: str, + ) -> Optional[Union[ModelResponse, TextCompletionResponse]]: + """ + Builds complete response from raw Anthropic chunks + + - Converts str chunks to generic chunks + - Converts generic chunks to litellm chunks (OpenAI format) + - Builds complete response from litellm chunks + """ + anthropic_model_response_iterator = AnthropicModelResponseIterator( + streaming_response=None, + sync_stream=False, + ) + all_openai_chunks = [] + for _chunk_str in all_chunks: + try: + transformed_openai_chunk = anthropic_model_response_iterator.convert_str_chunk_to_generic_chunk( + chunk=_chunk_str + ) + if transformed_openai_chunk is not None: + all_openai_chunks.append(transformed_openai_chunk) + + verbose_proxy_logger.debug( + "all openai chunks= %s", + json.dumps(all_openai_chunks, indent=4, default=str), + ) + except (StopIteration, StopAsyncIteration): + break + complete_streaming_response = litellm.stream_chunk_builder( + chunks=all_openai_chunks + ) + return complete_streaming_response diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py new file mode 100644 index 00000000..7cf3013d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py @@ -0,0 +1,326 @@ +import asyncio +import json +import time +from datetime import datetime +from typing import Literal, Optional, TypedDict +from urllib.parse import urlparse + +import httpx + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.litellm_logging import ( + get_standard_logging_object_payload, +) +from litellm.litellm_core_utils.thread_pool_executor import executor +from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload + + +class AssemblyAITranscriptResponse(TypedDict, total=False): + id: str + speech_model: str + acoustic_model: str + language_code: str + status: str + audio_duration: float + + +class AssemblyAIPassthroughLoggingHandler: + def __init__(self): + self.assembly_ai_base_url = "https://api.assemblyai.com" + self.assembly_ai_eu_base_url = "https://eu.assemblyai.com" + """ + The base URL for the AssemblyAI API + """ + + self.polling_interval: float = 10 + """ + The polling interval for the AssemblyAI API. + litellm needs to poll the GET /transcript/{transcript_id} endpoint to get the status of the transcript. + """ + + self.max_polling_attempts = 180 + """ + The maximum number of polling attempts for the AssemblyAI API. 
+ """ + + def assemblyai_passthrough_logging_handler( + self, + httpx_response: httpx.Response, + response_body: dict, + logging_obj: LiteLLMLoggingObj, + url_route: str, + result: str, + start_time: datetime, + end_time: datetime, + cache_hit: bool, + **kwargs, + ): + """ + Since cost tracking requires polling the AssemblyAI API, we need to handle this in a separate thread. Hence the executor.submit. + """ + executor.submit( + self._handle_assemblyai_passthrough_logging, + httpx_response, + response_body, + logging_obj, + url_route, + result, + start_time, + end_time, + cache_hit, + **kwargs, + ) + + def _handle_assemblyai_passthrough_logging( + self, + httpx_response: httpx.Response, + response_body: dict, + logging_obj: LiteLLMLoggingObj, + url_route: str, + result: str, + start_time: datetime, + end_time: datetime, + cache_hit: bool, + **kwargs, + ): + """ + Handles logging for AssemblyAI successful passthrough requests + """ + from ..pass_through_endpoints import pass_through_endpoint_logging + + model = response_body.get("speech_model", "") + verbose_proxy_logger.debug( + "response body %s", json.dumps(response_body, indent=4) + ) + kwargs["model"] = model + kwargs["custom_llm_provider"] = "assemblyai" + response_cost: Optional[float] = None + + transcript_id = response_body.get("id") + if transcript_id is None: + raise ValueError( + "Transcript ID is required to log the cost of the transcription" + ) + transcript_response = self._poll_assembly_for_transcript_response( + transcript_id=transcript_id, url_route=url_route + ) + verbose_proxy_logger.debug( + "finished polling assembly for transcript response- got transcript response %s", + json.dumps(transcript_response, indent=4), + ) + if transcript_response: + cost = self.get_cost_for_assembly_transcript( + speech_model=model, + transcript_response=transcript_response, + ) + response_cost = cost + + # Make standard logging object for Vertex AI + standard_logging_object = get_standard_logging_object_payload( + kwargs=kwargs, + init_response_obj=transcript_response, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + status="success", + ) + + passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = ( # type: ignore + kwargs.get("passthrough_logging_payload") + ) + + verbose_proxy_logger.debug( + "standard_passthrough_logging_object %s", + json.dumps(passthrough_logging_payload, indent=4), + ) + + # pretty print standard logging object + verbose_proxy_logger.debug( + "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4) + ) + logging_obj.model_call_details["model"] = model + logging_obj.model_call_details["custom_llm_provider"] = "assemblyai" + logging_obj.model_call_details["response_cost"] = response_cost + + asyncio.run( + pass_through_endpoint_logging._handle_logging( + logging_obj=logging_obj, + standard_logging_response_object=self._get_response_to_log( + transcript_response + ), + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) + ) + + pass + + def _get_response_to_log( + self, transcript_response: Optional[AssemblyAITranscriptResponse] + ) -> dict: + if transcript_response is None: + return {} + return dict(transcript_response) + + def _get_assembly_transcript( + self, + transcript_id: str, + request_region: Optional[Literal["eu"]] = None, + ) -> Optional[dict]: + """ + Get the transcript details from AssemblyAI API + + Args: + response_body (dict): Response containing the transcript ID + + Returns: + 
Optional[dict]: Transcript details if successful, None otherwise + """ + from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( + passthrough_endpoint_router, + ) + + _base_url = ( + self.assembly_ai_eu_base_url + if request_region == "eu" + else self.assembly_ai_base_url + ) + _api_key = passthrough_endpoint_router.get_credentials( + custom_llm_provider="assemblyai", + region_name=request_region, + ) + if _api_key is None: + raise ValueError("AssemblyAI API key not found") + try: + url = f"{_base_url}/v2/transcript/{transcript_id}" + headers = { + "Authorization": f"Bearer {_api_key}", + "Content-Type": "application/json", + } + + response = httpx.get(url, headers=headers) + response.raise_for_status() + + return response.json() + except Exception as e: + verbose_proxy_logger.exception( + f"[Non blocking logging error] Error getting AssemblyAI transcript: {str(e)}" + ) + return None + + def _poll_assembly_for_transcript_response( + self, + transcript_id: str, + url_route: Optional[str] = None, + ) -> Optional[AssemblyAITranscriptResponse]: + """ + Poll the status of the transcript until it is completed or timeout (30 minutes) + """ + for _ in range( + self.max_polling_attempts + ): # 180 attempts * 10s = 30 minutes max + transcript = self._get_assembly_transcript( + request_region=AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url( + url=url_route + ), + transcript_id=transcript_id, + ) + if transcript is None: + return None + if ( + transcript.get("status") == "completed" + or transcript.get("status") == "error" + ): + return AssemblyAITranscriptResponse(**transcript) + time.sleep(self.polling_interval) + return None + + @staticmethod + def get_cost_for_assembly_transcript( + transcript_response: AssemblyAITranscriptResponse, + speech_model: str, + ) -> Optional[float]: + """ + Get the cost for the assembly transcript + """ + _audio_duration = transcript_response.get("audio_duration") + if _audio_duration is None: + return None + _cost_per_second = ( + AssemblyAIPassthroughLoggingHandler.get_cost_per_second_for_assembly_model( + speech_model=speech_model + ) + ) + if _cost_per_second is None: + return None + return _audio_duration * _cost_per_second + + @staticmethod + def get_cost_per_second_for_assembly_model(speech_model: str) -> Optional[float]: + """ + Get the cost per second for the assembly model. + Falls back to assemblyai/nano if the specific speech model info cannot be found. + """ + try: + # First try with the provided speech model + try: + model_info = litellm.get_model_info( + model=speech_model, + custom_llm_provider="assemblyai", + ) + if model_info and model_info.get("input_cost_per_second") is not None: + return model_info.get("input_cost_per_second") + except Exception: + pass # Continue to fallback if model not found + + # Fallback to assemblyai/nano if speech model info not found + try: + model_info = litellm.get_model_info( + model="assemblyai/nano", + custom_llm_provider="assemblyai", + ) + if model_info and model_info.get("input_cost_per_second") is not None: + return model_info.get("input_cost_per_second") + except Exception: + pass + + return None + except Exception as e: + verbose_proxy_logger.exception( + f"[Non blocking logging error] Error getting AssemblyAI model info: {str(e)}" + ) + return None + + @staticmethod + def _should_log_request(request_method: str) -> bool: + """ + only POST transcription jobs are logged. 
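The AssemblyAI polling logic above reduces to a simple loop: GET the transcript until its status is `completed` or `error`, with a 10-second interval and 180 attempts (roughly 30 minutes). The standalone sketch below mirrors that behavior without the litellm-internal plumbing; the API key and transcript ID are placeholders, and this is not the library's own helper.

```python
# Simplified standalone version of the transcript polling described above.
# api_key and transcript_id are placeholders.
import time

import httpx


def poll_transcript(transcript_id: str, api_key: str,
                    base_url: str = "https://api.assemblyai.com",
                    interval: float = 10.0, max_attempts: int = 180):
    """Poll GET /v2/transcript/{id} until status is 'completed' or 'error'."""
    headers = {"Authorization": f"Bearer {api_key}"}
    for _ in range(max_attempts):  # 180 attempts * 10 s ~= 30 minutes
        resp = httpx.get(f"{base_url}/v2/transcript/{transcript_id}", headers=headers)
        resp.raise_for_status()
        body = resp.json()
        if body.get("status") in ("completed", "error"):
            return body
        time.sleep(interval)
    return None

# Cost attribution then follows the same arithmetic as get_cost_for_assembly_transcript:
# cost = audio_duration (seconds) * input_cost_per_second for the speech model.
```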
litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost + """ + return request_method == "POST" + + @staticmethod + def _get_assembly_region_from_url(url: Optional[str]) -> Optional[Literal["eu"]]: + """ + Get the region from the URL + """ + if url is None: + return None + if urlparse(url).hostname == "eu.assemblyai.com": + return "eu" + return None + + @staticmethod + def _get_assembly_base_url_from_region(region: Optional[Literal["eu"]]) -> str: + """ + Get the base URL for the AssemblyAI API + if region == "eu", return "https://api.eu.assemblyai.com" + else return "https://api.assemblyai.com" + """ + if region == "eu": + return "https://api.eu.assemblyai.com" + return "https://api.assemblyai.com" diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py new file mode 100644 index 00000000..94435637 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py @@ -0,0 +1,258 @@ +import json +import re +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from urllib.parse import urlparse +import httpx + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import ( + ModelResponseIterator as VertexModelResponseIterator, +) +from litellm.proxy._types import PassThroughEndpointLoggingTypedDict +from litellm.types.utils import ( + EmbeddingResponse, + ImageResponse, + ModelResponse, + TextCompletionResponse, +) + +if TYPE_CHECKING: + from ..success_handler import PassThroughEndpointLogging + from ..types import EndpointType +else: + PassThroughEndpointLogging = Any + EndpointType = Any + + +class VertexPassthroughLoggingHandler: + @staticmethod + def vertex_passthrough_handler( + httpx_response: httpx.Response, + logging_obj: LiteLLMLoggingObj, + url_route: str, + result: str, + start_time: datetime, + end_time: datetime, + cache_hit: bool, + **kwargs, + ) -> PassThroughEndpointLoggingTypedDict: + if "generateContent" in url_route: + model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route) + + instance_of_vertex_llm = litellm.VertexGeminiConfig() + litellm_model_response: ModelResponse = ( + instance_of_vertex_llm.transform_response( + model=model, + messages=[ + {"role": "user", "content": "no-message-pass-through-endpoint"} + ], + raw_response=httpx_response, + model_response=litellm.ModelResponse(), + logging_obj=logging_obj, + optional_params={}, + litellm_params={}, + api_key="", + request_data={}, + encoding=litellm.encoding, + ) + ) + kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content( + litellm_model_response=litellm_model_response, + model=model, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + custom_llm_provider=VertexPassthroughLoggingHandler._get_custom_llm_provider_from_url( + url_route + ), + ) + + return { + "result": litellm_model_response, + "kwargs": kwargs, + } + + elif "predict" in url_route: + from litellm.llms.vertex_ai.image_generation.image_generation_handler import ( + VertexImageGeneration, 
+ ) + from litellm.types.utils import PassthroughCallTypes + + vertex_image_generation_class = VertexImageGeneration() + + model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route) + _json_response = httpx_response.json() + + litellm_prediction_response: Union[ + ModelResponse, EmbeddingResponse, ImageResponse + ] = ModelResponse() + if vertex_image_generation_class.is_image_generation_response( + _json_response + ): + litellm_prediction_response = ( + vertex_image_generation_class.process_image_generation_response( + _json_response, + model_response=litellm.ImageResponse(), + model=model, + ) + ) + + logging_obj.call_type = ( + PassthroughCallTypes.passthrough_image_generation.value + ) + else: + litellm_prediction_response = litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai( + response=_json_response, + model=model, + model_response=litellm.EmbeddingResponse(), + ) + if isinstance(litellm_prediction_response, litellm.EmbeddingResponse): + litellm_prediction_response.model = model + + logging_obj.model = model + logging_obj.model_call_details["model"] = logging_obj.model + + return { + "result": litellm_prediction_response, + "kwargs": kwargs, + } + else: + return { + "result": None, + "kwargs": kwargs, + } + + @staticmethod + def _handle_logging_vertex_collected_chunks( + litellm_logging_obj: LiteLLMLoggingObj, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + request_body: dict, + endpoint_type: EndpointType, + start_time: datetime, + all_chunks: List[str], + end_time: datetime, + ) -> PassThroughEndpointLoggingTypedDict: + """ + Takes raw chunks from Vertex passthrough endpoint and logs them in litellm callbacks + + - Builds complete response from chunks + - Creates standard logging object + - Logs in litellm callbacks + """ + kwargs: Dict[str, Any] = {} + model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route) + complete_streaming_response = ( + VertexPassthroughLoggingHandler._build_complete_streaming_response( + all_chunks=all_chunks, + litellm_logging_obj=litellm_logging_obj, + model=model, + ) + ) + + if complete_streaming_response is None: + verbose_proxy_logger.error( + "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..." 
+ ) + return { + "result": None, + "kwargs": kwargs, + } + + kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content( + litellm_model_response=complete_streaming_response, + model=model, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + logging_obj=litellm_logging_obj, + custom_llm_provider=VertexPassthroughLoggingHandler._get_custom_llm_provider_from_url( + url_route + ), + ) + + return { + "result": complete_streaming_response, + "kwargs": kwargs, + } + + @staticmethod + def _build_complete_streaming_response( + all_chunks: List[str], + litellm_logging_obj: LiteLLMLoggingObj, + model: str, + ) -> Optional[Union[ModelResponse, TextCompletionResponse]]: + vertex_iterator = VertexModelResponseIterator( + streaming_response=None, + sync_stream=False, + ) + litellm_custom_stream_wrapper = litellm.CustomStreamWrapper( + completion_stream=vertex_iterator, + model=model, + logging_obj=litellm_logging_obj, + custom_llm_provider="vertex_ai", + ) + all_openai_chunks = [] + for chunk in all_chunks: + generic_chunk = vertex_iterator._common_chunk_parsing_logic(chunk) + litellm_chunk = litellm_custom_stream_wrapper.chunk_creator( + chunk=generic_chunk + ) + if litellm_chunk is not None: + all_openai_chunks.append(litellm_chunk) + + complete_streaming_response = litellm.stream_chunk_builder( + chunks=all_openai_chunks + ) + + return complete_streaming_response + + @staticmethod + def extract_model_from_url(url: str) -> str: + pattern = r"/models/([^:]+)" + match = re.search(pattern, url) + if match: + return match.group(1) + return "unknown" + + @staticmethod + def _get_custom_llm_provider_from_url(url: str) -> str: + parsed_url = urlparse(url) + if parsed_url.hostname and parsed_url.hostname.endswith("generativelanguage.googleapis.com"): + return litellm.LlmProviders.GEMINI.value + return litellm.LlmProviders.VERTEX_AI.value + + @staticmethod + def _create_vertex_response_logging_payload_for_generate_content( + litellm_model_response: Union[ModelResponse, TextCompletionResponse], + model: str, + kwargs: dict, + start_time: datetime, + end_time: datetime, + logging_obj: LiteLLMLoggingObj, + custom_llm_provider: str, + ): + """ + Create the standard logging object for Vertex passthrough generateContent (streaming and non-streaming) + + """ + response_cost = litellm.completion_cost( + completion_response=litellm_model_response, + model=model, + ) + kwargs["response_cost"] = response_cost + kwargs["model"] = model + + # pretty print standard logging object + verbose_proxy_logger.debug("kwargs= %s", json.dumps(kwargs, indent=4)) + + # set litellm_call_id to logging response object + litellm_model_response.id = logging_obj.litellm_call_id + logging_obj.model = litellm_model_response.model or model + logging_obj.model_call_details["model"] = logging_obj.model + logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider + return kwargs diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py new file mode 100644 index 00000000..a13b0dc2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -0,0 +1,1001 @@ +import ast +import asyncio +import json +import uuid +from base64 import b64encode +from datetime import datetime +from typing import Dict, List, Optional, Union +from urllib.parse import parse_qs, urlencode, 
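As a quick worked example of the Vertex URL parsing above (the request URL is invented for illustration): `extract_model_from_url` pulls the segment after `/models/` and before any `:action` suffix, and `_get_custom_llm_provider_from_url` infers the provider from the hostname.

```python
# Worked example of the model-extraction regex shown above (sample URL is made up).
import re
from urllib.parse import urlparse

url = ("https://us-central1-aiplatform.googleapis.com/v1/projects/my-proj/locations/"
       "us-central1/publishers/google/models/gemini-1.5-pro:streamGenerateContent")

match = re.search(r"/models/([^:]+)", url)  # same pattern as extract_model_from_url
model = match.group(1) if match else "unknown"
print(model)  # -> gemini-1.5-pro

# Provider selection mirrors _get_custom_llm_provider_from_url:
host = urlparse(url).hostname or ""
provider = "gemini" if host.endswith("generativelanguage.googleapis.com") else "vertex_ai"
print(provider)  # -> vertex_ai
```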
urlparse + +import httpx +from fastapi import APIRouter, Depends, HTTPException, Request, Response, status +from fastapi.responses import StreamingResponse + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.llms.custom_httpx.http_handler import get_async_httpx_client +from litellm.proxy._types import ( + ConfigFieldInfo, + ConfigFieldUpdate, + PassThroughEndpointResponse, + PassThroughGenericEndpoint, + ProxyException, + UserAPIKeyAuth, +) +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing +from litellm.proxy.common_utils.http_parsing_utils import _read_request_body +from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.custom_http import httpxSpecialProvider +from litellm.types.utils import StandardLoggingUserAPIKeyMetadata + +from .streaming_handler import PassThroughStreamingHandler +from .success_handler import PassThroughEndpointLogging +from .types import EndpointType, PassthroughStandardLoggingPayload + +router = APIRouter() + +pass_through_endpoint_logging = PassThroughEndpointLogging() + + +def get_response_body(response: httpx.Response) -> Optional[dict]: + try: + return response.json() + except Exception: + return None + + +async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]: + """ + checks if any headers on config.yaml are defined as os.environ/COHERE_API_KEY etc + + only runs for headers defined on config.yaml + + example header can be + + {"Authorization": "bearer os.environ/COHERE_API_KEY"} + """ + if custom_headers is None: + return None + headers = {} + for key, value in custom_headers.items(): + # langfuse Api requires base64 encoded headers - it's simpleer to just ask litellm users to set their langfuse public and secret keys + # we can then get the b64 encoded keys here + if key == "LANGFUSE_PUBLIC_KEY" or key == "LANGFUSE_SECRET_KEY": + # langfuse requires b64 encoded headers - we construct that here + _langfuse_public_key = custom_headers["LANGFUSE_PUBLIC_KEY"] + _langfuse_secret_key = custom_headers["LANGFUSE_SECRET_KEY"] + if isinstance( + _langfuse_public_key, str + ) and _langfuse_public_key.startswith("os.environ/"): + _langfuse_public_key = get_secret_str(_langfuse_public_key) + if isinstance( + _langfuse_secret_key, str + ) and _langfuse_secret_key.startswith("os.environ/"): + _langfuse_secret_key = get_secret_str(_langfuse_secret_key) + headers["Authorization"] = "Basic " + b64encode( + f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8") + ).decode("ascii") + else: + # for all other headers + headers[key] = value + if isinstance(value, str) and "os.environ/" in value: + verbose_proxy_logger.debug( + "pass through endpoint - looking up 'os.environ/' variable" + ) + # get string section that is os.environ/ + start_index = value.find("os.environ/") + _variable_name = value[start_index:] + + verbose_proxy_logger.debug( + "pass through endpoint - getting secret for variable name: %s", + _variable_name, + ) + _secret_value = get_secret_str(_variable_name) + if _secret_value is not None: + new_value = value.replace(_variable_name, _secret_value) + headers[key] = new_value + return headers + + +async def chat_completion_pass_through_endpoint( # noqa: PLR0915 + fastapi_response: Response, + request: Request, + adapter_id: str, + user_api_key_dict: UserAPIKeyAuth, +): + from litellm.proxy.proxy_server 
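The header-resolution helper above does two things: it expands `os.environ/...` placeholders via the secret manager, and it converts Langfuse public/secret keys into a Basic auth header. The snippet below reproduces just the Basic auth encoding with dummy keys, as a standalone illustration rather than the proxy's own code path.

```python
# Standalone illustration of the Langfuse Basic-auth encoding used above.
# The keys are dummy placeholders.
from base64 import b64encode

langfuse_public_key = "pk-lf-placeholder"   # placeholder
langfuse_secret_key = "sk-lf-placeholder"   # placeholder

authorization = "Basic " + b64encode(
    f"{langfuse_public_key}:{langfuse_secret_key}".encode("utf-8")
).decode("ascii")

headers = {"Authorization": authorization}
print(headers)
```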
import ( + add_litellm_data_to_request, + general_settings, + llm_router, + proxy_config, + proxy_logging_obj, + user_api_base, + user_max_tokens, + user_model, + user_request_timeout, + user_temperature, + version, + ) + + data = {} + try: + body = await request.body() + body_str = body.decode() + try: + data = ast.literal_eval(body_str) + except Exception: + data = json.loads(body_str) + + data["adapter_id"] = adapter_id + + verbose_proxy_logger.debug( + "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)), + ) + data["model"] = ( + general_settings.get("completion_model", None) # server default + or user_model # model name passed via cli args + or data.get("model", None) # default passed in http request + ) + if user_model: + data["model"] = user_model + + data = await add_litellm_data_to_request( + data=data, # type: ignore + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) + + # override with user settings, these are params passed via cli + if user_temperature: + data["temperature"] = user_temperature + if user_request_timeout: + data["request_timeout"] = user_request_timeout + if user_max_tokens: + data["max_tokens"] = user_max_tokens + if user_api_base: + data["api_base"] = user_api_base + + ### MODEL ALIAS MAPPING ### + # check if model name in model alias map + # get the actual model name + if data["model"] in litellm.model_alias_map: + data["model"] = litellm.model_alias_map[data["model"]] + + ### CALL HOOKS ### - modify incoming data before calling the model + data = await proxy_logging_obj.pre_call_hook( # type: ignore + user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion" + ) + + ### ROUTE THE REQUESTs ### + router_model_names = llm_router.model_names if llm_router is not None else [] + # skip router if user passed their key + if "api_key" in data: + llm_response = asyncio.create_task(litellm.aadapter_completion(**data)) + elif ( + llm_router is not None and data["model"] in router_model_names + ): # model in router model list + llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) + elif ( + llm_router is not None + and llm_router.model_group_alias is not None + and data["model"] in llm_router.model_group_alias + ): # model set in model_group_alias + llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) + elif ( + llm_router is not None and data["model"] in llm_router.deployment_names + ): # model in router deployments, calling a specific deployment on the router + llm_response = asyncio.create_task( + llm_router.aadapter_completion(**data, specific_deployment=True) + ) + elif ( + llm_router is not None and data["model"] in llm_router.get_model_ids() + ): # model in router model list + llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) + elif ( + llm_router is not None + and data["model"] not in router_model_names + and llm_router.default_deployment is not None + ): # model in router deployments, calling a specific deployment on the router + llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) + elif user_model is not None: # `litellm --model <your-model-name>` + llm_response = asyncio.create_task(litellm.aadapter_completion(**data)) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "error": "completion: Invalid model name passed in model=" + + data.get("model", "") + }, + ) + + # Await the llm_response task + response 
= await llm_response + + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + response_cost = hidden_params.get("response_cost", None) or "" + + ### ALERTING ### + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) + + verbose_proxy_logger.debug("final response: %s", response) + + fastapi_response.headers.update( + ProxyBaseLLMRequestProcessing.get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + response_cost=response_cost, + ) + ) + + verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response)) + return response + except Exception as e: + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.completion(): Exception occured - {}".format( + str(e) + ) + ) + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + +class HttpPassThroughEndpointHelpers: + @staticmethod + def forward_headers_from_request( + request: Request, + headers: dict, + forward_headers: Optional[bool] = False, + ): + """ + Helper to forward headers from original request + """ + if forward_headers is True: + request_headers = dict(request.headers) + + # Header We Should NOT forward + request_headers.pop("content-length", None) + request_headers.pop("host", None) + + # Combine request headers with custom headers + headers = {**request_headers, **headers} + return headers + + @staticmethod + def get_response_headers( + headers: httpx.Headers, + litellm_call_id: Optional[str] = None, + custom_headers: Optional[dict] = None, + ) -> dict: + excluded_headers = {"transfer-encoding", "content-encoding"} + + return_headers = { + key: value + for key, value in headers.items() + if key.lower() not in excluded_headers + } + if litellm_call_id: + return_headers["x-litellm-call-id"] = litellm_call_id + if custom_headers: + return_headers.update(custom_headers) + + return return_headers + + @staticmethod + def get_endpoint_type(url: str) -> EndpointType: + parsed_url = urlparse(url) + if ("generateContent") in url or ("streamGenerateContent") in url: + return EndpointType.VERTEX_AI + elif parsed_url.hostname == "api.anthropic.com": + return EndpointType.ANTHROPIC + return EndpointType.GENERIC + + @staticmethod + def get_merged_query_parameters( + existing_url: httpx.URL, request_query_params: Dict[str, Union[str, list]] + ) -> Dict[str, Union[str, List[str]]]: + # Get the existing query params from the target URL + existing_query_string = existing_url.query.decode("utf-8") + existing_query_params = parse_qs(existing_query_string) + + # parse_qs returns a dict where each value is a list, so let's flatten it + updated_existing_query_params = { + k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items() + } + # Merge the query params, giving priority to the existing ones + return {**request_query_params, **updated_existing_query_params} + + @staticmethod + async def _make_non_streaming_http_request( + request: Request, + async_client: httpx.AsyncClient, + url: str, + 
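The merge rule in `get_merged_query_parameters` above is easy to miss: query parameters already present on the target URL take priority over parameters from the incoming request. A small standalone demonstration, with invented sample values:

```python
# Demonstrates the merge precedence used by get_merged_query_parameters:
# params already on the target URL override params from the incoming request.
from urllib.parse import parse_qs

import httpx

existing_url = httpx.URL("https://example.upstream/api?key=TARGET_KEY&alt=sse")
request_query_params = {"key": "CLIENT_KEY", "timeout": "30"}

existing = parse_qs(existing_url.query.decode("utf-8"))
flattened = {k: v[0] if len(v) == 1 else v for k, v in existing.items()}

merged = {**request_query_params, **flattened}
print(merged)  # -> {'key': 'TARGET_KEY', 'timeout': '30', 'alt': 'sse'}
```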
headers: dict, + requested_query_params: Optional[dict] = None, + custom_body: Optional[dict] = None, + ) -> httpx.Response: + """ + Make a non-streaming HTTP request + + If request is GET, don't include a JSON body + """ + if request.method == "GET": + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + ) + else: + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + json=custom_body, + ) + return response + + +async def pass_through_request( # noqa: PLR0915 + request: Request, + target: str, + custom_headers: dict, + user_api_key_dict: UserAPIKeyAuth, + custom_body: Optional[dict] = None, + forward_headers: Optional[bool] = False, + merge_query_params: Optional[bool] = False, + query_params: Optional[dict] = None, + stream: Optional[bool] = None, +): + litellm_call_id = str(uuid.uuid4()) + url: Optional[httpx.URL] = None + try: + + from litellm.litellm_core_utils.litellm_logging import Logging + from litellm.proxy.proxy_server import proxy_logging_obj + + url = httpx.URL(target) + headers = custom_headers + headers = HttpPassThroughEndpointHelpers.forward_headers_from_request( + request=request, headers=headers, forward_headers=forward_headers + ) + + if merge_query_params: + + # Create a new URL with the merged query params + url = url.copy_with( + query=urlencode( + HttpPassThroughEndpointHelpers.get_merged_query_parameters( + existing_url=url, + request_query_params=dict(request.query_params), + ) + ).encode("ascii") + ) + + endpoint_type: EndpointType = HttpPassThroughEndpointHelpers.get_endpoint_type( + str(url) + ) + + _parsed_body = None + if custom_body: + _parsed_body = custom_body + else: + _parsed_body = await _read_request_body(request) + verbose_proxy_logger.debug( + "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format( + url, headers, _parsed_body + ) + ) + + ### CALL HOOKS ### - modify incoming data / reject request before calling the model + _parsed_body = await proxy_logging_obj.pre_call_hook( + user_api_key_dict=user_api_key_dict, + data=_parsed_body, + call_type="pass_through_endpoint", + ) + async_client_obj = get_async_httpx_client( + llm_provider=httpxSpecialProvider.PassThroughEndpoint, + params={"timeout": 600}, + ) + async_client = async_client_obj.client + + # create logging object + start_time = datetime.now() + logging_obj = Logging( + model="unknown", + messages=[{"role": "user", "content": json.dumps(_parsed_body)}], + stream=False, + call_type="pass_through_endpoint", + start_time=start_time, + litellm_call_id=litellm_call_id, + function_id="1245", + ) + passthrough_logging_payload = PassthroughStandardLoggingPayload( + url=str(url), + request_body=_parsed_body, + ) + kwargs = _init_kwargs_for_pass_through_endpoint( + user_api_key_dict=user_api_key_dict, + _parsed_body=_parsed_body, + passthrough_logging_payload=passthrough_logging_payload, + litellm_call_id=litellm_call_id, + request=request, + ) + # done for supporting 'parallel_request_limiter.py' with pass-through endpoints + logging_obj.update_environment_variables( + model="unknown", + user="unknown", + optional_params={}, + litellm_params=kwargs["litellm_params"], + call_type="pass_through_endpoint", + ) + logging_obj.model_call_details["litellm_call_id"] = litellm_call_id + + # combine url with query params for logging + + requested_query_params: Optional[dict] = ( + query_params or request.query_params.__dict__ + ) + 
if requested_query_params == request.query_params.__dict__: + requested_query_params = None + + requested_query_params_str = None + if requested_query_params: + requested_query_params_str = "&".join( + f"{k}={v}" for k, v in requested_query_params.items() + ) + + logging_url = str(url) + if requested_query_params_str: + if "?" in str(url): + logging_url = str(url) + "&" + requested_query_params_str + else: + logging_url = str(url) + "?" + requested_query_params_str + + logging_obj.pre_call( + input=[{"role": "user", "content": json.dumps(_parsed_body)}], + api_key="", + additional_args={ + "complete_input_dict": _parsed_body, + "api_base": str(logging_url), + "headers": headers, + }, + ) + if stream: + req = async_client.build_request( + "POST", + url, + json=_parsed_body, + params=requested_query_params, + headers=headers, + ) + + response = await async_client.send(req, stream=stream) + + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise HTTPException( + status_code=e.response.status_code, detail=await e.response.aread() + ) + + return StreamingResponse( + PassThroughStreamingHandler.chunk_processor( + response=response, + request_body=_parsed_body, + litellm_logging_obj=logging_obj, + endpoint_type=endpoint_type, + start_time=start_time, + passthrough_success_handler_obj=pass_through_endpoint_logging, + url_route=str(url), + ), + headers=HttpPassThroughEndpointHelpers.get_response_headers( + headers=response.headers, + litellm_call_id=litellm_call_id, + ), + status_code=response.status_code, + ) + + verbose_proxy_logger.debug("request method: {}".format(request.method)) + verbose_proxy_logger.debug("request url: {}".format(url)) + verbose_proxy_logger.debug("request headers: {}".format(headers)) + verbose_proxy_logger.debug( + "requested_query_params={}".format(requested_query_params) + ) + verbose_proxy_logger.debug("request body: {}".format(_parsed_body)) + + if request.method == "GET": + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + ) + else: + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + json=_parsed_body, + ) + + verbose_proxy_logger.debug("response.headers= %s", response.headers) + + if _is_streaming_response(response) is True: + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise HTTPException( + status_code=e.response.status_code, detail=await e.response.aread() + ) + + return StreamingResponse( + PassThroughStreamingHandler.chunk_processor( + response=response, + request_body=_parsed_body, + litellm_logging_obj=logging_obj, + endpoint_type=endpoint_type, + start_time=start_time, + passthrough_success_handler_obj=pass_through_endpoint_logging, + url_route=str(url), + ), + headers=HttpPassThroughEndpointHelpers.get_response_headers( + headers=response.headers, + litellm_call_id=litellm_call_id, + ), + status_code=response.status_code, + ) + + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise HTTPException( + status_code=e.response.status_code, detail=e.response.text + ) + + if response.status_code >= 300: + raise HTTPException(status_code=response.status_code, detail=response.text) + + content = await response.aread() + + ## LOG SUCCESS + response_body: Optional[dict] = get_response_body(response) + passthrough_logging_payload["response_body"] = response_body + end_time = datetime.now() + asyncio.create_task( + 
pass_through_endpoint_logging.pass_through_async_success_handler( + httpx_response=response, + response_body=response_body, + url_route=str(url), + result="", + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + cache_hit=False, + **kwargs, + ) + ) + + ## CUSTOM HEADERS - `x-litellm-*` + custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers( + user_api_key_dict=user_api_key_dict, + call_id=litellm_call_id, + model_id=None, + cache_key=None, + api_base=str(url._uri_reference), + ) + + return Response( + content=content, + status_code=response.status_code, + headers=HttpPassThroughEndpointHelpers.get_response_headers( + headers=response.headers, + custom_headers=custom_headers, + ), + ) + except Exception as e: + custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers( + user_api_key_dict=user_api_key_dict, + call_id=litellm_call_id, + model_id=None, + cache_key=None, + api_base=str(url._uri_reference) if url else None, + ) + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}".format( + str(e) + ) + ) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "message", str(e.detail)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + headers=custom_headers, + ) + else: + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + headers=custom_headers, + ) + + +def _init_kwargs_for_pass_through_endpoint( + request: Request, + user_api_key_dict: UserAPIKeyAuth, + passthrough_logging_payload: PassthroughStandardLoggingPayload, + _parsed_body: Optional[dict] = None, + litellm_call_id: Optional[str] = None, +) -> dict: + _parsed_body = _parsed_body or {} + _litellm_metadata: Optional[dict] = _parsed_body.pop("litellm_metadata", None) + _metadata = dict( + StandardLoggingUserAPIKeyMetadata( + user_api_key_hash=user_api_key_dict.api_key, + user_api_key_alias=user_api_key_dict.key_alias, + user_api_key_user_email=user_api_key_dict.user_email, + user_api_key_user_id=user_api_key_dict.user_id, + user_api_key_team_id=user_api_key_dict.team_id, + user_api_key_org_id=user_api_key_dict.org_id, + user_api_key_team_alias=user_api_key_dict.team_alias, + user_api_key_end_user_id=user_api_key_dict.end_user_id, + ) + ) + _metadata["user_api_key"] = user_api_key_dict.api_key + if _litellm_metadata: + _metadata.update(_litellm_metadata) + + _metadata = _update_metadata_with_tags_in_header( + request=request, + metadata=_metadata, + ) + + kwargs = { + "litellm_params": { + "metadata": _metadata, + }, + "call_type": "pass_through_endpoint", + "litellm_call_id": litellm_call_id, + "passthrough_logging_payload": passthrough_logging_payload, + } + return kwargs + + +def _update_metadata_with_tags_in_header(request: Request, metadata: dict) -> dict: + """ + If tags are in the request headers, add them to the metadata + + Used for google and vertex JS SDKs + """ + _tags = request.headers.get("tags") + if _tags: + metadata["tags"] = _tags.split(",") + return metadata + + +def create_pass_through_route( + endpoint, + target: str, + custom_headers: Optional[dict] = None, + _forward_headers: Optional[bool] = False, + _merge_query_params: Optional[bool] = False, + dependencies: Optional[List] = None, +): + # check if target is an adapter.py or a url + import uuid + + from 
litellm.proxy.types_utils.utils import get_instance_fn + + try: + if isinstance(target, CustomLogger): + adapter = target + else: + adapter = get_instance_fn(value=target) + adapter_id = str(uuid.uuid4()) + litellm.adapters = [{"id": adapter_id, "adapter": adapter}] + + async def endpoint_func( # type: ignore + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + ): + return await chat_completion_pass_through_endpoint( + fastapi_response=fastapi_response, + request=request, + adapter_id=adapter_id, + user_api_key_dict=user_api_key_dict, + ) + + except Exception: + verbose_proxy_logger.debug("Defaulting to target being a url.") + + async def endpoint_func( # type: ignore + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + query_params: Optional[dict] = None, + custom_body: Optional[dict] = None, + stream: Optional[ + bool + ] = None, # if pass-through endpoint is a streaming request + ): + return await pass_through_request( # type: ignore + request=request, + target=target, + custom_headers=custom_headers or {}, + user_api_key_dict=user_api_key_dict, + forward_headers=_forward_headers, + merge_query_params=_merge_query_params, + query_params=query_params, + stream=stream, + custom_body=custom_body, + ) + + return endpoint_func + + +def _is_streaming_response(response: httpx.Response) -> bool: + _content_type = response.headers.get("content-type") + if _content_type is not None and "text/event-stream" in _content_type: + return True + return False + + +async def initialize_pass_through_endpoints(pass_through_endpoints: list): + + verbose_proxy_logger.debug("initializing pass through endpoints") + from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes + from litellm.proxy.proxy_server import app, premium_user + + for endpoint in pass_through_endpoints: + _target = endpoint.get("target", None) + _path = endpoint.get("path", None) + _custom_headers = endpoint.get("headers", None) + _custom_headers = await set_env_variables_in_header( + custom_headers=_custom_headers + ) + _forward_headers = endpoint.get("forward_headers", None) + _merge_query_params = endpoint.get("merge_query_params", None) + _auth = endpoint.get("auth", None) + _dependencies = None + if _auth is not None and str(_auth).lower() == "true": + if premium_user is not True: + raise ValueError( + "Error Setting Authentication on Pass Through Endpoint: {}".format( + CommonProxyErrors.not_premium_user.value + ) + ) + _dependencies = [Depends(user_api_key_auth)] + LiteLLMRoutes.openai_routes.value.append(_path) + + if _target is None: + continue + + verbose_proxy_logger.debug( + "adding pass through endpoint: %s, dependencies: %s", _path, _dependencies + ) + app.add_api_route( # type: ignore + path=_path, + endpoint=create_pass_through_route( # type: ignore + _path, + _target, + _custom_headers, + _forward_headers, + _merge_query_params, + _dependencies, + ), + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + dependencies=_dependencies, + ) + + verbose_proxy_logger.debug("Added new pass through endpoint: %s", _path) + + +@router.get( + "/config/pass_through_endpoint", + dependencies=[Depends(user_api_key_auth)], + response_model=PassThroughEndpointResponse, +) +async def get_pass_through_endpoints( + endpoint_id: Optional[str] = None, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + GET configured pass through endpoint. 
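initialize_pass_through_endpoints above turns each configured entry into a live FastAPI route by pairing app.add_api_route with the closure returned from create_pass_through_route. The snippet below is a stripped-down sketch of that closure-per-config pattern under assumed names (make_pass_through, the /httpbin example target); it omits auth, logging, header forwarding, and streaming.

from typing import List

import httpx
from fastapi import FastAPI, Request, Response

app = FastAPI()

# Hypothetical config entries; in the proxy these come from general_settings.
PASS_THROUGH_ENDPOINTS: List[dict] = [
    {"path": "/httpbin/{subpath:path}", "target": "https://httpbin.org"},
]


def make_pass_through(target: str):
    # Build a closure that forwards the incoming request to `target`.
    async def endpoint_func(subpath: str, request: Request) -> Response:
        upstream = httpx.URL(target).copy_with(path="/" + subpath.lstrip("/"))
        async with httpx.AsyncClient() as client:
            resp = await client.request(
                method=request.method,
                url=upstream,
                params=dict(request.query_params),
                content=await request.body(),
            )
        return Response(content=resp.content, status_code=resp.status_code)

    return endpoint_func


for cfg in PASS_THROUGH_ENDPOINTS:
    app.add_api_route(
        path=cfg["path"],
        endpoint=make_pass_through(cfg["target"]),
        methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    )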
+ + If no endpoint_id given, return all configured endpoints. + """ + from litellm.proxy.proxy_server import get_config_general_settings + + ## Get existing pass-through endpoint field value + try: + response: ConfigFieldInfo = await get_config_general_settings( + field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict + ) + except Exception: + return PassThroughEndpointResponse(endpoints=[]) + + pass_through_endpoint_data: Optional[List] = response.field_value + if pass_through_endpoint_data is None: + return PassThroughEndpointResponse(endpoints=[]) + + returned_endpoints = [] + if endpoint_id is None: + for endpoint in pass_through_endpoint_data: + if isinstance(endpoint, dict): + returned_endpoints.append(PassThroughGenericEndpoint(**endpoint)) + elif isinstance(endpoint, PassThroughGenericEndpoint): + returned_endpoints.append(endpoint) + elif endpoint_id is not None: + for endpoint in pass_through_endpoint_data: + _endpoint: Optional[PassThroughGenericEndpoint] = None + if isinstance(endpoint, dict): + _endpoint = PassThroughGenericEndpoint(**endpoint) + elif isinstance(endpoint, PassThroughGenericEndpoint): + _endpoint = endpoint + + if _endpoint is not None and _endpoint.path == endpoint_id: + returned_endpoints.append(_endpoint) + + return PassThroughEndpointResponse(endpoints=returned_endpoints) + + +@router.post( + "/config/pass_through_endpoint/{endpoint_id}", + dependencies=[Depends(user_api_key_auth)], +) +async def update_pass_through_endpoints(request: Request, endpoint_id: str): + """ + Update a pass-through endpoint + """ + pass + + +@router.post( + "/config/pass_through_endpoint", + dependencies=[Depends(user_api_key_auth)], +) +async def create_pass_through_endpoints( + data: PassThroughGenericEndpoint, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Create new pass-through endpoint + """ + from litellm.proxy.proxy_server import ( + get_config_general_settings, + update_config_general_settings, + ) + + ## Get existing pass-through endpoint field value + + try: + response: ConfigFieldInfo = await get_config_general_settings( + field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict + ) + except Exception: + response = ConfigFieldInfo( + field_name="pass_through_endpoints", field_value=None + ) + + ## Update field with new endpoint + data_dict = data.model_dump() + if response.field_value is None: + response.field_value = [data_dict] + elif isinstance(response.field_value, List): + response.field_value.append(data_dict) + + ## Update db + updated_data = ConfigFieldUpdate( + field_name="pass_through_endpoints", + field_value=response.field_value, + config_type="general_settings", + ) + await update_config_general_settings( + data=updated_data, user_api_key_dict=user_api_key_dict + ) + + +@router.delete( + "/config/pass_through_endpoint", + dependencies=[Depends(user_api_key_auth)], + response_model=PassThroughEndpointResponse, +) +async def delete_pass_through_endpoints( + endpoint_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Delete a pass-through endpoint + + Returns - the deleted endpoint + """ + from litellm.proxy.proxy_server import ( + get_config_general_settings, + update_config_general_settings, + ) + + ## Get existing pass-through endpoint field value + + try: + response: ConfigFieldInfo = await get_config_general_settings( + field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict + ) + except Exception: + response = ConfigFieldInfo( + 
field_name="pass_through_endpoints", field_value=None + ) + + ## Update field by removing endpoint + pass_through_endpoint_data: Optional[List] = response.field_value + response_obj: Optional[PassThroughGenericEndpoint] = None + if response.field_value is None or pass_through_endpoint_data is None: + raise HTTPException( + status_code=400, + detail={"error": "There are no pass-through endpoints setup."}, + ) + elif isinstance(response.field_value, List): + invalid_idx: Optional[int] = None + for idx, endpoint in enumerate(pass_through_endpoint_data): + _endpoint: Optional[PassThroughGenericEndpoint] = None + if isinstance(endpoint, dict): + _endpoint = PassThroughGenericEndpoint(**endpoint) + elif isinstance(endpoint, PassThroughGenericEndpoint): + _endpoint = endpoint + + if _endpoint is not None and _endpoint.path == endpoint_id: + invalid_idx = idx + response_obj = _endpoint + + if invalid_idx is not None: + pass_through_endpoint_data.pop(invalid_idx) + + ## Update db + updated_data = ConfigFieldUpdate( + field_name="pass_through_endpoints", + field_value=pass_through_endpoint_data, + config_type="general_settings", + ) + await update_config_general_settings( + data=updated_data, user_api_key_dict=user_api_key_dict + ) + + if response_obj is None: + raise HTTPException( + status_code=400, + detail={ + "error": "Endpoint={} was not found in pass-through endpoint list.".format( + endpoint_id + ) + }, + ) + return PassThroughEndpointResponse(endpoints=[response_obj]) diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py new file mode 100644 index 00000000..adf7d0f3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py @@ -0,0 +1,93 @@ +from typing import Dict, Optional + +from litellm._logging import verbose_logger +from litellm.secret_managers.main import get_secret_str + + +class PassthroughEndpointRouter: + """ + Use this class to Set/Get credentials for pass-through endpoints + """ + + def __init__(self): + self.credentials: Dict[str, str] = {} + + def set_pass_through_credentials( + self, + custom_llm_provider: str, + api_base: Optional[str], + api_key: Optional[str], + ): + """ + Set credentials for a pass-through endpoint. Used when a user adds a pass-through LLM endpoint on the UI. 
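The delete handler above walks the stored endpoint list, remembers the matching index (somewhat confusingly named invalid_idx), pops it, and writes the shortened list back to general_settings. A minimal sketch of that find-and-pop step, with a made-up endpoint entry and an example.com target standing in for a real one:

from typing import List, Optional


def remove_endpoint_by_path(endpoints: List[dict], path: str) -> Optional[dict]:
    # Find the first entry whose 'path' matches, pop it, and return it.
    for idx, endpoint in enumerate(endpoints):
        if endpoint.get("path") == path:
            return endpoints.pop(idx)
    return None


endpoints = [{"path": "/my-passthrough", "target": "https://example.com/v1"}]
deleted = remove_endpoint_by_path(endpoints, "/my-passthrough")
assert deleted is not None and endpoints == []
# The caller would then persist `endpoints` via update_config_general_settings.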
+ + Args: + custom_llm_provider: The provider of the pass-through endpoint + api_base: The base URL of the pass-through endpoint + api_key: The API key for the pass-through endpoint + """ + credential_name = self._get_credential_name_for_provider( + custom_llm_provider=custom_llm_provider, + region_name=self._get_region_name_from_api_base( + api_base=api_base, custom_llm_provider=custom_llm_provider + ), + ) + if api_key is None: + raise ValueError("api_key is required for setting pass-through credentials") + self.credentials[credential_name] = api_key + + def get_credentials( + self, + custom_llm_provider: str, + region_name: Optional[str], + ) -> Optional[str]: + credential_name = self._get_credential_name_for_provider( + custom_llm_provider=custom_llm_provider, + region_name=region_name, + ) + verbose_logger.debug( + f"Pass-through llm endpoints router, looking for credentials for {credential_name}" + ) + if credential_name in self.credentials: + verbose_logger.debug(f"Found credentials for {credential_name}") + return self.credentials[credential_name] + else: + verbose_logger.debug( + f"No credentials found for {credential_name}, looking for env variable" + ) + _env_variable_name = ( + self._get_default_env_variable_name_passthrough_endpoint( + custom_llm_provider=custom_llm_provider, + ) + ) + return get_secret_str(_env_variable_name) + + def _get_credential_name_for_provider( + self, + custom_llm_provider: str, + region_name: Optional[str], + ) -> str: + if region_name is None: + return f"{custom_llm_provider.upper()}_API_KEY" + return f"{custom_llm_provider.upper()}_{region_name.upper()}_API_KEY" + + def _get_region_name_from_api_base( + self, + custom_llm_provider: str, + api_base: Optional[str], + ) -> Optional[str]: + """ + Get the region name from the API base. + + Each provider might have a different way of specifying the region in the API base - this is where you can use conditional logic to handle that. 
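get_credentials above resolves a name of the form PROVIDER_API_KEY (or PROVIDER_REGION_API_KEY when a region is known), checks the in-memory store first, and only then falls back to the environment. A self-contained sketch of that lookup order; os.environ stands in for get_secret_str here, and the key values are placeholders.

import os
from typing import Dict, Optional


def credential_name(provider: str, region: Optional[str]) -> str:
    if region is None:
        return f"{provider.upper()}_API_KEY"
    return f"{provider.upper()}_{region.upper()}_API_KEY"


def lookup(store: Dict[str, str], provider: str, region: Optional[str]) -> Optional[str]:
    name = credential_name(provider, region)
    if name in store:
        return store[name]
    # Fallback: the provider-level env var, e.g. ASSEMBLYAI_API_KEY.
    return os.environ.get(f"{provider.upper()}_API_KEY")


store = {"ASSEMBLYAI_EU_API_KEY": "placeholder-eu-key"}
assert lookup(store, "assemblyai", "eu") == "placeholder-eu-key"
print(lookup(store, "assemblyai", None))  # None unless ASSEMBLYAI_API_KEY is set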
+ """ + if custom_llm_provider == "assemblyai": + if api_base and "eu" in api_base: + return "eu" + return None + + @staticmethod + def _get_default_env_variable_name_passthrough_endpoint( + custom_llm_provider: str, + ) -> str: + return f"{custom_llm_provider.upper()}_API_KEY" diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/streaming_handler.py b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/streaming_handler.py new file mode 100644 index 00000000..b022bf1d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -0,0 +1,160 @@ +import asyncio +import threading +from datetime import datetime +from typing import List, Optional + +import httpx + +from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.proxy._types import PassThroughEndpointLoggingResultValues +from litellm.types.utils import StandardPassThroughResponseObject + +from .llm_provider_handlers.anthropic_passthrough_logging_handler import ( + AnthropicPassthroughLoggingHandler, +) +from .llm_provider_handlers.vertex_passthrough_logging_handler import ( + VertexPassthroughLoggingHandler, +) +from .success_handler import PassThroughEndpointLogging +from .types import EndpointType + + +class PassThroughStreamingHandler: + + @staticmethod + async def chunk_processor( + response: httpx.Response, + request_body: Optional[dict], + litellm_logging_obj: LiteLLMLoggingObj, + endpoint_type: EndpointType, + start_time: datetime, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + ): + """ + - Yields chunks from the response + - Collect non-empty chunks for post-processing (logging) + """ + try: + raw_bytes: List[bytes] = [] + async for chunk in response.aiter_bytes(): + raw_bytes.append(chunk) + yield chunk + + # After all chunks are processed, handle post-processing + end_time = datetime.now() + + asyncio.create_task( + PassThroughStreamingHandler._route_streaming_logging_to_handler( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body or {}, + endpoint_type=endpoint_type, + start_time=start_time, + raw_bytes=raw_bytes, + end_time=end_time, + ) + ) + except Exception as e: + verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}") + raise + + @staticmethod + async def _route_streaming_logging_to_handler( + litellm_logging_obj: LiteLLMLoggingObj, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + request_body: dict, + endpoint_type: EndpointType, + start_time: datetime, + raw_bytes: List[bytes], + end_time: datetime, + ): + """ + Route the logging for the collected chunks to the appropriate handler + + Supported endpoint types: + - Anthropic + - Vertex AI + """ + all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines( + raw_bytes + ) + standard_logging_response_object: Optional[ + PassThroughEndpointLoggingResultValues + ] = None + kwargs: dict = {} + if endpoint_type == EndpointType.ANTHROPIC: + anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + 
all_chunks=all_chunks, + end_time=end_time, + ) + standard_logging_response_object = ( + anthropic_passthrough_logging_handler_result["result"] + ) + kwargs = anthropic_passthrough_logging_handler_result["kwargs"] + elif endpoint_type == EndpointType.VERTEX_AI: + vertex_passthrough_logging_handler_result = ( + VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) + ) + standard_logging_response_object = ( + vertex_passthrough_logging_handler_result["result"] + ) + kwargs = vertex_passthrough_logging_handler_result["kwargs"] + + if standard_logging_response_object is None: + standard_logging_response_object = StandardPassThroughResponseObject( + response=f"cannot parse chunks to standard response object. Chunks={all_chunks}" + ) + threading.Thread( + target=litellm_logging_obj.success_handler, + args=( + standard_logging_response_object, + start_time, + end_time, + False, + ), + ).start() + await litellm_logging_obj.async_success_handler( + result=standard_logging_response_object, + start_time=start_time, + end_time=end_time, + cache_hit=False, + **kwargs, + ) + + @staticmethod + def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]: + """ + Converts a list of raw bytes into a list of string lines, similar to aiter_lines() + + Args: + raw_bytes: List of bytes chunks from aiter.bytes() + + Returns: + List of string lines, with each line being a complete data: {} chunk + """ + # Combine all bytes and decode to string + combined_str = b"".join(raw_bytes).decode("utf-8") + + # Split by newlines and filter out empty lines + lines = [line.strip() for line in combined_str.split("\n") if line.strip()] + + return lines diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/success_handler.py b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/success_handler.py new file mode 100644 index 00000000..02e81566 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/success_handler.py @@ -0,0 +1,182 @@ +import json +from datetime import datetime +from typing import Optional, Union +from urllib.parse import urlparse + +import httpx + +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.proxy._types import PassThroughEndpointLoggingResultValues +from litellm.types.utils import StandardPassThroughResponseObject +from litellm.utils import executor as thread_pool_executor + +from .llm_provider_handlers.anthropic_passthrough_logging_handler import ( + AnthropicPassthroughLoggingHandler, +) +from .llm_provider_handlers.assembly_passthrough_logging_handler import ( + AssemblyAIPassthroughLoggingHandler, +) +from .llm_provider_handlers.vertex_passthrough_logging_handler import ( + VertexPassthroughLoggingHandler, +) + + +class PassThroughEndpointLogging: + def __init__(self): + self.TRACKED_VERTEX_ROUTES = [ + "generateContent", + "streamGenerateContent", + "predict", + ] + + # Anthropic + self.TRACKED_ANTHROPIC_ROUTES = ["/messages"] + + self.assemblyai_passthrough_logging_handler = ( + AssemblyAIPassthroughLoggingHandler() + ) + + async def _handle_logging( + self, + logging_obj: LiteLLMLoggingObj, + standard_logging_response_object: Union[ + StandardPassThroughResponseObject, + 
PassThroughEndpointLoggingResultValues, + dict, + ], + result: str, + start_time: datetime, + end_time: datetime, + cache_hit: bool, + **kwargs, + ): + """Helper function to handle both sync and async logging operations""" + # Submit to thread pool for sync logging + thread_pool_executor.submit( + logging_obj.success_handler, + standard_logging_response_object, + start_time, + end_time, + cache_hit, + **kwargs, + ) + + # Handle async logging + await logging_obj.async_success_handler( + result=( + json.dumps(result) + if isinstance(result, dict) + else standard_logging_response_object + ), + start_time=start_time, + end_time=end_time, + cache_hit=False, + **kwargs, + ) + + async def pass_through_async_success_handler( + self, + httpx_response: httpx.Response, + response_body: Optional[dict], + logging_obj: LiteLLMLoggingObj, + url_route: str, + result: str, + start_time: datetime, + end_time: datetime, + cache_hit: bool, + **kwargs, + ): + standard_logging_response_object: Optional[ + PassThroughEndpointLoggingResultValues + ] = None + if self.is_vertex_route(url_route): + vertex_passthrough_logging_handler_result = ( + VertexPassthroughLoggingHandler.vertex_passthrough_handler( + httpx_response=httpx_response, + logging_obj=logging_obj, + url_route=url_route, + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) + ) + standard_logging_response_object = ( + vertex_passthrough_logging_handler_result["result"] + ) + kwargs = vertex_passthrough_logging_handler_result["kwargs"] + elif self.is_anthropic_route(url_route): + anthropic_passthrough_logging_handler_result = ( + AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( + httpx_response=httpx_response, + response_body=response_body or {}, + logging_obj=logging_obj, + url_route=url_route, + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) + ) + + standard_logging_response_object = ( + anthropic_passthrough_logging_handler_result["result"] + ) + kwargs = anthropic_passthrough_logging_handler_result["kwargs"] + elif self.is_assemblyai_route(url_route): + if ( + AssemblyAIPassthroughLoggingHandler._should_log_request( + httpx_response.request.method + ) + is not True + ): + return + self.assemblyai_passthrough_logging_handler.assemblyai_passthrough_logging_handler( + httpx_response=httpx_response, + response_body=response_body or {}, + logging_obj=logging_obj, + url_route=url_route, + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) + return + + if standard_logging_response_object is None: + standard_logging_response_object = StandardPassThroughResponseObject( + response=httpx_response.text + ) + + await self._handle_logging( + logging_obj=logging_obj, + standard_logging_response_object=standard_logging_response_object, + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) + + def is_vertex_route(self, url_route: str): + for route in self.TRACKED_VERTEX_ROUTES: + if route in url_route: + return True + return False + + def is_anthropic_route(self, url_route: str): + for route in self.TRACKED_ANTHROPIC_ROUTES: + if route in url_route: + return True + return False + + def is_assemblyai_route(self, url_route: str): + parsed_url = urlparse(url_route) + if parsed_url.hostname == "api.assemblyai.com": + return True + elif "/transcript" in parsed_url.path: + return True + return False diff --git 
a/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/types.py b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/types.py new file mode 100644 index 00000000..59047a63 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/types.py @@ -0,0 +1,18 @@ +from enum import Enum +from typing import Optional, TypedDict + + +class EndpointType(str, Enum): + VERTEX_AI = "vertex-ai" + ANTHROPIC = "anthropic" + GENERIC = "generic" + + +class PassthroughStandardLoggingPayload(TypedDict, total=False): + """ + Standard logging payload for all pass through endpoints + """ + + url: str + request_body: Optional[dict] + response_body: Optional[dict] # only tracked for non-streaming responses |
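PassthroughStandardLoggingPayload is declared with total=False so the proxy can build it in stages: url and request_body when the request goes out, response_body only after a non-streaming body has been read. A small usage sketch follows; the class is restated so the snippet runs standalone, and the request/response values are placeholders.

from typing import Optional, TypedDict


class PassthroughStandardLoggingPayload(TypedDict, total=False):
    url: str
    request_body: Optional[dict]
    response_body: Optional[dict]


# Built at request time...
payload = PassthroughStandardLoggingPayload(
    url="https://api.anthropic.com/v1/messages",
    request_body={"model": "placeholder-model", "stream": False},
)
# ...and completed only once a non-streaming response body is available.
payload["response_body"] = {"id": "msg_placeholder", "role": "assistant"}
print(payload)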