Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers')
3 files changed, 804 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py
new file mode 100644
index 00000000..51845956
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py
@@ -0,0 +1,220 @@
+import json
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.anthropic.chat.handler import (
+    ModelResponseIterator as AnthropicModelResponseIterator,
+)
+from litellm.llms.anthropic.chat.transformation import AnthropicConfig
+from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
+from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
+from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload
+from litellm.types.utils import ModelResponse, TextCompletionResponse
+
+if TYPE_CHECKING:
+    from ..success_handler import PassThroughEndpointLogging
+    from ..types import EndpointType
+else:
+    PassThroughEndpointLogging = Any
+    EndpointType = Any
+
+
+class AnthropicPassthroughLoggingHandler:
+
+    @staticmethod
+    def anthropic_passthrough_handler(
+        httpx_response: httpx.Response,
+        response_body: dict,
+        logging_obj: LiteLLMLoggingObj,
+        url_route: str,
+        result: str,
+        start_time: datetime,
+        end_time: datetime,
+        cache_hit: bool,
+        **kwargs,
+    ) -> PassThroughEndpointLoggingTypedDict:
+        """
+        Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled
+        """
+        model = response_body.get("model", "")
+        litellm_model_response: ModelResponse = AnthropicConfig().transform_response(
+            raw_response=httpx_response,
+            model_response=litellm.ModelResponse(),
+            model=model,
+            messages=[],
+            logging_obj=logging_obj,
+            optional_params={},
+            api_key="",
+            request_data={},
+            encoding=litellm.encoding,
+            json_mode=False,
+            litellm_params={},
+        )
+
+        kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
+            litellm_model_response=litellm_model_response,
+            model=model,
+            kwargs=kwargs,
+            start_time=start_time,
+            end_time=end_time,
+            logging_obj=logging_obj,
+        )
+
+        return {
+            "result": litellm_model_response,
+            "kwargs": kwargs,
+        }
+
+    @staticmethod
+    def _get_user_from_metadata(
+        passthrough_logging_payload: PassthroughStandardLoggingPayload,
+    ) -> Optional[str]:
+        request_body = passthrough_logging_payload.get("request_body")
+        if request_body:
+            return get_end_user_id_from_request_body(request_body)
+        return None
+
+    @staticmethod
+    def _create_anthropic_response_logging_payload(
+        litellm_model_response: Union[ModelResponse, TextCompletionResponse],
+        model: str,
+        kwargs: dict,
+        start_time: datetime,
+        end_time: datetime,
+        logging_obj: LiteLLMLoggingObj,
+    ):
+        """
+        Create the standard logging object for Anthropic passthrough
+
+        handles streaming and non-streaming responses
+        """
+        try:
+            response_cost = litellm.completion_cost(
+                completion_response=litellm_model_response,
+                model=model,
+            )
+            kwargs["response_cost"] = response_cost
+            kwargs["model"] = model
+            passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = (  # type: ignore
+                kwargs.get("passthrough_logging_payload")
+            )
+            if passthrough_logging_payload:
+                user = AnthropicPassthroughLoggingHandler._get_user_from_metadata(
+                    passthrough_logging_payload=passthrough_logging_payload,
+                )
+                if user:
+                    kwargs.setdefault("litellm_params", {})
+                    kwargs["litellm_params"].update(
+                        {"proxy_server_request": {"body": {"user": user}}}
+                    )
+
+            # pretty print standard logging object
+            verbose_proxy_logger.debug(
+                "kwargs= %s",
+                json.dumps(kwargs, indent=4, default=str),
+            )
+
+            # set litellm_call_id to logging response object
+            litellm_model_response.id = logging_obj.litellm_call_id
+            litellm_model_response.model = model
+            logging_obj.model_call_details["model"] = model
+            logging_obj.model_call_details["custom_llm_provider"] = (
+                litellm.LlmProviders.ANTHROPIC.value
+            )
+            return kwargs
+        except Exception as e:
+            verbose_proxy_logger.exception(
+                "Error creating Anthropic response logging payload: %s", e
+            )
+            return kwargs
+
+    @staticmethod
+    def _handle_logging_anthropic_collected_chunks(
+        litellm_logging_obj: LiteLLMLoggingObj,
+        passthrough_success_handler_obj: PassThroughEndpointLogging,
+        url_route: str,
+        request_body: dict,
+        endpoint_type: EndpointType,
+        start_time: datetime,
+        all_chunks: List[str],
+        end_time: datetime,
+    ) -> PassThroughEndpointLoggingTypedDict:
+        """
+        Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks
+
+        - Builds complete response from chunks
+        - Creates standard logging object
+        - Logs in litellm callbacks
+        """
+
+        model = request_body.get("model", "")
+        complete_streaming_response = (
+            AnthropicPassthroughLoggingHandler._build_complete_streaming_response(
+                all_chunks=all_chunks,
+                litellm_logging_obj=litellm_logging_obj,
+                model=model,
+            )
+        )
+        if complete_streaming_response is None:
+            verbose_proxy_logger.error(
+                "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..."
+            )
+            return {
+                "result": None,
+                "kwargs": {},
+            }
+        kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
+            litellm_model_response=complete_streaming_response,
+            model=model,
+            kwargs={},
+            start_time=start_time,
+            end_time=end_time,
+            logging_obj=litellm_logging_obj,
+        )
+
+        return {
+            "result": complete_streaming_response,
+            "kwargs": kwargs,
+        }
+
+    @staticmethod
+    def _build_complete_streaming_response(
+        all_chunks: List[str],
+        litellm_logging_obj: LiteLLMLoggingObj,
+        model: str,
+    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
+        """
+        Builds complete response from raw Anthropic chunks
+
+        - Converts str chunks to generic chunks
+        - Converts generic chunks to litellm chunks (OpenAI format)
+        - Builds complete response from litellm chunks
+        """
+        anthropic_model_response_iterator = AnthropicModelResponseIterator(
+            streaming_response=None,
+            sync_stream=False,
+        )
+        all_openai_chunks = []
+        for _chunk_str in all_chunks:
+            try:
+                transformed_openai_chunk = anthropic_model_response_iterator.convert_str_chunk_to_generic_chunk(
+                    chunk=_chunk_str
+                )
+                if transformed_openai_chunk is not None:
+                    all_openai_chunks.append(transformed_openai_chunk)
+
+                verbose_proxy_logger.debug(
+                    "all openai chunks= %s",
+                    json.dumps(all_openai_chunks, indent=4, default=str),
+                )
+            except (StopIteration, StopAsyncIteration):
+                break
+        complete_streaming_response = litellm.stream_chunk_builder(
+            chunks=all_openai_chunks
+        )
+        return complete_streaming_response
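Note on the streaming path above: `_build_complete_streaming_response` converts the collected SSE chunk strings into OpenAI-format chunks and then folds them into one final response via `litellm.stream_chunk_builder`. The following is a minimal, dependency-free sketch of that folding step only; the chunk shapes are illustrative OpenAI-style dicts, not litellm's own types.

# Minimal sketch of folding streamed deltas into one final message
# (the role litellm.stream_chunk_builder plays in the handler above).
# Chunk dicts here are illustrative, not litellm objects.
from typing import List, Optional


def fold_deltas(chunks: List[dict]) -> dict:
    """Concatenate `delta.content` fragments into a single assistant message."""
    content_parts: List[str] = []
    finish_reason: Optional[str] = None
    for chunk in chunks:
        choice = chunk["choices"][0]
        delta = choice.get("delta", {})
        if delta.get("content"):
            content_parts.append(delta["content"])
        finish_reason = choice.get("finish_reason") or finish_reason
    return {
        "choices": [
            {
                "message": {"role": "assistant", "content": "".join(content_parts)},
                "finish_reason": finish_reason,
            }
        ]
    }


if __name__ == "__main__":
    demo_chunks = [
        {"choices": [{"delta": {"content": "Hel"}, "finish_reason": None}]},
        {"choices": [{"delta": {"content": "lo"}, "finish_reason": "stop"}]},
    ]
    print(fold_deltas(demo_chunks))  # single message with content "Hello"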
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py
new file mode 100644
index 00000000..7cf3013d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py
@@ -0,0 +1,326 @@
+import asyncio
+import json
+import time
+from datetime import datetime
+from typing import Literal, Optional, TypedDict
+from urllib.parse import urlparse
+
+import httpx
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.litellm_logging import (
+    get_standard_logging_object_payload,
+)
+from litellm.litellm_core_utils.thread_pool_executor import executor
+from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload
+
+
+class AssemblyAITranscriptResponse(TypedDict, total=False):
+    id: str
+    speech_model: str
+    acoustic_model: str
+    language_code: str
+    status: str
+    audio_duration: float
+
+
+class AssemblyAIPassthroughLoggingHandler:
+    def __init__(self):
+        self.assembly_ai_base_url = "https://api.assemblyai.com"
+        self.assembly_ai_eu_base_url = "https://eu.assemblyai.com"
+        """
+        The base URL for the AssemblyAI API
+        """
+
+        self.polling_interval: float = 10
+        """
+        The polling interval for the AssemblyAI API.
+        litellm needs to poll the GET /transcript/{transcript_id} endpoint to get the status of the transcript.
+        """
+
+        self.max_polling_attempts = 180
+        """
+        The maximum number of polling attempts for the AssemblyAI API.
+        """
+
+    def assemblyai_passthrough_logging_handler(
+        self,
+        httpx_response: httpx.Response,
+        response_body: dict,
+        logging_obj: LiteLLMLoggingObj,
+        url_route: str,
+        result: str,
+        start_time: datetime,
+        end_time: datetime,
+        cache_hit: bool,
+        **kwargs,
+    ):
+        """
+        Since cost tracking requires polling the AssemblyAI API, we need to handle this in a separate thread. Hence the executor.submit.
+        """
+        executor.submit(
+            self._handle_assemblyai_passthrough_logging,
+            httpx_response,
+            response_body,
+            logging_obj,
+            url_route,
+            result,
+            start_time,
+            end_time,
+            cache_hit,
+            **kwargs,
+        )
+
+    def _handle_assemblyai_passthrough_logging(
+        self,
+        httpx_response: httpx.Response,
+        response_body: dict,
+        logging_obj: LiteLLMLoggingObj,
+        url_route: str,
+        result: str,
+        start_time: datetime,
+        end_time: datetime,
+        cache_hit: bool,
+        **kwargs,
+    ):
+        """
+        Handles logging for AssemblyAI successful passthrough requests
+        """
+        from ..pass_through_endpoints import pass_through_endpoint_logging
+
+        model = response_body.get("speech_model", "")
+        verbose_proxy_logger.debug(
+            "response body %s", json.dumps(response_body, indent=4)
+        )
+        kwargs["model"] = model
+        kwargs["custom_llm_provider"] = "assemblyai"
+        response_cost: Optional[float] = None
+
+        transcript_id = response_body.get("id")
+        if transcript_id is None:
+            raise ValueError(
+                "Transcript ID is required to log the cost of the transcription"
+            )
+        transcript_response = self._poll_assembly_for_transcript_response(
+            transcript_id=transcript_id, url_route=url_route
+        )
+        verbose_proxy_logger.debug(
+            "finished polling assembly for transcript response- got transcript response %s",
+            json.dumps(transcript_response, indent=4),
+        )
+        if transcript_response:
+            cost = self.get_cost_for_assembly_transcript(
+                speech_model=model,
+                transcript_response=transcript_response,
+            )
+            response_cost = cost
+
+        # Make standard logging object for AssemblyAI
+        standard_logging_object = get_standard_logging_object_payload(
+            kwargs=kwargs,
+            init_response_obj=transcript_response,
+            start_time=start_time,
+            end_time=end_time,
+            logging_obj=logging_obj,
+            status="success",
+        )
+
+        passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = (  # type: ignore
+            kwargs.get("passthrough_logging_payload")
+        )
+
+        verbose_proxy_logger.debug(
+            "standard_passthrough_logging_object %s",
+            json.dumps(passthrough_logging_payload, indent=4),
+        )
+
+        # pretty print standard logging object
+        verbose_proxy_logger.debug(
+            "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4)
+        )
+        logging_obj.model_call_details["model"] = model
+        logging_obj.model_call_details["custom_llm_provider"] = "assemblyai"
+        logging_obj.model_call_details["response_cost"] = response_cost
+
+        asyncio.run(
+            pass_through_endpoint_logging._handle_logging(
+                logging_obj=logging_obj,
+                standard_logging_response_object=self._get_response_to_log(
+                    transcript_response
+                ),
+                result=result,
+                start_time=start_time,
+                end_time=end_time,
+                cache_hit=cache_hit,
+                **kwargs,
+            )
+        )
+
+        pass
+
+    def _get_response_to_log(
+        self, transcript_response: Optional[AssemblyAITranscriptResponse]
+    ) -> dict:
+        if transcript_response is None:
+            return {}
+        return dict(transcript_response)
+
+    def _get_assembly_transcript(
+        self,
+        transcript_id: str,
+        request_region: Optional[Literal["eu"]] = None,
+    ) -> Optional[dict]:
+        """
+        Get the transcript details from AssemblyAI API
+
+        Args:
+            transcript_id (str): ID of the transcript to fetch
+            request_region (Optional[Literal["eu"]]): AssemblyAI region to query, if any
+
+        Returns:
+            Optional[dict]: Transcript details if successful, None otherwise
+        """
+        from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+            passthrough_endpoint_router,
+        )
+
+        _base_url = (
+            self.assembly_ai_eu_base_url
+            if request_region == "eu"
+            else self.assembly_ai_base_url
+        )
+        _api_key = passthrough_endpoint_router.get_credentials(
+            custom_llm_provider="assemblyai",
+            region_name=request_region,
+        )
+        if _api_key is None:
+            raise ValueError("AssemblyAI API key not found")
+        try:
+            url = f"{_base_url}/v2/transcript/{transcript_id}"
+            headers = {
+                "Authorization": f"Bearer {_api_key}",
+                "Content-Type": "application/json",
+            }
+
+            response = httpx.get(url, headers=headers)
+            response.raise_for_status()
+
+            return response.json()
+        except Exception as e:
+            verbose_proxy_logger.exception(
+                f"[Non blocking logging error] Error getting AssemblyAI transcript: {str(e)}"
+            )
+            return None
+
+    def _poll_assembly_for_transcript_response(
+        self,
+        transcript_id: str,
+        url_route: Optional[str] = None,
+    ) -> Optional[AssemblyAITranscriptResponse]:
+        """
+        Poll the status of the transcript until it is completed or timeout (30 minutes)
+        """
+        for _ in range(
+            self.max_polling_attempts
+        ):  # 180 attempts * 10s = 30 minutes max
+            transcript = self._get_assembly_transcript(
+                request_region=AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
+                    url=url_route
+                ),
+                transcript_id=transcript_id,
+            )
+            if transcript is None:
+                return None
+            if (
+                transcript.get("status") == "completed"
+                or transcript.get("status") == "error"
+            ):
+                return AssemblyAITranscriptResponse(**transcript)
+            time.sleep(self.polling_interval)
+        return None
+
+    @staticmethod
+    def get_cost_for_assembly_transcript(
+        transcript_response: AssemblyAITranscriptResponse,
+        speech_model: str,
+    ) -> Optional[float]:
+        """
+        Get the cost for the assembly transcript
+        """
+        _audio_duration = transcript_response.get("audio_duration")
+        if _audio_duration is None:
+            return None
+        _cost_per_second = (
+            AssemblyAIPassthroughLoggingHandler.get_cost_per_second_for_assembly_model(
+                speech_model=speech_model
+            )
+        )
+        if _cost_per_second is None:
+            return None
+        return _audio_duration * _cost_per_second
+
+    @staticmethod
+    def get_cost_per_second_for_assembly_model(speech_model: str) -> Optional[float]:
+        """
+        Get the cost per second for the assembly model.
+        Falls back to assemblyai/nano if the specific speech model info cannot be found.
+        """
+        try:
+            # First try with the provided speech model
+            try:
+                model_info = litellm.get_model_info(
+                    model=speech_model,
+                    custom_llm_provider="assemblyai",
+                )
+                if model_info and model_info.get("input_cost_per_second") is not None:
+                    return model_info.get("input_cost_per_second")
+            except Exception:
+                pass  # Continue to fallback if model not found
+
+            # Fallback to assemblyai/nano if speech model info not found
+            try:
+                model_info = litellm.get_model_info(
+                    model="assemblyai/nano",
+                    custom_llm_provider="assemblyai",
+                )
+                if model_info and model_info.get("input_cost_per_second") is not None:
+                    return model_info.get("input_cost_per_second")
+            except Exception:
+                pass
+
+            return None
+        except Exception as e:
+            verbose_proxy_logger.exception(
+                f"[Non blocking logging error] Error getting AssemblyAI model info: {str(e)}"
+            )
+            return None
+
+    @staticmethod
+    def _should_log_request(request_method: str) -> bool:
+        """
+        only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
+        """
+        return request_method == "POST"
+
+    @staticmethod
+    def _get_assembly_region_from_url(url: Optional[str]) -> Optional[Literal["eu"]]:
+        """
+        Get the region from the URL
+        """
+        if url is None:
+            return None
+        if urlparse(url).hostname == "eu.assemblyai.com":
+            return "eu"
+        return None
+
+    @staticmethod
+    def _get_assembly_base_url_from_region(region: Optional[Literal["eu"]]) -> str:
+        """
+        Get the base URL for the AssemblyAI API
+        if region == "eu", return "https://api.eu.assemblyai.com"
+        else return "https://api.assemblyai.com"
+        """
+        if region == "eu":
+            return "https://api.eu.assemblyai.com"
+        return "https://api.assemblyai.com"
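Note on cost tracking above: `get_cost_for_assembly_transcript` multiplies the transcript's `audio_duration` (in seconds) by the model's `input_cost_per_second`, falling back to `assemblyai/nano` pricing when the specific speech model is unknown. Below is a minimal sketch of that arithmetic; the price table is hypothetical and stands in for what `litellm.get_model_info()` would return.

# Minimal sketch of the cost math in get_cost_for_assembly_transcript:
# cost = audio_duration (seconds) * input_cost_per_second, with a fallback
# entry when the model is unknown. Rates below are made up for illustration.
from typing import Optional

HYPOTHETICAL_COST_PER_SECOND = {
    "assemblyai/best": 0.0002,  # made-up rate
    "assemblyai/nano": 0.0001,  # made-up fallback rate
}


def cost_for_transcript(speech_model: str, audio_duration: Optional[float]) -> Optional[float]:
    if audio_duration is None:
        return None
    rate = HYPOTHETICAL_COST_PER_SECOND.get(
        speech_model, HYPOTHETICAL_COST_PER_SECOND["assemblyai/nano"]
    )
    return audio_duration * rate


if __name__ == "__main__":
    # 10 minutes of audio on the fallback rate -> 600 * 0.0001 = 0.06
    print(cost_for_transcript("unknown-model", 600.0))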
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py
new file mode 100644
index 00000000..94435637
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py
@@ -0,0 +1,258 @@
+import json
+import re
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from urllib.parse import urlparse
+import httpx
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+    ModelResponseIterator as VertexModelResponseIterator,
+)
+from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
+from litellm.types.utils import (
+    EmbeddingResponse,
+    ImageResponse,
+    ModelResponse,
+    TextCompletionResponse,
+)
+
+if TYPE_CHECKING:
+    from ..success_handler import PassThroughEndpointLogging
+    from ..types import EndpointType
+else:
+    PassThroughEndpointLogging = Any
+    EndpointType = Any
+
+
+class VertexPassthroughLoggingHandler:
+    @staticmethod
+    def vertex_passthrough_handler(
+        httpx_response: httpx.Response,
+        logging_obj: LiteLLMLoggingObj,
+        url_route: str,
+        result: str,
+        start_time: datetime,
+        end_time: datetime,
+        cache_hit: bool,
+        **kwargs,
+    ) -> PassThroughEndpointLoggingTypedDict:
+        if "generateContent" in url_route:
+            model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
+
+            instance_of_vertex_llm = litellm.VertexGeminiConfig()
+            litellm_model_response: ModelResponse = (
+                instance_of_vertex_llm.transform_response(
+                    model=model,
+                    messages=[
+                        {"role": "user", "content": "no-message-pass-through-endpoint"}
+                    ],
+                    raw_response=httpx_response,
+                    model_response=litellm.ModelResponse(),
+                    logging_obj=logging_obj,
+                    optional_params={},
+                    litellm_params={},
+                    api_key="",
+                    request_data={},
+                    encoding=litellm.encoding,
+                )
+            )
+            kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
+                litellm_model_response=litellm_model_response,
+                model=model,
+                kwargs=kwargs,
+                start_time=start_time,
+                end_time=end_time,
+                logging_obj=logging_obj,
+                custom_llm_provider=VertexPassthroughLoggingHandler._get_custom_llm_provider_from_url(
+                    url_route
+                ),
+            )
+
+            return {
+                "result": litellm_model_response,
+                "kwargs": kwargs,
+            }
+
+        elif "predict" in url_route:
+            from litellm.llms.vertex_ai.image_generation.image_generation_handler import (
+                VertexImageGeneration,
+            )
+            from litellm.types.utils import PassthroughCallTypes
+
+            vertex_image_generation_class = VertexImageGeneration()
+
+            model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
+            _json_response = httpx_response.json()
+
+            litellm_prediction_response: Union[
+                ModelResponse, EmbeddingResponse, ImageResponse
+            ] = ModelResponse()
+            if vertex_image_generation_class.is_image_generation_response(
+                _json_response
+            ):
+                litellm_prediction_response = (
+                    vertex_image_generation_class.process_image_generation_response(
+                        _json_response,
+                        model_response=litellm.ImageResponse(),
+                        model=model,
+                    )
+                )
+
+                logging_obj.call_type = (
+                    PassthroughCallTypes.passthrough_image_generation.value
+                )
+            else:
+                litellm_prediction_response = litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
+                    response=_json_response,
+                    model=model,
+                    model_response=litellm.EmbeddingResponse(),
+                )
+            if isinstance(litellm_prediction_response, litellm.EmbeddingResponse):
+                litellm_prediction_response.model = model
+
+            logging_obj.model = model
+            logging_obj.model_call_details["model"] = logging_obj.model
+
+            return {
+                "result": litellm_prediction_response,
+                "kwargs": kwargs,
+            }
+        else:
+            return {
+                "result": None,
+                "kwargs": kwargs,
+            }
+
+    @staticmethod
+    def _handle_logging_vertex_collected_chunks(
+        litellm_logging_obj: LiteLLMLoggingObj,
+        passthrough_success_handler_obj: PassThroughEndpointLogging,
+        url_route: str,
+        request_body: dict,
+        endpoint_type: EndpointType,
+        start_time: datetime,
+        all_chunks: List[str],
+        end_time: datetime,
+    ) -> PassThroughEndpointLoggingTypedDict:
+        """
+        Takes raw chunks from Vertex passthrough endpoint and logs them in litellm callbacks
+
+        - Builds complete response from chunks
+        - Creates standard logging object
+        - Logs in litellm callbacks
+        """
+        kwargs: Dict[str, Any] = {}
+        model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
+        complete_streaming_response = (
+            VertexPassthroughLoggingHandler._build_complete_streaming_response(
+                all_chunks=all_chunks,
+                litellm_logging_obj=litellm_logging_obj,
+                model=model,
+            )
+        )
+
+        if complete_streaming_response is None:
+            verbose_proxy_logger.error(
+                "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..."
+            )
+            return {
+                "result": None,
+                "kwargs": kwargs,
+            }
+
+        kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
+            litellm_model_response=complete_streaming_response,
+            model=model,
+            kwargs=kwargs,
+            start_time=start_time,
+            end_time=end_time,
+            logging_obj=litellm_logging_obj,
+            custom_llm_provider=VertexPassthroughLoggingHandler._get_custom_llm_provider_from_url(
+                url_route
+            ),
+        )
+
+        return {
+            "result": complete_streaming_response,
+            "kwargs": kwargs,
+        }
+
+    @staticmethod
+    def _build_complete_streaming_response(
+        all_chunks: List[str],
+        litellm_logging_obj: LiteLLMLoggingObj,
+        model: str,
+    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
+        vertex_iterator = VertexModelResponseIterator(
+            streaming_response=None,
+            sync_stream=False,
+        )
+        litellm_custom_stream_wrapper = litellm.CustomStreamWrapper(
+            completion_stream=vertex_iterator,
+            model=model,
+            logging_obj=litellm_logging_obj,
+            custom_llm_provider="vertex_ai",
+        )
+        all_openai_chunks = []
+        for chunk in all_chunks:
+            generic_chunk = vertex_iterator._common_chunk_parsing_logic(chunk)
+            litellm_chunk = litellm_custom_stream_wrapper.chunk_creator(
+                chunk=generic_chunk
+            )
+            if litellm_chunk is not None:
+                all_openai_chunks.append(litellm_chunk)
+
+        complete_streaming_response = litellm.stream_chunk_builder(
+            chunks=all_openai_chunks
+        )
+
+        return complete_streaming_response
+
+    @staticmethod
+    def extract_model_from_url(url: str) -> str:
+        pattern = r"/models/([^:]+)"
+        match = re.search(pattern, url)
+        if match:
+            return match.group(1)
+        return "unknown"
+
+    @staticmethod
+    def _get_custom_llm_provider_from_url(url: str) -> str:
+        parsed_url = urlparse(url)
+        if parsed_url.hostname and parsed_url.hostname.endswith("generativelanguage.googleapis.com"):
+            return litellm.LlmProviders.GEMINI.value
+        return litellm.LlmProviders.VERTEX_AI.value
+
+    @staticmethod
+    def _create_vertex_response_logging_payload_for_generate_content(
+        litellm_model_response: Union[ModelResponse, TextCompletionResponse],
+        model: str,
+        kwargs: dict,
+        start_time: datetime,
+        end_time: datetime,
+        logging_obj: LiteLLMLoggingObj,
+        custom_llm_provider: str,
+    ):
+        """
+        Create the standard logging object for Vertex passthrough generateContent (streaming and non-streaming)
+
+        """
+        response_cost = litellm.completion_cost(
+            completion_response=litellm_model_response,
+            model=model,
+        )
+        kwargs["response_cost"] = response_cost
+        kwargs["model"] = model
+
+        # pretty print standard logging object
+        verbose_proxy_logger.debug("kwargs= %s", json.dumps(kwargs, indent=4))
+
+        # set litellm_call_id to logging response object
+        litellm_model_response.id = logging_obj.litellm_call_id
+        logging_obj.model = litellm_model_response.model or model
+        logging_obj.model_call_details["model"] = logging_obj.model
+        logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider
+        return kwargs
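Note on the URL helpers above: `extract_model_from_url` pulls the model name out of the `/models/<name>:<action>` path segment, and `_get_custom_llm_provider_from_url` switches between Gemini and Vertex AI based on the hostname. A small usage sketch follows; the URLs are examples, and the plain strings "gemini" / "vertex_ai" stand in for `litellm.LlmProviders` values.

# Usage sketch mirroring extract_model_from_url / _get_custom_llm_provider_from_url.
# Example URLs and provider strings are illustrative assumptions, not litellm constants.
import re
from urllib.parse import urlparse


def extract_model(url: str) -> str:
    match = re.search(r"/models/([^:]+)", url)
    return match.group(1) if match else "unknown"


def provider_for(url: str) -> str:
    hostname = urlparse(url).hostname or ""
    return "gemini" if hostname.endswith("generativelanguage.googleapis.com") else "vertex_ai"


if __name__ == "__main__":
    vertex_url = (
        "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/"
        "locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent"
    )
    gemini_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:streamGenerateContent"
    print(extract_model(vertex_url), provider_for(vertex_url))  # gemini-1.5-pro vertex_ai
    print(extract_model(gemini_url), provider_for(gemini_url))  # gemini-1.5-flash gemini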