Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/custom_logger.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/custom_logger.py | 388
1 file changed, 388 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/custom_logger.py b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_logger.py
new file mode 100644
index 00000000..6f1ec88d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_logger.py
@@ -0,0 +1,388 @@
+#### What this does ####
+# On success, logs events to Promptlayer
+import traceback
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+)
+
+from pydantic import BaseModel
+
+from litellm.caching.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.types.integrations.argilla import ArgillaItem
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
+from litellm.types.utils import (
+    AdapterCompletionStreamWrapper,
+    EmbeddingResponse,
+    ImageResponse,
+    ModelResponse,
+    ModelResponseStream,
+    StandardCallbackDynamicParams,
+    StandardLoggingPayload,
+)
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any
+
+
+class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
+    # Class variables or attributes
+    def __init__(self, message_logging: bool = True) -> None:
+        self.message_logging = message_logging
+        pass
+
+    def log_pre_api_call(self, model, messages, kwargs):
+        pass
+
+    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
+        pass
+
+    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
+        pass
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        pass
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        pass
+
+    #### ASYNC ####
+
+    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
+        pass
+
+    async def async_log_pre_api_call(self, model, messages, kwargs):
+        pass
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        pass
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        pass
+
+    #### PROMPT MANAGEMENT HOOKS ####
+
+    async def async_get_chat_completion_prompt(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        non_default_params: dict,
+        prompt_id: str,
+        prompt_variables: Optional[dict],
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> Tuple[str, List[AllMessageValues], dict]:
+        """
+        Returns:
+        - model: str - the model to use (can be pulled from prompt management tool)
+        - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
+        - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
+        """
+        return model, messages, non_default_params
+
+    def get_chat_completion_prompt(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        non_default_params: dict,
+        prompt_id: str,
+        prompt_variables: Optional[dict],
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> Tuple[str, List[AllMessageValues], dict]:
+        """
+        Returns:
+        - model: str - the model to use (can be pulled from prompt management tool)
+        - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
+        - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
+        """
+        return model, messages, non_default_params
+
+    #### PRE-CALL CHECKS - router/proxy only ####
+    """
+    Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
+    """
+
+    async def async_filter_deployments(
+        self,
+        model: str,
+        healthy_deployments: List,
+        messages: Optional[List[AllMessageValues]],
+        request_kwargs: Optional[dict] = None,
+        parent_otel_span: Optional[Span] = None,
+    ) -> List[dict]:
+        return healthy_deployments
+
+    async def async_pre_call_check(
+        self, deployment: dict, parent_otel_span: Optional[Span]
+    ) -> Optional[dict]:
+        pass
+
+    def pre_call_check(self, deployment: dict) -> Optional[dict]:
+        pass
+
+    #### Fallback Events - router/proxy only ####
+    async def log_model_group_rate_limit_error(
+        self, exception: Exception, original_model_group: Optional[str], kwargs: dict
+    ):
+        pass
+
+    async def log_success_fallback_event(
+        self, original_model_group: str, kwargs: dict, original_exception: Exception
+    ):
+        pass
+
+    async def log_failure_fallback_event(
+        self, original_model_group: str, kwargs: dict, original_exception: Exception
+    ):
+        pass
+
+    #### ADAPTERS #### Allow calling 100+ LLMs in custom format - https://github.com/BerriAI/litellm/pulls
+
+    def translate_completion_input_params(
+        self, kwargs
+    ) -> Optional[ChatCompletionRequest]:
+        """
+        Translates the input params, from the provider's native format to the litellm.completion() format.
+        """
+        pass
+
+    def translate_completion_output_params(
+        self, response: ModelResponse
+    ) -> Optional[BaseModel]:
+        """
+        Translates the output params, from the OpenAI format to the custom format.
+        """
+        pass
+
+    def translate_completion_output_params_streaming(
+        self, completion_stream: Any
+    ) -> Optional[AdapterCompletionStreamWrapper]:
+        """
+        Translates the streaming chunk, from the OpenAI format to the custom format.
+        """
+        pass
+
+    #### DATASET HOOKS #### - currently only used for Argilla
+
+    async def async_dataset_hook(
+        self,
+        logged_item: ArgillaItem,
+        standard_logging_payload: Optional[StandardLoggingPayload],
+    ) -> Optional[ArgillaItem]:
+        """
+        - Decide if the result should be logged to Argilla.
+        - Modify the result before logging to Argilla.
+        - Return None if the result should not be logged to Argilla.
+        """
+        raise NotImplementedError("async_dataset_hook not implemented")
+
+    #### CALL HOOKS - proxy only ####
+    """
+    Control / modify the incoming / outgoing data before calling the model.
+    """
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+            "rerank",
+        ],
+    ) -> Optional[
+        Union[Exception, str, dict]
+    ]:  # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
+        pass
+
+    async def async_post_call_failure_hook(
+        self,
+        request_data: dict,
+        original_exception: Exception,
+        user_api_key_dict: UserAPIKeyAuth,
+    ):
+        pass
+
+    async def async_post_call_success_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
+    ) -> Any:
+        pass
+
+    async def async_logging_hook(
+        self, kwargs: dict, result: Any, call_type: str
+    ) -> Tuple[dict, Any]:
+        """For masking logged request/response. Return a modified version of the request/result."""
+        return kwargs, result
+
+    def logging_hook(
+        self, kwargs: dict, result: Any, call_type: str
+    ) -> Tuple[dict, Any]:
+        """For masking logged request/response. Return a modified version of the request/result."""
+        return kwargs, result
+
+    async def async_moderation_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal[
+            "completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "responses",
+        ],
+    ) -> Any:
+        pass
+
+    async def async_post_call_streaming_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: str,
+    ) -> Any:
+        pass
+
+    async def async_post_call_streaming_iterator_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: Any,
+        request_data: dict,
+    ) -> AsyncGenerator[ModelResponseStream, None]:
+        async for item in response:
+            yield item
+
+    #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
+
+    def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
+        try:
+            kwargs["model"] = model
+            kwargs["messages"] = messages
+            kwargs["log_event_type"] = "pre_api_call"
+            callback_func(
+                kwargs,
+            )
+            print_verbose(f"Custom Logger - model call details: {kwargs}")
+        except Exception:
+            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
+
+    async def async_log_input_event(
+        self, model, messages, kwargs, print_verbose, callback_func
+    ):
+        try:
+            kwargs["model"] = model
+            kwargs["messages"] = messages
+            kwargs["log_event_type"] = "pre_api_call"
+            await callback_func(
+                kwargs,
+            )
+            print_verbose(f"Custom Logger - model call details: {kwargs}")
+        except Exception:
+            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
+
+    def log_event(
+        self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
+    ):
+        # Method definition
+        try:
+            kwargs["log_event_type"] = "post_api_call"
+            callback_func(
+                kwargs,  # kwargs to func
+                response_obj,
+                start_time,
+                end_time,
+            )
+        except Exception:
+            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
+            pass
+
+    async def async_log_event(
+        self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
+    ):
+        # Method definition
+        try:
+            kwargs["log_event_type"] = "post_api_call"
+            await callback_func(
+                kwargs,  # kwargs to func
+                response_obj,
+                start_time,
+                end_time,
+            )
+        except Exception:
+            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
+            pass
+
+    # Useful helpers for custom logger classes
+
+    def truncate_standard_logging_payload_content(
+        self,
+        standard_logging_object: StandardLoggingPayload,
+    ):
+        """
+        Truncate error strings and message content in the logging payload.
+
+        Some loggers, like DataDog / GCS Bucket, have a limit on payload size (1MB).
+
+        This function truncates the error string and the message content if they exceed a certain length.
+        """
+        MAX_STR_LENGTH = 10_000
+
+        # Truncate fields that might exceed max length
+        fields_to_truncate = ["error_str", "messages", "response"]
+        for field in fields_to_truncate:
+            self._truncate_field(
+                standard_logging_object=standard_logging_object,
+                field_name=field,
+                max_length=MAX_STR_LENGTH,
+            )
+
+    def _truncate_field(
+        self,
+        standard_logging_object: StandardLoggingPayload,
+        field_name: str,
+        max_length: int,
+    ) -> None:
+        """
+        Helper function to truncate a field in the logging payload.
+
+        This converts the field to a string and then truncates it if it exceeds the max length.
+
+        Why convert to string?
+        1. A user was sending a poorly formatted list for the `messages` field; we could not predict where they would send content.
+           - Converting to string and then truncating the logged content catches this.
+        2. We want to avoid modifying the original `messages`, `response`, and `error_str` in the logging payload, since these are in kwargs and could be returned to the user.
+        """
+        field_value = standard_logging_object.get(field_name)  # type: ignore
+        if field_value:
+            str_value = str(field_value)
+            if len(str_value) > max_length:
+                standard_logging_object[field_name] = self._truncate_text(  # type: ignore
+                    text=str_value, max_length=max_length
+                )
+
+    def _truncate_text(self, text: str, max_length: int) -> str:
+        """Truncate text if it exceeds max_length."""
+        return (
+            text[:max_length]
+            + "...truncated by litellm, this logger does not support large content"
+            if len(text) > max_length
+            else text
+        )
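
For context, a minimal sketch of how this class is typically consumed, following the pattern at the docs URL referenced in the file (the handler name and hook bodies below are illustrative, not part of litellm):

# sketch: subclass CustomLogger and register it; class name is illustrative
import litellm
from litellm.integrations.custom_logger import CustomLogger


class MyCustomHandler(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        # kwargs carries the request details; start/end times are datetimes
        print(f"success: {kwargs.get('model')} took {end_time - start_time}")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"failure: {kwargs.get('exception')}")


litellm.callbacks = [MyCustomHandler()]  # register the handler for all calls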
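
The prompt-management hooks let a callback substitute a stored prompt before the call is made. A hedged sketch, where the in-memory PROMPTS dict stands in for a real prompt-management tool:

# sketch: async_get_chat_completion_prompt that prepends a stored, templated prompt
from typing import List, Optional, Tuple

from litellm.integrations.custom_logger import CustomLogger
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams

# illustrative stand-in for a prompt-management backend
PROMPTS = {"greet-v1": [{"role": "system", "content": "You greet users in {language}."}]}


class PromptStoreLogger(CustomLogger):
    async def async_get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        stored = PROMPTS.get(prompt_id)
        if stored is not None:
            variables = prompt_variables or {}
            # render the stored template and prepend it to the caller's messages
            rendered = [
                {**m, "content": m["content"].format(**variables)} for m in stored
            ]
            messages = rendered + messages
        return model, messages, non_default_params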
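
On the proxy, async_pre_call_hook can reject or rewrite a request per its return-type comment: raise for invalid input, return a str to reject with that message, or return a dict to forward modified data. A sketch (class name and the "forbidden" check are illustrative assumptions):

# sketch: a guardrail-style pre-call hook for the proxy
from typing import Optional, Union

from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class BlockWordsHandler(CustomLogger):
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ) -> Optional[Union[Exception, str, dict]]:
        # returning a str rejects the request with that message
        for message in data.get("messages", []):
            if "forbidden" in str(message.get("content", "")):
                return "Request blocked: contains a forbidden word."
        # returning a dict forwards the (possibly modified) request to litellm
        data["metadata"] = {**data.get("metadata", {}), "checked": True}
        return data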
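
The truncation helpers exist for size-limited sinks. A sketch of calling them from a success callback; reading the payload via kwargs.get("standard_logging_object") is an assumption about how litellm populates kwargs, so treat it as illustrative:

# sketch: truncate oversized fields before shipping to a size-limited log sink
from litellm.integrations.custom_logger import CustomLogger


class SizeLimitedLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # assumption: the standard logging payload is available under this key
        payload = kwargs.get("standard_logging_object")
        if payload is not None:
            # truncates error_str / messages / response to 10k chars in place
            self.truncate_standard_logging_payload_content(payload)
        # ... ship `payload` to the log sink here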