Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations')
57 files changed, 15250 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/Readme.md b/.venv/lib/python3.12/site-packages/litellm/integrations/Readme.md new file mode 100644 index 00000000..2b0b530a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/Readme.md @@ -0,0 +1,5 @@ +# Integrations + +This folder contains logging integrations for litellm + +eg. logging to Datadog, Langfuse, Prometheus, s3, GCS Bucket, etc.
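For context, these logging integrations are normally switched on through litellm's callback lists rather than imported directly. A minimal sketch is shown below; the Langfuse choice, the model name, and the credentials are placeholder assumptions and not part of this diff:

```python
# Minimal sketch: enabling one of the logging integrations listed above.
# Assumes the relevant credentials (e.g. LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY)
# are already set in the environment; the model name is just a placeholder.
import litellm

litellm.success_callback = ["langfuse"]   # log successful calls
litellm.failure_callback = ["langfuse"]   # log failed calls as well

response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hello"}],
)
```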
\ No newline at end of file diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/Readme.md b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/Readme.md new file mode 100644 index 00000000..f28f7150 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/Readme.md @@ -0,0 +1,13 @@ +# Slack Alerting on LiteLLM Gateway + +This folder contains the Slack Alerting integration for LiteLLM Gateway. + +## Folder Structure + +- `slack_alerting.py`: This is the main file that handles sending different types of alerts +- `batching_handler.py`: Handles Batching + sending Httpx Post requests to slack. Slack alerts are sent every 10s or when events are greater than X events. Done to ensure litellm has good performance under high traffic +- `types.py`: This file contains the AlertType enum which is used to define the different types of alerts that can be sent to Slack. +- `utils.py`: This file contains common utils used specifically for slack alerting + +## Further Reading +- [Doc setting up Alerting on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/proxy/alerting)
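As a rough illustration of how the files in this folder fit together, the sketch below instantiates `SlackAlerting` directly and queues a single alert. In practice the LiteLLM proxy constructs this object from its config; the webhook URL here is a placeholder assumption.

```python
# Hedged sketch only - normally the LiteLLM proxy wires SlackAlerting up from
# config. Shown here just to connect slack_alerting.py and batching_handler.py.
import asyncio

from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
from litellm.proxy._types import AlertType


async def main():
    alerting = SlackAlerting(
        alerting=["slack"],
        alerting_threshold=300,  # seconds before a request counts as slow/hanging
        default_webhook_url="https://hooks.slack.com/services/T000/B000/XXXX",  # placeholder
    )
    # send_alert appends to the batch queue handled by batching_handler.py;
    # queued alerts are flushed roughly every 10s or once the batch size is reached.
    await alerting.send_alert(
        message="Test alert from the README sketch",
        level="Low",
        alert_type=AlertType.daily_reports,
        alerting_metadata={},
    )
    await alerting.flush_queue()


asyncio.run(main())
```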
\ No newline at end of file diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/batching_handler.py b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/batching_handler.py new file mode 100644 index 00000000..e35cf61d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/batching_handler.py @@ -0,0 +1,82 @@ +""" +Handles Batching + sending Httpx Post requests to slack + +Slack alerts are sent every 10s or when events are greater than X events + +see custom_batch_logger.py for more details / defaults +""" + +from typing import TYPE_CHECKING, Any + +from litellm._logging import verbose_proxy_logger + +if TYPE_CHECKING: + from .slack_alerting import SlackAlerting as _SlackAlerting + + SlackAlertingType = _SlackAlerting +else: + SlackAlertingType = Any + + +def squash_payloads(queue): + + squashed = {} + if len(queue) == 0: + return squashed + if len(queue) == 1: + return {"key": {"item": queue[0], "count": 1}} + + for item in queue: + url = item["url"] + alert_type = item["alert_type"] + _key = (url, alert_type) + + if _key in squashed: + squashed[_key]["count"] += 1 + # Merge the payloads + + else: + squashed[_key] = {"item": item, "count": 1} + + return squashed + + +def _print_alerting_payload_warning( + payload: dict, slackAlertingInstance: SlackAlertingType +): + """ + Print the payload to the console when + slackAlertingInstance.alerting_args.log_to_console is True + + Relevant issue: https://github.com/BerriAI/litellm/issues/7372 + """ + if slackAlertingInstance.alerting_args.log_to_console is True: + verbose_proxy_logger.warning(payload) + + +async def send_to_webhook(slackAlertingInstance: SlackAlertingType, item, count): + """ + Send a single slack alert to the webhook + """ + import json + + payload = item.get("payload", {}) + try: + if count > 1: + payload["text"] = f"[Num Alerts: {count}]\n\n{payload['text']}" + + response = await slackAlertingInstance.async_http_handler.post( + url=item["url"], + headers=item["headers"], + data=json.dumps(payload), + ) + if response.status_code != 200: + verbose_proxy_logger.debug( + f"Error sending slack alert to url={item['url']}. 
Error={response.text}" + ) + except Exception as e: + verbose_proxy_logger.debug(f"Error sending slack alert: {str(e)}") + finally: + _print_alerting_payload_warning( + payload, slackAlertingInstance=slackAlertingInstance + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/slack_alerting.py b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/slack_alerting.py new file mode 100644 index 00000000..a2e62647 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/slack_alerting.py @@ -0,0 +1,1822 @@ +#### What this does #### +# Class for sending Slack Alerts # +import asyncio +import datetime +import os +import random +import time +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union + +from openai import APIError + +import litellm +import litellm.litellm_core_utils +import litellm.litellm_core_utils.litellm_logging +import litellm.types +from litellm._logging import verbose_logger, verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_batch_logger import CustomBatchLogger +from litellm.litellm_core_utils.duration_parser import duration_in_seconds +from litellm.litellm_core_utils.exception_mapping_utils import ( + _add_key_name_and_team_to_alert, +) +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.proxy._types import AlertType, CallInfo, VirtualKeyEvent, WebhookEvent +from litellm.types.integrations.slack_alerting import * + +from ..email_templates.templates import * +from .batching_handler import send_to_webhook, squash_payloads +from .utils import _add_langfuse_trace_id_to_alert, process_slack_alerting_variables + +if TYPE_CHECKING: + from litellm.router import Router as _Router + + Router = _Router +else: + Router = Any + + +class SlackAlerting(CustomBatchLogger): + """ + Class for sending Slack Alerts + """ + + # Class variables or attributes + def __init__( + self, + internal_usage_cache: Optional[DualCache] = None, + alerting_threshold: Optional[ + float + ] = None, # threshold for slow / hanging llm responses (in seconds) + alerting: Optional[List] = [], + alert_types: List[AlertType] = DEFAULT_ALERT_TYPES, + alert_to_webhook_url: Optional[ + Dict[AlertType, Union[List[str], str]] + ] = None, # if user wants to separate alerts to diff channels + alerting_args={}, + default_webhook_url: Optional[str] = None, + **kwargs, + ): + if alerting_threshold is None: + alerting_threshold = 300 + self.alerting_threshold = alerting_threshold + self.alerting = alerting + self.alert_types = alert_types + self.internal_usage_cache = internal_usage_cache or DualCache() + self.async_http_handler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + self.alert_to_webhook_url = process_slack_alerting_variables( + alert_to_webhook_url=alert_to_webhook_url + ) + self.is_running = False + self.alerting_args = SlackAlertingArgs(**alerting_args) + self.default_webhook_url = default_webhook_url + self.flush_lock = asyncio.Lock() + super().__init__(**kwargs, flush_lock=self.flush_lock) + + def update_values( + self, + alerting: Optional[List] = None, + alerting_threshold: Optional[float] = None, + alert_types: Optional[List[AlertType]] = None, + alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]] = None, + alerting_args: Optional[Dict] = None, + llm_router: Optional[Router] = None, + ): + if alerting is 
not None: + self.alerting = alerting + asyncio.create_task(self.periodic_flush()) + if alerting_threshold is not None: + self.alerting_threshold = alerting_threshold + if alert_types is not None: + self.alert_types = alert_types + if alerting_args is not None: + self.alerting_args = SlackAlertingArgs(**alerting_args) + if alert_to_webhook_url is not None: + # update the dict + if self.alert_to_webhook_url is None: + self.alert_to_webhook_url = process_slack_alerting_variables( + alert_to_webhook_url=alert_to_webhook_url + ) + else: + _new_values = ( + process_slack_alerting_variables( + alert_to_webhook_url=alert_to_webhook_url + ) + or {} + ) + self.alert_to_webhook_url.update(_new_values) + if llm_router is not None: + self.llm_router = llm_router + + async def deployment_in_cooldown(self): + pass + + async def deployment_removed_from_cooldown(self): + pass + + def _all_possible_alert_types(self): + # used by the UI to show all supported alert types + # Note: This is not the alerts the user has configured, instead it's all possible alert types a user can select + # return list of all values AlertType enum + return list(AlertType) + + def _response_taking_too_long_callback_helper( + self, + kwargs, # kwargs to completion + start_time, + end_time, # start/end time + ): + try: + time_difference = end_time - start_time + # Convert the timedelta to float (in seconds) + time_difference_float = time_difference.total_seconds() + litellm_params = kwargs.get("litellm_params", {}) + model = kwargs.get("model", "") + api_base = litellm.get_api_base(model=model, optional_params=litellm_params) + messages = kwargs.get("messages", None) + # if messages does not exist fallback to "input" + if messages is None: + messages = kwargs.get("input", None) + + # only use first 100 chars for alerting + _messages = str(messages)[:100] + + return time_difference_float, model, api_base, _messages + except Exception as e: + raise e + + def _get_deployment_latencies_to_alert(self, metadata=None): + if metadata is None: + return None + + if "_latency_per_deployment" in metadata: + # Translate model_id to -> api_base + # _latency_per_deployment is a dictionary that looks like this: + """ + _latency_per_deployment: { + api_base: 0.01336697916666667 + } + """ + _message_to_send = "" + _deployment_latencies = metadata["_latency_per_deployment"] + if len(_deployment_latencies) == 0: + return None + _deployment_latency_map: Optional[dict] = None + try: + # try sorting deployments by latency + _deployment_latencies = sorted( + _deployment_latencies.items(), key=lambda x: x[1] + ) + _deployment_latency_map = dict(_deployment_latencies) + except Exception: + pass + + if _deployment_latency_map is None: + return + + for api_base, latency in _deployment_latency_map.items(): + _message_to_send += f"\n{api_base}: {round(latency,2)}s" + _message_to_send = "```" + _message_to_send + "```" + return _message_to_send + + async def response_taking_too_long_callback( + self, + kwargs, # kwargs to completion + completion_response, # response from completion + start_time, + end_time, # start/end time + ): + if self.alerting is None or self.alert_types is None: + return + + time_difference_float, model, api_base, messages = ( + self._response_taking_too_long_callback_helper( + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + ) + if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions: + messages = "Message not logged. 
litellm.redact_messages_in_exceptions=True" + request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`" + slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`" + alerting_metadata: dict = {} + if time_difference_float > self.alerting_threshold: + # add deployment latencies to alert + if ( + kwargs is not None + and "litellm_params" in kwargs + and "metadata" in kwargs["litellm_params"] + ): + _metadata: dict = kwargs["litellm_params"]["metadata"] + request_info = _add_key_name_and_team_to_alert( + request_info=request_info, metadata=_metadata + ) + + _deployment_latency_map = self._get_deployment_latencies_to_alert( + metadata=_metadata + ) + if _deployment_latency_map is not None: + request_info += ( + f"\nAvailable Deployment Latencies\n{_deployment_latency_map}" + ) + + if "alerting_metadata" in _metadata: + alerting_metadata = _metadata["alerting_metadata"] + await self.send_alert( + message=slow_message + request_info, + level="Low", + alert_type=AlertType.llm_too_slow, + alerting_metadata=alerting_metadata, + ) + + async def async_update_daily_reports( + self, deployment_metrics: DeploymentMetrics + ) -> int: + """ + Store the perf by deployment in cache + - Number of failed requests per deployment + - Latency / output tokens per deployment + + 'deployment_id:daily_metrics:failed_requests' + 'deployment_id:daily_metrics:latency_per_output_token' + + Returns + int - count of metrics set (1 - if just latency, 2 - if failed + latency) + """ + + return_val = 0 + try: + ## FAILED REQUESTS ## + if deployment_metrics.failed_request: + await self.internal_usage_cache.async_increment_cache( + key="{}:{}".format( + deployment_metrics.id, + SlackAlertingCacheKeys.failed_requests_key.value, + ), + value=1, + parent_otel_span=None, # no attached request, this is a background operation + ) + + return_val += 1 + + ## LATENCY ## + if deployment_metrics.latency_per_output_token is not None: + await self.internal_usage_cache.async_increment_cache( + key="{}:{}".format( + deployment_metrics.id, SlackAlertingCacheKeys.latency_key.value + ), + value=deployment_metrics.latency_per_output_token, + parent_otel_span=None, # no attached request, this is a background operation + ) + + return_val += 1 + + return return_val + except Exception: + return 0 + + async def send_daily_reports(self, router) -> bool: # noqa: PLR0915 + """ + Send a daily report on: + - Top 5 deployments with most failed requests + - Top 5 slowest deployments (normalized by latency/output tokens) + + Get the value from redis cache (if available) or in-memory and send it + + Cleanup: + - reset values in cache -> prevent memory leak + + Returns: + True -> if successfuly sent + False -> if not sent + """ + + ids = router.get_model_ids() + + # get keys + failed_request_keys = [ + "{}:{}".format(id, SlackAlertingCacheKeys.failed_requests_key.value) + for id in ids + ] + latency_keys = [ + "{}:{}".format(id, SlackAlertingCacheKeys.latency_key.value) for id in ids + ] + + combined_metrics_keys = failed_request_keys + latency_keys # reduce cache calls + + combined_metrics_values = await self.internal_usage_cache.async_batch_get_cache( + keys=combined_metrics_keys + ) # [1, 2, None, ..] 
+ + if combined_metrics_values is None: + return False + + all_none = True + for val in combined_metrics_values: + if val is not None and val > 0: + all_none = False + break + + if all_none: + return False + + failed_request_values = combined_metrics_values[ + : len(failed_request_keys) + ] # # [1, 2, None, ..] + latency_values = combined_metrics_values[len(failed_request_keys) :] + + # find top 5 failed + ## Replace None values with a placeholder value (-1 in this case) + placeholder_value = 0 + replaced_failed_values = [ + value if value is not None else placeholder_value + for value in failed_request_values + ] + + ## Get the indices of top 5 keys with the highest numerical values (ignoring None and 0 values) + top_5_failed = sorted( + range(len(replaced_failed_values)), + key=lambda i: replaced_failed_values[i], + reverse=True, + )[:5] + top_5_failed = [ + index for index in top_5_failed if replaced_failed_values[index] > 0 + ] + + # find top 5 slowest + # Replace None values with a placeholder value (-1 in this case) + placeholder_value = 0 + replaced_slowest_values = [ + value if value is not None else placeholder_value + for value in latency_values + ] + + # Get the indices of top 5 values with the highest numerical values (ignoring None and 0 values) + top_5_slowest = sorted( + range(len(replaced_slowest_values)), + key=lambda i: replaced_slowest_values[i], + reverse=True, + )[:5] + top_5_slowest = [ + index for index in top_5_slowest if replaced_slowest_values[index] > 0 + ] + + # format alert -> return the litellm model name + api base + message = f"\n\nTime: `{time.time()}`s\nHere are today's key metrics 📈: \n\n" + + message += "\n\n*❗️ Top Deployments with Most Failed Requests:*\n\n" + if not top_5_failed: + message += "\tNone\n" + for i in range(len(top_5_failed)): + key = failed_request_keys[top_5_failed[i]].split(":")[0] + _deployment = router.get_model_info(key) + if isinstance(_deployment, dict): + deployment_name = _deployment["litellm_params"].get("model", "") + else: + return False + + api_base = litellm.get_api_base( + model=deployment_name, + optional_params=( + _deployment["litellm_params"] if _deployment is not None else {} + ), + ) + if api_base is None: + api_base = "" + value = replaced_failed_values[top_5_failed[i]] + message += f"\t{i+1}. Deployment: `{deployment_name}`, Failed Requests: `{value}`, API Base: `{api_base}`\n" + + message += "\n\n*😅 Top Slowest Deployments:*\n\n" + if not top_5_slowest: + message += "\tNone\n" + for i in range(len(top_5_slowest)): + key = latency_keys[top_5_slowest[i]].split(":")[0] + _deployment = router.get_model_info(key) + if _deployment is not None: + deployment_name = _deployment["litellm_params"].get("model", "") + else: + deployment_name = "" + api_base = litellm.get_api_base( + model=deployment_name, + optional_params=( + _deployment["litellm_params"] if _deployment is not None else {} + ), + ) + value = round(replaced_slowest_values[top_5_slowest[i]], 3) + message += f"\t{i+1}. 
Deployment: `{deployment_name}`, Latency per output token: `{value}s/token`, API Base: `{api_base}`\n\n" + + # cache cleanup -> reset values to 0 + latency_cache_keys = [(key, 0) for key in latency_keys] + failed_request_cache_keys = [(key, 0) for key in failed_request_keys] + combined_metrics_cache_keys = latency_cache_keys + failed_request_cache_keys + await self.internal_usage_cache.async_set_cache_pipeline( + cache_list=combined_metrics_cache_keys + ) + + message += f"\n\nNext Run is at: `{time.time() + self.alerting_args.daily_report_frequency}`s" + + # send alert + await self.send_alert( + message=message, + level="Low", + alert_type=AlertType.daily_reports, + alerting_metadata={}, + ) + + return True + + async def response_taking_too_long( + self, + start_time: Optional[datetime.datetime] = None, + end_time: Optional[datetime.datetime] = None, + type: Literal["hanging_request", "slow_response"] = "hanging_request", + request_data: Optional[dict] = None, + ): + if self.alerting is None or self.alert_types is None: + return + model: str = "" + if request_data is not None: + model = request_data.get("model", "") + messages = request_data.get("messages", None) + if messages is None: + # if messages does not exist fallback to "input" + messages = request_data.get("input", None) + + # try casting messages to str and get the first 100 characters, else mark as None + try: + messages = str(messages) + messages = messages[:100] + except Exception: + messages = "" + + if ( + litellm.turn_off_message_logging + or litellm.redact_messages_in_exceptions + ): + messages = ( + "Message not logged. litellm.redact_messages_in_exceptions=True" + ) + request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`" + else: + request_info = "" + + if type == "hanging_request": + await asyncio.sleep( + self.alerting_threshold + ) # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests + alerting_metadata: dict = {} + if await self._request_is_completed(request_data=request_data) is True: + return + + if request_data is not None: + if request_data.get("deployment", None) is not None and isinstance( + request_data["deployment"], dict + ): + _api_base = litellm.get_api_base( + model=model, + optional_params=request_data["deployment"].get( + "litellm_params", {} + ), + ) + + if _api_base is None: + _api_base = "" + + request_info += f"\nAPI Base: {_api_base}" + elif request_data.get("metadata", None) is not None and isinstance( + request_data["metadata"], dict + ): + # In hanging requests sometime it has not made it to the point where the deployment is passed to the `request_data`` + # in that case we fallback to the api base set in the request metadata + _metadata: dict = request_data["metadata"] + _api_base = _metadata.get("api_base", "") + + request_info = _add_key_name_and_team_to_alert( + request_info=request_info, metadata=_metadata + ) + + if _api_base is None: + _api_base = "" + + if "alerting_metadata" in _metadata: + alerting_metadata = _metadata["alerting_metadata"] + request_info += f"\nAPI Base: `{_api_base}`" + # only alert hanging responses if they have not been marked as success + alerting_message = ( + f"`Requests are hanging - {self.alerting_threshold}s+ request time`" + ) + + if "langfuse" in litellm.success_callback: + langfuse_url = await _add_langfuse_trace_id_to_alert( + request_data=request_data, + ) + + if langfuse_url is not None: + request_info += "\n🪢 Langfuse Trace: {}".format(langfuse_url) + + # add 
deployment latencies to alert + _deployment_latency_map = self._get_deployment_latencies_to_alert( + metadata=request_data.get("metadata", {}) + ) + if _deployment_latency_map is not None: + request_info += f"\nDeployment Latencies\n{_deployment_latency_map}" + + await self.send_alert( + message=alerting_message + request_info, + level="Medium", + alert_type=AlertType.llm_requests_hanging, + alerting_metadata=alerting_metadata, + ) + + async def failed_tracking_alert(self, error_message: str, failing_model: str): + """ + Raise alert when tracking failed for specific model + + Args: + error_message (str): Error message + failing_model (str): Model that failed tracking + """ + if self.alerting is None or self.alert_types is None: + # do nothing if alerting is not switched on + return + if "failed_tracking_spend" not in self.alert_types: + return + + _cache: DualCache = self.internal_usage_cache + message = "Failed Tracking Cost for " + error_message + _cache_key = "budget_alerts:failed_tracking:{}".format(failing_model) + result = await _cache.async_get_cache(key=_cache_key) + if result is None: + await self.send_alert( + message=message, + level="High", + alert_type=AlertType.failed_tracking_spend, + alerting_metadata={}, + ) + await _cache.async_set_cache( + key=_cache_key, + value="SENT", + ttl=self.alerting_args.budget_alert_ttl, + ) + + async def budget_alerts( # noqa: PLR0915 + self, + type: Literal[ + "token_budget", + "soft_budget", + "user_budget", + "team_budget", + "proxy_budget", + "projected_limit_exceeded", + ], + user_info: CallInfo, + ): + ## PREVENTITIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727 + # - Alert once within 24hr period + # - Cache this information + # - Don't re-alert, if alert already sent + _cache: DualCache = self.internal_usage_cache + + if self.alerting is None or self.alert_types is None: + # do nothing if alerting is not switched on + return + if "budget_alerts" not in self.alert_types: + return + _id: Optional[str] = "default_id" # used for caching + user_info_json = user_info.model_dump(exclude_none=True) + user_info_str = self._get_user_info_str(user_info) + event: Optional[ + Literal[ + "budget_crossed", + "threshold_crossed", + "projected_limit_exceeded", + "soft_budget_crossed", + ] + ] = None + event_group: Optional[ + Literal["internal_user", "team", "key", "proxy", "customer"] + ] = None + event_message: str = "" + webhook_event: Optional[WebhookEvent] = None + if type == "proxy_budget": + event_group = "proxy" + event_message += "Proxy Budget: " + elif type == "soft_budget": + event_group = "proxy" + event_message += "Soft Budget Crossed: " + elif type == "user_budget": + event_group = "internal_user" + event_message += "User Budget: " + _id = user_info.user_id or _id + elif type == "team_budget": + event_group = "team" + event_message += "Team Budget: " + _id = user_info.team_id or _id + elif type == "token_budget": + event_group = "key" + event_message += "Key Budget: " + _id = user_info.token + elif type == "projected_limit_exceeded": + event_group = "key" + event_message += "Key Budget: Projected Limit Exceeded" + event = "projected_limit_exceeded" + _id = user_info.token + + # percent of max_budget left to spend + if user_info.max_budget is None and user_info.soft_budget is None: + return + percent_left: float = 0 + if user_info.max_budget is not None: + if user_info.max_budget > 0: + percent_left = ( + user_info.max_budget - user_info.spend + ) / user_info.max_budget + + # check if crossed budget + if user_info.max_budget 
is not None: + if user_info.spend >= user_info.max_budget: + event = "budget_crossed" + event_message += ( + f"Budget Crossed\n Total Budget:`{user_info.max_budget}`" + ) + elif percent_left <= 0.05: + event = "threshold_crossed" + event_message += "5% Threshold Crossed " + elif percent_left <= 0.15: + event = "threshold_crossed" + event_message += "15% Threshold Crossed" + elif user_info.soft_budget is not None: + if user_info.spend >= user_info.soft_budget: + event = "soft_budget_crossed" + if event is not None and event_group is not None: + _cache_key = "budget_alerts:{}:{}".format(event, _id) + result = await _cache.async_get_cache(key=_cache_key) + if result is None: + webhook_event = WebhookEvent( + event=event, + event_group=event_group, + event_message=event_message, + **user_info_json, + ) + await self.send_alert( + message=event_message + "\n\n" + user_info_str, + level="High", + alert_type=AlertType.budget_alerts, + user_info=webhook_event, + alerting_metadata={}, + ) + await _cache.async_set_cache( + key=_cache_key, + value="SENT", + ttl=self.alerting_args.budget_alert_ttl, + ) + + return + return + + def _get_user_info_str(self, user_info: CallInfo) -> str: + """ + Create a standard message for a budget alert + """ + _all_fields_as_dict = user_info.model_dump(exclude_none=True) + _all_fields_as_dict.pop("token") + msg = "" + for k, v in _all_fields_as_dict.items(): + msg += f"*{k}:* `{v}`\n" + + return msg + + async def customer_spend_alert( + self, + token: Optional[str], + key_alias: Optional[str], + end_user_id: Optional[str], + response_cost: Optional[float], + max_budget: Optional[float], + ): + if ( + self.alerting is not None + and "webhook" in self.alerting + and end_user_id is not None + and token is not None + and response_cost is not None + ): + # log customer spend + event = WebhookEvent( + spend=response_cost, + max_budget=max_budget, + token=token, + customer_id=end_user_id, + user_id=None, + team_id=None, + user_email=None, + key_alias=key_alias, + projected_exceeded_date=None, + projected_spend=None, + event="spend_tracked", + event_group="customer", + event_message="Customer spend tracked. Customer={}, spend={}".format( + end_user_id, response_cost + ), + ) + + await self.send_webhook_alert(webhook_event=event) + + def _count_outage_alerts(self, alerts: List[int]) -> str: + """ + Parameters: + - alerts: List[int] -> list of error codes (either 408 or 500+) + + Returns: + - str -> formatted string. This is an alert message, giving a human-friendly description of the errors. 
+ """ + error_breakdown = {"Timeout Errors": 0, "API Errors": 0, "Unknown Errors": 0} + for alert in alerts: + if alert == 408: + error_breakdown["Timeout Errors"] += 1 + elif alert >= 500: + error_breakdown["API Errors"] += 1 + else: + error_breakdown["Unknown Errors"] += 1 + + error_msg = "" + for key, value in error_breakdown.items(): + if value > 0: + error_msg += "\n{}: {}\n".format(key, value) + + return error_msg + + def _outage_alert_msg_factory( + self, + alert_type: Literal["Major", "Minor"], + key: Literal["Model", "Region"], + key_val: str, + provider: str, + api_base: Optional[str], + outage_value: BaseOutageModel, + ) -> str: + """Format an alert message for slack""" + headers = {f"{key} Name": key_val, "Provider": provider} + if api_base is not None: + headers["API Base"] = api_base # type: ignore + + headers_str = "\n" + for k, v in headers.items(): + headers_str += f"*{k}:* `{v}`\n" + return f"""\n\n +*⚠️ {alert_type} Service Outage* + +{headers_str} + +*Errors:* +{self._count_outage_alerts(alerts=outage_value["alerts"])} + +*Last Check:* `{round(time.time() - outage_value["last_updated_at"], 4)}s ago`\n\n +""" + + async def region_outage_alerts( + self, + exception: APIError, + deployment_id: str, + ) -> None: + """ + Send slack alert if specific provider region is having an outage. + + Track for 408 (Timeout) and >=500 Error codes + """ + ## CREATE (PROVIDER+REGION) ID ## + if self.llm_router is None: + return + + deployment = self.llm_router.get_deployment(model_id=deployment_id) + + if deployment is None: + return + + model = deployment.litellm_params.model + ### GET PROVIDER ### + provider = deployment.litellm_params.custom_llm_provider + if provider is None: + model, provider, _, _ = litellm.get_llm_provider(model=model) + + ### GET REGION ### + region_name = deployment.litellm_params.region_name + if region_name is None: + region_name = litellm.utils._get_model_region( + custom_llm_provider=provider, litellm_params=deployment.litellm_params + ) + + if region_name is None: + return + + ### UNIQUE CACHE KEY ### + cache_key = provider + region_name + + outage_value: Optional[ProviderRegionOutageModel] = ( + await self.internal_usage_cache.async_get_cache(key=cache_key) + ) + + if ( + getattr(exception, "status_code", None) is None + or ( + exception.status_code != 408 # type: ignore + and exception.status_code < 500 # type: ignore + ) + or self.llm_router is None + ): + return + + if outage_value is None: + _deployment_set = set() + _deployment_set.add(deployment_id) + outage_value = ProviderRegionOutageModel( + provider_region_id=cache_key, + alerts=[exception.status_code], # type: ignore + minor_alert_sent=False, + major_alert_sent=False, + last_updated_at=time.time(), + deployment_ids=_deployment_set, + ) + + ## add to cache ## + await self.internal_usage_cache.async_set_cache( + key=cache_key, + value=outage_value, + ttl=self.alerting_args.region_outage_alert_ttl, + ) + return + + if len(outage_value["alerts"]) < self.alerting_args.max_outage_alert_list_size: + outage_value["alerts"].append(exception.status_code) # type: ignore + else: # prevent memory leaks + pass + _deployment_set = outage_value["deployment_ids"] + _deployment_set.add(deployment_id) + outage_value["deployment_ids"] = _deployment_set + outage_value["last_updated_at"] = time.time() + + ## MINOR OUTAGE ALERT SENT ## + if ( + outage_value["minor_alert_sent"] is False + and len(outage_value["alerts"]) + >= self.alerting_args.minor_outage_alert_threshold + and len(_deployment_set) > 1 # make sure it's 
not just 1 bad deployment + ): + msg = self._outage_alert_msg_factory( + alert_type="Minor", + key="Region", + key_val=region_name, + api_base=None, + outage_value=outage_value, + provider=provider, + ) + # send minor alert + await self.send_alert( + message=msg, + level="Medium", + alert_type=AlertType.outage_alerts, + alerting_metadata={}, + ) + # set to true + outage_value["minor_alert_sent"] = True + + ## MAJOR OUTAGE ALERT SENT ## + elif ( + outage_value["major_alert_sent"] is False + and len(outage_value["alerts"]) + >= self.alerting_args.major_outage_alert_threshold + and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment + ): + msg = self._outage_alert_msg_factory( + alert_type="Major", + key="Region", + key_val=region_name, + api_base=None, + outage_value=outage_value, + provider=provider, + ) + + # send minor alert + await self.send_alert( + message=msg, + level="High", + alert_type=AlertType.outage_alerts, + alerting_metadata={}, + ) + # set to true + outage_value["major_alert_sent"] = True + + ## update cache ## + await self.internal_usage_cache.async_set_cache( + key=cache_key, value=outage_value + ) + + async def outage_alerts( + self, + exception: APIError, + deployment_id: str, + ) -> None: + """ + Send slack alert if model is badly configured / having an outage (408, 401, 429, >=500). + + key = model_id + + value = { + - model_id + - threshold + - alerts [] + } + + ttl = 1hr + max_alerts_size = 10 + """ + try: + outage_value: Optional[OutageModel] = await self.internal_usage_cache.async_get_cache(key=deployment_id) # type: ignore + if ( + getattr(exception, "status_code", None) is None + or ( + exception.status_code != 408 # type: ignore + and exception.status_code < 500 # type: ignore + ) + or self.llm_router is None + ): + return + + ### EXTRACT MODEL DETAILS ### + deployment = self.llm_router.get_deployment(model_id=deployment_id) + if deployment is None: + return + + model = deployment.litellm_params.model + provider = deployment.litellm_params.custom_llm_provider + if provider is None: + try: + model, provider, _, _ = litellm.get_llm_provider(model=model) + except Exception: + provider = "" + api_base = litellm.get_api_base( + model=model, optional_params=deployment.litellm_params + ) + + if outage_value is None: + outage_value = OutageModel( + model_id=deployment_id, + alerts=[exception.status_code], # type: ignore + minor_alert_sent=False, + major_alert_sent=False, + last_updated_at=time.time(), + ) + + ## add to cache ## + await self.internal_usage_cache.async_set_cache( + key=deployment_id, + value=outage_value, + ttl=self.alerting_args.outage_alert_ttl, + ) + return + + if ( + len(outage_value["alerts"]) + < self.alerting_args.max_outage_alert_list_size + ): + outage_value["alerts"].append(exception.status_code) # type: ignore + else: # prevent memory leaks + pass + + outage_value["last_updated_at"] = time.time() + + ## MINOR OUTAGE ALERT SENT ## + if ( + outage_value["minor_alert_sent"] is False + and len(outage_value["alerts"]) + >= self.alerting_args.minor_outage_alert_threshold + ): + msg = self._outage_alert_msg_factory( + alert_type="Minor", + key="Model", + key_val=model, + api_base=api_base, + outage_value=outage_value, + provider=provider, + ) + # send minor alert + await self.send_alert( + message=msg, + level="Medium", + alert_type=AlertType.outage_alerts, + alerting_metadata={}, + ) + # set to true + outage_value["minor_alert_sent"] = True + elif ( + outage_value["major_alert_sent"] is False + and len(outage_value["alerts"]) + >= 
self.alerting_args.major_outage_alert_threshold + ): + msg = self._outage_alert_msg_factory( + alert_type="Major", + key="Model", + key_val=model, + api_base=api_base, + outage_value=outage_value, + provider=provider, + ) + # send minor alert + await self.send_alert( + message=msg, + level="High", + alert_type=AlertType.outage_alerts, + alerting_metadata={}, + ) + # set to true + outage_value["major_alert_sent"] = True + + ## update cache ## + await self.internal_usage_cache.async_set_cache( + key=deployment_id, value=outage_value + ) + except Exception: + pass + + async def model_added_alert( + self, model_name: str, litellm_model_name: str, passed_model_info: Any + ): + base_model_from_user = getattr(passed_model_info, "base_model", None) + model_info = {} + base_model = "" + if base_model_from_user is not None: + model_info = litellm.model_cost.get(base_model_from_user, {}) + base_model = f"Base Model: `{base_model_from_user}`\n" + else: + model_info = litellm.model_cost.get(litellm_model_name, {}) + model_info_str = "" + for k, v in model_info.items(): + if k == "input_cost_per_token" or k == "output_cost_per_token": + # when converting to string it should not be 1.63e-06 + v = "{:.8f}".format(v) + + model_info_str += f"{k}: {v}\n" + + message = f""" +*🚅 New Model Added* +Model Name: `{model_name}` +{base_model} + +Usage OpenAI Python SDK: +``` +import openai +client = openai.OpenAI( + api_key="your_api_key", + base_url={os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")} +) + +response = client.chat.completions.create( + model="{model_name}", # model to send to the proxy + messages = [ + {{ + "role": "user", + "content": "this is a test request, write a short poem" + }} + ] +) +``` + +Model Info: +``` +{model_info_str} +``` +""" + + alert_val = self.send_alert( + message=message, + level="Low", + alert_type=AlertType.new_model_added, + alerting_metadata={}, + ) + + if alert_val is not None and asyncio.iscoroutine(alert_val): + await alert_val + + async def model_removed_alert(self, model_name: str): + pass + + async def send_webhook_alert(self, webhook_event: WebhookEvent) -> bool: + """ + Sends structured alert to webhook, if set. + + Currently only implemented for budget alerts + + Returns -> True if sent, False if not. + + Raises Exception + - if WEBHOOK_URL is not set + """ + + webhook_url = os.getenv("WEBHOOK_URL", None) + if webhook_url is None: + raise Exception("Missing webhook_url from environment") + + payload = webhook_event.model_dump_json() + headers = {"Content-type": "application/json"} + + response = await self.async_http_handler.post( + url=webhook_url, + headers=headers, + data=payload, + ) + if response.status_code == 200: + return True + else: + print("Error sending webhook alert. 
Error=", response.text) # noqa + + return False + + async def _check_if_using_premium_email_feature( + self, + premium_user: bool, + email_logo_url: Optional[str] = None, + email_support_contact: Optional[str] = None, + ): + from litellm.proxy.proxy_server import CommonProxyErrors, premium_user + + if premium_user is not True: + if email_logo_url is not None or email_support_contact is not None: + raise ValueError( + f"Trying to Customize Email Alerting\n {CommonProxyErrors.not_premium_user.value}" + ) + return + + async def send_key_created_or_user_invited_email( + self, webhook_event: WebhookEvent + ) -> bool: + try: + from litellm.proxy.utils import send_email + + if self.alerting is None or "email" not in self.alerting: + # do nothing if user does not want email alerts + verbose_proxy_logger.error( + "Error sending email alert - 'email' not in self.alerting %s", + self.alerting, + ) + return False + from litellm.proxy.proxy_server import premium_user, prisma_client + + email_logo_url = os.getenv( + "SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None) + ) + email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None) + await self._check_if_using_premium_email_feature( + premium_user, email_logo_url, email_support_contact + ) + if email_logo_url is None: + email_logo_url = LITELLM_LOGO_URL + if email_support_contact is None: + email_support_contact = LITELLM_SUPPORT_CONTACT + + event_name = webhook_event.event_message + recipient_email = webhook_event.user_email + recipient_user_id = webhook_event.user_id + if ( + recipient_email is None + and recipient_user_id is not None + and prisma_client is not None + ): + user_row = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": recipient_user_id} + ) + + if user_row is not None: + recipient_email = user_row.user_email + + key_token = webhook_event.token + key_budget = webhook_event.max_budget + base_url = os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000") + + email_html_content = "Alert from LiteLLM Server" + if recipient_email is None: + verbose_proxy_logger.error( + "Trying to send email alert to no recipient", + extra=webhook_event.dict(), + ) + + if webhook_event.event == "key_created": + email_html_content = KEY_CREATED_EMAIL_TEMPLATE.format( + email_logo_url=email_logo_url, + recipient_email=recipient_email, + key_budget=key_budget, + key_token=key_token, + base_url=base_url, + email_support_contact=email_support_contact, + ) + elif webhook_event.event == "internal_user_created": + # GET TEAM NAME + team_id = webhook_event.team_id + team_name = "Default Team" + if team_id is not None and prisma_client is not None: + team_row = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} + ) + if team_row is not None: + team_name = team_row.team_alias or "-" + email_html_content = USER_INVITED_EMAIL_TEMPLATE.format( + email_logo_url=email_logo_url, + recipient_email=recipient_email, + team_name=team_name, + base_url=base_url, + email_support_contact=email_support_contact, + ) + else: + verbose_proxy_logger.error( + "Trying to send email alert on unknown webhook event", + extra=webhook_event.model_dump(), + ) + + webhook_event.model_dump_json() + email_event = { + "to": recipient_email, + "subject": f"LiteLLM: {event_name}", + "html": email_html_content, + } + + await send_email( + receiver_email=email_event["to"], + subject=email_event["subject"], + html=email_event["html"], + ) + + return True + + except Exception as e: + verbose_proxy_logger.error("Error sending email alert %s", str(e)) 
+ return False + + async def send_email_alert_using_smtp( + self, webhook_event: WebhookEvent, alert_type: str + ) -> bool: + """ + Sends structured Email alert to an SMTP server + + Currently only implemented for budget alerts + + Returns -> True if sent, False if not. + """ + from litellm.proxy.proxy_server import premium_user + from litellm.proxy.utils import send_email + + email_logo_url = os.getenv( + "SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None) + ) + email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None) + await self._check_if_using_premium_email_feature( + premium_user, email_logo_url, email_support_contact + ) + + if email_logo_url is None: + email_logo_url = LITELLM_LOGO_URL + if email_support_contact is None: + email_support_contact = LITELLM_SUPPORT_CONTACT + + event_name = webhook_event.event_message + recipient_email = webhook_event.user_email + user_name = webhook_event.user_id + max_budget = webhook_event.max_budget + email_html_content = "Alert from LiteLLM Server" + if recipient_email is None: + verbose_proxy_logger.error( + "Trying to send email alert to no recipient", extra=webhook_event.dict() + ) + + if webhook_event.event == "budget_crossed": + email_html_content = f""" + <img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" /> + + <p> Hi {user_name}, <br/> + + Your LLM API usage this month has reached your account's <b> monthly budget of ${max_budget} </b> <br /> <br /> + + API requests will be rejected until either (a) you increase your monthly budget or (b) your monthly usage resets at the beginning of the next calendar month. <br /> <br /> + + If you have any questions, please send an email to {email_support_contact} <br /> <br /> + + Best, <br /> + The LiteLLM team <br /> + """ + + webhook_event.model_dump_json() + email_event = { + "to": recipient_email, + "subject": f"LiteLLM: {event_name}", + "html": email_html_content, + } + + await send_email( + receiver_email=email_event["to"], + subject=email_event["subject"], + html=email_event["html"], + ) + if webhook_event.event_group == "team": + from litellm.integrations.email_alerting import send_team_budget_alert + + await send_team_budget_alert(webhook_event=webhook_event) + + return False + + async def send_alert( + self, + message: str, + level: Literal["Low", "Medium", "High"], + alert_type: AlertType, + alerting_metadata: dict, + user_info: Optional[WebhookEvent] = None, + **kwargs, + ): + """ + Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298 + + - Responses taking too long + - Requests are hanging + - Calls are failing + - DB Read/Writes are failing + - Proxy Close to max budget + - Key Close to max budget + + Parameters: + level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'. 
+ message: str - what is the alert about + """ + if self.alerting is None: + return + + if ( + "webhook" in self.alerting + and alert_type == "budget_alerts" + and user_info is not None + ): + await self.send_webhook_alert(webhook_event=user_info) + + if ( + "email" in self.alerting + and alert_type == "budget_alerts" + and user_info is not None + ): + # only send budget alerts over Email + await self.send_email_alert_using_smtp( + webhook_event=user_info, alert_type=alert_type + ) + + if "slack" not in self.alerting: + return + if alert_type not in self.alert_types: + return + + from datetime import datetime + + # Get the current timestamp + current_time = datetime.now().strftime("%H:%M:%S") + _proxy_base_url = os.getenv("PROXY_BASE_URL", None) + if alert_type == "daily_reports" or alert_type == "new_model_added": + formatted_message = message + else: + formatted_message = ( + f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}" + ) + + if kwargs: + for key, value in kwargs.items(): + formatted_message += f"\n\n{key}: `{value}`\n\n" + if alerting_metadata: + for key, value in alerting_metadata.items(): + formatted_message += f"\n\n*Alerting Metadata*: \n{key}: `{value}`\n\n" + if _proxy_base_url is not None: + formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`" + + # check if we find the slack webhook url in self.alert_to_webhook_url + if ( + self.alert_to_webhook_url is not None + and alert_type in self.alert_to_webhook_url + ): + slack_webhook_url: Optional[Union[str, List[str]]] = ( + self.alert_to_webhook_url[alert_type] + ) + elif self.default_webhook_url is not None: + slack_webhook_url = self.default_webhook_url + else: + slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None) + + if slack_webhook_url is None: + raise ValueError("Missing SLACK_WEBHOOK_URL from environment") + payload = {"text": formatted_message} + headers = {"Content-type": "application/json"} + + if isinstance(slack_webhook_url, list): + for url in slack_webhook_url: + self.log_queue.append( + { + "url": url, + "headers": headers, + "payload": payload, + "alert_type": alert_type, + } + ) + else: + self.log_queue.append( + { + "url": slack_webhook_url, + "headers": headers, + "payload": payload, + "alert_type": alert_type, + } + ) + + if len(self.log_queue) >= self.batch_size: + await self.flush_queue() + + async def async_send_batch(self): + if not self.log_queue: + return + + squashed_queue = squash_payloads(self.log_queue) + tasks = [ + send_to_webhook( + slackAlertingInstance=self, item=item["item"], count=item["count"] + ) + for item in squashed_queue.values() + ] + await asyncio.gather(*tasks) + self.log_queue.clear() + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + """Log deployment latency""" + try: + if "daily_reports" in self.alert_types: + litellm_params = kwargs.get("litellm_params", {}) or {} + model_info = litellm_params.get("model_info", {}) or {} + model_id = model_info.get("id", "") or "" + response_s: timedelta = end_time - start_time + + final_value = response_s + + if isinstance(response_obj, litellm.ModelResponse) and ( + hasattr(response_obj, "usage") + and response_obj.usage is not None # type: ignore + and hasattr(response_obj.usage, "completion_tokens") # type: ignore + ): + completion_tokens = response_obj.usage.completion_tokens # type: ignore + if completion_tokens is not None and completion_tokens > 0: + final_value = float( + response_s.total_seconds() / completion_tokens + ) + if isinstance(final_value, timedelta): + 
final_value = final_value.total_seconds() + + await self.async_update_daily_reports( + DeploymentMetrics( + id=model_id, + failed_request=False, + latency_per_output_token=final_value, + updated_at=litellm.utils.get_utc_datetime(), + ) + ) + except Exception as e: + verbose_proxy_logger.error( + f"[Non-Blocking Error] Slack Alerting: Got error in logging LLM deployment latency: {str(e)}" + ) + pass + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + """Log failure + deployment latency""" + _litellm_params = kwargs.get("litellm_params", {}) + _model_info = _litellm_params.get("model_info", {}) or {} + model_id = _model_info.get("id", "") + try: + if "daily_reports" in self.alert_types: + try: + await self.async_update_daily_reports( + DeploymentMetrics( + id=model_id, + failed_request=True, + latency_per_output_token=None, + updated_at=litellm.utils.get_utc_datetime(), + ) + ) + except Exception as e: + verbose_logger.debug(f"Exception raises -{str(e)}") + + if isinstance(kwargs.get("exception", ""), APIError): + if "outage_alerts" in self.alert_types: + await self.outage_alerts( + exception=kwargs["exception"], + deployment_id=model_id, + ) + + if "region_outage_alerts" in self.alert_types: + await self.region_outage_alerts( + exception=kwargs["exception"], deployment_id=model_id + ) + except Exception: + pass + + async def _run_scheduler_helper(self, llm_router) -> bool: + """ + Returns: + - True -> report sent + - False -> report not sent + """ + report_sent_bool = False + + report_sent = await self.internal_usage_cache.async_get_cache( + key=SlackAlertingCacheKeys.report_sent_key.value, + parent_otel_span=None, + ) # None | float + + current_time = time.time() + + if report_sent is None: + await self.internal_usage_cache.async_set_cache( + key=SlackAlertingCacheKeys.report_sent_key.value, + value=current_time, + ) + elif isinstance(report_sent, float): + # Check if current time - interval >= time last sent + interval_seconds = self.alerting_args.daily_report_frequency + + if current_time - report_sent >= interval_seconds: + # Sneak in the reporting logic here + await self.send_daily_reports(router=llm_router) + # Also, don't forget to update the report_sent time after sending the report! + await self.internal_usage_cache.async_set_cache( + key=SlackAlertingCacheKeys.report_sent_key.value, + value=current_time, + ) + report_sent_bool = True + + return report_sent_bool + + async def _run_scheduled_daily_report(self, llm_router: Optional[Any] = None): + """ + If 'daily_reports' enabled + + Ping redis cache every 5 minutes to check if we should send the report + + If yes -> call send_daily_report() + """ + if llm_router is None or self.alert_types is None: + return + + if "daily_reports" in self.alert_types: + while True: + await self._run_scheduler_helper(llm_router=llm_router) + interval = random.randint( + self.alerting_args.report_check_interval - 3, + self.alerting_args.report_check_interval + 3, + ) # shuffle to prevent collisions + await asyncio.sleep(interval) + return + + async def send_weekly_spend_report( + self, + time_range: str = "7d", + ): + """ + Send a spend report for a configurable time range. 
+ + Args: + time_range: A string specifying the time range for the report, e.g., "1d", "7d", "30d" + """ + if self.alerting is None or "spend_reports" not in self.alert_types: + return + + try: + from litellm.proxy.spend_tracking.spend_management_endpoints import ( + _get_spend_report_for_time_range, + ) + + # Parse the time range + days = int(time_range[:-1]) + if time_range[-1].lower() != "d": + raise ValueError("Time range must be specified in days, e.g., '7d'") + + todays_date = datetime.datetime.now().date() + start_date = todays_date - datetime.timedelta(days=days) + + _event_cache_key = f"weekly_spend_report_sent_{start_date.strftime('%Y-%m-%d')}_{todays_date.strftime('%Y-%m-%d')}" + if await self.internal_usage_cache.async_get_cache(key=_event_cache_key): + return + + _resp = await _get_spend_report_for_time_range( + start_date=start_date.strftime("%Y-%m-%d"), + end_date=todays_date.strftime("%Y-%m-%d"), + ) + if _resp is None or _resp == ([], []): + return + + spend_per_team, spend_per_tag = _resp + + _spend_message = f"*💸 Spend Report for `{start_date.strftime('%m-%d-%Y')} - {todays_date.strftime('%m-%d-%Y')}` ({days} days)*\n" + + if spend_per_team is not None: + _spend_message += "\n*Team Spend Report:*\n" + for spend in spend_per_team: + _team_spend = round(float(spend["total_spend"]), 4) + _spend_message += ( + f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n" + ) + + if spend_per_tag is not None: + _spend_message += "\n*Tag Spend Report:*\n" + for spend in spend_per_tag: + _tag_spend = round(float(spend["total_spend"]), 4) + _spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n" + + await self.send_alert( + message=_spend_message, + level="Low", + alert_type=AlertType.spend_reports, + alerting_metadata={}, + ) + + await self.internal_usage_cache.async_set_cache( + key=_event_cache_key, + value="SENT", + ttl=duration_in_seconds(time_range), + ) + + except ValueError as ve: + verbose_proxy_logger.error(f"Invalid time range format: {ve}") + except Exception as e: + verbose_proxy_logger.error(f"Error sending spend report: {e}") + + async def send_monthly_spend_report(self): + """ """ + try: + from calendar import monthrange + + from litellm.proxy.spend_tracking.spend_management_endpoints import ( + _get_spend_report_for_time_range, + ) + + todays_date = datetime.datetime.now().date() + first_day_of_month = todays_date.replace(day=1) + _, last_day_of_month = monthrange(todays_date.year, todays_date.month) + last_day_of_month = first_day_of_month + datetime.timedelta( + days=last_day_of_month - 1 + ) + + _event_cache_key = f"monthly_spend_report_sent_{first_day_of_month.strftime('%Y-%m-%d')}_{last_day_of_month.strftime('%Y-%m-%d')}" + if await self.internal_usage_cache.async_get_cache(key=_event_cache_key): + return + + _resp = await _get_spend_report_for_time_range( + start_date=first_day_of_month.strftime("%Y-%m-%d"), + end_date=last_day_of_month.strftime("%Y-%m-%d"), + ) + + if _resp is None or _resp == ([], []): + return + + monthly_spend_per_team, monthly_spend_per_tag = _resp + + _spend_message = f"*💸 Monthly Spend Report for `{first_day_of_month.strftime('%m-%d-%Y')} - {last_day_of_month.strftime('%m-%d-%Y')}` *\n" + + if monthly_spend_per_team is not None: + _spend_message += "\n*Team Spend Report:*\n" + for spend in monthly_spend_per_team: + _team_spend = spend["total_spend"] + _team_spend = float(_team_spend) + # round to 4 decimal places + _team_spend = round(_team_spend, 4) + _spend_message += ( + f"Team: 
`{spend['team_alias']}` | Spend: `${_team_spend}`\n" + ) + + if monthly_spend_per_tag is not None: + _spend_message += "\n*Tag Spend Report:*\n" + for spend in monthly_spend_per_tag: + _tag_spend = spend["total_spend"] + _tag_spend = float(_tag_spend) + # round to 4 decimal places + _tag_spend = round(_tag_spend, 4) + _spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n" + + await self.send_alert( + message=_spend_message, + level="Low", + alert_type=AlertType.spend_reports, + alerting_metadata={}, + ) + + await self.internal_usage_cache.async_set_cache( + key=_event_cache_key, + value="SENT", + ttl=(30 * 24 * 60 * 60), # 1 month + ) + + except Exception as e: + verbose_proxy_logger.exception("Error sending weekly spend report %s", e) + + async def send_fallback_stats_from_prometheus(self): + """ + Helper to send fallback statistics from prometheus server -> to slack + + This runs once per day and sends an overview of all the fallback statistics + """ + try: + from litellm.integrations.prometheus_helpers.prometheus_api import ( + get_fallback_metric_from_prometheus, + ) + + # call prometheuslogger. + falllback_success_info_prometheus = ( + await get_fallback_metric_from_prometheus() + ) + + fallback_message = ( + f"*Fallback Statistics:*\n{falllback_success_info_prometheus}" + ) + + await self.send_alert( + message=fallback_message, + level="Low", + alert_type=AlertType.fallback_reports, + alerting_metadata={}, + ) + + except Exception as e: + verbose_proxy_logger.error("Error sending weekly spend report %s", e) + + pass + + async def send_virtual_key_event_slack( + self, + key_event: VirtualKeyEvent, + alert_type: AlertType, + event_name: str, + ): + """ + Handles sending Virtual Key related alerts + + Example: + - New Virtual Key Created + - Internal User Updated + - Team Created, Updated, Deleted + """ + try: + + message = f"`{event_name}`\n" + + key_event_dict = key_event.model_dump() + + # Add Created by information first + message += "*Action Done by:*\n" + for key, value in key_event_dict.items(): + if "created_by" in key: + message += f"{key}: `{value}`\n" + + # Add args sent to function in the alert + message += "\n*Arguments passed:*\n" + request_kwargs = key_event.request_kwargs + for key, value in request_kwargs.items(): + if key == "user_api_key_dict": + continue + message += f"{key}: `{value}`\n" + + await self.send_alert( + message=message, + level="High", + alert_type=alert_type, + alerting_metadata={}, + ) + + except Exception as e: + verbose_proxy_logger.error( + "Error sending send_virtual_key_event_slack %s", e + ) + + return + + async def _request_is_completed(self, request_data: Optional[dict]) -> bool: + """ + Returns True if the request is completed - either as a success or failure + """ + if request_data is None: + return False + + if ( + request_data.get("litellm_status", "") != "success" + and request_data.get("litellm_status", "") != "fail" + ): + ## CHECK IF CACHE IS UPDATED + litellm_call_id = request_data.get("litellm_call_id", "") + status: Optional[str] = await self.internal_usage_cache.async_get_cache( + key="request_status:{}".format(litellm_call_id), local_only=True + ) + if status is not None and (status == "success" or status == "fail"): + return True + return False diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/utils.py b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/utils.py new file mode 100644 index 00000000..0dc8bae5 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/utils.py @@ -0,0 +1,92 @@ +""" +Utils used for slack alerting +""" + +import asyncio +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from litellm.proxy._types import AlertType +from litellm.secret_managers.main import get_secret + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _Logging + + Logging = _Logging +else: + Logging = Any + + +def process_slack_alerting_variables( + alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]] +) -> Optional[Dict[AlertType, Union[List[str], str]]]: + """ + process alert_to_webhook_url + - check if any urls are set as os.environ/SLACK_WEBHOOK_URL_1 read env var and set the correct value + """ + if alert_to_webhook_url is None: + return None + + for alert_type, webhook_urls in alert_to_webhook_url.items(): + if isinstance(webhook_urls, list): + _webhook_values: List[str] = [] + for webhook_url in webhook_urls: + if "os.environ/" in webhook_url: + _env_value = get_secret(secret_name=webhook_url) + if not isinstance(_env_value, str): + raise ValueError( + f"Invalid webhook url value for: {webhook_url}. Got type={type(_env_value)}" + ) + _webhook_values.append(_env_value) + else: + _webhook_values.append(webhook_url) + + alert_to_webhook_url[alert_type] = _webhook_values + else: + _webhook_value_str: str = webhook_urls + if "os.environ/" in webhook_urls: + _env_value = get_secret(secret_name=webhook_urls) + if not isinstance(_env_value, str): + raise ValueError( + f"Invalid webhook url value for: {webhook_urls}. Got type={type(_env_value)}" + ) + _webhook_value_str = _env_value + else: + _webhook_value_str = webhook_urls + + alert_to_webhook_url[alert_type] = _webhook_value_str + + return alert_to_webhook_url + + +async def _add_langfuse_trace_id_to_alert( + request_data: Optional[dict] = None, +) -> Optional[str]: + """ + Returns langfuse trace url + + - check: + -> existing_trace_id + -> trace_id + -> litellm_call_id + """ + # do nothing for now + if ( + request_data is not None + and request_data.get("litellm_logging_obj", None) is not None + ): + trace_id: Optional[str] = None + litellm_logging_obj: Logging = request_data["litellm_logging_obj"] + + for _ in range(3): + trace_id = litellm_logging_obj._get_trace_id(service_name="langfuse") + if trace_id is not None: + break + await asyncio.sleep(3) # wait 3s before retrying for trace id + + _langfuse_object = litellm_logging_obj._get_callback_object( + service_name="langfuse" + ) + if _langfuse_object is not None: + base_url = _langfuse_object.Langfuse.base_url + return f"{base_url}/trace/{trace_id}" + return None diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/__init__.py b/.venv/lib/python3.12/site-packages/litellm/integrations/__init__.py new file mode 100644 index 00000000..b6e690fd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/__init__.py @@ -0,0 +1 @@ +from . import * diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/_types/open_inference.py b/.venv/lib/python3.12/site-packages/litellm/integrations/_types/open_inference.py new file mode 100644 index 00000000..b5076c0e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/_types/open_inference.py @@ -0,0 +1,286 @@ +from enum import Enum + + +class SpanAttributes: + OUTPUT_VALUE = "output.value" + OUTPUT_MIME_TYPE = "output.mime_type" + """ + The type of output.value. 
If unspecified, the type is plain text by default. + If type is JSON, the value is a string representing a JSON object. + """ + INPUT_VALUE = "input.value" + INPUT_MIME_TYPE = "input.mime_type" + """ + The type of input.value. If unspecified, the type is plain text by default. + If type is JSON, the value is a string representing a JSON object. + """ + + EMBEDDING_EMBEDDINGS = "embedding.embeddings" + """ + A list of objects containing embedding data, including the vector and represented piece of text. + """ + EMBEDDING_MODEL_NAME = "embedding.model_name" + """ + The name of the embedding model. + """ + + LLM_FUNCTION_CALL = "llm.function_call" + """ + For models and APIs that support function calling. Records attributes such as the function + name and arguments to the called function. + """ + LLM_INVOCATION_PARAMETERS = "llm.invocation_parameters" + """ + Invocation parameters passed to the LLM or API, such as the model name, temperature, etc. + """ + LLM_INPUT_MESSAGES = "llm.input_messages" + """ + Messages provided to a chat API. + """ + LLM_OUTPUT_MESSAGES = "llm.output_messages" + """ + Messages received from a chat API. + """ + LLM_MODEL_NAME = "llm.model_name" + """ + The name of the model being used. + """ + LLM_PROMPTS = "llm.prompts" + """ + Prompts provided to a completions API. + """ + LLM_PROMPT_TEMPLATE = "llm.prompt_template.template" + """ + The prompt template as a Python f-string. + """ + LLM_PROMPT_TEMPLATE_VARIABLES = "llm.prompt_template.variables" + """ + A list of input variables to the prompt template. + """ + LLM_PROMPT_TEMPLATE_VERSION = "llm.prompt_template.version" + """ + The version of the prompt template being used. + """ + LLM_TOKEN_COUNT_PROMPT = "llm.token_count.prompt" + """ + Number of tokens in the prompt. + """ + LLM_TOKEN_COUNT_COMPLETION = "llm.token_count.completion" + """ + Number of tokens in the completion. + """ + LLM_TOKEN_COUNT_TOTAL = "llm.token_count.total" + """ + Total number of tokens, including both prompt and completion. + """ + + TOOL_NAME = "tool.name" + """ + Name of the tool being used. + """ + TOOL_DESCRIPTION = "tool.description" + """ + Description of the tool's purpose, typically used to select the tool. + """ + TOOL_PARAMETERS = "tool.parameters" + """ + Parameters of the tool represented a dictionary JSON string, e.g. + see https://platform.openai.com/docs/guides/gpt/function-calling + """ + + RETRIEVAL_DOCUMENTS = "retrieval.documents" + + METADATA = "metadata" + """ + Metadata attributes are used to store user-defined key-value pairs. + For example, LangChain uses metadata to store user-defined attributes for a chain. + """ + + TAG_TAGS = "tag.tags" + """ + Custom categorical tags for the span. + """ + + OPENINFERENCE_SPAN_KIND = "openinference.span.kind" + + SESSION_ID = "session.id" + """ + The id of the session + """ + USER_ID = "user.id" + """ + The id of the user + """ + + +class MessageAttributes: + """ + Attributes for a message sent to or from an LLM + """ + + MESSAGE_ROLE = "message.role" + """ + The role of the message, such as "user", "agent", "function". + """ + MESSAGE_CONTENT = "message.content" + """ + The content of the message to or from the llm, must be a string. + """ + MESSAGE_CONTENTS = "message.contents" + """ + The message contents to the llm, it is an array of + `message_content` prefixed attributes. + """ + MESSAGE_NAME = "message.name" + """ + The name of the message, often used to identify the function + that was used to generate the message. 
+ """ + MESSAGE_TOOL_CALLS = "message.tool_calls" + """ + The tool calls generated by the model, such as function calls. + """ + MESSAGE_FUNCTION_CALL_NAME = "message.function_call_name" + """ + The function name that is a part of the message list. + This is populated for role 'function' or 'agent' as a mechanism to identify + the function that was called during the execution of a tool. + """ + MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = "message.function_call_arguments_json" + """ + The JSON string representing the arguments passed to the function + during a function call. + """ + + +class MessageContentAttributes: + """ + Attributes for the contents of user messages sent to an LLM. + """ + + MESSAGE_CONTENT_TYPE = "message_content.type" + """ + The type of the content, such as "text" or "image". + """ + MESSAGE_CONTENT_TEXT = "message_content.text" + """ + The text content of the message, if the type is "text". + """ + MESSAGE_CONTENT_IMAGE = "message_content.image" + """ + The image content of the message, if the type is "image". + An image can be made available to the model by passing a link to + the image or by passing the base64 encoded image directly in the + request. + """ + + +class ImageAttributes: + """ + Attributes for images + """ + + IMAGE_URL = "image.url" + """ + An http or base64 image url + """ + + +class DocumentAttributes: + """ + Attributes for a document. + """ + + DOCUMENT_ID = "document.id" + """ + The id of the document. + """ + DOCUMENT_SCORE = "document.score" + """ + The score of the document + """ + DOCUMENT_CONTENT = "document.content" + """ + The content of the document. + """ + DOCUMENT_METADATA = "document.metadata" + """ + The metadata of the document represented as a dictionary + JSON string, e.g. `"{ 'title': 'foo' }"` + """ + + +class RerankerAttributes: + """ + Attributes for a reranker + """ + + RERANKER_INPUT_DOCUMENTS = "reranker.input_documents" + """ + List of documents as input to the reranker + """ + RERANKER_OUTPUT_DOCUMENTS = "reranker.output_documents" + """ + List of documents as output from the reranker + """ + RERANKER_QUERY = "reranker.query" + """ + Query string for the reranker + """ + RERANKER_MODEL_NAME = "reranker.model_name" + """ + Model name of the reranker + """ + RERANKER_TOP_K = "reranker.top_k" + """ + Top K parameter of the reranker + """ + + +class EmbeddingAttributes: + """ + Attributes for an embedding + """ + + EMBEDDING_TEXT = "embedding.text" + """ + The text represented by the embedding. + """ + EMBEDDING_VECTOR = "embedding.vector" + """ + The embedding vector. + """ + + +class ToolCallAttributes: + """ + Attributes for a tool call + """ + + TOOL_CALL_FUNCTION_NAME = "tool_call.function.name" + """ + The name of function that is being called during a tool call. + """ + TOOL_CALL_FUNCTION_ARGUMENTS_JSON = "tool_call.function.arguments" + """ + The JSON string representing the arguments passed to the function + during a tool call. + """ + + +class OpenInferenceSpanKindValues(Enum): + TOOL = "TOOL" + CHAIN = "CHAIN" + LLM = "LLM" + RETRIEVER = "RETRIEVER" + EMBEDDING = "EMBEDDING" + AGENT = "AGENT" + RERANKER = "RERANKER" + UNKNOWN = "UNKNOWN" + GUARDRAIL = "GUARDRAIL" + EVALUATOR = "EVALUATOR" + + +class OpenInferenceMimeTypeValues(Enum): + TEXT = "text/plain" + JSON = "application/json"
\ No newline at end of file diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/additional_logging_utils.py b/.venv/lib/python3.12/site-packages/litellm/integrations/additional_logging_utils.py new file mode 100644 index 00000000..795afd81 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/additional_logging_utils.py @@ -0,0 +1,36 @@ +""" +Base class for Additional Logging Utils for CustomLoggers + +- Health Check for the logging util +- Get Request / Response Payload for the logging util +""" + +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Optional + +from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus + + +class AdditionalLoggingUtils(ABC): + def __init__(self): + super().__init__() + + @abstractmethod + async def async_health_check(self) -> IntegrationHealthCheckStatus: + """ + Check if the service is healthy + """ + pass + + @abstractmethod + async def get_request_response_payload( + self, + request_id: str, + start_time_utc: Optional[datetime], + end_time_utc: Optional[datetime], + ) -> Optional[dict]: + """ + Get the request and response payload for a given `request_id` + """ + return None diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py b/.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py new file mode 100644 index 00000000..055ad902 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py @@ -0,0 +1,392 @@ +""" +Send logs to Argilla for annotation +""" + +import asyncio +import json +import os +import random +import types +from typing import Any, Dict, List, Optional + +import httpx +from pydantic import BaseModel # type: ignore + +import litellm +from litellm._logging import verbose_logger +from litellm.integrations.custom_batch_logger import CustomBatchLogger +from litellm.integrations.custom_logger import CustomLogger +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.types.integrations.argilla import ( + SUPPORTED_PAYLOAD_FIELDS, + ArgillaCredentialsObject, + ArgillaItem, +) +from litellm.types.utils import StandardLoggingPayload + + +def is_serializable(value): + non_serializable_types = ( + types.CoroutineType, + types.FunctionType, + types.GeneratorType, + BaseModel, + ) + return not isinstance(value, non_serializable_types) + + +class ArgillaLogger(CustomBatchLogger): + def __init__( + self, + argilla_api_key: Optional[str] = None, + argilla_dataset_name: Optional[str] = None, + argilla_base_url: Optional[str] = None, + **kwargs, + ): + if litellm.argilla_transformation_object is None: + raise Exception( + "'litellm.argilla_transformation_object' is required, to log your payload to Argilla." 
+ ) + self.validate_argilla_transformation_object( + litellm.argilla_transformation_object + ) + self.argilla_transformation_object = litellm.argilla_transformation_object + self.default_credentials = self.get_credentials_from_env( + argilla_api_key=argilla_api_key, + argilla_dataset_name=argilla_dataset_name, + argilla_base_url=argilla_base_url, + ) + self.sampling_rate: float = ( + float(os.getenv("ARGILLA_SAMPLING_RATE")) # type: ignore + if os.getenv("ARGILLA_SAMPLING_RATE") is not None + and os.getenv("ARGILLA_SAMPLING_RATE").strip().isdigit() # type: ignore + else 1.0 + ) + + self.async_httpx_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + _batch_size = ( + os.getenv("ARGILLA_BATCH_SIZE", None) or litellm.argilla_batch_size + ) + if _batch_size: + self.batch_size = int(_batch_size) + asyncio.create_task(self.periodic_flush()) + self.flush_lock = asyncio.Lock() + super().__init__(**kwargs, flush_lock=self.flush_lock) + + def validate_argilla_transformation_object( + self, argilla_transformation_object: Dict[str, Any] + ): + if not isinstance(argilla_transformation_object, dict): + raise Exception( + "'argilla_transformation_object' must be a dictionary, to log your payload to Argilla." + ) + + for v in argilla_transformation_object.values(): + if v not in SUPPORTED_PAYLOAD_FIELDS: + raise Exception( + f"All values in argilla_transformation_object must be a key in SUPPORTED_PAYLOAD_FIELDS, {v} is not a valid key." + ) + + def get_credentials_from_env( + self, + argilla_api_key: Optional[str], + argilla_dataset_name: Optional[str], + argilla_base_url: Optional[str], + ) -> ArgillaCredentialsObject: + + _credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY") + if _credentials_api_key is None: + raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.") + + _credentials_base_url = ( + argilla_base_url + or os.getenv("ARGILLA_BASE_URL") + or "http://localhost:6900/" + ) + if _credentials_base_url is None: + raise Exception( + "Invalid Argilla Base URL given. _credentials_base_url=None." + ) + + _credentials_dataset_name = ( + argilla_dataset_name + or os.getenv("ARGILLA_DATASET_NAME") + or "litellm-completion" + ) + if _credentials_dataset_name is None: + raise Exception("Invalid Argilla Dataset give. 
Value=None.") + else: + dataset_response = litellm.module_level_client.get( + url=f"{_credentials_base_url}/api/v1/me/datasets?name={_credentials_dataset_name}", + headers={"X-Argilla-Api-Key": _credentials_api_key}, + ) + json_response = dataset_response.json() + if ( + "items" in json_response + and isinstance(json_response["items"], list) + and len(json_response["items"]) > 0 + ): + _credentials_dataset_name = json_response["items"][0]["id"] + + return ArgillaCredentialsObject( + ARGILLA_API_KEY=_credentials_api_key, + ARGILLA_BASE_URL=_credentials_base_url, + ARGILLA_DATASET_NAME=_credentials_dataset_name, + ) + + def get_chat_messages( + self, payload: StandardLoggingPayload + ) -> List[Dict[str, Any]]: + payload_messages = payload.get("messages", None) + + if payload_messages is None: + raise Exception("No chat messages found in payload.") + + if ( + isinstance(payload_messages, list) + and len(payload_messages) > 0 + and isinstance(payload_messages[0], dict) + ): + return payload_messages + elif isinstance(payload_messages, dict): + return [payload_messages] + else: + raise Exception(f"Invalid chat messages format: {payload_messages}") + + def get_str_response(self, payload: StandardLoggingPayload) -> str: + response = payload["response"] + + if response is None: + raise Exception("No response found in payload.") + + if isinstance(response, str): + return response + elif isinstance(response, dict): + return ( + response.get("choices", [{}])[0].get("message", {}).get("content", "") + ) + else: + raise Exception(f"Invalid response format: {response}") + + def _prepare_log_data( + self, kwargs, response_obj, start_time, end_time + ) -> Optional[ArgillaItem]: + try: + # Ensure everything in the payload is converted to str + payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + + if payload is None: + raise Exception("Error logging request payload. 
Payload=none.") + + argilla_message = self.get_chat_messages(payload) + argilla_response = self.get_str_response(payload) + argilla_item: ArgillaItem = {"fields": {}} + for k, v in self.argilla_transformation_object.items(): + if v == "messages": + argilla_item["fields"][k] = argilla_message + elif v == "response": + argilla_item["fields"][k] = argilla_response + else: + argilla_item["fields"][k] = payload.get(v, None) + + return argilla_item + except Exception: + raise + + def _send_batch(self): + if not self.log_queue: + return + + argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"] + argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"] + + url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk" + + argilla_api_key = self.default_credentials["ARGILLA_API_KEY"] + + headers = {"X-Argilla-Api-Key": argilla_api_key} + + try: + response = litellm.module_level_client.post( + url=url, + json=self.log_queue, + headers=headers, + ) + + if response.status_code >= 300: + verbose_logger.error( + f"Argilla Error: {response.status_code} - {response.text}" + ) + else: + verbose_logger.debug( + f"Batch of {len(self.log_queue)} runs successfully created" + ) + + self.log_queue.clear() + except Exception: + verbose_logger.exception("Argilla Layer Error - Error sending batch.") + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + sampling_rate = ( + float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore + if os.getenv("LANGSMITH_SAMPLING_RATE") is not None + and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore + else 1.0 + ) + random_sample = random.random() + if random_sample > sampling_rate: + verbose_logger.info( + "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format( + sampling_rate, random_sample + ) + ) + return # Skip logging + verbose_logger.debug( + "Langsmith Sync Layer Logging - kwargs: %s, response_obj: %s", + kwargs, + response_obj, + ) + data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) + if data is None: + return + + self.log_queue.append(data) + verbose_logger.debug( + f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..." + ) + + if len(self.log_queue) >= self.batch_size: + self._send_batch() + + except Exception: + verbose_logger.exception("Langsmith Layer Error - log_success_event error") + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + sampling_rate = self.sampling_rate + random_sample = random.random() + if random_sample > sampling_rate: + verbose_logger.info( + "Skipping Langsmith logging. 
Sampling rate={}, random_sample={}".format( + sampling_rate, random_sample + ) + ) + return # Skip logging + verbose_logger.debug( + "Langsmith Async Layer Logging - kwargs: %s, response_obj: %s", + kwargs, + response_obj, + ) + payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + + data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) + + ## ALLOW CUSTOM LOGGERS TO MODIFY / FILTER DATA BEFORE LOGGING + for callback in litellm.callbacks: + if isinstance(callback, CustomLogger): + try: + if data is None: + break + data = await callback.async_dataset_hook(data, payload) + except NotImplementedError: + pass + + if data is None: + return + + self.log_queue.append(data) + verbose_logger.debug( + "Langsmith logging: queue length %s, batch size %s", + len(self.log_queue), + self.batch_size, + ) + if len(self.log_queue) >= self.batch_size: + await self.flush_queue() + except Exception: + verbose_logger.exception( + "Argilla Layer Error - error logging async success event." + ) + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + sampling_rate = self.sampling_rate + random_sample = random.random() + if random_sample > sampling_rate: + verbose_logger.info( + "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format( + sampling_rate, random_sample + ) + ) + return # Skip logging + verbose_logger.info("Langsmith Failure Event Logging!") + try: + data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) + self.log_queue.append(data) + verbose_logger.debug( + "Langsmith logging: queue length %s, batch size %s", + len(self.log_queue), + self.batch_size, + ) + if len(self.log_queue) >= self.batch_size: + await self.flush_queue() + except Exception: + verbose_logger.exception( + "Langsmith Layer Error - error logging async failure event." 
+ ) + + async def async_send_batch(self): + """ + sends runs to /batch endpoint + + Sends runs from self.log_queue + + Returns: None + + Raises: Does not raise an exception, will only verbose_logger.exception() + """ + if not self.log_queue: + return + + argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"] + argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"] + + url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk" + + argilla_api_key = self.default_credentials["ARGILLA_API_KEY"] + + headers = {"X-Argilla-Api-Key": argilla_api_key} + + try: + response = await self.async_httpx_client.put( + url=url, + data=json.dumps( + { + "items": self.log_queue, + } + ), + headers=headers, + timeout=60000, + ) + response.raise_for_status() + + if response.status_code >= 300: + verbose_logger.error( + f"Argilla Error: {response.status_code} - {response.text}" + ) + else: + verbose_logger.debug( + "Batch of %s runs successfully created", len(self.log_queue) + ) + except httpx.HTTPStatusError: + verbose_logger.exception("Argilla HTTP Error") + except Exception: + verbose_logger.exception("Argilla Layer Error") diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/arize/_utils.py b/.venv/lib/python3.12/site-packages/litellm/integrations/arize/_utils.py new file mode 100644 index 00000000..487304cc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/arize/_utils.py @@ -0,0 +1,126 @@ +from typing import TYPE_CHECKING, Any, Optional + +from litellm._logging import verbose_logger +from litellm.litellm_core_utils.safe_json_dumps import safe_dumps +from litellm.types.utils import StandardLoggingPayload + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + + +def set_attributes(span: Span, kwargs, response_obj): + from litellm.integrations._types.open_inference import ( + MessageAttributes, + OpenInferenceSpanKindValues, + SpanAttributes, + ) + + try: + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object" + ) + + ############################################# + ############ LLM CALL METADATA ############## + ############################################# + + if standard_logging_payload and ( + metadata := standard_logging_payload["metadata"] + ): + span.set_attribute(SpanAttributes.METADATA, safe_dumps(metadata)) + + ############################################# + ########## LLM Request Attributes ########### + ############################################# + + # The name of the LLM a request is being made to + if kwargs.get("model"): + span.set_attribute(SpanAttributes.LLM_MODEL_NAME, kwargs.get("model")) + + span.set_attribute( + SpanAttributes.OPENINFERENCE_SPAN_KIND, + OpenInferenceSpanKindValues.LLM.value, + ) + messages = kwargs.get("messages") + + # for /chat/completions + # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions + if messages: + span.set_attribute( + SpanAttributes.INPUT_VALUE, + messages[-1].get("content", ""), # get the last message for input + ) + + # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page + for idx, msg in enumerate(messages): + # Set the role per message + span.set_attribute( + f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_ROLE}", + msg["role"], + ) + # Set the content per message + span.set_attribute( + f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_CONTENT}", + msg.get("content", ""), + 
) + + if standard_logging_payload and ( + model_params := standard_logging_payload["model_parameters"] + ): + # The Generative AI Provider: Azure, OpenAI, etc. + span.set_attribute( + SpanAttributes.LLM_INVOCATION_PARAMETERS, safe_dumps(model_params) + ) + + if model_params.get("user"): + user_id = model_params.get("user") + if user_id is not None: + span.set_attribute(SpanAttributes.USER_ID, user_id) + + ############################################# + ########## LLM Response Attributes ########## + # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions + ############################################# + if hasattr(response_obj, "get"): + for choice in response_obj.get("choices", []): + response_message = choice.get("message", {}) + span.set_attribute( + SpanAttributes.OUTPUT_VALUE, response_message.get("content", "") + ) + + # This shows up under `output_messages` tab on the span page + # This code assumes a single response + span.set_attribute( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}", + response_message.get("role"), + ) + span.set_attribute( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}", + response_message.get("content", ""), + ) + + usage = response_obj.get("usage") + if usage: + span.set_attribute( + SpanAttributes.LLM_TOKEN_COUNT_TOTAL, + usage.get("total_tokens"), + ) + + # The number of tokens used in the LLM response (completion). + span.set_attribute( + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, + usage.get("completion_tokens"), + ) + + # The number of tokens used in the LLM prompt. + span.set_attribute( + SpanAttributes.LLM_TOKEN_COUNT_PROMPT, + usage.get("prompt_tokens"), + ) + pass + except Exception as e: + verbose_logger.error(f"Error setting arize attributes: {e}") diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/arize/arize.py b/.venv/lib/python3.12/site-packages/litellm/integrations/arize/arize.py new file mode 100644 index 00000000..7a0fb785 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/arize/arize.py @@ -0,0 +1,105 @@ +""" +arize AI is OTEL compatible + +this file has Arize ai specific helper functions +""" + +import os +from datetime import datetime +from typing import TYPE_CHECKING, Any, Optional, Union + +from litellm.integrations.arize import _utils +from litellm.integrations.opentelemetry import OpenTelemetry +from litellm.types.integrations.arize import ArizeConfig +from litellm.types.services import ServiceLoggerPayload + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + from litellm.types.integrations.arize import Protocol as _Protocol + + Protocol = _Protocol + Span = _Span +else: + Protocol = Any + Span = Any + + +class ArizeLogger(OpenTelemetry): + + def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]): + ArizeLogger.set_arize_attributes(span, kwargs, response_obj) + return + + @staticmethod + def set_arize_attributes(span: Span, kwargs, response_obj): + _utils.set_attributes(span, kwargs, response_obj) + return + + @staticmethod + def get_arize_config() -> ArizeConfig: + """ + Helper function to get Arize configuration. + + Returns: + ArizeConfig: A Pydantic model containing Arize configuration. + + Raises: + ValueError: If required environment variables are not set. 
+ """ + space_key = os.environ.get("ARIZE_SPACE_KEY") + api_key = os.environ.get("ARIZE_API_KEY") + + grpc_endpoint = os.environ.get("ARIZE_ENDPOINT") + http_endpoint = os.environ.get("ARIZE_HTTP_ENDPOINT") + + endpoint = None + protocol: Protocol = "otlp_grpc" + + if grpc_endpoint: + protocol = "otlp_grpc" + endpoint = grpc_endpoint + elif http_endpoint: + protocol = "otlp_http" + endpoint = http_endpoint + else: + protocol = "otlp_grpc" + endpoint = "https://otlp.arize.com/v1" + + return ArizeConfig( + space_key=space_key, + api_key=api_key, + protocol=protocol, + endpoint=endpoint, + ) + + async def async_service_success_hook( + self, + payload: ServiceLoggerPayload, + parent_otel_span: Optional[Span] = None, + start_time: Optional[Union[datetime, float]] = None, + end_time: Optional[Union[datetime, float]] = None, + event_metadata: Optional[dict] = None, + ): + """Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs""" + pass + + async def async_service_failure_hook( + self, + payload: ServiceLoggerPayload, + error: Optional[str] = "", + parent_otel_span: Optional[Span] = None, + start_time: Optional[Union[datetime, float]] = None, + end_time: Optional[Union[float, datetime]] = None, + event_metadata: Optional[dict] = None, + ): + """Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs""" + pass + + def create_litellm_proxy_request_started_span( + self, + start_time: datetime, + headers: dict, + ): + """Arize is used mainly for LLM I/O tracing, sending Proxy Server Request adds bloat to arize logs""" + pass diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/arize/arize_phoenix.py b/.venv/lib/python3.12/site-packages/litellm/integrations/arize/arize_phoenix.py new file mode 100644 index 00000000..d7b7d581 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/arize/arize_phoenix.py @@ -0,0 +1,73 @@ +import os +from typing import TYPE_CHECKING, Any +from litellm.integrations.arize import _utils +from litellm._logging import verbose_logger +from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig + +if TYPE_CHECKING: + from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig + from litellm.types.integrations.arize import Protocol as _Protocol + from opentelemetry.trace import Span as _Span + + Protocol = _Protocol + OpenTelemetryConfig = _OpenTelemetryConfig + Span = _Span +else: + Protocol = Any + OpenTelemetryConfig = Any + Span = Any + + +ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://app.phoenix.arize.com/v1/traces" + +class ArizePhoenixLogger: + @staticmethod + def set_arize_phoenix_attributes(span: Span, kwargs, response_obj): + _utils.set_attributes(span, kwargs, response_obj) + return + + @staticmethod + def get_arize_phoenix_config() -> ArizePhoenixConfig: + """ + Retrieves the Arize Phoenix configuration based on environment variables. + + Returns: + ArizePhoenixConfig: A Pydantic model containing Arize Phoenix configuration. 
+ """ + api_key = os.environ.get("PHOENIX_API_KEY", None) + grpc_endpoint = os.environ.get("PHOENIX_COLLECTOR_ENDPOINT", None) + http_endpoint = os.environ.get("PHOENIX_COLLECTOR_HTTP_ENDPOINT", None) + + endpoint = None + protocol: Protocol = "otlp_http" + + if http_endpoint: + endpoint = http_endpoint + protocol = "otlp_http" + elif grpc_endpoint: + endpoint = grpc_endpoint + protocol = "otlp_grpc" + else: + endpoint = ARIZE_HOSTED_PHOENIX_ENDPOINT + protocol = "otlp_http" + verbose_logger.debug( + f"No PHOENIX_COLLECTOR_ENDPOINT or PHOENIX_COLLECTOR_HTTP_ENDPOINT found, using default endpoint with http: {ARIZE_HOSTED_PHOENIX_ENDPOINT}" + ) + + otlp_auth_headers = None + # If the endpoint is the Arize hosted Phoenix endpoint, use the api_key as the auth header as currently it is uses + # a slightly different auth header format than self hosted phoenix + if endpoint == ARIZE_HOSTED_PHOENIX_ENDPOINT: + if api_key is None: + raise ValueError("PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used.") + otlp_auth_headers = f"api_key={api_key}" + elif api_key is not None: + # api_key/auth is optional for self hosted phoenix + otlp_auth_headers = f"Authorization=Bearer {api_key}" + + return ArizePhoenixConfig( + otlp_auth_headers=otlp_auth_headers, + protocol=protocol, + endpoint=endpoint + ) + diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/athina.py b/.venv/lib/python3.12/site-packages/litellm/integrations/athina.py new file mode 100644 index 00000000..705dc11f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/athina.py @@ -0,0 +1,102 @@ +import datetime + +import litellm + + +class AthinaLogger: + def __init__(self): + import os + + self.athina_api_key = os.getenv("ATHINA_API_KEY") + self.headers = { + "athina-api-key": self.athina_api_key, + "Content-Type": "application/json", + } + self.athina_logging_url = os.getenv("ATHINA_BASE_URL", "https://log.athina.ai") + "/api/v1/log/inference" + self.additional_keys = [ + "environment", + "prompt_slug", + "customer_id", + "customer_user_id", + "session_id", + "external_reference_id", + "context", + "expected_response", + "user_query", + "tags", + "user_feedback", + "model_options", + "custom_attributes", + ] + + def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): + import json + import traceback + + try: + is_stream = kwargs.get("stream", False) + if is_stream: + if "complete_streaming_response" in kwargs: + # Log the completion response in streaming mode + completion_response = kwargs["complete_streaming_response"] + response_json = ( + completion_response.model_dump() if completion_response else {} + ) + else: + # Skip logging if the completion response is not available + return + else: + # Log the completion response in non streaming mode + response_json = response_obj.model_dump() if response_obj else {} + data = { + "language_model_id": kwargs.get("model"), + "request": kwargs, + "response": response_json, + "prompt_tokens": response_json.get("usage", {}).get("prompt_tokens"), + "completion_tokens": response_json.get("usage", {}).get( + "completion_tokens" + ), + "total_tokens": response_json.get("usage", {}).get("total_tokens"), + } + + if ( + type(end_time) is datetime.datetime + and type(start_time) is datetime.datetime + ): + data["response_time"] = int( + (end_time - start_time).total_seconds() * 1000 + ) + + if "messages" in kwargs: + data["prompt"] = kwargs.get("messages", None) + + # Directly add tools or functions if present + 
optional_params = kwargs.get("optional_params", {}) + data.update( + (k, v) + for k, v in optional_params.items() + if k in ["tools", "functions"] + ) + + # Add additional metadata keys + metadata = kwargs.get("litellm_params", {}).get("metadata", {}) + if metadata: + for key in self.additional_keys: + if key in metadata: + data[key] = metadata[key] + response = litellm.module_level_client.post( + self.athina_logging_url, + headers=self.headers, + data=json.dumps(data, default=str), + ) + if response.status_code != 200: + print_verbose( + f"Athina Logger Error - {response.text}, {response.status_code}" + ) + else: + print_verbose(f"Athina Logger Succeeded - {response.text}") + except Exception as e: + print_verbose( + f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}" + ) + pass diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/azure_storage/azure_storage.py b/.venv/lib/python3.12/site-packages/litellm/integrations/azure_storage/azure_storage.py new file mode 100644 index 00000000..ddc46b11 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/azure_storage/azure_storage.py @@ -0,0 +1,381 @@ +import asyncio +import json +import os +import uuid +from datetime import datetime, timedelta +from typing import List, Optional + +from litellm._logging import verbose_logger +from litellm.constants import AZURE_STORAGE_MSFT_VERSION +from litellm.integrations.custom_batch_logger import CustomBatchLogger +from litellm.llms.azure.common_utils import get_azure_ad_token_from_entrata_id +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.types.utils import StandardLoggingPayload + + +class AzureBlobStorageLogger(CustomBatchLogger): + def __init__( + self, + **kwargs, + ): + try: + verbose_logger.debug( + "AzureBlobStorageLogger: in init azure blob storage logger" + ) + + # Env Variables used for Azure Storage Authentication + self.tenant_id = os.getenv("AZURE_STORAGE_TENANT_ID") + self.client_id = os.getenv("AZURE_STORAGE_CLIENT_ID") + self.client_secret = os.getenv("AZURE_STORAGE_CLIENT_SECRET") + self.azure_storage_account_key: Optional[str] = os.getenv( + "AZURE_STORAGE_ACCOUNT_KEY" + ) + + # Required Env Variables for Azure Storage + _azure_storage_account_name = os.getenv("AZURE_STORAGE_ACCOUNT_NAME") + if not _azure_storage_account_name: + raise ValueError( + "Missing required environment variable: AZURE_STORAGE_ACCOUNT_NAME" + ) + self.azure_storage_account_name: str = _azure_storage_account_name + _azure_storage_file_system = os.getenv("AZURE_STORAGE_FILE_SYSTEM") + if not _azure_storage_file_system: + raise ValueError( + "Missing required environment variable: AZURE_STORAGE_FILE_SYSTEM" + ) + self.azure_storage_file_system: str = _azure_storage_file_system + + # Internal variables used for Token based authentication + self.azure_auth_token: Optional[str] = ( + None # the Azure AD token to use for Azure Storage API requests + ) + self.token_expiry: Optional[datetime] = ( + None # the expiry time of the currentAzure AD token + ) + + asyncio.create_task(self.periodic_flush()) + self.flush_lock = asyncio.Lock() + self.log_queue: List[StandardLoggingPayload] = [] + super().__init__(**kwargs, flush_lock=self.flush_lock) + except Exception as e: + verbose_logger.exception( + f"AzureBlobStorageLogger: Got exception on init AzureBlobStorageLogger client {str(e)}" + ) + raise e + + async def async_log_success_event(self, kwargs, response_obj, start_time, 
end_time): + """ + Async Log success events to Azure Blob Storage + + Raises: + Raises a NON Blocking verbose_logger.exception if an error occurs + """ + try: + self._premium_user_check() + verbose_logger.debug( + "AzureBlobStorageLogger: Logging - Enters logging function for model %s", + kwargs, + ) + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object" + ) + + if standard_logging_payload is None: + raise ValueError("standard_logging_payload is not set") + + self.log_queue.append(standard_logging_payload) + + except Exception as e: + verbose_logger.exception(f"AzureBlobStorageLogger Layer Error - {str(e)}") + pass + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + """ + Async Log failure events to Azure Blob Storage + + Raises: + Raises a NON Blocking verbose_logger.exception if an error occurs + """ + try: + self._premium_user_check() + verbose_logger.debug( + "AzureBlobStorageLogger: Logging - Enters logging function for model %s", + kwargs, + ) + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object" + ) + + if standard_logging_payload is None: + raise ValueError("standard_logging_payload is not set") + + self.log_queue.append(standard_logging_payload) + except Exception as e: + verbose_logger.exception(f"AzureBlobStorageLogger Layer Error - {str(e)}") + pass + + async def async_send_batch(self): + """ + Sends the in memory logs queue to Azure Blob Storage + + Raises: + Raises a NON Blocking verbose_logger.exception if an error occurs + """ + try: + if not self.log_queue: + verbose_logger.exception("Datadog: log_queue does not exist") + return + + verbose_logger.debug( + "AzureBlobStorageLogger - about to flush %s events", + len(self.log_queue), + ) + + for payload in self.log_queue: + await self.async_upload_payload_to_azure_blob_storage(payload=payload) + + except Exception as e: + verbose_logger.exception( + f"AzureBlobStorageLogger Error sending batch API - {str(e)}" + ) + + async def async_upload_payload_to_azure_blob_storage( + self, payload: StandardLoggingPayload + ): + """ + Uploads the payload to Azure Blob Storage using a 3-step process: + 1. Create file resource + 2. Append data + 3. 
Flush the data + """ + try: + + if self.azure_storage_account_key: + await self.upload_to_azure_data_lake_with_azure_account_key( + payload=payload + ) + else: + # Get a valid token instead of always requesting a new one + await self.set_valid_azure_ad_token() + async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + json_payload = ( + json.dumps(payload) + "\n" + ) # Add newline for each log entry + payload_bytes = json_payload.encode("utf-8") + filename = f"{payload.get('id') or str(uuid.uuid4())}.json" + base_url = f"https://{self.azure_storage_account_name}.dfs.core.windows.net/{self.azure_storage_file_system}/{filename}" + + # Execute the 3-step upload process + await self._create_file(async_client, base_url) + await self._append_data(async_client, base_url, json_payload) + await self._flush_data(async_client, base_url, len(payload_bytes)) + + verbose_logger.debug( + f"Successfully uploaded log to Azure Blob Storage: {filename}" + ) + + except Exception as e: + verbose_logger.exception(f"Error uploading to Azure Blob Storage: {str(e)}") + raise e + + async def _create_file(self, client: AsyncHTTPHandler, base_url: str): + """Helper method to create the file resource""" + try: + verbose_logger.debug(f"Creating file resource at: {base_url}") + headers = { + "x-ms-version": AZURE_STORAGE_MSFT_VERSION, + "Content-Length": "0", + "Authorization": f"Bearer {self.azure_auth_token}", + } + response = await client.put(f"{base_url}?resource=file", headers=headers) + response.raise_for_status() + verbose_logger.debug("Successfully created file resource") + except Exception as e: + verbose_logger.exception(f"Error creating file resource: {str(e)}") + raise + + async def _append_data( + self, client: AsyncHTTPHandler, base_url: str, json_payload: str + ): + """Helper method to append data to the file""" + try: + verbose_logger.debug(f"Appending data to file: {base_url}") + headers = { + "x-ms-version": AZURE_STORAGE_MSFT_VERSION, + "Content-Type": "application/json", + "Authorization": f"Bearer {self.azure_auth_token}", + } + response = await client.patch( + f"{base_url}?action=append&position=0", + headers=headers, + data=json_payload, + ) + response.raise_for_status() + verbose_logger.debug("Successfully appended data") + except Exception as e: + verbose_logger.exception(f"Error appending data: {str(e)}") + raise + + async def _flush_data(self, client: AsyncHTTPHandler, base_url: str, position: int): + """Helper method to flush the data""" + try: + verbose_logger.debug(f"Flushing data at position {position}") + headers = { + "x-ms-version": AZURE_STORAGE_MSFT_VERSION, + "Content-Length": "0", + "Authorization": f"Bearer {self.azure_auth_token}", + } + response = await client.patch( + f"{base_url}?action=flush&position={position}", headers=headers + ) + response.raise_for_status() + verbose_logger.debug("Successfully flushed data") + except Exception as e: + verbose_logger.exception(f"Error flushing data: {str(e)}") + raise + + ####### Helper methods to managing Authentication to Azure Storage ####### + ########################################################################## + + async def set_valid_azure_ad_token(self): + """ + Wrapper to set self.azure_auth_token to a valid Azure AD token, refreshing if necessary + + Refreshes the token when: + - Token is expired + - Token is not set + """ + # Check if token needs refresh + if self._azure_ad_token_is_expired() or self.azure_auth_token is None: + verbose_logger.debug("Azure AD token needs refresh") + 
self.azure_auth_token = self.get_azure_ad_token_from_azure_storage( + tenant_id=self.tenant_id, + client_id=self.client_id, + client_secret=self.client_secret, + ) + # Token typically expires in 1 hour + self.token_expiry = datetime.now() + timedelta(hours=1) + verbose_logger.debug(f"New token will expire at {self.token_expiry}") + + def get_azure_ad_token_from_azure_storage( + self, + tenant_id: Optional[str], + client_id: Optional[str], + client_secret: Optional[str], + ) -> str: + """ + Gets Azure AD token to use for Azure Storage API requests + """ + verbose_logger.debug("Getting Azure AD Token from Azure Storage") + verbose_logger.debug( + "tenant_id %s, client_id %s, client_secret %s", + tenant_id, + client_id, + client_secret, + ) + if tenant_id is None: + raise ValueError( + "Missing required environment variable: AZURE_STORAGE_TENANT_ID" + ) + if client_id is None: + raise ValueError( + "Missing required environment variable: AZURE_STORAGE_CLIENT_ID" + ) + if client_secret is None: + raise ValueError( + "Missing required environment variable: AZURE_STORAGE_CLIENT_SECRET" + ) + + token_provider = get_azure_ad_token_from_entrata_id( + tenant_id=tenant_id, + client_id=client_id, + client_secret=client_secret, + scope="https://storage.azure.com/.default", + ) + token = token_provider() + + verbose_logger.debug("azure auth token %s", token) + + return token + + def _azure_ad_token_is_expired(self): + """ + Returns True if Azure AD token is expired, False otherwise + """ + if self.azure_auth_token and self.token_expiry: + if datetime.now() + timedelta(minutes=5) >= self.token_expiry: + verbose_logger.debug("Azure AD token is expired. Requesting new token") + return True + return False + + def _premium_user_check(self): + """ + Checks if the user is a premium user, raises an error if not + """ + from litellm.proxy.proxy_server import CommonProxyErrors, premium_user + + if premium_user is not True: + raise ValueError( + f"AzureBlobStorageLogger is only available for premium users. 
{CommonProxyErrors.not_premium_user}" + ) + + async def upload_to_azure_data_lake_with_azure_account_key( + self, payload: StandardLoggingPayload + ): + """ + Uploads the payload to Azure Data Lake using the Azure SDK + + This is used when Azure Storage Account Key is set - Azure Storage Account Key does not work directly with Azure Rest API + """ + from azure.storage.filedatalake.aio import DataLakeServiceClient + + # Create an async service client + service_client = DataLakeServiceClient( + account_url=f"https://{self.azure_storage_account_name}.dfs.core.windows.net", + credential=self.azure_storage_account_key, + ) + # Get file system client + file_system_client = service_client.get_file_system_client( + file_system=self.azure_storage_file_system + ) + + try: + # Create directory with today's date + from datetime import datetime + + today = datetime.now().strftime("%Y-%m-%d") + directory_client = file_system_client.get_directory_client(today) + + # check if the directory exists + if not await directory_client.exists(): + await directory_client.create_directory() + verbose_logger.debug(f"Created directory: {today}") + + # Create a file client + file_name = f"{payload.get('id') or str(uuid.uuid4())}.json" + file_client = directory_client.get_file_client(file_name) + + # Create the file + await file_client.create_file() + + # Content to append + content = json.dumps(payload).encode("utf-8") + + # Append content to the file + await file_client.append_data(data=content, offset=0, length=len(content)) + + # Flush the content to finalize the file + await file_client.flush_data(position=len(content), offset=0) + + verbose_logger.debug( + f"Successfully uploaded and wrote to {today}/{file_name}" + ) + + except Exception as e: + verbose_logger.exception(f"Error occurred: {str(e)}") diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/braintrust_logging.py b/.venv/lib/python3.12/site-packages/litellm/integrations/braintrust_logging.py new file mode 100644 index 00000000..281fbda0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/braintrust_logging.py @@ -0,0 +1,399 @@ +# What is this? 
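# A minimal usage sketch for the Braintrust logger defined below: register an instance
# as a litellm callback so completions are logged to a Braintrust project. Assumes
# BRAINTRUST_API_KEY and provider credentials (e.g. OPENAI_API_KEY) are configured; the
# model, message, and project name are placeholders.
import litellm
from litellm.integrations.braintrust_logging import BraintrustLogger

litellm.callbacks = [BraintrustLogger()]

litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hello"}],
    metadata={"project_name": "my-braintrust-project"},  # read as project_name by the logger below
)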
+## Log success + failure events to Braintrust + +import copy +import os +from datetime import datetime +from typing import Optional, Dict + +import httpx +from pydantic import BaseModel + +import litellm +from litellm import verbose_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.llms.custom_httpx.http_handler import ( + HTTPHandler, + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.utils import print_verbose + +global_braintrust_http_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.LoggingCallback) +global_braintrust_sync_http_handler = HTTPHandler() +API_BASE = "https://api.braintrustdata.com/v1" + + +def get_utc_datetime(): + import datetime as dt + from datetime import datetime + + if hasattr(dt, "UTC"): + return datetime.now(dt.UTC) # type: ignore + else: + return datetime.utcnow() # type: ignore + + +class BraintrustLogger(CustomLogger): + def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None) -> None: + super().__init__() + self.validate_environment(api_key=api_key) + self.api_base = api_base or API_BASE + self.default_project_id = None + self.api_key: str = api_key or os.getenv("BRAINTRUST_API_KEY") # type: ignore + self.headers = { + "Authorization": "Bearer " + self.api_key, + "Content-Type": "application/json", + } + self._project_id_cache: Dict[str, str] = {} # Cache mapping project names to IDs + + def validate_environment(self, api_key: Optional[str]): + """ + Expects + BRAINTRUST_API_KEY + + in the environment + """ + missing_keys = [] + if api_key is None and os.getenv("BRAINTRUST_API_KEY", None) is None: + missing_keys.append("BRAINTRUST_API_KEY") + + if len(missing_keys) > 0: + raise Exception("Missing keys={} in environment.".format(missing_keys)) + + def get_project_id_sync(self, project_name: str) -> str: + """ + Get project ID from name, using cache if available. + If project doesn't exist, creates it. + """ + if project_name in self._project_id_cache: + return self._project_id_cache[project_name] + + try: + response = global_braintrust_sync_http_handler.post( + f"{self.api_base}/project", headers=self.headers, json={"name": project_name} + ) + project_dict = response.json() + project_id = project_dict["id"] + self._project_id_cache[project_name] = project_id + return project_id + except httpx.HTTPStatusError as e: + raise Exception(f"Failed to register project: {e.response.text}") + + async def get_project_id_async(self, project_name: str) -> str: + """ + Async version of get_project_id_sync + """ + if project_name in self._project_id_cache: + return self._project_id_cache[project_name] + + try: + response = await global_braintrust_http_handler.post( + f"{self.api_base}/project/register", headers=self.headers, json={"name": project_name} + ) + project_dict = response.json() + project_id = project_dict["id"] + self._project_id_cache[project_name] = project_id + return project_id + except httpx.HTTPStatusError as e: + raise Exception(f"Failed to register project: {e.response.text}") + + @staticmethod + def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict: + """ + Adds metadata from proxy request headers to Langfuse logging if keys start with "langfuse_" + and overwrites litellm_params.metadata if already included. + + For example if you want to append your trace to an existing `trace_id` via header, send + `headers: { ..., langfuse_existing_trace_id: your-existing-trace-id }` via proxy request. 
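# A sketch of the header-forwarding pattern described above, assuming requests go through
# the LiteLLM proxy: header names starting with "braintrust" are merged into the logging
# metadata with that prefix stripped (see the implementation just below). The base URL,
# proxy key, and header value here are placeholders.
import openai

client = openai.OpenAI(
    api_key="sk-litellm-proxy-key",    # placeholder proxy key
    base_url="http://localhost:4000",  # placeholder proxy URL
)

client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hello"}],
    # forwarded by the proxy and picked up by add_metadata_from_header()
    extra_headers={"braintrust_custom_field": "some-value"},  # placeholder
)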
+ """ + if litellm_params is None: + return metadata + + if litellm_params.get("proxy_server_request") is None: + return metadata + + if metadata is None: + metadata = {} + + proxy_headers = litellm_params.get("proxy_server_request", {}).get("headers", {}) or {} + + for metadata_param_key in proxy_headers: + if metadata_param_key.startswith("braintrust"): + trace_param_key = metadata_param_key.replace("braintrust", "", 1) + if trace_param_key in metadata: + verbose_logger.warning(f"Overwriting Braintrust `{trace_param_key}` from request header") + else: + verbose_logger.debug(f"Found Braintrust `{trace_param_key}` in request header") + metadata[trace_param_key] = proxy_headers.get(metadata_param_key) + + return metadata + + async def create_default_project_and_experiment(self): + project = await global_braintrust_http_handler.post( + f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"} + ) + + project_dict = project.json() + + self.default_project_id = project_dict["id"] + + def create_sync_default_project_and_experiment(self): + project = global_braintrust_sync_http_handler.post( + f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"} + ) + + project_dict = project.json() + + self.default_project_id = project_dict["id"] + + def log_success_event( # noqa: PLR0915 + self, kwargs, response_obj, start_time, end_time + ): + verbose_logger.debug("REACHES BRAINTRUST SUCCESS") + try: + litellm_call_id = kwargs.get("litellm_call_id") + prompt = {"messages": kwargs.get("messages")} + output = None + choices = [] + if response_obj is not None and ( + kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse) + ): + output = None + elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse): + output = response_obj["choices"][0]["message"].json() + choices = response_obj["choices"] + elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse): + output = response_obj.choices[0].text + choices = response_obj.choices + elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse): + output = response_obj["data"] + + litellm_params = kwargs.get("litellm_params", {}) + metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None + metadata = self.add_metadata_from_header(litellm_params, metadata) + clean_metadata = {} + try: + metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata + except Exception: + new_metadata = {} + for key, value in metadata.items(): + if ( + isinstance(value, list) + or isinstance(value, dict) + or isinstance(value, str) + or isinstance(value, int) + or isinstance(value, float) + ): + new_metadata[key] = copy.deepcopy(value) + metadata = new_metadata + + # Get project_id from metadata or create default if needed + project_id = metadata.get("project_id") + if project_id is None: + project_name = metadata.get("project_name") + project_id = self.get_project_id_sync(project_name) if project_name else None + + if project_id is None: + if self.default_project_id is None: + self.create_sync_default_project_and_experiment() + project_id = self.default_project_id + + tags = [] + if isinstance(metadata, dict): + for key, value in metadata.items(): + # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy + if ( + litellm.langfuse_default_tags is not None + and isinstance(litellm.langfuse_default_tags, list) + and key in litellm.langfuse_default_tags + ): + 
tags.append(f"{key}:{value}") + + # clean litellm metadata before logging + if key in [ + "headers", + "endpoint", + "caching_groups", + "previous_models", + ]: + continue + else: + clean_metadata[key] = value + + cost = kwargs.get("response_cost", None) + if cost is not None: + clean_metadata["litellm_response_cost"] = cost + + metrics: Optional[dict] = None + usage_obj = getattr(response_obj, "usage", None) + if usage_obj and isinstance(usage_obj, litellm.Usage): + litellm.utils.get_logging_id(start_time, response_obj) + metrics = { + "prompt_tokens": usage_obj.prompt_tokens, + "completion_tokens": usage_obj.completion_tokens, + "total_tokens": usage_obj.total_tokens, + "total_cost": cost, + "time_to_first_token": end_time.timestamp() - start_time.timestamp(), + "start": start_time.timestamp(), + "end": end_time.timestamp(), + } + + request_data = { + "id": litellm_call_id, + "input": prompt["messages"], + "metadata": clean_metadata, + "tags": tags, + "span_attributes": {"name": "Chat Completion", "type": "llm"}, + } + if choices is not None: + request_data["output"] = [choice.dict() for choice in choices] + else: + request_data["output"] = output + + if metrics is not None: + request_data["metrics"] = metrics + + try: + print_verbose(f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}") + global_braintrust_sync_http_handler.post( + url=f"{self.api_base}/project_logs/{project_id}/insert", + json={"events": [request_data]}, + headers=self.headers, + ) + except httpx.HTTPStatusError as e: + raise Exception(e.response.text) + except Exception as e: + raise e # don't use verbose_logger.exception, if exception is raised + + async def async_log_success_event( # noqa: PLR0915 + self, kwargs, response_obj, start_time, end_time + ): + verbose_logger.debug("REACHES BRAINTRUST SUCCESS") + try: + litellm_call_id = kwargs.get("litellm_call_id") + prompt = {"messages": kwargs.get("messages")} + output = None + choices = [] + if response_obj is not None and ( + kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse) + ): + output = None + elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse): + output = response_obj["choices"][0]["message"].json() + choices = response_obj["choices"] + elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse): + output = response_obj.choices[0].text + choices = response_obj.choices + elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse): + output = response_obj["data"] + + litellm_params = kwargs.get("litellm_params", {}) + metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None + metadata = self.add_metadata_from_header(litellm_params, metadata) + clean_metadata = {} + new_metadata = {} + for key, value in metadata.items(): + if ( + isinstance(value, list) + or isinstance(value, str) + or isinstance(value, int) + or isinstance(value, float) + ): + new_metadata[key] = value + elif isinstance(value, BaseModel): + new_metadata[key] = value.model_dump_json() + elif isinstance(value, dict): + for k, v in value.items(): + if isinstance(v, datetime): + value[k] = v.isoformat() + new_metadata[key] = value + + # Get project_id from metadata or create default if needed + project_id = metadata.get("project_id") + if project_id is None: + project_name = metadata.get("project_name") + project_id = await self.get_project_id_async(project_name) if project_name else None + + if 
project_id is None: + if self.default_project_id is None: + await self.create_default_project_and_experiment() + project_id = self.default_project_id + + tags = [] + if isinstance(metadata, dict): + for key, value in metadata.items(): + # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy + if ( + litellm.langfuse_default_tags is not None + and isinstance(litellm.langfuse_default_tags, list) + and key in litellm.langfuse_default_tags + ): + tags.append(f"{key}:{value}") + + # clean litellm metadata before logging + if key in [ + "headers", + "endpoint", + "caching_groups", + "previous_models", + ]: + continue + else: + clean_metadata[key] = value + + cost = kwargs.get("response_cost", None) + if cost is not None: + clean_metadata["litellm_response_cost"] = cost + + metrics: Optional[dict] = None + usage_obj = getattr(response_obj, "usage", None) + if usage_obj and isinstance(usage_obj, litellm.Usage): + litellm.utils.get_logging_id(start_time, response_obj) + metrics = { + "prompt_tokens": usage_obj.prompt_tokens, + "completion_tokens": usage_obj.completion_tokens, + "total_tokens": usage_obj.total_tokens, + "total_cost": cost, + "start": start_time.timestamp(), + "end": end_time.timestamp(), + } + + api_call_start_time = kwargs.get("api_call_start_time") + completion_start_time = kwargs.get("completion_start_time") + + if api_call_start_time is not None and completion_start_time is not None: + metrics["time_to_first_token"] = completion_start_time.timestamp() - api_call_start_time.timestamp() + + request_data = { + "id": litellm_call_id, + "input": prompt["messages"], + "output": output, + "metadata": clean_metadata, + "tags": tags, + "span_attributes": {"name": "Chat Completion", "type": "llm"}, + } + if choices is not None: + request_data["output"] = [choice.dict() for choice in choices] + else: + request_data["output"] = output + + if metrics is not None: + request_data["metrics"] = metrics + + if metrics is not None: + request_data["metrics"] = metrics + + try: + await global_braintrust_http_handler.post( + url=f"{self.api_base}/project_logs/{project_id}/insert", + json={"events": [request_data]}, + headers=self.headers, + ) + except httpx.HTTPStatusError as e: + raise Exception(e.response.text) + except Exception as e: + raise e # don't use verbose_logger.exception, if exception is raised + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + return super().log_failure_event(kwargs, response_obj, start_time, end_time) diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/custom_batch_logger.py b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_batch_logger.py new file mode 100644 index 00000000..3cfdf82c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_batch_logger.py @@ -0,0 +1,59 @@ +""" +Custom Logger that handles batching logic + +Use this if you want your logs to be stored in memory and flushed periodically. +""" + +import asyncio +import time +from typing import List, Optional + +import litellm +from litellm._logging import verbose_logger +from litellm.integrations.custom_logger import CustomLogger + + +class CustomBatchLogger(CustomLogger): + + def __init__( + self, + flush_lock: Optional[asyncio.Lock] = None, + batch_size: Optional[int] = None, + flush_interval: Optional[int] = None, + **kwargs, + ) -> None: + """ + Args: + flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. 
Only used for custom loggers that do batching + """ + self.log_queue: List = [] + self.flush_interval = flush_interval or litellm.DEFAULT_FLUSH_INTERVAL_SECONDS + self.batch_size: int = batch_size or litellm.DEFAULT_BATCH_SIZE + self.last_flush_time = time.time() + self.flush_lock = flush_lock + + super().__init__(**kwargs) + + async def periodic_flush(self): + while True: + await asyncio.sleep(self.flush_interval) + verbose_logger.debug( + f"CustomLogger periodic flush after {self.flush_interval} seconds" + ) + await self.flush_queue() + + async def flush_queue(self): + if self.flush_lock is None: + return + + async with self.flush_lock: + if self.log_queue: + verbose_logger.debug( + "CustomLogger: Flushing batch of %s events", len(self.log_queue) + ) + await self.async_send_batch() + self.log_queue.clear() + self.last_flush_time = time.time() + + async def async_send_batch(self, *args, **kwargs): + pass diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py new file mode 100644 index 00000000..4421664b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py @@ -0,0 +1,274 @@ +from typing import Dict, List, Literal, Optional, Union + +from litellm._logging import verbose_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.types.guardrails import DynamicGuardrailParams, GuardrailEventHooks +from litellm.types.utils import StandardLoggingGuardrailInformation + + +class CustomGuardrail(CustomLogger): + + def __init__( + self, + guardrail_name: Optional[str] = None, + supported_event_hooks: Optional[List[GuardrailEventHooks]] = None, + event_hook: Optional[ + Union[GuardrailEventHooks, List[GuardrailEventHooks]] + ] = None, + default_on: bool = False, + **kwargs, + ): + """ + Initialize the CustomGuardrail class + + Args: + guardrail_name: The name of the guardrail. This is the name used in your requests. 
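# ---------------------------------------------------------------------------
# Editor's illustrative sketch, not part of the vendored diff above: a minimal
# subclass of the CustomBatchLogger shown in custom_batch_logger.py. Events are
# appended to `self.log_queue` and `async_send_batch` is the only method a
# subclass has to override; the class name and print "sink" here are made up.
import asyncio

from litellm.integrations.custom_batch_logger import CustomBatchLogger


class MyBatchLogger(CustomBatchLogger):
    def __init__(self, **kwargs):
        self.flush_lock = asyncio.Lock()
        super().__init__(flush_lock=self.flush_lock, batch_size=10, **kwargs)

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # queue a small record; flush_queue() drains the queue under the lock
        self.log_queue.append({"model": kwargs.get("model")})
        if len(self.log_queue) >= self.batch_size:
            await self.flush_queue()

    async def async_send_batch(self, *args, **kwargs):
        # replace with a real sink (HTTP POST, file write, etc.)
        print(f"flushing {len(self.log_queue)} events")
# ---------------------------------------------------------------------------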
+ supported_event_hooks: The event hooks that the guardrail supports + event_hook: The event hook to run the guardrail on + default_on: If True, the guardrail will be run by default on all requests + """ + self.guardrail_name = guardrail_name + self.supported_event_hooks = supported_event_hooks + self.event_hook: Optional[ + Union[GuardrailEventHooks, List[GuardrailEventHooks]] + ] = event_hook + self.default_on: bool = default_on + + if supported_event_hooks: + ## validate event_hook is in supported_event_hooks + self._validate_event_hook(event_hook, supported_event_hooks) + super().__init__(**kwargs) + + def _validate_event_hook( + self, + event_hook: Optional[Union[GuardrailEventHooks, List[GuardrailEventHooks]]], + supported_event_hooks: List[GuardrailEventHooks], + ) -> None: + if event_hook is None: + return + if isinstance(event_hook, list): + for hook in event_hook: + if hook not in supported_event_hooks: + raise ValueError( + f"Event hook {hook} is not in the supported event hooks {supported_event_hooks}" + ) + elif isinstance(event_hook, GuardrailEventHooks): + if event_hook not in supported_event_hooks: + raise ValueError( + f"Event hook {event_hook} is not in the supported event hooks {supported_event_hooks}" + ) + + def get_guardrail_from_metadata( + self, data: dict + ) -> Union[List[str], List[Dict[str, DynamicGuardrailParams]]]: + """ + Returns the guardrail(s) to be run from the metadata + """ + metadata = data.get("metadata") or {} + requested_guardrails = metadata.get("guardrails") or [] + return requested_guardrails + + def _guardrail_is_in_requested_guardrails( + self, + requested_guardrails: Union[List[str], List[Dict[str, DynamicGuardrailParams]]], + ) -> bool: + for _guardrail in requested_guardrails: + if isinstance(_guardrail, dict): + if self.guardrail_name in _guardrail: + return True + elif isinstance(_guardrail, str): + if self.guardrail_name == _guardrail: + return True + return False + + def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool: + """ + Returns True if the guardrail should be run on the event_type + """ + requested_guardrails = self.get_guardrail_from_metadata(data) + + verbose_logger.debug( + "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s self.default_on= %s", + self.guardrail_name, + event_type, + self.event_hook, + requested_guardrails, + self.default_on, + ) + + if self.default_on is True: + if self._event_hook_is_event_type(event_type): + return True + return False + + if ( + self.event_hook + and not self._guardrail_is_in_requested_guardrails(requested_guardrails) + and event_type.value != "logging_only" + ): + return False + + if not self._event_hook_is_event_type(event_type): + return False + + return True + + def _event_hook_is_event_type(self, event_type: GuardrailEventHooks) -> bool: + """ + Returns True if the event_hook is the same as the event_type + + eg. if `self.event_hook == "pre_call" and event_type == "pre_call"` -> then True + eg. if `self.event_hook == "pre_call" and event_type == "post_call"` -> then False + """ + + if self.event_hook is None: + return True + if isinstance(self.event_hook, list): + return event_type.value in self.event_hook + return self.event_hook == event_type.value + + def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict: + """ + Returns `extra_body` to be added to the request body for the Guardrail API call + + Use this to pass dynamic params to the guardrail API call - eg. 
success_threshold, failure_threshold, etc. + + ``` + [{"lakera_guard": {"extra_body": {"foo": "bar"}}}] + ``` + + Will return: for guardrail=`lakera-guard`: + { + "foo": "bar" + } + + Args: + request_data: The original `request_data` passed to LiteLLM Proxy + """ + requested_guardrails = self.get_guardrail_from_metadata(request_data) + + # Look for the guardrail configuration matching self.guardrail_name + for guardrail in requested_guardrails: + if isinstance(guardrail, dict) and self.guardrail_name in guardrail: + # Get the configuration for this guardrail + guardrail_config: DynamicGuardrailParams = DynamicGuardrailParams( + **guardrail[self.guardrail_name] + ) + if self._validate_premium_user() is not True: + return {} + + # Return the extra_body if it exists, otherwise empty dict + return guardrail_config.get("extra_body", {}) + + return {} + + def _validate_premium_user(self) -> bool: + """ + Returns True if the user is a premium user + """ + from litellm.proxy.proxy_server import CommonProxyErrors, premium_user + + if premium_user is not True: + verbose_logger.warning( + f"Trying to use premium guardrail without premium user {CommonProxyErrors.not_premium_user.value}" + ) + return False + return True + + def add_standard_logging_guardrail_information_to_request_data( + self, + guardrail_json_response: Union[Exception, str, dict], + request_data: dict, + guardrail_status: Literal["success", "failure"], + ) -> None: + """ + Builds `StandardLoggingGuardrailInformation` and adds it to the request metadata so it can be used for logging to DataDog, Langfuse, etc. + """ + from litellm.proxy.proxy_server import premium_user + + if premium_user is not True: + verbose_logger.warning( + f"Guardrail Tracing is only available for premium users. Skipping guardrail logging for guardrail={self.guardrail_name} event_hook={self.event_hook}" + ) + return + if isinstance(guardrail_json_response, Exception): + guardrail_json_response = str(guardrail_json_response) + slg = StandardLoggingGuardrailInformation( + guardrail_name=self.guardrail_name, + guardrail_mode=self.event_hook, + guardrail_response=guardrail_json_response, + guardrail_status=guardrail_status, + ) + if "metadata" in request_data: + request_data["metadata"]["standard_logging_guardrail_information"] = slg + elif "litellm_metadata" in request_data: + request_data["litellm_metadata"][ + "standard_logging_guardrail_information" + ] = slg + else: + verbose_logger.warning( + "unable to log guardrail information. No metadata found in request_data" + ) + + +def log_guardrail_information(func): + """ + Decorator to add standard logging guardrail information to any function + + Add this decorator to ensure your guardrail response is logged to DataDog, OTEL, s3, GCS etc. + + Logs for: + - pre_call + - during_call + - TODO: log post_call. 
This is more involved since the logs are sent to DD, s3 before the guardrail is even run + """ + import asyncio + import functools + + def process_response(self, response, request_data): + self.add_standard_logging_guardrail_information_to_request_data( + guardrail_json_response=response, + request_data=request_data, + guardrail_status="success", + ) + return response + + def process_error(self, e, request_data): + self.add_standard_logging_guardrail_information_to_request_data( + guardrail_json_response=e, + request_data=request_data, + guardrail_status="failure", + ) + raise e + + @functools.wraps(func) + async def async_wrapper(*args, **kwargs): + self: CustomGuardrail = args[0] + request_data: Optional[dict] = ( + kwargs.get("data") or kwargs.get("request_data") or {} + ) + try: + response = await func(*args, **kwargs) + return process_response(self, response, request_data) + except Exception as e: + return process_error(self, e, request_data) + + @functools.wraps(func) + def sync_wrapper(*args, **kwargs): + self: CustomGuardrail = args[0] + request_data: Optional[dict] = ( + kwargs.get("data") or kwargs.get("request_data") or {} + ) + try: + response = func(*args, **kwargs) + return process_response(self, response, request_data) + except Exception as e: + return process_error(self, e, request_data) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if asyncio.iscoroutinefunction(func): + return async_wrapper(*args, **kwargs) + return sync_wrapper(*args, **kwargs) + + return wrapper diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/custom_logger.py b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_logger.py new file mode 100644 index 00000000..6f1ec88d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_logger.py @@ -0,0 +1,388 @@ +#### What this does #### +# On success, logs events to Promptlayer +import traceback +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + List, + Literal, + Optional, + Tuple, + Union, +) + +from pydantic import BaseModel + +from litellm.caching.caching import DualCache +from litellm.proxy._types import UserAPIKeyAuth +from litellm.types.integrations.argilla import ArgillaItem +from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest +from litellm.types.utils import ( + AdapterCompletionStreamWrapper, + EmbeddingResponse, + ImageResponse, + ModelResponse, + ModelResponseStream, + StandardCallbackDynamicParams, + StandardLoggingPayload, +) + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + + +class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class + # Class variables or attributes + def __init__(self, message_logging: bool = True) -> None: + self.message_logging = message_logging + pass + + def log_pre_api_call(self, model, messages, kwargs): + pass + + def log_post_api_call(self, kwargs, response_obj, start_time, end_time): + pass + + def log_stream_event(self, kwargs, response_obj, start_time, end_time): + pass + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + pass + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + pass + + #### ASYNC #### + + async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time): + pass + + async def async_log_pre_api_call(self, model, messages, kwargs): + pass + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + 
pass + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + pass + + #### PROMPT MANAGEMENT HOOKS #### + + async def async_get_chat_completion_prompt( + self, + model: str, + messages: List[AllMessageValues], + non_default_params: dict, + prompt_id: str, + prompt_variables: Optional[dict], + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> Tuple[str, List[AllMessageValues], dict]: + """ + Returns: + - model: str - the model to use (can be pulled from prompt management tool) + - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool) + - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool) + """ + return model, messages, non_default_params + + def get_chat_completion_prompt( + self, + model: str, + messages: List[AllMessageValues], + non_default_params: dict, + prompt_id: str, + prompt_variables: Optional[dict], + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> Tuple[str, List[AllMessageValues], dict]: + """ + Returns: + - model: str - the model to use (can be pulled from prompt management tool) + - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool) + - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool) + """ + return model, messages, non_default_params + + #### PRE-CALL CHECKS - router/proxy only #### + """ + Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks). + """ + + async def async_filter_deployments( + self, + model: str, + healthy_deployments: List, + messages: Optional[List[AllMessageValues]], + request_kwargs: Optional[dict] = None, + parent_otel_span: Optional[Span] = None, + ) -> List[dict]: + return healthy_deployments + + async def async_pre_call_check( + self, deployment: dict, parent_otel_span: Optional[Span] + ) -> Optional[dict]: + pass + + def pre_call_check(self, deployment: dict) -> Optional[dict]: + pass + + #### Fallback Events - router/proxy only #### + async def log_model_group_rate_limit_error( + self, exception: Exception, original_model_group: Optional[str], kwargs: dict + ): + pass + + async def log_success_fallback_event( + self, original_model_group: str, kwargs: dict, original_exception: Exception + ): + pass + + async def log_failure_fallback_event( + self, original_model_group: str, kwargs: dict, original_exception: Exception + ): + pass + + #### ADAPTERS #### Allow calling 100+ LLMs in custom format - https://github.com/BerriAI/litellm/pulls + + def translate_completion_input_params( + self, kwargs + ) -> Optional[ChatCompletionRequest]: + """ + Translates the input params, from the provider's native format to the litellm.completion() format. + """ + pass + + def translate_completion_output_params( + self, response: ModelResponse + ) -> Optional[BaseModel]: + """ + Translates the output params, from the OpenAI format to the custom format. + """ + pass + + def translate_completion_output_params_streaming( + self, completion_stream: Any + ) -> Optional[AdapterCompletionStreamWrapper]: + """ + Translates the streaming chunk, from the OpenAI format to the custom format. 
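# ---------------------------------------------------------------------------
# Editor's illustrative sketch, not part of the vendored diff above: a
# bare-bones CustomLogger subclass using the hook names defined in
# custom_logger.py. Registration via `litellm.callbacks` follows the docs link
# in the class comment; treat the exact wiring as an assumption.
import litellm
from litellm.integrations.custom_logger import CustomLogger


class MyLogger(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"success: model={kwargs.get('model')} cost={kwargs.get('response_cost')}")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"async success: model={kwargs.get('model')}")


litellm.callbacks = [MyLogger()]
# ---------------------------------------------------------------------------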
+ """ + pass + + ### DATASET HOOKS #### - currently only used for Argilla + + async def async_dataset_hook( + self, + logged_item: ArgillaItem, + standard_logging_payload: Optional[StandardLoggingPayload], + ) -> Optional[ArgillaItem]: + """ + - Decide if the result should be logged to Argilla. + - Modify the result before logging to Argilla. + - Return None if the result should not be logged to Argilla. + """ + raise NotImplementedError("async_dataset_hook not implemented") + + #### CALL HOOKS - proxy only #### + """ + Control the modify incoming / outgoung data before calling the model + """ + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: Literal[ + "completion", + "text_completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + "pass_through_endpoint", + "rerank", + ], + ) -> Optional[ + Union[Exception, str, dict] + ]: # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm + pass + + async def async_post_call_failure_hook( + self, + request_data: dict, + original_exception: Exception, + user_api_key_dict: UserAPIKeyAuth, + ): + pass + + async def async_post_call_success_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse], + ) -> Any: + pass + + async def async_logging_hook( + self, kwargs: dict, result: Any, call_type: str + ) -> Tuple[dict, Any]: + """For masking logged request/response. Return a modified version of the request/result.""" + return kwargs, result + + def logging_hook( + self, kwargs: dict, result: Any, call_type: str + ) -> Tuple[dict, Any]: + """For masking logged request/response. 
Return a modified version of the request/result.""" + return kwargs, result + + async def async_moderation_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + call_type: Literal[ + "completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + "responses", + ], + ) -> Any: + pass + + async def async_post_call_streaming_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + response: str, + ) -> Any: + pass + + async def async_post_call_streaming_iterator_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + response: Any, + request_data: dict, + ) -> AsyncGenerator[ModelResponseStream, None]: + async for item in response: + yield item + + #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function + + def log_input_event(self, model, messages, kwargs, print_verbose, callback_func): + try: + kwargs["model"] = model + kwargs["messages"] = messages + kwargs["log_event_type"] = "pre_api_call" + callback_func( + kwargs, + ) + print_verbose(f"Custom Logger - model call details: {kwargs}") + except Exception: + print_verbose(f"Custom Logger Error - {traceback.format_exc()}") + + async def async_log_input_event( + self, model, messages, kwargs, print_verbose, callback_func + ): + try: + kwargs["model"] = model + kwargs["messages"] = messages + kwargs["log_event_type"] = "pre_api_call" + await callback_func( + kwargs, + ) + print_verbose(f"Custom Logger - model call details: {kwargs}") + except Exception: + print_verbose(f"Custom Logger Error - {traceback.format_exc()}") + + def log_event( + self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func + ): + # Method definition + try: + kwargs["log_event_type"] = "post_api_call" + callback_func( + kwargs, # kwargs to func + response_obj, + start_time, + end_time, + ) + except Exception: + print_verbose(f"Custom Logger Error - {traceback.format_exc()}") + pass + + async def async_log_event( + self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func + ): + # Method definition + try: + kwargs["log_event_type"] = "post_api_call" + await callback_func( + kwargs, # kwargs to func + response_obj, + start_time, + end_time, + ) + except Exception: + print_verbose(f"Custom Logger Error - {traceback.format_exc()}") + pass + + # Useful helpers for custom logger classes + + def truncate_standard_logging_payload_content( + self, + standard_logging_object: StandardLoggingPayload, + ): + """ + Truncate error strings and message content in logging payload + + Some loggers like DataDog/ GCS Bucket have a limit on the size of the payload. (1MB) + + This function truncates the error string and the message content if they exceed a certain length. + """ + MAX_STR_LENGTH = 10_000 + + # Truncate fields that might exceed max length + fields_to_truncate = ["error_str", "messages", "response"] + for field in fields_to_truncate: + self._truncate_field( + standard_logging_object=standard_logging_object, + field_name=field, + max_length=MAX_STR_LENGTH, + ) + + def _truncate_field( + self, + standard_logging_object: StandardLoggingPayload, + field_name: str, + max_length: int, + ) -> None: + """ + Helper function to truncate a field in the logging payload + + This converts the field to a string and then truncates it if it exceeds the max length. + + Why convert to string ? + 1. 
User was sending a poorly formatted list for `messages` field, we could not predict where they would send content + - Converting to string and then truncating the logged content catches this + 2. We want to avoid modifying the original `messages`, `response`, and `error_str` in the logging payload since these are in kwargs and could be returned to the user + """ + field_value = standard_logging_object.get(field_name) # type: ignore + if field_value: + str_value = str(field_value) + if len(str_value) > max_length: + standard_logging_object[field_name] = self._truncate_text( # type: ignore + text=str_value, max_length=max_length + ) + + def _truncate_text(self, text: str, max_length: int) -> str: + """Truncate text if it exceeds max_length""" + return ( + text[:max_length] + + "...truncated by litellm, this logger does not support large content" + if len(text) > max_length + else text + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/custom_prompt_management.py b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_prompt_management.py new file mode 100644 index 00000000..5b34ef0c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_prompt_management.py @@ -0,0 +1,49 @@ +from typing import List, Optional, Tuple + +from litellm.integrations.custom_logger import CustomLogger +from litellm.integrations.prompt_management_base import ( + PromptManagementBase, + PromptManagementClient, +) +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import StandardCallbackDynamicParams + + +class CustomPromptManagement(CustomLogger, PromptManagementBase): + def get_chat_completion_prompt( + self, + model: str, + messages: List[AllMessageValues], + non_default_params: dict, + prompt_id: str, + prompt_variables: Optional[dict], + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> Tuple[str, List[AllMessageValues], dict]: + """ + Returns: + - model: str - the model to use (can be pulled from prompt management tool) + - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool) + - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) 
to use (can be pulled from prompt management tool) + """ + return model, messages, non_default_params + + @property + def integration_name(self) -> str: + return "custom-prompt-management" + + def should_run_prompt_management( + self, + prompt_id: str, + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> bool: + return True + + def _compile_prompt_helper( + self, + prompt_id: str, + prompt_variables: Optional[dict], + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> PromptManagementClient: + raise NotImplementedError( + "Custom prompt management does not support compile prompt helper" + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/datadog/datadog.py b/.venv/lib/python3.12/site-packages/litellm/integrations/datadog/datadog.py new file mode 100644 index 00000000..4f4b05c8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/datadog/datadog.py @@ -0,0 +1,580 @@ +""" +DataDog Integration - sends logs to /api/v2/log + +DD Reference API: https://docs.datadoghq.com/api/latest/logs + +`async_log_success_event` - used by litellm proxy to send logs to datadog +`log_success_event` - sync version of logging to DataDog, only used on litellm Python SDK, if user opts in to using sync functions + +async_log_success_event: will store batch of DD_MAX_BATCH_SIZE in memory and flush to Datadog once it reaches DD_MAX_BATCH_SIZE or every 5 seconds + +async_service_failure_hook: Logs failures from Redis, Postgres (Adjacent systems), as 'WARNING' on DataDog + +For batching specific details see CustomBatchLogger class +""" + +import asyncio +import datetime +import json +import os +import traceback +import uuid +from datetime import datetime as datetimeObj +from typing import Any, List, Optional, Union + +import httpx +from httpx import Response + +import litellm +from litellm._logging import verbose_logger +from litellm.integrations.custom_batch_logger import CustomBatchLogger +from litellm.llms.custom_httpx.http_handler import ( + _get_httpx_client, + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus +from litellm.types.integrations.datadog import * +from litellm.types.services import ServiceLoggerPayload, ServiceTypes +from litellm.types.utils import StandardLoggingPayload + +from ..additional_logging_utils import AdditionalLoggingUtils + +# max number of logs DD API can accept +DD_MAX_BATCH_SIZE = 1000 + +# specify what ServiceTypes are logged as success events to DD. 
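# ---------------------------------------------------------------------------
# Editor's illustrative sketch, not part of the vendored diff above: minimal
# wiring for the DataDogLogger defined below. The logger reads DD_API_KEY and
# DD_SITE from the environment (see its __init__); enabling it by the string
# callback name "datadog" is an assumption based on litellm's callback naming.
import os

import litellm

os.environ["DD_API_KEY"] = "..."            # your Datadog API key
os.environ["DD_SITE"] = "us5.datadoghq.com"  # your Datadog site

litellm.success_callback = ["datadog"]       # batch-log successful calls to /api/v2/logs
# ---------------------------------------------------------------------------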
(We don't want to spam DD traces with large number of service types) +DD_LOGGED_SUCCESS_SERVICE_TYPES = [ + ServiceTypes.RESET_BUDGET_JOB, +] + + +class DataDogLogger( + CustomBatchLogger, + AdditionalLoggingUtils, +): + # Class variables or attributes + def __init__( + self, + **kwargs, + ): + """ + Initializes the datadog logger, checks if the correct env variables are set + + Required environment variables: + `DD_API_KEY` - your datadog api key + `DD_SITE` - your datadog site, example = `"us5.datadoghq.com"` + """ + try: + verbose_logger.debug("Datadog: in init datadog logger") + # check if the correct env variables are set + if os.getenv("DD_API_KEY", None) is None: + raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>") + if os.getenv("DD_SITE", None) is None: + raise Exception("DD_SITE is not set in .env, set 'DD_SITE=<>") + self.async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + self.DD_API_KEY = os.getenv("DD_API_KEY") + self.intake_url = ( + f"https://http-intake.logs.{os.getenv('DD_SITE')}/api/v2/logs" + ) + + ################################### + # OPTIONAL -only used for testing + dd_base_url: Optional[str] = ( + os.getenv("_DATADOG_BASE_URL") + or os.getenv("DATADOG_BASE_URL") + or os.getenv("DD_BASE_URL") + ) + if dd_base_url is not None: + self.intake_url = f"{dd_base_url}/api/v2/logs" + ################################### + self.sync_client = _get_httpx_client() + asyncio.create_task(self.periodic_flush()) + self.flush_lock = asyncio.Lock() + super().__init__( + **kwargs, flush_lock=self.flush_lock, batch_size=DD_MAX_BATCH_SIZE + ) + except Exception as e: + verbose_logger.exception( + f"Datadog: Got exception on init Datadog client {str(e)}" + ) + raise e + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + """ + Async Log success events to Datadog + + - Creates a Datadog payload + - Adds the Payload to the in memory logs queue + - Payload is flushed every 10 seconds or when batch size is greater than 100 + + + Raises: + Raises a NON Blocking verbose_logger.exception if an error occurs + """ + try: + verbose_logger.debug( + "Datadog: Logging - Enters logging function for model %s", kwargs + ) + await self._log_async_event(kwargs, response_obj, start_time, end_time) + + except Exception as e: + verbose_logger.exception( + f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}" + ) + pass + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + try: + verbose_logger.debug( + "Datadog: Logging - Enters logging function for model %s", kwargs + ) + await self._log_async_event(kwargs, response_obj, start_time, end_time) + + except Exception as e: + verbose_logger.exception( + f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}" + ) + pass + + async def async_send_batch(self): + """ + Sends the in memory logs queue to datadog api + + Logs sent to /api/v2/logs + + DD Ref: https://docs.datadoghq.com/api/latest/logs/ + + Raises: + Raises a NON Blocking verbose_logger.exception if an error occurs + """ + try: + if not self.log_queue: + verbose_logger.exception("Datadog: log_queue does not exist") + return + + verbose_logger.debug( + "Datadog - about to flush %s events on %s", + len(self.log_queue), + self.intake_url, + ) + + response = await self.async_send_compressed_data(self.log_queue) + if response.status_code == 413: + verbose_logger.exception(DD_ERRORS.DATADOG_413_ERROR.value) + return + + response.raise_for_status() + if response.status_code 
!= 202: + raise Exception( + f"Response from datadog API status_code: {response.status_code}, text: {response.text}" + ) + + verbose_logger.debug( + "Datadog: Response from datadog API status_code: %s, text: %s", + response.status_code, + response.text, + ) + except Exception as e: + verbose_logger.exception( + f"Datadog Error sending batch API - {str(e)}\n{traceback.format_exc()}" + ) + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + """ + Sync Log success events to Datadog + + - Creates a Datadog payload + - instantly logs it on DD API + """ + try: + if litellm.datadog_use_v1 is True: + dd_payload = self._create_v0_logging_payload( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + ) + else: + dd_payload = self.create_datadog_logging_payload( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + ) + + response = self.sync_client.post( + url=self.intake_url, + json=dd_payload, # type: ignore + headers={ + "DD-API-KEY": self.DD_API_KEY, + }, + ) + + response.raise_for_status() + if response.status_code != 202: + raise Exception( + f"Response from datadog API status_code: {response.status_code}, text: {response.text}" + ) + + verbose_logger.debug( + "Datadog: Response from datadog API status_code: %s, text: %s", + response.status_code, + response.text, + ) + + except Exception as e: + verbose_logger.exception( + f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}" + ) + pass + pass + + async def _log_async_event(self, kwargs, response_obj, start_time, end_time): + + dd_payload = self.create_datadog_logging_payload( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + ) + + self.log_queue.append(dd_payload) + verbose_logger.debug( + f"Datadog, event added to queue. Will flush in {self.flush_interval} seconds..." 
+ ) + + if len(self.log_queue) >= self.batch_size: + await self.async_send_batch() + + def _create_datadog_logging_payload_helper( + self, + standard_logging_object: StandardLoggingPayload, + status: DataDogStatus, + ) -> DatadogPayload: + json_payload = json.dumps(standard_logging_object, default=str) + verbose_logger.debug("Datadog: Logger - Logging payload = %s", json_payload) + dd_payload = DatadogPayload( + ddsource=self._get_datadog_source(), + ddtags=self._get_datadog_tags( + standard_logging_object=standard_logging_object + ), + hostname=self._get_datadog_hostname(), + message=json_payload, + service=self._get_datadog_service(), + status=status, + ) + return dd_payload + + def create_datadog_logging_payload( + self, + kwargs: Union[dict, Any], + response_obj: Any, + start_time: datetime.datetime, + end_time: datetime.datetime, + ) -> DatadogPayload: + """ + Helper function to create a datadog payload for logging + + Args: + kwargs (Union[dict, Any]): request kwargs + response_obj (Any): llm api response + start_time (datetime.datetime): start time of request + end_time (datetime.datetime): end time of request + + Returns: + DatadogPayload: defined in types.py + """ + + standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + if standard_logging_object is None: + raise ValueError("standard_logging_object not found in kwargs") + + status = DataDogStatus.INFO + if standard_logging_object.get("status") == "failure": + status = DataDogStatus.ERROR + + # Build the initial payload + self.truncate_standard_logging_payload_content(standard_logging_object) + + dd_payload = self._create_datadog_logging_payload_helper( + standard_logging_object=standard_logging_object, + status=status, + ) + return dd_payload + + async def async_send_compressed_data(self, data: List) -> Response: + """ + Async helper to send compressed data to datadog self.intake_url + + Datadog recommends using gzip to compress data + https://docs.datadoghq.com/api/latest/logs/ + + "Datadog recommends sending your logs compressed. 
Add the Content-Encoding: gzip header to the request when sending" + """ + + import gzip + import json + + compressed_data = gzip.compress(json.dumps(data, default=str).encode("utf-8")) + response = await self.async_client.post( + url=self.intake_url, + data=compressed_data, # type: ignore + headers={ + "DD-API-KEY": self.DD_API_KEY, + "Content-Encoding": "gzip", + "Content-Type": "application/json", + }, + ) + return response + + async def async_service_failure_hook( + self, + payload: ServiceLoggerPayload, + error: Optional[str] = "", + parent_otel_span: Optional[Any] = None, + start_time: Optional[Union[datetimeObj, float]] = None, + end_time: Optional[Union[float, datetimeObj]] = None, + event_metadata: Optional[dict] = None, + ): + """ + Logs failures from Redis, Postgres (Adjacent systems), as 'WARNING' on DataDog + + - example - Redis is failing / erroring, will be logged on DataDog + """ + try: + _payload_dict = payload.model_dump() + _payload_dict.update(event_metadata or {}) + _dd_message_str = json.dumps(_payload_dict, default=str) + _dd_payload = DatadogPayload( + ddsource=self._get_datadog_source(), + ddtags=self._get_datadog_tags(), + hostname=self._get_datadog_hostname(), + message=_dd_message_str, + service=self._get_datadog_service(), + status=DataDogStatus.WARN, + ) + + self.log_queue.append(_dd_payload) + + except Exception as e: + verbose_logger.exception( + f"Datadog: Logger - Exception in async_service_failure_hook: {e}" + ) + pass + + async def async_service_success_hook( + self, + payload: ServiceLoggerPayload, + error: Optional[str] = "", + parent_otel_span: Optional[Any] = None, + start_time: Optional[Union[datetimeObj, float]] = None, + end_time: Optional[Union[float, datetimeObj]] = None, + event_metadata: Optional[dict] = None, + ): + """ + Logs success from Redis, Postgres (Adjacent systems), as 'INFO' on DataDog + + No user has asked for this so far, this might be spammy on datatdog. If need arises we can implement this + """ + try: + # intentionally done. 
Don't want to log all service types to DD + if payload.service not in DD_LOGGED_SUCCESS_SERVICE_TYPES: + return + + _payload_dict = payload.model_dump() + _payload_dict.update(event_metadata or {}) + + _dd_message_str = json.dumps(_payload_dict, default=str) + _dd_payload = DatadogPayload( + ddsource=self._get_datadog_source(), + ddtags=self._get_datadog_tags(), + hostname=self._get_datadog_hostname(), + message=_dd_message_str, + service=self._get_datadog_service(), + status=DataDogStatus.INFO, + ) + + self.log_queue.append(_dd_payload) + + except Exception as e: + verbose_logger.exception( + f"Datadog: Logger - Exception in async_service_failure_hook: {e}" + ) + + def _create_v0_logging_payload( + self, + kwargs: Union[dict, Any], + response_obj: Any, + start_time: datetime.datetime, + end_time: datetime.datetime, + ) -> DatadogPayload: + """ + Note: This is our V1 Version of DataDog Logging Payload + + + (Not Recommended) If you want this to get logged set `litellm.datadog_use_v1 = True` + """ + import json + + litellm_params = kwargs.get("litellm_params", {}) + metadata = ( + litellm_params.get("metadata", {}) or {} + ) # if litellm_params['metadata'] == None + messages = kwargs.get("messages") + optional_params = kwargs.get("optional_params", {}) + call_type = kwargs.get("call_type", "litellm.completion") + cache_hit = kwargs.get("cache_hit", False) + usage = response_obj["usage"] + id = response_obj.get("id", str(uuid.uuid4())) + usage = dict(usage) + try: + response_time = (end_time - start_time).total_seconds() * 1000 + except Exception: + response_time = None + + try: + response_obj = dict(response_obj) + except Exception: + response_obj = response_obj + + # Clean Metadata before logging - never log raw metadata + # the raw metadata can contain circular references which leads to infinite recursion + # we clean out all extra litellm metadata params before logging + clean_metadata = {} + if isinstance(metadata, dict): + for key, value in metadata.items(): + # clean litellm metadata before logging + if key in [ + "endpoint", + "caching_groups", + "previous_models", + ]: + continue + else: + clean_metadata[key] = value + + # Build the initial payload + payload = { + "id": id, + "call_type": call_type, + "cache_hit": cache_hit, + "start_time": start_time, + "end_time": end_time, + "response_time": response_time, + "model": kwargs.get("model", ""), + "user": kwargs.get("user", ""), + "model_parameters": optional_params, + "spend": kwargs.get("response_cost", 0), + "messages": messages, + "response": response_obj, + "usage": usage, + "metadata": clean_metadata, + } + + json_payload = json.dumps(payload, default=str) + + verbose_logger.debug("Datadog: Logger - Logging payload = %s", json_payload) + + dd_payload = DatadogPayload( + ddsource=self._get_datadog_source(), + ddtags=self._get_datadog_tags(), + hostname=self._get_datadog_hostname(), + message=json_payload, + service=self._get_datadog_service(), + status=DataDogStatus.INFO, + ) + return dd_payload + + @staticmethod + def _get_datadog_tags( + standard_logging_object: Optional[StandardLoggingPayload] = None, + ) -> str: + """ + Get the datadog tags for the request + + DD tags need to be as follows: + - tags: ["user_handle:dog@gmail.com", "app_version:1.0.0"] + """ + base_tags = { + "env": os.getenv("DD_ENV", "unknown"), + "service": os.getenv("DD_SERVICE", "litellm"), + "version": os.getenv("DD_VERSION", "unknown"), + "HOSTNAME": DataDogLogger._get_datadog_hostname(), + "POD_NAME": os.getenv("POD_NAME", "unknown"), + } + + tags = 
[f"{k}:{v}" for k, v in base_tags.items()] + + if standard_logging_object: + _request_tags: List[str] = ( + standard_logging_object.get("request_tags", []) or [] + ) + request_tags = [f"request_tag:{tag}" for tag in _request_tags] + tags.extend(request_tags) + + return ",".join(tags) + + @staticmethod + def _get_datadog_source(): + return os.getenv("DD_SOURCE", "litellm") + + @staticmethod + def _get_datadog_service(): + return os.getenv("DD_SERVICE", "litellm-server") + + @staticmethod + def _get_datadog_hostname(): + return os.getenv("HOSTNAME", "") + + @staticmethod + def _get_datadog_env(): + return os.getenv("DD_ENV", "unknown") + + @staticmethod + def _get_datadog_pod_name(): + return os.getenv("POD_NAME", "unknown") + + async def async_health_check(self) -> IntegrationHealthCheckStatus: + """ + Check if the service is healthy + """ + from litellm.litellm_core_utils.litellm_logging import ( + create_dummy_standard_logging_payload, + ) + + standard_logging_object = create_dummy_standard_logging_payload() + dd_payload = self._create_datadog_logging_payload_helper( + standard_logging_object=standard_logging_object, + status=DataDogStatus.INFO, + ) + log_queue = [dd_payload] + response = await self.async_send_compressed_data(log_queue) + try: + response.raise_for_status() + return IntegrationHealthCheckStatus( + status="healthy", + error_message=None, + ) + except httpx.HTTPStatusError as e: + return IntegrationHealthCheckStatus( + status="unhealthy", + error_message=e.response.text, + ) + except Exception as e: + return IntegrationHealthCheckStatus( + status="unhealthy", + error_message=str(e), + ) + + async def get_request_response_payload( + self, + request_id: str, + start_time_utc: Optional[datetimeObj], + end_time_utc: Optional[datetimeObj], + ) -> Optional[dict]: + pass diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/datadog/datadog_llm_obs.py b/.venv/lib/python3.12/site-packages/litellm/integrations/datadog/datadog_llm_obs.py new file mode 100644 index 00000000..e4e074ba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/datadog/datadog_llm_obs.py @@ -0,0 +1,203 @@ +""" +Implements logging integration with Datadog's LLM Observability Service + + +API Reference: https://docs.datadoghq.com/llm_observability/setup/api/?tab=example#api-standards + +""" + +import asyncio +import json +import os +import uuid +from datetime import datetime +from typing import Any, Dict, List, Optional, Union + +import litellm +from litellm._logging import verbose_logger +from litellm.integrations.custom_batch_logger import CustomBatchLogger +from litellm.integrations.datadog.datadog import DataDogLogger +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.types.integrations.datadog_llm_obs import * +from litellm.types.utils import StandardLoggingPayload + + +class DataDogLLMObsLogger(DataDogLogger, CustomBatchLogger): + def __init__(self, **kwargs): + try: + verbose_logger.debug("DataDogLLMObs: Initializing logger") + if os.getenv("DD_API_KEY", None) is None: + raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>'") + if os.getenv("DD_SITE", None) is None: + raise Exception( + "DD_SITE is not set, set 'DD_SITE=<>', example sit = `us5.datadoghq.com`" + ) + + self.async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + self.DD_API_KEY = os.getenv("DD_API_KEY") + self.DD_SITE = os.getenv("DD_SITE") + self.intake_url = ( + 
f"https://api.{self.DD_SITE}/api/intake/llm-obs/v1/trace/spans" + ) + + # testing base url + dd_base_url = os.getenv("DD_BASE_URL") + if dd_base_url: + self.intake_url = f"{dd_base_url}/api/intake/llm-obs/v1/trace/spans" + + asyncio.create_task(self.periodic_flush()) + self.flush_lock = asyncio.Lock() + self.log_queue: List[LLMObsPayload] = [] + CustomBatchLogger.__init__(self, **kwargs, flush_lock=self.flush_lock) + except Exception as e: + verbose_logger.exception(f"DataDogLLMObs: Error initializing - {str(e)}") + raise e + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + verbose_logger.debug( + f"DataDogLLMObs: Logging success event for model {kwargs.get('model', 'unknown')}" + ) + payload = self.create_llm_obs_payload( + kwargs, response_obj, start_time, end_time + ) + verbose_logger.debug(f"DataDogLLMObs: Payload: {payload}") + self.log_queue.append(payload) + + if len(self.log_queue) >= self.batch_size: + await self.async_send_batch() + except Exception as e: + verbose_logger.exception( + f"DataDogLLMObs: Error logging success event - {str(e)}" + ) + + async def async_send_batch(self): + try: + if not self.log_queue: + return + + verbose_logger.debug( + f"DataDogLLMObs: Flushing {len(self.log_queue)} events" + ) + + # Prepare the payload + payload = { + "data": DDIntakePayload( + type="span", + attributes=DDSpanAttributes( + ml_app=self._get_datadog_service(), + tags=[self._get_datadog_tags()], + spans=self.log_queue, + ), + ), + } + verbose_logger.debug("payload %s", json.dumps(payload, indent=4)) + response = await self.async_client.post( + url=self.intake_url, + json=payload, + headers={ + "DD-API-KEY": self.DD_API_KEY, + "Content-Type": "application/json", + }, + ) + + response.raise_for_status() + if response.status_code != 202: + raise Exception( + f"DataDogLLMObs: Unexpected response - status_code: {response.status_code}, text: {response.text}" + ) + + verbose_logger.debug( + f"DataDogLLMObs: Successfully sent batch - status_code: {response.status_code}" + ) + self.log_queue.clear() + except Exception as e: + verbose_logger.exception(f"DataDogLLMObs: Error sending batch - {str(e)}") + + def create_llm_obs_payload( + self, kwargs: Dict, response_obj: Any, start_time: datetime, end_time: datetime + ) -> LLMObsPayload: + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object" + ) + if standard_logging_payload is None: + raise Exception("DataDogLLMObs: standard_logging_object is not set") + + messages = standard_logging_payload["messages"] + messages = self._ensure_string_content(messages=messages) + + metadata = kwargs.get("litellm_params", {}).get("metadata", {}) + + input_meta = InputMeta(messages=messages) # type: ignore + output_meta = OutputMeta(messages=self._get_response_messages(response_obj)) + + meta = Meta( + kind="llm", + input=input_meta, + output=output_meta, + metadata=self._get_dd_llm_obs_payload_metadata(standard_logging_payload), + ) + + # Calculate metrics (you may need to adjust these based on available data) + metrics = LLMMetrics( + input_tokens=float(standard_logging_payload.get("prompt_tokens", 0)), + output_tokens=float(standard_logging_payload.get("completion_tokens", 0)), + total_tokens=float(standard_logging_payload.get("total_tokens", 0)), + ) + + return LLMObsPayload( + parent_id=metadata.get("parent_id", "undefined"), + trace_id=metadata.get("trace_id", str(uuid.uuid4())), + span_id=metadata.get("span_id", str(uuid.uuid4())), + name=metadata.get("name", 
"litellm_llm_call"), + meta=meta, + start_ns=int(start_time.timestamp() * 1e9), + duration=int((end_time - start_time).total_seconds() * 1e9), + metrics=metrics, + tags=[ + self._get_datadog_tags(standard_logging_object=standard_logging_payload) + ], + ) + + def _get_response_messages(self, response_obj: Any) -> List[Any]: + """ + Get the messages from the response object + + for now this handles logging /chat/completions responses + """ + if isinstance(response_obj, litellm.ModelResponse): + return [response_obj["choices"][0]["message"].json()] + return [] + + def _ensure_string_content( + self, messages: Optional[Union[str, List[Any], Dict[Any, Any]]] + ) -> List[Any]: + if messages is None: + return [] + if isinstance(messages, str): + return [messages] + elif isinstance(messages, list): + return [message for message in messages] + elif isinstance(messages, dict): + return [str(messages.get("content", ""))] + return [] + + def _get_dd_llm_obs_payload_metadata( + self, standard_logging_payload: StandardLoggingPayload + ) -> Dict: + _metadata = { + "model_name": standard_logging_payload.get("model", "unknown"), + "model_provider": standard_logging_payload.get( + "custom_llm_provider", "unknown" + ), + } + _standard_logging_metadata: dict = ( + dict(standard_logging_payload.get("metadata", {})) or {} + ) + _metadata.update(_standard_logging_metadata) + return _metadata diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/dynamodb.py b/.venv/lib/python3.12/site-packages/litellm/integrations/dynamodb.py new file mode 100644 index 00000000..2c527ea8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/dynamodb.py @@ -0,0 +1,89 @@ +#### What this does #### +# On success + failure, log events to Supabase + +import os +import traceback +import uuid +from typing import Any + +import litellm + + +class DyanmoDBLogger: + # Class variables or attributes + + def __init__(self): + # Instance variables + import boto3 + + self.dynamodb: Any = boto3.resource( + "dynamodb", region_name=os.environ["AWS_REGION_NAME"] + ) + if litellm.dynamodb_table_name is None: + raise ValueError( + "LiteLLM Error, trying to use DynamoDB but not table name passed. 
Create a table and set `litellm.dynamodb_table_name=<your-table>`" + ) + self.table_name = litellm.dynamodb_table_name + + async def _async_log_event( + self, kwargs, response_obj, start_time, end_time, print_verbose + ): + self.log_event(kwargs, response_obj, start_time, end_time, print_verbose) + + def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): + try: + print_verbose( + f"DynamoDB Logging - Enters logging function for model {kwargs}" + ) + + # construct payload to send to DynamoDB + # follows the same params as langfuse.py + litellm_params = kwargs.get("litellm_params", {}) + metadata = ( + litellm_params.get("metadata", {}) or {} + ) # if litellm_params['metadata'] == None + messages = kwargs.get("messages") + optional_params = kwargs.get("optional_params", {}) + call_type = kwargs.get("call_type", "litellm.completion") + usage = response_obj["usage"] + id = response_obj.get("id", str(uuid.uuid4())) + + # Build the initial payload + payload = { + "id": id, + "call_type": call_type, + "startTime": start_time, + "endTime": end_time, + "model": kwargs.get("model", ""), + "user": kwargs.get("user", ""), + "modelParameters": optional_params, + "messages": messages, + "response": response_obj, + "usage": usage, + "metadata": metadata, + } + + # Ensure everything in the payload is converted to str + for key, value in payload.items(): + try: + payload[key] = str(value) + except Exception: + # non blocking if it can't cast to a str + pass + + print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}") + + # put data in dyanmo DB + table = self.dynamodb.Table(self.table_name) + # Assuming log_data is a dictionary with log information + response = table.put_item(Item=payload) + + print_verbose(f"Response from DynamoDB:{str(response)}") + + print_verbose( + f"DynamoDB Layer Logging - final response object: {response_obj}" + ) + return response + except Exception: + print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}") + pass diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/email_alerting.py b/.venv/lib/python3.12/site-packages/litellm/integrations/email_alerting.py new file mode 100644 index 00000000..b45b9aa7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/email_alerting.py @@ -0,0 +1,136 @@ +""" +Functions for sending Email Alerts +""" + +import os +from typing import List, Optional + +from litellm._logging import verbose_logger, verbose_proxy_logger +from litellm.proxy._types import WebhookEvent + +# we use this for the email header, please send a test email if you change this. 
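# ---------------------------------------------------------------------------
# Editor's illustrative sketch, not part of the vendored diff above: wiring for
# the DynamoDB logger defined earlier in this diff. It requires AWS_REGION_NAME
# and `litellm.dynamodb_table_name` (both read in its __init__); the "dynamodb"
# callback string is an assumption based on litellm's callback naming.
import os

import litellm

os.environ["AWS_REGION_NAME"] = "us-west-2"   # region of your DynamoDB table
litellm.dynamodb_table_name = "litellm-logs"  # table the logger writes items to
litellm.success_callback = ["dynamodb"]
# ---------------------------------------------------------------------------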
verify it looks good on email +LITELLM_LOGO_URL = "https://litellm-listing.s3.amazonaws.com/litellm_logo.png" +LITELLM_SUPPORT_CONTACT = "support@berri.ai" + + +async def get_all_team_member_emails(team_id: Optional[str] = None) -> list: + verbose_logger.debug( + "Email Alerting: Getting all team members for team_id=%s", team_id + ) + if team_id is None: + return [] + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise Exception("Not connected to DB!") + + team_row = await prisma_client.db.litellm_teamtable.find_unique( + where={ + "team_id": team_id, + } + ) + + if team_row is None: + return [] + + _team_members = team_row.members_with_roles + verbose_logger.debug( + "Email Alerting: Got team members for team_id=%s Team Members: %s", + team_id, + _team_members, + ) + _team_member_user_ids: List[str] = [] + for member in _team_members: + if member and isinstance(member, dict): + _user_id = member.get("user_id") + if _user_id and isinstance(_user_id, str): + _team_member_user_ids.append(_user_id) + + sql_query = """ + SELECT user_email + FROM "LiteLLM_UserTable" + WHERE user_id = ANY($1::TEXT[]); + """ + + _result = await prisma_client.db.query_raw(sql_query, _team_member_user_ids) + + verbose_logger.debug("Email Alerting: Got all Emails for team, emails=%s", _result) + + if _result is None: + return [] + + emails = [] + for user in _result: + if user and isinstance(user, dict) and user.get("user_email", None) is not None: + emails.append(user.get("user_email")) + return emails + + +async def send_team_budget_alert(webhook_event: WebhookEvent) -> bool: + """ + Send an Email Alert to All Team Members when the Team Budget is crossed + Returns -> True if sent, False if not. + """ + from litellm.proxy.utils import send_email + + _team_id = webhook_event.team_id + team_alias = webhook_event.team_alias + verbose_logger.debug( + "Email Alerting: Sending Team Budget Alert for team=%s", team_alias + ) + + email_logo_url = os.getenv("SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None)) + email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None) + + # await self._check_if_using_premium_email_feature( + # premium_user, email_logo_url, email_support_contact + # ) + + if email_logo_url is None: + email_logo_url = LITELLM_LOGO_URL + if email_support_contact is None: + email_support_contact = LITELLM_SUPPORT_CONTACT + recipient_emails = await get_all_team_member_emails(_team_id) + recipient_emails_str: str = ",".join(recipient_emails) + verbose_logger.debug( + "Email Alerting: Sending team budget alert to %s", recipient_emails_str + ) + + event_name = webhook_event.event_message + max_budget = webhook_event.max_budget + email_html_content = "Alert from LiteLLM Server" + + if recipient_emails_str is None: + verbose_proxy_logger.warning( + "Email Alerting: Trying to send email alert to no recipient, got recipient_emails=%s", + recipient_emails_str, + ) + + email_html_content = f""" + <img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" /> <br/><br/><br/> + + Budget Crossed for Team <b> {team_alias} </b> <br/> <br/> + + Your Teams LLM API usage has crossed it's <b> budget of ${max_budget} </b>, current spend is <b>${webhook_event.spend}</b><br /> <br /> + + API requests will be rejected until either (a) you increase your budget or (b) your budget gets reset <br /> <br /> + + If you have any questions, please send an email to {email_support_contact} <br /> <br /> + + Best, <br /> + The LiteLLM team <br /> + """ + + email_event = { + "to": 
recipient_emails_str, + "subject": f"LiteLLM {event_name} for Team {team_alias}", + "html": email_html_content, + } + + await send_email( + receiver_email=email_event["to"], + subject=email_event["subject"], + html=email_event["html"], + ) + + return False diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/email_templates/templates.py b/.venv/lib/python3.12/site-packages/litellm/integrations/email_templates/templates.py new file mode 100644 index 00000000..7029e8ce --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/email_templates/templates.py @@ -0,0 +1,62 @@ +""" +Email Templates used by the LiteLLM Email Service in slack_alerting.py +""" + +KEY_CREATED_EMAIL_TEMPLATE = """ + <img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" /> + + <p> Hi {recipient_email}, <br/> + + I'm happy to provide you with an OpenAI Proxy API Key, loaded with ${key_budget} per month. <br /> <br /> + + <b> + Key: <pre>{key_token}</pre> <br> + </b> + + <h2>Usage Example</h2> + + Detailed Documentation on <a href="https://docs.litellm.ai/docs/proxy/user_keys">Usage with OpenAI Python SDK, Langchain, LlamaIndex, Curl</a> + + <pre> + + import openai + client = openai.OpenAI( + api_key="{key_token}", + base_url={{base_url}} + ) + + response = client.chat.completions.create( + model="gpt-3.5-turbo", # model to send to the proxy + messages = [ + {{ + "role": "user", + "content": "this is a test request, write a short poem" + }} + ] + ) + + </pre> + + + If you have any questions, please send an email to {email_support_contact} <br /> <br /> + + Best, <br /> + The LiteLLM team <br /> +""" + + +USER_INVITED_EMAIL_TEMPLATE = """ + <img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" /> + + <p> Hi {recipient_email}, <br/> + + You were invited to use OpenAI Proxy API for team {team_name} <br /> <br /> + + <a href="{base_url}" style="display: inline-block; padding: 10px 20px; background-color: #87ceeb; color: #fff; text-decoration: none; border-radius: 20px;">Get Started here</a> <br /> <br /> + + + If you have any questions, please send an email to {email_support_contact} <br /> <br /> + + Best, <br /> + The LiteLLM team <br /> +""" diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/galileo.py b/.venv/lib/python3.12/site-packages/litellm/integrations/galileo.py new file mode 100644 index 00000000..e99d5f23 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/galileo.py @@ -0,0 +1,157 @@ +import os +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + +import litellm +from litellm._logging import verbose_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) + + +# from here: https://docs.rungalileo.io/galileo/gen-ai-studio-products/galileo-observe/how-to/logging-data-via-restful-apis#structuring-your-records +class LLMResponse(BaseModel): + latency_ms: int + status_code: int + input_text: str + output_text: str + node_type: str + model: str + num_input_tokens: int + num_output_tokens: int + output_logprobs: Optional[Dict[str, Any]] = Field( + default=None, + description="Optional. 
When available, logprobs are used to compute Uncertainty.", + ) + created_at: str = Field( + ..., description='timestamp constructed in "%Y-%m-%dT%H:%M:%S" format' + ) + tags: Optional[List[str]] = None + user_metadata: Optional[Dict[str, Any]] = None + + +class GalileoObserve(CustomLogger): + def __init__(self) -> None: + self.in_memory_records: List[dict] = [] + self.batch_size = 1 + self.base_url = os.getenv("GALILEO_BASE_URL", None) + self.project_id = os.getenv("GALILEO_PROJECT_ID", None) + self.headers: Optional[Dict[str, str]] = None + self.async_httpx_handler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + pass + + def set_galileo_headers(self): + # following https://docs.rungalileo.io/galileo/gen-ai-studio-products/galileo-observe/how-to/logging-data-via-restful-apis#logging-your-records + + headers = { + "accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded", + } + galileo_login_response = litellm.module_level_client.post( + url=f"{self.base_url}/login", + headers=headers, + data={ + "username": os.getenv("GALILEO_USERNAME"), + "password": os.getenv("GALILEO_PASSWORD"), + }, + ) + + access_token = galileo_login_response.json()["access_token"] + + self.headers = { + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + } + + def get_output_str_from_response(self, response_obj, kwargs): + output = None + if response_obj is not None and ( + kwargs.get("call_type", None) == "embedding" + or isinstance(response_obj, litellm.EmbeddingResponse) + ): + output = None + elif response_obj is not None and isinstance( + response_obj, litellm.ModelResponse + ): + output = response_obj["choices"][0]["message"].json() + elif response_obj is not None and isinstance( + response_obj, litellm.TextCompletionResponse + ): + output = response_obj.choices[0].text + elif response_obj is not None and isinstance( + response_obj, litellm.ImageResponse + ): + output = response_obj["data"] + + return output + + async def async_log_success_event( + self, kwargs: Any, response_obj: Any, start_time: Any, end_time: Any + ): + verbose_logger.debug("On Async Success") + + _latency_ms = int((end_time - start_time).total_seconds() * 1000) + _call_type = kwargs.get("call_type", "litellm") + input_text = litellm.utils.get_formatted_prompt( + data=kwargs, call_type=_call_type + ) + + _usage = response_obj.get("usage", {}) or {} + num_input_tokens = _usage.get("prompt_tokens", 0) + num_output_tokens = _usage.get("completion_tokens", 0) + + output_text = self.get_output_str_from_response( + response_obj=response_obj, kwargs=kwargs + ) + + if output_text is not None: + request_record = LLMResponse( + latency_ms=_latency_ms, + status_code=200, + input_text=input_text, + output_text=output_text, + node_type=_call_type, + model=kwargs.get("model", "-"), + num_input_tokens=num_input_tokens, + num_output_tokens=num_output_tokens, + created_at=start_time.strftime( + "%Y-%m-%dT%H:%M:%S" + ), # timestamp str constructed in "%Y-%m-%dT%H:%M:%S" format + ) + + # dump to dict + request_dict = request_record.model_dump() + self.in_memory_records.append(request_dict) + + if len(self.in_memory_records) >= self.batch_size: + await self.flush_in_memory_records() + + async def flush_in_memory_records(self): + verbose_logger.debug("flushing in memory records") + response = await self.async_httpx_handler.post( + url=f"{self.base_url}/projects/{self.project_id}/observe/ingest", + headers=self.headers, + json={"records": 
self.in_memory_records}, + ) + + if response.status_code == 200: + verbose_logger.debug( + "Galileo Logger:successfully flushed in memory records" + ) + self.in_memory_records = [] + else: + verbose_logger.debug("Galileo Logger: failed to flush in memory records") + verbose_logger.debug( + "Galileo Logger error=%s, status code=%s", + response.text, + response.status_code, + ) + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + verbose_logger.debug("On Async Failure") diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/Readme.md b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/Readme.md new file mode 100644 index 00000000..2ab0b233 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/Readme.md @@ -0,0 +1,12 @@ +# GCS (Google Cloud Storage) Bucket Logging on LiteLLM Gateway + +This folder contains the GCS Bucket Logging integration for LiteLLM Gateway. + +## Folder Structure + +- `gcs_bucket.py`: This is the main file that handles failure/success logging to GCS Bucket +- `gcs_bucket_base.py`: This file contains the GCSBucketBase class which handles Authentication for GCS Buckets + +## Further Reading +- [Doc setting up GCS Bucket Logging on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/proxy/bucket) +- [Doc on Key / Team Based logging with GCS](https://docs.litellm.ai/docs/proxy/team_logging)
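+
+## Object naming (illustrative)
+
+A standalone sketch, not part of the integration itself, of how the code in this folder derives object names: success logs are written to `YYYY-MM-DD/<response_id>`, failure logs to `YYYY-MM-DD/failure-<uuid>`, and a bucket name carrying a folder suffix is split into the real bucket plus an object prefix. The helper names below are simplified stand-ins for `_generate_success_object_name`, `_generate_failure_object_name` and `_handle_folders_in_bucket_name`.
+
+```python
+import uuid
+from datetime import datetime, timezone
+
+
+def success_object_name(response_id: str) -> str:
+    # success logs: "<YYYY-MM-DD>/<response_id>"
+    date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
+    return f"{date_str}/{response_id}"
+
+
+def failure_object_name() -> str:
+    # failure logs: "<YYYY-MM-DD>/failure-<uuid hex>"
+    date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
+    return f"{date_str}/failure-{uuid.uuid4().hex}"
+
+
+def split_bucket_and_prefix(bucket_name: str, object_name: str):
+    # "my-bucket/my-folder/dev" -> bucket "my-bucket", object "my-folder/dev/<object_name>"
+    if "/" in bucket_name:
+        bucket_name, prefix = bucket_name.split("/", 1)
+        object_name = f"{prefix}/{object_name}"
+    return bucket_name, object_name
+
+
+print(success_object_name("chatcmpl-123"))  # e.g. "2025-01-01/chatcmpl-123"
+print(split_bucket_and_prefix("my-bucket/my-folder/dev", "2025-01-01/chatcmpl-123"))
+```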
\ No newline at end of file diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py new file mode 100644 index 00000000..187ab779 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py @@ -0,0 +1,237 @@ +import asyncio +import json +import os +import uuid +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any, Dict, List, Optional +from urllib.parse import quote + +from litellm._logging import verbose_logger +from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils +from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase +from litellm.proxy._types import CommonProxyErrors +from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus +from litellm.types.integrations.gcs_bucket import * +from litellm.types.utils import StandardLoggingPayload + +if TYPE_CHECKING: + from litellm.llms.vertex_ai.vertex_llm_base import VertexBase +else: + VertexBase = Any + + +GCS_DEFAULT_BATCH_SIZE = 2048 +GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20 + + +class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils): + def __init__(self, bucket_name: Optional[str] = None) -> None: + from litellm.proxy.proxy_server import premium_user + + super().__init__(bucket_name=bucket_name) + + # Init Batch logging settings + self.log_queue: List[GCSLogQueueItem] = [] + self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE)) + self.flush_interval = int( + os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS) + ) + asyncio.create_task(self.periodic_flush()) + self.flush_lock = asyncio.Lock() + super().__init__( + flush_lock=self.flush_lock, + batch_size=self.batch_size, + flush_interval=self.flush_interval, + ) + AdditionalLoggingUtils.__init__(self) + + if premium_user is not True: + raise ValueError( + f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}" + ) + + #### ASYNC #### + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + from litellm.proxy.proxy_server import premium_user + + if premium_user is not True: + raise ValueError( + f"GCS Bucket logging is a premium feature. Please upgrade to use it. 
{CommonProxyErrors.not_premium_user.value}" + ) + try: + verbose_logger.debug( + "GCS Logger: async_log_success_event logging kwargs: %s, response_obj: %s", + kwargs, + response_obj, + ) + logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + if logging_payload is None: + raise ValueError("standard_logging_object not found in kwargs") + # Add to logging queue - this will be flushed periodically + self.log_queue.append( + GCSLogQueueItem( + payload=logging_payload, kwargs=kwargs, response_obj=response_obj + ) + ) + + except Exception as e: + verbose_logger.exception(f"GCS Bucket logging error: {str(e)}") + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + try: + verbose_logger.debug( + "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s", + kwargs, + response_obj, + ) + + logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + if logging_payload is None: + raise ValueError("standard_logging_object not found in kwargs") + # Add to logging queue - this will be flushed periodically + self.log_queue.append( + GCSLogQueueItem( + payload=logging_payload, kwargs=kwargs, response_obj=response_obj + ) + ) + + except Exception as e: + verbose_logger.exception(f"GCS Bucket logging error: {str(e)}") + + async def async_send_batch(self): + """ + Process queued logs in batch - sends logs to GCS Bucket + + + GCS Bucket does not have a Batch endpoint to batch upload logs + + Instead, we + - collect the logs to flush every `GCS_FLUSH_INTERVAL` seconds + - during async_send_batch, we make 1 POST request per log to GCS Bucket + + """ + if not self.log_queue: + return + + for log_item in self.log_queue: + logging_payload = log_item["payload"] + kwargs = log_item["kwargs"] + response_obj = log_item.get("response_obj", None) or {} + + gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config( + kwargs + ) + headers = await self.construct_request_headers( + vertex_instance=gcs_logging_config["vertex_instance"], + service_account_json=gcs_logging_config["path_service_account"], + ) + bucket_name = gcs_logging_config["bucket_name"] + object_name = self._get_object_name(kwargs, logging_payload, response_obj) + + try: + await self._log_json_data_on_gcs( + headers=headers, + bucket_name=bucket_name, + object_name=object_name, + logging_payload=logging_payload, + ) + except Exception as e: + # don't let one log item fail the entire batch + verbose_logger.exception( + f"GCS Bucket error logging payload to GCS bucket: {str(e)}" + ) + pass + + # Clear the queue after processing + self.log_queue.clear() + + def _get_object_name( + self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any + ) -> str: + """ + Get the object name to use for the current payload + """ + current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc)) + if logging_payload.get("error_str", None) is not None: + object_name = self._generate_failure_object_name( + request_date_str=current_date, + ) + else: + object_name = self._generate_success_object_name( + request_date_str=current_date, + response_id=response_obj.get("id", ""), + ) + + # used for testing + _litellm_params = kwargs.get("litellm_params", None) or {} + _metadata = _litellm_params.get("metadata", None) or {} + if "gcs_log_id" in _metadata: + object_name = _metadata["gcs_log_id"] + + return object_name + + async def get_request_response_payload( + self, + request_id: str, + 
start_time_utc: Optional[datetime], + end_time_utc: Optional[datetime], + ) -> Optional[dict]: + """ + Get the request and response payload for a given `request_id` + Tries current day, next day, and previous day until it finds the payload + """ + if start_time_utc is None: + raise ValueError( + "start_time_utc is required for getting a payload from GCS Bucket" + ) + + # Try current day, next day, and previous day + dates_to_try = [ + start_time_utc, + start_time_utc + timedelta(days=1), + start_time_utc - timedelta(days=1), + ] + date_str = None + for date in dates_to_try: + try: + date_str = self._get_object_date_from_datetime(datetime_obj=date) + object_name = self._generate_success_object_name( + request_date_str=date_str, + response_id=request_id, + ) + encoded_object_name = quote(object_name, safe="") + response = await self.download_gcs_object(encoded_object_name) + + if response is not None: + loaded_response = json.loads(response) + return loaded_response + except Exception as e: + verbose_logger.debug( + f"Failed to fetch payload for date {date_str}: {str(e)}" + ) + continue + + return None + + def _generate_success_object_name( + self, + request_date_str: str, + response_id: str, + ) -> str: + return f"{request_date_str}/{response_id}" + + def _generate_failure_object_name( + self, + request_date_str: str, + ) -> str: + return f"{request_date_str}/failure-{uuid.uuid4().hex}" + + def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str: + return datetime_obj.strftime("%Y-%m-%d") + + async def async_health_check(self) -> IntegrationHealthCheckStatus: + raise NotImplementedError("GCS Bucket does not support health check") diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py new file mode 100644 index 00000000..66995d84 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py @@ -0,0 +1,326 @@ +import json +import os +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + +from litellm._logging import verbose_logger +from litellm.integrations.custom_batch_logger import CustomBatchLogger +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.types.integrations.gcs_bucket import * +from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload + +if TYPE_CHECKING: + from litellm.llms.vertex_ai.vertex_llm_base import VertexBase +else: + VertexBase = Any +IAM_AUTH_KEY = "IAM_AUTH" + + +class GCSBucketBase(CustomBatchLogger): + def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None: + self.async_httpx_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + _path_service_account = os.getenv("GCS_PATH_SERVICE_ACCOUNT") + _bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME") + self.path_service_account_json: Optional[str] = _path_service_account + self.BUCKET_NAME: Optional[str] = _bucket_name + self.vertex_instances: Dict[str, VertexBase] = {} + super().__init__(**kwargs) + + async def construct_request_headers( + self, + service_account_json: Optional[str], + vertex_instance: Optional[VertexBase] = None, + ) -> Dict[str, str]: + from litellm import vertex_chat_completion + + if vertex_instance is None: + vertex_instance = vertex_chat_completion + + _auth_header, vertex_project = await vertex_instance._ensure_access_token_async( + 
credentials=service_account_json, + project_id=None, + custom_llm_provider="vertex_ai", + ) + + auth_header, _ = vertex_instance._get_token_and_url( + model="gcs-bucket", + auth_header=_auth_header, + vertex_credentials=service_account_json, + vertex_project=vertex_project, + vertex_location=None, + gemini_api_key=None, + stream=None, + custom_llm_provider="vertex_ai", + api_base=None, + ) + verbose_logger.debug("constructed auth_header %s", auth_header) + headers = { + "Authorization": f"Bearer {auth_header}", # auth_header + "Content-Type": "application/json", + } + + return headers + + def sync_construct_request_headers(self) -> Dict[str, str]: + from litellm import vertex_chat_completion + + _auth_header, vertex_project = vertex_chat_completion._ensure_access_token( + credentials=self.path_service_account_json, + project_id=None, + custom_llm_provider="vertex_ai", + ) + + auth_header, _ = vertex_chat_completion._get_token_and_url( + model="gcs-bucket", + auth_header=_auth_header, + vertex_credentials=self.path_service_account_json, + vertex_project=vertex_project, + vertex_location=None, + gemini_api_key=None, + stream=None, + custom_llm_provider="vertex_ai", + api_base=None, + ) + verbose_logger.debug("constructed auth_header %s", auth_header) + headers = { + "Authorization": f"Bearer {auth_header}", # auth_header + "Content-Type": "application/json", + } + + return headers + + def _handle_folders_in_bucket_name( + self, + bucket_name: str, + object_name: str, + ) -> Tuple[str, str]: + """ + Handles when the user passes a bucket name with a folder postfix + + + Example: + - Bucket name: "my-bucket/my-folder/dev" + - Object name: "my-object" + - Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object" + + """ + if "/" in bucket_name: + bucket_name, prefix = bucket_name.split("/", 1) + object_name = f"{prefix}/{object_name}" + return bucket_name, object_name + return bucket_name, object_name + + async def get_gcs_logging_config( + self, kwargs: Optional[Dict[str, Any]] = {} + ) -> GCSLoggingConfig: + """ + This function is used to get the GCS logging config for the GCS Bucket Logger. + It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config. + If no dynamic parameters are provided, it uses the default values. + """ + if kwargs is None: + kwargs = {} + + standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( + kwargs.get("standard_callback_dynamic_params", None) + ) + + bucket_name: str + path_service_account: Optional[str] + if standard_callback_dynamic_params is not None: + verbose_logger.debug("Using dynamic GCS logging") + verbose_logger.debug( + "standard_callback_dynamic_params: %s", standard_callback_dynamic_params + ) + + _bucket_name: Optional[str] = ( + standard_callback_dynamic_params.get("gcs_bucket_name", None) + or self.BUCKET_NAME + ) + _path_service_account: Optional[str] = ( + standard_callback_dynamic_params.get("gcs_path_service_account", None) + or self.path_service_account_json + ) + + if _bucket_name is None: + raise ValueError( + "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment." 
+ ) + bucket_name = _bucket_name + path_service_account = _path_service_account + vertex_instance = await self.get_or_create_vertex_instance( + credentials=path_service_account + ) + else: + # If no dynamic parameters, use the default instance + if self.BUCKET_NAME is None: + raise ValueError( + "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment." + ) + bucket_name = self.BUCKET_NAME + path_service_account = self.path_service_account_json + vertex_instance = await self.get_or_create_vertex_instance( + credentials=path_service_account + ) + + return GCSLoggingConfig( + bucket_name=bucket_name, + vertex_instance=vertex_instance, + path_service_account=path_service_account, + ) + + async def get_or_create_vertex_instance( + self, credentials: Optional[str] + ) -> VertexBase: + """ + This function is used to get the Vertex instance for the GCS Bucket Logger. + It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it. + """ + from litellm.llms.vertex_ai.vertex_llm_base import VertexBase + + _in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials) + if _in_memory_key not in self.vertex_instances: + vertex_instance = VertexBase() + await vertex_instance._ensure_access_token_async( + credentials=credentials, + project_id=None, + custom_llm_provider="vertex_ai", + ) + self.vertex_instances[_in_memory_key] = vertex_instance + return self.vertex_instances[_in_memory_key] + + def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str: + """ + Returns key to use for caching the Vertex instance in-memory. + + When using Vertex with Key based logging, we need to cache the Vertex instance in-memory. + + - If a credentials string is provided, it is used as the key. + - If no credentials string is provided, "IAM_AUTH" is used as the key. + """ + return credentials or IAM_AUTH_KEY + + async def download_gcs_object(self, object_name: str, **kwargs): + """ + Download an object from GCS. + + https://cloud.google.com/storage/docs/downloading-objects#download-object-json + """ + try: + gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config( + kwargs=kwargs + ) + headers = await self.construct_request_headers( + vertex_instance=gcs_logging_config["vertex_instance"], + service_account_json=gcs_logging_config["path_service_account"], + ) + bucket_name = gcs_logging_config["bucket_name"] + bucket_name, object_name = self._handle_folders_in_bucket_name( + bucket_name=bucket_name, + object_name=object_name, + ) + + url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media" + + # Send the GET request to download the object + response = await self.async_httpx_client.get(url=url, headers=headers) + + if response.status_code != 200: + verbose_logger.error( + "GCS object download error: %s", str(response.text) + ) + return None + + verbose_logger.debug( + "GCS object download response status code: %s", response.status_code + ) + + # Return the content of the downloaded object + return response.content + + except Exception as e: + verbose_logger.error("GCS object download error: %s", str(e)) + return None + + async def delete_gcs_object(self, object_name: str, **kwargs): + """ + Delete an object from GCS. 
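+
+        The object is removed via the JSON API, i.e. a DELETE request to
+        https://storage.googleapis.com/storage/v1/b/<bucket>/o/<object>.
+
+        https://cloud.google.com/storage/docs/deleting-objects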
+ """ + try: + gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config( + kwargs=kwargs + ) + headers = await self.construct_request_headers( + vertex_instance=gcs_logging_config["vertex_instance"], + service_account_json=gcs_logging_config["path_service_account"], + ) + bucket_name = gcs_logging_config["bucket_name"] + bucket_name, object_name = self._handle_folders_in_bucket_name( + bucket_name=bucket_name, + object_name=object_name, + ) + + url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}" + + # Send the DELETE request to delete the object + response = await self.async_httpx_client.delete(url=url, headers=headers) + + if (response.status_code != 200) or (response.status_code != 204): + verbose_logger.error( + "GCS object delete error: %s, status code: %s", + str(response.text), + response.status_code, + ) + return None + + verbose_logger.debug( + "GCS object delete response status code: %s, response: %s", + response.status_code, + response.text, + ) + + # Return the content of the downloaded object + return response.text + + except Exception as e: + verbose_logger.error("GCS object download error: %s", str(e)) + return None + + async def _log_json_data_on_gcs( + self, + headers: Dict[str, str], + bucket_name: str, + object_name: str, + logging_payload: Union[StandardLoggingPayload, str], + ): + """ + Helper function to make POST request to GCS Bucket in the specified bucket. + """ + if isinstance(logging_payload, str): + json_logged_payload = logging_payload + else: + json_logged_payload = json.dumps(logging_payload, default=str) + + bucket_name, object_name = self._handle_folders_in_bucket_name( + bucket_name=bucket_name, + object_name=object_name, + ) + + response = await self.async_httpx_client.post( + headers=headers, + url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}", + data=json_logged_payload, + ) + + if response.status_code != 200: + verbose_logger.error("GCS Bucket logging error: %s", str(response.text)) + + verbose_logger.debug("GCS Bucket response %s", response) + verbose_logger.debug("GCS Bucket status code %s", response.status_code) + verbose_logger.debug("GCS Bucket response.text %s", response.text) + + return response.json() diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_pubsub/pub_sub.py b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_pubsub/pub_sub.py new file mode 100644 index 00000000..e94c853f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_pubsub/pub_sub.py @@ -0,0 +1,203 @@ +""" +BETA + +This is the PubSub logger for GCS PubSub, this sends LiteLLM SpendLogs Payloads to GCS PubSub. + +Users can use this instead of sending their SpendLogs to their Postgres database. 
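+
+Configuration comes from the constructor args or the environment:
+GCS_PUBSUB_PROJECT_ID and GCS_PUBSUB_TOPIC_ID are required;
+GCS_PATH_SERVICE_ACCOUNT can optionally point at a service-account JSON
+file used when constructing auth headers.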
+""" + +import asyncio +import json +import os +import traceback +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +if TYPE_CHECKING: + from litellm.proxy._types import SpendLogsPayload +else: + SpendLogsPayload = Any + +from litellm._logging import verbose_logger +from litellm.integrations.custom_batch_logger import CustomBatchLogger +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) + + +class GcsPubSubLogger(CustomBatchLogger): + def __init__( + self, + project_id: Optional[str] = None, + topic_id: Optional[str] = None, + credentials_path: Optional[str] = None, + **kwargs, + ): + """ + Initialize Google Cloud Pub/Sub publisher + + Args: + project_id (str): Google Cloud project ID + topic_id (str): Pub/Sub topic ID + credentials_path (str, optional): Path to Google Cloud credentials JSON file + """ + from litellm.proxy.utils import _premium_user_check + + _premium_user_check() + + self.async_httpx_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + + self.project_id = project_id or os.getenv("GCS_PUBSUB_PROJECT_ID") + self.topic_id = topic_id or os.getenv("GCS_PUBSUB_TOPIC_ID") + self.path_service_account_json = credentials_path or os.getenv( + "GCS_PATH_SERVICE_ACCOUNT" + ) + + if not self.project_id or not self.topic_id: + raise ValueError("Both project_id and topic_id must be provided") + + self.flush_lock = asyncio.Lock() + super().__init__(**kwargs, flush_lock=self.flush_lock) + asyncio.create_task(self.periodic_flush()) + self.log_queue: List[SpendLogsPayload] = [] + + async def construct_request_headers(self) -> Dict[str, str]: + """Construct authorization headers using Vertex AI auth""" + from litellm import vertex_chat_completion + + _auth_header, vertex_project = ( + await vertex_chat_completion._ensure_access_token_async( + credentials=self.path_service_account_json, + project_id=None, + custom_llm_provider="vertex_ai", + ) + ) + + auth_header, _ = vertex_chat_completion._get_token_and_url( + model="pub-sub", + auth_header=_auth_header, + vertex_credentials=self.path_service_account_json, + vertex_project=vertex_project, + vertex_location=None, + gemini_api_key=None, + stream=None, + custom_llm_provider="vertex_ai", + api_base=None, + ) + + headers = { + "Authorization": f"Bearer {auth_header}", + "Content-Type": "application/json", + } + return headers + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + """ + Async Log success events to GCS PubSub Topic + + - Creates a SpendLogsPayload + - Adds to batch queue + - Flushes based on CustomBatchLogger settings + + Raises: + Raises a NON Blocking verbose_logger.exception if an error occurs + """ + from litellm.proxy.spend_tracking.spend_tracking_utils import ( + get_logging_payload, + ) + from litellm.proxy.utils import _premium_user_check + + _premium_user_check() + + try: + verbose_logger.debug( + "PubSub: Logging - Enters logging function for model %s", kwargs + ) + spend_logs_payload = get_logging_payload( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + ) + self.log_queue.append(spend_logs_payload) + + if len(self.log_queue) >= self.batch_size: + await self.async_send_batch() + + except Exception as e: + verbose_logger.exception( + f"PubSub Layer Error - {str(e)}\n{traceback.format_exc()}" + ) + pass + + async def async_send_batch(self): + """ + Sends the batch of messages to Pub/Sub + """ + try: + if not self.log_queue: + return + + 
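+            # Note: there is no batched publish call here - each queued
+            # SpendLogsPayload is base64-encoded and sent in its own POST to
+            # the topic's `:publish` endpoint via publish_message() below.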
verbose_logger.debug( + f"PubSub - about to flush {len(self.log_queue)} events" + ) + + for message in self.log_queue: + await self.publish_message(message) + + except Exception as e: + verbose_logger.exception( + f"PubSub Error sending batch - {str(e)}\n{traceback.format_exc()}" + ) + finally: + self.log_queue.clear() + + async def publish_message( + self, message: SpendLogsPayload + ) -> Optional[Dict[str, Any]]: + """ + Publish message to Google Cloud Pub/Sub using REST API + + Args: + message: Message to publish (dict or string) + + Returns: + dict: Published message response + """ + try: + headers = await self.construct_request_headers() + + # Prepare message data + if isinstance(message, str): + message_data = message + else: + message_data = json.dumps(message, default=str) + + # Base64 encode the message + import base64 + + encoded_message = base64.b64encode(message_data.encode("utf-8")).decode( + "utf-8" + ) + + # Construct request body + request_body = {"messages": [{"data": encoded_message}]} + + url = f"https://pubsub.googleapis.com/v1/projects/{self.project_id}/topics/{self.topic_id}:publish" + + response = await self.async_httpx_client.post( + url=url, headers=headers, json=request_body + ) + + if response.status_code not in [200, 202]: + verbose_logger.error("Pub/Sub publish error: %s", str(response.text)) + raise Exception(f"Failed to publish message: {response.text}") + + verbose_logger.debug("Pub/Sub response: %s", response.text) + return response.json() + + except Exception as e: + verbose_logger.error("Pub/Sub publish error: %s", str(e)) + return None diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/greenscale.py b/.venv/lib/python3.12/site-packages/litellm/integrations/greenscale.py new file mode 100644 index 00000000..430c3d0a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/greenscale.py @@ -0,0 +1,72 @@ +import json +import traceback +from datetime import datetime, timezone + +import litellm + + +class GreenscaleLogger: + def __init__(self): + import os + + self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY") + self.headers = { + "api-key": self.greenscale_api_key, + "Content-Type": "application/json", + } + self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT") + + def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): + try: + response_json = response_obj.model_dump() if response_obj else {} + data = { + "modelId": kwargs.get("model"), + "inputTokenCount": response_json.get("usage", {}).get("prompt_tokens"), + "outputTokenCount": response_json.get("usage", {}).get( + "completion_tokens" + ), + } + data["timestamp"] = datetime.now(timezone.utc).strftime( + "%Y-%m-%dT%H:%M:%SZ" + ) + + if type(end_time) is datetime and type(start_time) is datetime: + data["invocationLatency"] = int( + (end_time - start_time).total_seconds() * 1000 + ) + + # Add additional metadata keys to tags + tags = [] + metadata = kwargs.get("litellm_params", {}).get("metadata", {}) + for key, value in metadata.items(): + if key.startswith("greenscale"): + if key == "greenscale_project": + data["project"] = value + elif key == "greenscale_application": + data["application"] = value + else: + tags.append( + {"key": key.replace("greenscale_", ""), "value": str(value)} + ) + + data["tags"] = tags + + if self.greenscale_logging_url is None: + raise Exception("Greenscale Logger Error - No logging URL found") + + response = litellm.module_level_client.post( + self.greenscale_logging_url, + headers=self.headers, + 
data=json.dumps(data, default=str), + ) + if response.status_code != 200: + print_verbose( + f"Greenscale Logger Error - {response.text}, {response.status_code}" + ) + else: + print_verbose(f"Greenscale Logger Succeeded - {response.text}") + except Exception as e: + print_verbose( + f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}" + ) + pass diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/helicone.py b/.venv/lib/python3.12/site-packages/litellm/integrations/helicone.py new file mode 100644 index 00000000..a526a74f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/helicone.py @@ -0,0 +1,188 @@ +#### What this does #### +# On success, logs events to Helicone +import os +import traceback + +import litellm + + +class HeliconeLogger: + # Class variables or attributes + helicone_model_list = [ + "gpt", + "claude", + "command-r", + "command-r-plus", + "command-light", + "command-medium", + "command-medium-beta", + "command-xlarge-nightly", + "command-nightly", + ] + + def __init__(self): + # Instance variables + self.provider_url = "https://api.openai.com/v1" + self.key = os.getenv("HELICONE_API_KEY") + + def claude_mapping(self, model, messages, response_obj): + from anthropic import AI_PROMPT, HUMAN_PROMPT + + prompt = f"{HUMAN_PROMPT}" + for message in messages: + if "role" in message: + if message["role"] == "user": + prompt += f"{HUMAN_PROMPT}{message['content']}" + else: + prompt += f"{AI_PROMPT}{message['content']}" + else: + prompt += f"{HUMAN_PROMPT}{message['content']}" + prompt += f"{AI_PROMPT}" + + choice = response_obj["choices"][0] + message = choice["message"] + + content = [] + if "tool_calls" in message and message["tool_calls"]: + for tool_call in message["tool_calls"]: + content.append( + { + "type": "tool_use", + "id": tool_call["id"], + "name": tool_call["function"]["name"], + "input": tool_call["function"]["arguments"], + } + ) + elif "content" in message and message["content"]: + content = [{"type": "text", "text": message["content"]}] + + claude_response_obj = { + "id": response_obj["id"], + "type": "message", + "role": "assistant", + "model": model, + "content": content, + "stop_reason": choice["finish_reason"], + "stop_sequence": None, + "usage": { + "input_tokens": response_obj["usage"]["prompt_tokens"], + "output_tokens": response_obj["usage"]["completion_tokens"], + }, + } + + return claude_response_obj + + @staticmethod + def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict: + """ + Adds metadata from proxy request headers to Helicone logging if keys start with "helicone_" + and overwrites litellm_params.metadata if already included. + + For example if you want to add custom property to your request, send + `headers: { ..., helicone-property-something: 1234 }` via proxy request. 
+ """ + if litellm_params is None: + return metadata + + if litellm_params.get("proxy_server_request") is None: + return metadata + + if metadata is None: + metadata = {} + + proxy_headers = ( + litellm_params.get("proxy_server_request", {}).get("headers", {}) or {} + ) + + for header_key in proxy_headers: + if header_key.startswith("helicone_"): + metadata[header_key] = proxy_headers.get(header_key) + + return metadata + + def log_success( + self, model, messages, response_obj, start_time, end_time, print_verbose, kwargs + ): + # Method definition + try: + print_verbose( + f"Helicone Logging - Enters logging function for model {model}" + ) + litellm_params = kwargs.get("litellm_params", {}) + kwargs.get("litellm_call_id", None) + metadata = litellm_params.get("metadata", {}) or {} + metadata = self.add_metadata_from_header(litellm_params, metadata) + model = ( + model + if any( + accepted_model in model + for accepted_model in self.helicone_model_list + ) + else "gpt-3.5-turbo" + ) + provider_request = {"model": model, "messages": messages} + if isinstance(response_obj, litellm.EmbeddingResponse) or isinstance( + response_obj, litellm.ModelResponse + ): + response_obj = response_obj.json() + + if "claude" in model: + response_obj = self.claude_mapping( + model=model, messages=messages, response_obj=response_obj + ) + + providerResponse = { + "json": response_obj, + "headers": {"openai-version": "2020-10-01"}, + "status": 200, + } + + # Code to be executed + provider_url = self.provider_url + url = "https://api.hconeai.com/oai/v1/log" + if "claude" in model: + url = "https://api.hconeai.com/anthropic/v1/log" + provider_url = "https://api.anthropic.com/v1/messages" + headers = { + "Authorization": f"Bearer {self.key}", + "Content-Type": "application/json", + } + start_time_seconds = int(start_time.timestamp()) + start_time_milliseconds = int( + (start_time.timestamp() - start_time_seconds) * 1000 + ) + end_time_seconds = int(end_time.timestamp()) + end_time_milliseconds = int( + (end_time.timestamp() - end_time_seconds) * 1000 + ) + meta = {"Helicone-Auth": f"Bearer {self.key}"} + meta.update(metadata) + data = { + "providerRequest": { + "url": provider_url, + "json": provider_request, + "meta": meta, + }, + "providerResponse": providerResponse, + "timing": { + "startTime": { + "seconds": start_time_seconds, + "milliseconds": start_time_milliseconds, + }, + "endTime": { + "seconds": end_time_seconds, + "milliseconds": end_time_milliseconds, + }, + }, # {"seconds": .., "milliseconds": ..} + } + response = litellm.module_level_client.post(url, headers=headers, json=data) + if response.status_code == 200: + print_verbose("Helicone Logging - Success!") + else: + print_verbose( + f"Helicone Logging - Error Request was not successful. 
Status Code: {response.status_code}" + ) + print_verbose(f"Helicone Logging - Error {response.text}") + except Exception: + print_verbose(f"Helicone Logging Error - {traceback.format_exc()}") + pass diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/humanloop.py b/.venv/lib/python3.12/site-packages/litellm/integrations/humanloop.py new file mode 100644 index 00000000..fd3463f9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/humanloop.py @@ -0,0 +1,197 @@ +""" +Humanloop integration + +https://humanloop.com/ +""" + +from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast + +import httpx + +import litellm +from litellm.caching import DualCache +from litellm.llms.custom_httpx.http_handler import _get_httpx_client +from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import StandardCallbackDynamicParams + +from .custom_logger import CustomLogger + + +class PromptManagementClient(TypedDict): + prompt_id: str + prompt_template: List[AllMessageValues] + model: Optional[str] + optional_params: Optional[Dict[str, Any]] + + +class HumanLoopPromptManager(DualCache): + @property + def integration_name(self): + return "humanloop" + + def _get_prompt_from_id_cache( + self, humanloop_prompt_id: str + ) -> Optional[PromptManagementClient]: + return cast( + Optional[PromptManagementClient], self.get_cache(key=humanloop_prompt_id) + ) + + def _compile_prompt_helper( + self, prompt_template: List[AllMessageValues], prompt_variables: Dict[str, Any] + ) -> List[AllMessageValues]: + """ + Helper function to compile the prompt by substituting variables in the template. + + Args: + prompt_template: List[AllMessageValues] + prompt_variables (dict): A dictionary of variables to substitute into the prompt template. + + Returns: + list: A list of dictionaries with variables substituted. 
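+
+        Example:
+            A template message {"role": "user", "content": "Hello {{name}}"}
+            with prompt_variables {"name": "Ada"} compiles to
+            {"role": "user", "content": "Hello Ada"} - the double-brace
+            placeholders are rewritten to str.format() fields before
+            substitution.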
+ """ + compiled_prompts: List[AllMessageValues] = [] + + for template in prompt_template: + tc = template.get("content") + if tc and isinstance(tc, str): + formatted_template = tc.replace("{{", "{").replace("}}", "}") + compiled_content = formatted_template.format(**prompt_variables) + template["content"] = compiled_content + compiled_prompts.append(template) + + return compiled_prompts + + def _get_prompt_from_id_api( + self, humanloop_prompt_id: str, humanloop_api_key: str + ) -> PromptManagementClient: + client = _get_httpx_client() + + base_url = "https://api.humanloop.com/v5/prompts/{}".format(humanloop_prompt_id) + + response = client.get( + url=base_url, + headers={ + "X-Api-Key": humanloop_api_key, + "Content-Type": "application/json", + }, + ) + + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise Exception(f"Error getting prompt from Humanloop: {e.response.text}") + + json_response = response.json() + template_message = json_response["template"] + if isinstance(template_message, dict): + template_messages = [template_message] + elif isinstance(template_message, list): + template_messages = template_message + else: + raise ValueError(f"Invalid template message type: {type(template_message)}") + template_model = json_response["model"] + optional_params = {} + for k, v in json_response.items(): + if k in litellm.OPENAI_CHAT_COMPLETION_PARAMS: + optional_params[k] = v + return PromptManagementClient( + prompt_id=humanloop_prompt_id, + prompt_template=cast(List[AllMessageValues], template_messages), + model=template_model, + optional_params=optional_params, + ) + + def _get_prompt_from_id( + self, humanloop_prompt_id: str, humanloop_api_key: str + ) -> PromptManagementClient: + prompt = self._get_prompt_from_id_cache(humanloop_prompt_id) + if prompt is None: + prompt = self._get_prompt_from_id_api( + humanloop_prompt_id, humanloop_api_key + ) + self.set_cache( + key=humanloop_prompt_id, + value=prompt, + ttl=litellm.HUMANLOOP_PROMPT_CACHE_TTL_SECONDS, + ) + return prompt + + def compile_prompt( + self, + prompt_template: List[AllMessageValues], + prompt_variables: Optional[dict], + ) -> List[AllMessageValues]: + compiled_prompt: Optional[Union[str, list]] = None + + if prompt_variables is None: + prompt_variables = {} + + compiled_prompt = self._compile_prompt_helper( + prompt_template=prompt_template, + prompt_variables=prompt_variables, + ) + + return compiled_prompt + + def _get_model_from_prompt( + self, prompt_management_client: PromptManagementClient, model: str + ) -> str: + if prompt_management_client["model"] is not None: + return prompt_management_client["model"] + else: + return model.replace("{}/".format(self.integration_name), "") + + +prompt_manager = HumanLoopPromptManager() + + +class HumanloopLogger(CustomLogger): + def get_chat_completion_prompt( + self, + model: str, + messages: List[AllMessageValues], + non_default_params: dict, + prompt_id: str, + prompt_variables: Optional[dict], + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> Tuple[ + str, + List[AllMessageValues], + dict, + ]: + humanloop_api_key = dynamic_callback_params.get( + "humanloop_api_key" + ) or get_secret_str("HUMANLOOP_API_KEY") + + if humanloop_api_key is None: + return super().get_chat_completion_prompt( + model=model, + messages=messages, + non_default_params=non_default_params, + prompt_id=prompt_id, + prompt_variables=prompt_variables, + dynamic_callback_params=dynamic_callback_params, + ) + + prompt_template = prompt_manager._get_prompt_from_id( + 
humanloop_prompt_id=prompt_id, humanloop_api_key=humanloop_api_key + ) + + updated_messages = prompt_manager.compile_prompt( + prompt_template=prompt_template["prompt_template"], + prompt_variables=prompt_variables, + ) + + prompt_template_optional_params = prompt_template["optional_params"] or {} + + updated_non_default_params = { + **non_default_params, + **prompt_template_optional_params, + } + + model = prompt_manager._get_model_from_prompt( + prompt_management_client=prompt_template, model=model + ) + + return model, updated_messages, updated_non_default_params diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/lago.py b/.venv/lib/python3.12/site-packages/litellm/integrations/lago.py new file mode 100644 index 00000000..5dfb1ce0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/lago.py @@ -0,0 +1,202 @@ +# What is this? +## On Success events log cost to Lago - https://github.com/BerriAI/litellm/issues/3639 + +import json +import os +import uuid +from typing import Literal, Optional + +import httpx + +import litellm +from litellm._logging import verbose_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.llms.custom_httpx.http_handler import ( + HTTPHandler, + get_async_httpx_client, + httpxSpecialProvider, +) + + +def get_utc_datetime(): + import datetime as dt + from datetime import datetime + + if hasattr(dt, "UTC"): + return datetime.now(dt.UTC) # type: ignore + else: + return datetime.utcnow() # type: ignore + + +class LagoLogger(CustomLogger): + def __init__(self) -> None: + super().__init__() + self.validate_environment() + self.async_http_handler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + self.sync_http_handler = HTTPHandler() + + def validate_environment(self): + """ + Expects + LAGO_API_BASE, + LAGO_API_KEY, + LAGO_API_EVENT_CODE, + + Optional: + LAGO_API_CHARGE_BY + + in the environment + """ + missing_keys = [] + if os.getenv("LAGO_API_KEY", None) is None: + missing_keys.append("LAGO_API_KEY") + + if os.getenv("LAGO_API_BASE", None) is None: + missing_keys.append("LAGO_API_BASE") + + if os.getenv("LAGO_API_EVENT_CODE", None) is None: + missing_keys.append("LAGO_API_EVENT_CODE") + + if len(missing_keys) > 0: + raise Exception("Missing keys={} in environment.".format(missing_keys)) + + def _common_logic(self, kwargs: dict, response_obj) -> dict: + response_obj.get("id", kwargs.get("litellm_call_id")) + get_utc_datetime().isoformat() + cost = kwargs.get("response_cost", None) + model = kwargs.get("model") + usage = {} + + if ( + isinstance(response_obj, litellm.ModelResponse) + or isinstance(response_obj, litellm.EmbeddingResponse) + ) and hasattr(response_obj, "usage"): + usage = { + "prompt_tokens": response_obj["usage"].get("prompt_tokens", 0), + "completion_tokens": response_obj["usage"].get("completion_tokens", 0), + "total_tokens": response_obj["usage"].get("total_tokens"), + } + + litellm_params = kwargs.get("litellm_params", {}) or {} + proxy_server_request = litellm_params.get("proxy_server_request") or {} + end_user_id = proxy_server_request.get("body", {}).get("user", None) + user_id = litellm_params["metadata"].get("user_api_key_user_id", None) + team_id = litellm_params["metadata"].get("user_api_key_team_id", None) + litellm_params["metadata"].get("user_api_key_org_id", None) + + charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id" + external_customer_id: Optional[str] = None + + if os.getenv("LAGO_API_CHARGE_BY", None) is not None 
and isinstance( + os.environ["LAGO_API_CHARGE_BY"], str + ): + if os.environ["LAGO_API_CHARGE_BY"] in [ + "end_user_id", + "user_id", + "team_id", + ]: + charge_by = os.environ["LAGO_API_CHARGE_BY"] # type: ignore + else: + raise Exception("invalid LAGO_API_CHARGE_BY set") + + if charge_by == "end_user_id": + external_customer_id = end_user_id + elif charge_by == "team_id": + external_customer_id = team_id + elif charge_by == "user_id": + external_customer_id = user_id + + if external_customer_id is None: + raise Exception( + "External Customer ID is not set. Charge_by={}. User_id={}. End_user_id={}. Team_id={}".format( + charge_by, user_id, end_user_id, team_id + ) + ) + + returned_val = { + "event": { + "transaction_id": str(uuid.uuid4()), + "external_subscription_id": external_customer_id, + "code": os.getenv("LAGO_API_EVENT_CODE"), + "properties": {"model": model, "response_cost": cost, **usage}, + } + } + + verbose_logger.debug( + "\033[91mLogged Lago Object:\n{}\033[0m\n".format(returned_val) + ) + return returned_val + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + _url = os.getenv("LAGO_API_BASE") + assert _url is not None and isinstance( + _url, str + ), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(_url) + if _url.endswith("/"): + _url += "api/v1/events" + else: + _url += "/api/v1/events" + + api_key = os.getenv("LAGO_API_KEY") + + _data = self._common_logic(kwargs=kwargs, response_obj=response_obj) + _headers = { + "Content-Type": "application/json", + "Authorization": "Bearer {}".format(api_key), + } + + try: + response = self.sync_http_handler.post( + url=_url, + data=json.dumps(_data), + headers=_headers, + ) + + response.raise_for_status() + except Exception as e: + error_response = getattr(e, "response", None) + if error_response is not None and hasattr(error_response, "text"): + verbose_logger.debug(f"\nError Message: {error_response.text}") + raise e + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + verbose_logger.debug("ENTERS LAGO CALLBACK") + _url = os.getenv("LAGO_API_BASE") + assert _url is not None and isinstance( + _url, str + ), "LAGO_API_BASE missing or not set correctly. 
LAGO_API_BASE={}".format( + _url + ) + if _url.endswith("/"): + _url += "api/v1/events" + else: + _url += "/api/v1/events" + + api_key = os.getenv("LAGO_API_KEY") + + _data = self._common_logic(kwargs=kwargs, response_obj=response_obj) + _headers = { + "Content-Type": "application/json", + "Authorization": "Bearer {}".format(api_key), + } + except Exception as e: + raise e + + response: Optional[httpx.Response] = None + try: + response = await self.async_http_handler.post( + url=_url, + data=json.dumps(_data), + headers=_headers, + ) + + response.raise_for_status() + + verbose_logger.debug(f"Logged Lago Object: {response.text}") + except Exception as e: + if response is not None and hasattr(response, "text"): + verbose_logger.debug(f"\nError Message: {response.text}") + raise e diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse.py b/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse.py new file mode 100644 index 00000000..f990a316 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse.py @@ -0,0 +1,955 @@ +#### What this does #### +# On success, logs events to Langfuse +import copy +import os +import traceback +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast + +from packaging.version import Version + +import litellm +from litellm._logging import verbose_logger +from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info +from litellm.llms.custom_httpx.http_handler import _get_httpx_client +from litellm.secret_managers.main import str_to_bool +from litellm.types.integrations.langfuse import * +from litellm.types.llms.openai import HttpxBinaryResponseContent +from litellm.types.utils import ( + EmbeddingResponse, + ImageResponse, + ModelResponse, + RerankResponse, + StandardLoggingPayload, + StandardLoggingPromptManagementMetadata, + TextCompletionResponse, + TranscriptionResponse, +) + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache +else: + DynamicLoggingCache = Any + + +class LangFuseLogger: + # Class variables or attributes + def __init__( + self, + langfuse_public_key=None, + langfuse_secret=None, + langfuse_host=None, + flush_interval=1, + ): + try: + import langfuse + from langfuse import Langfuse + except Exception as e: + raise Exception( + f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n{traceback.format_exc()}\033[0m" + ) + # Instance variables + self.secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY") + self.public_key = langfuse_public_key or os.getenv("LANGFUSE_PUBLIC_KEY") + self.langfuse_host = langfuse_host or os.getenv( + "LANGFUSE_HOST", "https://cloud.langfuse.com" + ) + if not ( + self.langfuse_host.startswith("http://") + or self.langfuse_host.startswith("https://") + ): + # add http:// if unset, assume communicating over private network - e.g. 
render + self.langfuse_host = "http://" + self.langfuse_host + self.langfuse_release = os.getenv("LANGFUSE_RELEASE") + self.langfuse_debug = os.getenv("LANGFUSE_DEBUG") + self.langfuse_flush_interval = LangFuseLogger._get_langfuse_flush_interval( + flush_interval + ) + http_client = _get_httpx_client() + self.langfuse_client = http_client.client + + parameters = { + "public_key": self.public_key, + "secret_key": self.secret_key, + "host": self.langfuse_host, + "release": self.langfuse_release, + "debug": self.langfuse_debug, + "flush_interval": self.langfuse_flush_interval, # flush interval in seconds + "httpx_client": self.langfuse_client, + } + self.langfuse_sdk_version: str = langfuse.version.__version__ + + if Version(self.langfuse_sdk_version) >= Version("2.6.0"): + parameters["sdk_integration"] = "litellm" + + self.Langfuse = Langfuse(**parameters) + + # set the current langfuse project id in the environ + # this is used by Alerting to link to the correct project + try: + project_id = self.Langfuse.client.projects.get().data[0].id + os.environ["LANGFUSE_PROJECT_ID"] = project_id + except Exception: + project_id = None + + if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None: + upstream_langfuse_debug = ( + str_to_bool(self.upstream_langfuse_debug) + if self.upstream_langfuse_debug is not None + else None + ) + self.upstream_langfuse_secret_key = os.getenv( + "UPSTREAM_LANGFUSE_SECRET_KEY" + ) + self.upstream_langfuse_public_key = os.getenv( + "UPSTREAM_LANGFUSE_PUBLIC_KEY" + ) + self.upstream_langfuse_host = os.getenv("UPSTREAM_LANGFUSE_HOST") + self.upstream_langfuse_release = os.getenv("UPSTREAM_LANGFUSE_RELEASE") + self.upstream_langfuse_debug = os.getenv("UPSTREAM_LANGFUSE_DEBUG") + self.upstream_langfuse = Langfuse( + public_key=self.upstream_langfuse_public_key, + secret_key=self.upstream_langfuse_secret_key, + host=self.upstream_langfuse_host, + release=self.upstream_langfuse_release, + debug=( + upstream_langfuse_debug + if upstream_langfuse_debug is not None + else False + ), + ) + else: + self.upstream_langfuse = None + + @staticmethod + def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict: + """ + Adds metadata from proxy request headers to Langfuse logging if keys start with "langfuse_" + and overwrites litellm_params.metadata if already included. + + For example if you want to append your trace to an existing `trace_id` via header, send + `headers: { ..., langfuse_existing_trace_id: your-existing-trace-id }` via proxy request. 
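+
+        Matched headers are copied into metadata with the "langfuse_" prefix
+        stripped, so the example above becomes metadata["existing_trace_id"];
+        an existing metadata value under the same key is overwritten (with a
+        warning).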
+ """ + if litellm_params is None: + return metadata + + if litellm_params.get("proxy_server_request") is None: + return metadata + + if metadata is None: + metadata = {} + + proxy_headers = ( + litellm_params.get("proxy_server_request", {}).get("headers", {}) or {} + ) + + for metadata_param_key in proxy_headers: + if metadata_param_key.startswith("langfuse_"): + trace_param_key = metadata_param_key.replace("langfuse_", "", 1) + if trace_param_key in metadata: + verbose_logger.warning( + f"Overwriting Langfuse `{trace_param_key}` from request header" + ) + else: + verbose_logger.debug( + f"Found Langfuse `{trace_param_key}` in request header" + ) + metadata[trace_param_key] = proxy_headers.get(metadata_param_key) + + return metadata + + def log_event_on_langfuse( + self, + kwargs: dict, + response_obj: Union[ + None, + dict, + EmbeddingResponse, + ModelResponse, + TextCompletionResponse, + ImageResponse, + TranscriptionResponse, + RerankResponse, + HttpxBinaryResponseContent, + ], + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + user_id: Optional[str] = None, + level: str = "DEFAULT", + status_message: Optional[str] = None, + ) -> dict: + """ + Logs a success or error event on Langfuse + """ + try: + verbose_logger.debug( + f"Langfuse Logging - Enters logging function for model {kwargs}" + ) + + # set default values for input/output for langfuse logging + input = None + output = None + + litellm_params = kwargs.get("litellm_params", {}) + litellm_call_id = kwargs.get("litellm_call_id", None) + metadata = ( + litellm_params.get("metadata", {}) or {} + ) # if litellm_params['metadata'] == None + metadata = self.add_metadata_from_header(litellm_params, metadata) + optional_params = copy.deepcopy(kwargs.get("optional_params", {})) + + prompt = {"messages": kwargs.get("messages")} + + functions = optional_params.pop("functions", None) + tools = optional_params.pop("tools", None) + if functions is not None: + prompt["functions"] = functions + if tools is not None: + prompt["tools"] = tools + + # langfuse only accepts str, int, bool, float for logging + for param, value in optional_params.items(): + if not isinstance(value, (str, int, bool, float)): + try: + optional_params[param] = str(value) + except Exception: + # if casting value to str fails don't block logging + pass + + input, output = self._get_langfuse_input_output_content( + kwargs=kwargs, + response_obj=response_obj, + prompt=prompt, + level=level, + status_message=status_message, + ) + verbose_logger.debug( + f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}" + ) + trace_id = None + generation_id = None + if self._is_langfuse_v2(): + trace_id, generation_id = self._log_langfuse_v2( + user_id=user_id, + metadata=metadata, + litellm_params=litellm_params, + output=output, + start_time=start_time, + end_time=end_time, + kwargs=kwargs, + optional_params=optional_params, + input=input, + response_obj=response_obj, + level=level, + litellm_call_id=litellm_call_id, + ) + elif response_obj is not None: + self._log_langfuse_v1( + user_id=user_id, + metadata=metadata, + output=output, + start_time=start_time, + end_time=end_time, + kwargs=kwargs, + optional_params=optional_params, + input=input, + response_obj=response_obj, + ) + verbose_logger.debug( + f"Langfuse Layer Logging - final response object: {response_obj}" + ) + verbose_logger.info("Langfuse Layer Logging - logging success") + + return {"trace_id": trace_id, "generation_id": generation_id} + except Exception as e: + verbose_logger.exception( 
+ "Langfuse Layer Error(): Exception occured - {}".format(str(e)) + ) + return {"trace_id": None, "generation_id": None} + + def _get_langfuse_input_output_content( + self, + kwargs: dict, + response_obj: Union[ + None, + dict, + EmbeddingResponse, + ModelResponse, + TextCompletionResponse, + ImageResponse, + TranscriptionResponse, + RerankResponse, + HttpxBinaryResponseContent, + ], + prompt: dict, + level: str, + status_message: Optional[str], + ) -> Tuple[Optional[dict], Optional[Union[str, dict, list]]]: + """ + Get the input and output content for Langfuse logging + + Args: + kwargs: The keyword arguments passed to the function + response_obj: The response object returned by the function + prompt: The prompt used to generate the response + level: The level of the log message + status_message: The status message of the log message + + Returns: + input: The input content for Langfuse logging + output: The output content for Langfuse logging + """ + input = None + output: Optional[Union[str, dict, List[Any]]] = None + if ( + level == "ERROR" + and status_message is not None + and isinstance(status_message, str) + ): + input = prompt + output = status_message + elif response_obj is not None and ( + kwargs.get("call_type", None) == "embedding" + or isinstance(response_obj, litellm.EmbeddingResponse) + ): + input = prompt + output = None + elif response_obj is not None and isinstance( + response_obj, litellm.ModelResponse + ): + input = prompt + output = self._get_chat_content_for_langfuse(response_obj) + elif response_obj is not None and isinstance( + response_obj, litellm.HttpxBinaryResponseContent + ): + input = prompt + output = "speech-output" + elif response_obj is not None and isinstance( + response_obj, litellm.TextCompletionResponse + ): + input = prompt + output = self._get_text_completion_content_for_langfuse(response_obj) + elif response_obj is not None and isinstance( + response_obj, litellm.ImageResponse + ): + input = prompt + output = response_obj.get("data", None) + elif response_obj is not None and isinstance( + response_obj, litellm.TranscriptionResponse + ): + input = prompt + output = response_obj.get("text", None) + elif response_obj is not None and isinstance( + response_obj, litellm.RerankResponse + ): + input = prompt + output = response_obj.results + elif ( + kwargs.get("call_type") is not None + and kwargs.get("call_type") == "_arealtime" + and response_obj is not None + and isinstance(response_obj, list) + ): + input = kwargs.get("input") + output = response_obj + elif ( + kwargs.get("call_type") is not None + and kwargs.get("call_type") == "pass_through_endpoint" + and response_obj is not None + and isinstance(response_obj, dict) + ): + input = prompt + output = response_obj.get("response", "") + return input, output + + async def _async_log_event( + self, kwargs, response_obj, start_time, end_time, user_id + ): + """ + Langfuse SDK uses a background thread to log events + + This approach does not impact latency and runs in the background + """ + + def _is_langfuse_v2(self): + import langfuse + + return Version(langfuse.version.__version__) >= Version("2.0.0") + + def _log_langfuse_v1( + self, + user_id, + metadata, + output, + start_time, + end_time, + kwargs, + optional_params, + input, + response_obj, + ): + from langfuse.model import CreateGeneration, CreateTrace # type: ignore + + verbose_logger.warning( + "Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1" + ) + + trace = self.Langfuse.trace( # 
type: ignore + CreateTrace( # type: ignore + name=metadata.get("generation_name", "litellm-completion"), + input=input, + output=output, + userId=user_id, + ) + ) + + trace.generation( + CreateGeneration( + name=metadata.get("generation_name", "litellm-completion"), + startTime=start_time, + endTime=end_time, + model=kwargs["model"], + modelParameters=optional_params, + prompt=input, + completion=output, + usage={ + "prompt_tokens": response_obj.usage.prompt_tokens, + "completion_tokens": response_obj.usage.completion_tokens, + }, + metadata=metadata, + ) + ) + + def _log_langfuse_v2( # noqa: PLR0915 + self, + user_id: Optional[str], + metadata: dict, + litellm_params: dict, + output: Optional[Union[str, dict, list]], + start_time: Optional[datetime], + end_time: Optional[datetime], + kwargs: dict, + optional_params: dict, + input: Optional[dict], + response_obj, + level: str, + litellm_call_id: Optional[str], + ) -> tuple: + verbose_logger.debug("Langfuse Layer Logging - logging to langfuse v2") + + try: + metadata = metadata or {} + standard_logging_object: Optional[StandardLoggingPayload] = cast( + Optional[StandardLoggingPayload], + kwargs.get("standard_logging_object", None), + ) + tags = ( + self._get_langfuse_tags(standard_logging_object=standard_logging_object) + if self._supports_tags() + else [] + ) + + if standard_logging_object is None: + end_user_id = None + prompt_management_metadata: Optional[ + StandardLoggingPromptManagementMetadata + ] = None + else: + end_user_id = standard_logging_object["metadata"].get( + "user_api_key_end_user_id", None + ) + + prompt_management_metadata = cast( + Optional[StandardLoggingPromptManagementMetadata], + standard_logging_object["metadata"].get( + "prompt_management_metadata", None + ), + ) + + # Clean Metadata before logging - never log raw metadata + # the raw metadata can contain circular references which leads to infinite recursion + # we clean out all extra litellm metadata params before logging + clean_metadata: Dict[str, Any] = {} + if prompt_management_metadata is not None: + clean_metadata["prompt_management_metadata"] = ( + prompt_management_metadata + ) + if isinstance(metadata, dict): + for key, value in metadata.items(): + # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy + if ( + litellm.langfuse_default_tags is not None + and isinstance(litellm.langfuse_default_tags, list) + and key in litellm.langfuse_default_tags + ): + tags.append(f"{key}:{value}") + + # clean litellm metadata before logging + if key in [ + "headers", + "endpoint", + "caching_groups", + "previous_models", + ]: + continue + else: + clean_metadata[key] = value + + # Add default langfuse tags + tags = self.add_default_langfuse_tags( + tags=tags, kwargs=kwargs, metadata=metadata + ) + + session_id = clean_metadata.pop("session_id", None) + trace_name = cast(Optional[str], clean_metadata.pop("trace_name", None)) + trace_id = clean_metadata.pop("trace_id", litellm_call_id) + existing_trace_id = clean_metadata.pop("existing_trace_id", None) + update_trace_keys = cast(list, clean_metadata.pop("update_trace_keys", [])) + debug = clean_metadata.pop("debug_langfuse", None) + mask_input = clean_metadata.pop("mask_input", False) + mask_output = clean_metadata.pop("mask_output", False) + + clean_metadata = redact_user_api_key_info(metadata=clean_metadata) + + if trace_name is None and existing_trace_id is None: + # just log `litellm-{call_type}` as the trace name + ## DO NOT SET TRACE_NAME if trace-id set. 
this can lead to overwriting of past traces. + trace_name = f"litellm-{kwargs.get('call_type', 'completion')}" + + if existing_trace_id is not None: + trace_params: Dict[str, Any] = {"id": existing_trace_id} + + # Update the following keys for this trace + for metadata_param_key in update_trace_keys: + trace_param_key = metadata_param_key.replace("trace_", "") + if trace_param_key not in trace_params: + updated_trace_value = clean_metadata.pop( + metadata_param_key, None + ) + if updated_trace_value is not None: + trace_params[trace_param_key] = updated_trace_value + + # Pop the trace specific keys that would have been popped if there were a new trace + for key in list( + filter(lambda key: key.startswith("trace_"), clean_metadata.keys()) + ): + clean_metadata.pop(key, None) + + # Special keys that are found in the function arguments and not the metadata + if "input" in update_trace_keys: + trace_params["input"] = ( + input if not mask_input else "redacted-by-litellm" + ) + if "output" in update_trace_keys: + trace_params["output"] = ( + output if not mask_output else "redacted-by-litellm" + ) + else: # don't overwrite an existing trace + trace_params = { + "id": trace_id, + "name": trace_name, + "session_id": session_id, + "input": input if not mask_input else "redacted-by-litellm", + "version": clean_metadata.pop( + "trace_version", clean_metadata.get("version", None) + ), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence + "user_id": end_user_id, + } + for key in list( + filter(lambda key: key.startswith("trace_"), clean_metadata.keys()) + ): + trace_params[key.replace("trace_", "")] = clean_metadata.pop( + key, None + ) + + if level == "ERROR": + trace_params["status_message"] = output + else: + trace_params["output"] = ( + output if not mask_output else "redacted-by-litellm" + ) + + if debug is True or (isinstance(debug, str) and debug.lower() == "true"): + if "metadata" in trace_params: + # log the raw_metadata in the trace + trace_params["metadata"]["metadata_passed_to_litellm"] = metadata + else: + trace_params["metadata"] = {"metadata_passed_to_litellm": metadata} + + cost = kwargs.get("response_cost", None) + verbose_logger.debug(f"trace: {cost}") + + clean_metadata["litellm_response_cost"] = cost + if standard_logging_object is not None: + clean_metadata["hidden_params"] = standard_logging_object[ + "hidden_params" + ] + + if ( + litellm.langfuse_default_tags is not None + and isinstance(litellm.langfuse_default_tags, list) + and "proxy_base_url" in litellm.langfuse_default_tags + ): + proxy_base_url = os.environ.get("PROXY_BASE_URL", None) + if proxy_base_url is not None: + tags.append(f"proxy_base_url:{proxy_base_url}") + + api_base = litellm_params.get("api_base", None) + if api_base: + clean_metadata["api_base"] = api_base + + vertex_location = kwargs.get("vertex_location", None) + if vertex_location: + clean_metadata["vertex_location"] = vertex_location + + aws_region_name = kwargs.get("aws_region_name", None) + if aws_region_name: + clean_metadata["aws_region_name"] = aws_region_name + + if self._supports_tags(): + if "cache_hit" in kwargs: + if kwargs["cache_hit"] is None: + kwargs["cache_hit"] = False + clean_metadata["cache_hit"] = kwargs["cache_hit"] + if existing_trace_id is None: + trace_params.update({"tags": tags}) + + proxy_server_request = litellm_params.get("proxy_server_request", None) + if proxy_server_request: + proxy_server_request.get("method", None) + proxy_server_request.get("url", None) + 
headers = proxy_server_request.get("headers", None) + clean_headers = {} + if headers: + for key, value in headers.items(): + # these headers can leak our API keys and/or JWT tokens + if key.lower() not in ["authorization", "cookie", "referer"]: + clean_headers[key] = value + + # clean_metadata["request"] = { + # "method": method, + # "url": url, + # "headers": clean_headers, + # } + trace = self.Langfuse.trace(**trace_params) + + # Log provider specific information as a span + log_provider_specific_information_as_span(trace, clean_metadata) + + generation_id = None + usage = None + if response_obj is not None: + if ( + hasattr(response_obj, "id") + and response_obj.get("id", None) is not None + ): + generation_id = litellm.utils.get_logging_id( + start_time, response_obj + ) + _usage_obj = getattr(response_obj, "usage", None) + + if _usage_obj: + usage = { + "prompt_tokens": _usage_obj.prompt_tokens, + "completion_tokens": _usage_obj.completion_tokens, + "total_cost": cost if self._supports_costs() else None, + } + generation_name = clean_metadata.pop("generation_name", None) + if generation_name is None: + # if `generation_name` is None, use sensible default values + # If using litellm proxy user `key_alias` if not None + # If `key_alias` is None, just log `litellm-{call_type}` as the generation name + _user_api_key_alias = cast( + Optional[str], clean_metadata.get("user_api_key_alias", None) + ) + generation_name = ( + f"litellm-{cast(str, kwargs.get('call_type', 'completion'))}" + ) + if _user_api_key_alias is not None: + generation_name = f"litellm:{_user_api_key_alias}" + + if response_obj is not None: + system_fingerprint = getattr(response_obj, "system_fingerprint", None) + else: + system_fingerprint = None + + if system_fingerprint is not None: + optional_params["system_fingerprint"] = system_fingerprint + + generation_params = { + "name": generation_name, + "id": clean_metadata.pop("generation_id", generation_id), + "start_time": start_time, + "end_time": end_time, + "model": kwargs["model"], + "model_parameters": optional_params, + "input": input if not mask_input else "redacted-by-litellm", + "output": output if not mask_output else "redacted-by-litellm", + "usage": usage, + "metadata": log_requester_metadata(clean_metadata), + "level": level, + "version": clean_metadata.pop("version", None), + } + + parent_observation_id = metadata.get("parent_observation_id", None) + if parent_observation_id is not None: + generation_params["parent_observation_id"] = parent_observation_id + + if self._supports_prompt(): + generation_params = _add_prompt_to_generation_params( + generation_params=generation_params, + clean_metadata=clean_metadata, + prompt_management_metadata=prompt_management_metadata, + langfuse_client=self.Langfuse, + ) + if output is not None and isinstance(output, str) and level == "ERROR": + generation_params["status_message"] = output + + if self._supports_completion_start_time(): + generation_params["completion_start_time"] = kwargs.get( + "completion_start_time", None + ) + + generation_client = trace.generation(**generation_params) + + return generation_client.trace_id, generation_id + except Exception: + verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}") + return None, None + + @staticmethod + def _get_chat_content_for_langfuse( + response_obj: ModelResponse, + ): + """ + Get the chat content for Langfuse logging + """ + if response_obj.choices and len(response_obj.choices) > 0: + output = response_obj["choices"][0]["message"].json() + return 
output + else: + return None + + @staticmethod + def _get_text_completion_content_for_langfuse( + response_obj: TextCompletionResponse, + ): + """ + Get the text completion content for Langfuse logging + """ + if response_obj.choices and len(response_obj.choices) > 0: + return response_obj.choices[0].text + else: + return None + + @staticmethod + def _get_langfuse_tags( + standard_logging_object: Optional[StandardLoggingPayload], + ) -> List[str]: + if standard_logging_object is None: + return [] + return standard_logging_object.get("request_tags", []) or [] + + def add_default_langfuse_tags(self, tags, kwargs, metadata): + """ + Helper function to add litellm default langfuse tags + + - Special LiteLLM tags: + - cache_hit + - cache_key + + """ + if litellm.langfuse_default_tags is not None and isinstance( + litellm.langfuse_default_tags, list + ): + if "cache_hit" in litellm.langfuse_default_tags: + _cache_hit_value = kwargs.get("cache_hit", False) + tags.append(f"cache_hit:{_cache_hit_value}") + if "cache_key" in litellm.langfuse_default_tags: + _hidden_params = metadata.get("hidden_params", {}) or {} + _cache_key = _hidden_params.get("cache_key", None) + if _cache_key is None and litellm.cache is not None: + # fallback to using "preset_cache_key" + _preset_cache_key = litellm.cache._get_preset_cache_key_from_kwargs( + **kwargs + ) + _cache_key = _preset_cache_key + tags.append(f"cache_key:{_cache_key}") + return tags + + def _supports_tags(self): + """Check if current langfuse version supports tags""" + return Version(self.langfuse_sdk_version) >= Version("2.6.3") + + def _supports_prompt(self): + """Check if current langfuse version supports prompt""" + return Version(self.langfuse_sdk_version) >= Version("2.7.3") + + def _supports_costs(self): + """Check if current langfuse version supports costs""" + return Version(self.langfuse_sdk_version) >= Version("2.7.3") + + def _supports_completion_start_time(self): + """Check if current langfuse version supports completion start time""" + return Version(self.langfuse_sdk_version) >= Version("2.7.3") + + @staticmethod + def _get_langfuse_flush_interval(flush_interval: int) -> int: + """ + Get the langfuse flush interval to initialize the Langfuse client + + Reads `LANGFUSE_FLUSH_INTERVAL` from the environment variable. + If not set, uses the flush interval passed in as an argument. 
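+
+        For example (hypothetical value), setting `LANGFUSE_FLUSH_INTERVAL=5`
+        makes the client flush every 5 seconds, and the `flush_interval`
+        argument passed in is ignored.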
+ + Args: + flush_interval: The flush interval to use if LANGFUSE_FLUSH_INTERVAL is not set + + Returns: + [int] The flush interval to use to initialize the Langfuse client + """ + return int(os.getenv("LANGFUSE_FLUSH_INTERVAL") or flush_interval) + + +def _add_prompt_to_generation_params( + generation_params: dict, + clean_metadata: dict, + prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata], + langfuse_client: Any, +) -> dict: + from langfuse import Langfuse + from langfuse.model import ( + ChatPromptClient, + Prompt_Chat, + Prompt_Text, + TextPromptClient, + ) + + langfuse_client = cast(Langfuse, langfuse_client) + + user_prompt = clean_metadata.pop("prompt", None) + if user_prompt is None and prompt_management_metadata is None: + pass + elif isinstance(user_prompt, dict): + if user_prompt.get("type", "") == "chat": + _prompt_chat = Prompt_Chat(**user_prompt) + generation_params["prompt"] = ChatPromptClient(prompt=_prompt_chat) + elif user_prompt.get("type", "") == "text": + _prompt_text = Prompt_Text(**user_prompt) + generation_params["prompt"] = TextPromptClient(prompt=_prompt_text) + elif "version" in user_prompt and "prompt" in user_prompt: + # prompts + if isinstance(user_prompt["prompt"], str): + prompt_text_params = getattr( + Prompt_Text, "model_fields", Prompt_Text.__fields__ + ) + _data = { + "name": user_prompt["name"], + "prompt": user_prompt["prompt"], + "version": user_prompt["version"], + "config": user_prompt.get("config", None), + } + if "labels" in prompt_text_params and "tags" in prompt_text_params: + _data["labels"] = user_prompt.get("labels", []) or [] + _data["tags"] = user_prompt.get("tags", []) or [] + _prompt_obj = Prompt_Text(**_data) # type: ignore + generation_params["prompt"] = TextPromptClient(prompt=_prompt_obj) + + elif isinstance(user_prompt["prompt"], list): + prompt_chat_params = getattr( + Prompt_Chat, "model_fields", Prompt_Chat.__fields__ + ) + _data = { + "name": user_prompt["name"], + "prompt": user_prompt["prompt"], + "version": user_prompt["version"], + "config": user_prompt.get("config", None), + } + if "labels" in prompt_chat_params and "tags" in prompt_chat_params: + _data["labels"] = user_prompt.get("labels", []) or [] + _data["tags"] = user_prompt.get("tags", []) or [] + + _prompt_obj = Prompt_Chat(**_data) # type: ignore + + generation_params["prompt"] = ChatPromptClient(prompt=_prompt_obj) + else: + verbose_logger.error( + "[Non-blocking] Langfuse Logger: Invalid prompt format" + ) + else: + verbose_logger.error( + "[Non-blocking] Langfuse Logger: Invalid prompt format. No prompt logged to Langfuse" + ) + elif ( + prompt_management_metadata is not None + and prompt_management_metadata["prompt_integration"] == "langfuse" + ): + try: + generation_params["prompt"] = langfuse_client.get_prompt( + prompt_management_metadata["prompt_id"] + ) + except Exception as e: + verbose_logger.debug( + f"[Non-blocking] Langfuse Logger: Error getting prompt client for logging: {e}" + ) + pass + + else: + generation_params["prompt"] = user_prompt + + return generation_params + + +def log_provider_specific_information_as_span( + trace, + clean_metadata, +): + """ + Logs provider-specific information as spans. + + Parameters: + trace: The tracing object used to log spans. + clean_metadata: A dictionary containing metadata to be logged. 
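+
+    Note:
+        At the moment only `vertex_ai_grounding_metadata`, read from
+        `clean_metadata["hidden_params"]`, is logged as spans.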
+ + Returns: + None + """ + + _hidden_params = clean_metadata.get("hidden_params", None) + if _hidden_params is None: + return + + vertex_ai_grounding_metadata = _hidden_params.get( + "vertex_ai_grounding_metadata", None + ) + + if vertex_ai_grounding_metadata is not None: + if isinstance(vertex_ai_grounding_metadata, list): + for elem in vertex_ai_grounding_metadata: + if isinstance(elem, dict): + for key, value in elem.items(): + trace.span( + name=key, + input=value, + ) + else: + trace.span( + name="vertex_ai_grounding_metadata", + input=elem, + ) + else: + trace.span( + name="vertex_ai_grounding_metadata", + input=vertex_ai_grounding_metadata, + ) + + +def log_requester_metadata(clean_metadata: dict): + returned_metadata = {} + requester_metadata = clean_metadata.get("requester_metadata") or {} + for k, v in clean_metadata.items(): + if k not in requester_metadata: + returned_metadata[k] = v + + returned_metadata.update({"requester_metadata": requester_metadata}) + + return returned_metadata diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse_handler.py b/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse_handler.py new file mode 100644 index 00000000..aebe1461 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse_handler.py @@ -0,0 +1,169 @@ +""" +This file contains the LangFuseHandler class + +Used to get the LangFuseLogger for a given request + +Handles Key/Team Based Langfuse Logging +""" + +from typing import TYPE_CHECKING, Any, Dict, Optional + +from litellm.litellm_core_utils.litellm_logging import StandardCallbackDynamicParams + +from .langfuse import LangFuseLogger, LangfuseLoggingConfig + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache +else: + DynamicLoggingCache = Any + + +class LangFuseHandler: + + @staticmethod + def get_langfuse_logger_for_request( + standard_callback_dynamic_params: StandardCallbackDynamicParams, + in_memory_dynamic_logger_cache: DynamicLoggingCache, + globalLangfuseLogger: Optional[LangFuseLogger] = None, + ) -> LangFuseLogger: + """ + This function is used to get the LangFuseLogger for a given request + + 1. If dynamic credentials are passed + - check if a LangFuseLogger is cached for the dynamic credentials + - if cached LangFuseLogger is not found, create a new LangFuseLogger and cache it + + 2. 
If dynamic credentials are not passed return the globalLangfuseLogger + + """ + temp_langfuse_logger: Optional[LangFuseLogger] = globalLangfuseLogger + if ( + LangFuseHandler._dynamic_langfuse_credentials_are_passed( + standard_callback_dynamic_params + ) + is False + ): + return LangFuseHandler._return_global_langfuse_logger( + globalLangfuseLogger=globalLangfuseLogger, + in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache, + ) + + # get langfuse logging config to use for this request, based on standard_callback_dynamic_params + _credentials = LangFuseHandler.get_dynamic_langfuse_logging_config( + globalLangfuseLogger=globalLangfuseLogger, + standard_callback_dynamic_params=standard_callback_dynamic_params, + ) + credentials_dict = dict(_credentials) + + # check if langfuse logger is already cached + temp_langfuse_logger = in_memory_dynamic_logger_cache.get_cache( + credentials=credentials_dict, service_name="langfuse" + ) + + # if not cached, create a new langfuse logger and cache it + if temp_langfuse_logger is None: + temp_langfuse_logger = ( + LangFuseHandler._create_langfuse_logger_from_credentials( + credentials=credentials_dict, + in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache, + ) + ) + + return temp_langfuse_logger + + @staticmethod + def _return_global_langfuse_logger( + globalLangfuseLogger: Optional[LangFuseLogger], + in_memory_dynamic_logger_cache: DynamicLoggingCache, + ) -> LangFuseLogger: + """ + Returns the Global LangfuseLogger set on litellm + + (this is the default langfuse logger - used when no dynamic credentials are passed) + + If no Global LangfuseLogger is set, it will check in_memory_dynamic_logger_cache for a cached LangFuseLogger + This function is used to return the globalLangfuseLogger if it exists, otherwise it will check in_memory_dynamic_logger_cache for a cached LangFuseLogger + """ + if globalLangfuseLogger is not None: + return globalLangfuseLogger + + credentials_dict: Dict[str, Any] = ( + {} + ) # the global langfuse logger uses Environment Variables, there are no dynamic credentials + globalLangfuseLogger = in_memory_dynamic_logger_cache.get_cache( + credentials=credentials_dict, + service_name="langfuse", + ) + if globalLangfuseLogger is None: + globalLangfuseLogger = ( + LangFuseHandler._create_langfuse_logger_from_credentials( + credentials=credentials_dict, + in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache, + ) + ) + return globalLangfuseLogger + + @staticmethod + def _create_langfuse_logger_from_credentials( + credentials: Dict, + in_memory_dynamic_logger_cache: DynamicLoggingCache, + ) -> LangFuseLogger: + """ + This function is used to + 1. create a LangFuseLogger from the credentials + 2. cache the LangFuseLogger to prevent re-creating it for the same credentials + """ + + langfuse_logger = LangFuseLogger( + langfuse_public_key=credentials.get("langfuse_public_key"), + langfuse_secret=credentials.get("langfuse_secret"), + langfuse_host=credentials.get("langfuse_host"), + ) + in_memory_dynamic_logger_cache.set_cache( + credentials=credentials, + service_name="langfuse", + logging_obj=langfuse_logger, + ) + return langfuse_logger + + @staticmethod + def get_dynamic_langfuse_logging_config( + standard_callback_dynamic_params: StandardCallbackDynamicParams, + globalLangfuseLogger: Optional[LangFuseLogger] = None, + ) -> LangfuseLoggingConfig: + """ + This function is used to get the Langfuse logging config to use for a given request. 
+ + It checks if the dynamic parameters are provided in the standard_callback_dynamic_params and uses them to get the Langfuse logging config. + + If no dynamic parameters are provided, it uses the `globalLangfuseLogger` values + """ + # only use dynamic params if langfuse credentials are passed dynamically + return LangfuseLoggingConfig( + langfuse_secret=standard_callback_dynamic_params.get("langfuse_secret") + or standard_callback_dynamic_params.get("langfuse_secret_key"), + langfuse_public_key=standard_callback_dynamic_params.get( + "langfuse_public_key" + ), + langfuse_host=standard_callback_dynamic_params.get("langfuse_host"), + ) + + @staticmethod + def _dynamic_langfuse_credentials_are_passed( + standard_callback_dynamic_params: StandardCallbackDynamicParams, + ) -> bool: + """ + This function is used to check if the dynamic langfuse credentials are passed in standard_callback_dynamic_params + + Returns: + bool: True if the dynamic langfuse credentials are passed, False otherwise + """ + + if ( + standard_callback_dynamic_params.get("langfuse_host") is not None + or standard_callback_dynamic_params.get("langfuse_public_key") is not None + or standard_callback_dynamic_params.get("langfuse_secret") is not None + or standard_callback_dynamic_params.get("langfuse_secret_key") is not None + ): + return True + return False diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse_prompt_management.py b/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse_prompt_management.py new file mode 100644 index 00000000..1f4ca84d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse_prompt_management.py @@ -0,0 +1,287 @@ +""" +Call Hook for LiteLLM Proxy which allows Langfuse prompt management. +""" + +import os +from functools import lru_cache +from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, cast + +from packaging.version import Version +from typing_extensions import TypeAlias + +from litellm.integrations.custom_logger import CustomLogger +from litellm.integrations.prompt_management_base import PromptManagementClient +from litellm.litellm_core_utils.asyncify import run_async_function +from litellm.types.llms.openai import AllMessageValues, ChatCompletionSystemMessage +from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload + +from ...litellm_core_utils.specialty_caches.dynamic_logging_cache import ( + DynamicLoggingCache, +) +from ..prompt_management_base import PromptManagementBase +from .langfuse import LangFuseLogger +from .langfuse_handler import LangFuseHandler + +if TYPE_CHECKING: + from langfuse import Langfuse + from langfuse.client import ChatPromptClient, TextPromptClient + + LangfuseClass: TypeAlias = Langfuse + + PROMPT_CLIENT = Union[TextPromptClient, ChatPromptClient] +else: + PROMPT_CLIENT = Any + LangfuseClass = Any + +in_memory_dynamic_logger_cache = DynamicLoggingCache() + + +@lru_cache(maxsize=10) +def langfuse_client_init( + langfuse_public_key=None, + langfuse_secret=None, + langfuse_secret_key=None, + langfuse_host=None, + flush_interval=1, +) -> LangfuseClass: + """ + Initialize Langfuse client with caching to prevent multiple initializations. + + Args: + langfuse_public_key (str, optional): Public key for Langfuse. Defaults to None. + langfuse_secret (str, optional): Secret key for Langfuse. Defaults to None. + langfuse_host (str, optional): Host URL for Langfuse. Defaults to None. 
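+        langfuse_secret_key (str, optional): Alternate name for the Langfuse secret key; used as a fallback when langfuse_secret is not set. Defaults to None.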
+ flush_interval (int, optional): Flush interval in seconds. Defaults to 1. + + Returns: + Langfuse: Initialized Langfuse client instance + + Raises: + Exception: If langfuse package is not installed + """ + try: + import langfuse + from langfuse import Langfuse + except Exception as e: + raise Exception( + f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n\033[0m" + ) + + # Instance variables + + secret_key = ( + langfuse_secret or langfuse_secret_key or os.getenv("LANGFUSE_SECRET_KEY") + ) + public_key = langfuse_public_key or os.getenv("LANGFUSE_PUBLIC_KEY") + langfuse_host = langfuse_host or os.getenv( + "LANGFUSE_HOST", "https://cloud.langfuse.com" + ) + + if not ( + langfuse_host.startswith("http://") or langfuse_host.startswith("https://") + ): + # add http:// if unset, assume communicating over private network - e.g. render + langfuse_host = "http://" + langfuse_host + + langfuse_release = os.getenv("LANGFUSE_RELEASE") + langfuse_debug = os.getenv("LANGFUSE_DEBUG") + + parameters = { + "public_key": public_key, + "secret_key": secret_key, + "host": langfuse_host, + "release": langfuse_release, + "debug": langfuse_debug, + "flush_interval": LangFuseLogger._get_langfuse_flush_interval( + flush_interval + ), # flush interval in seconds + } + + if Version(langfuse.version.__version__) >= Version("2.6.0"): + parameters["sdk_integration"] = "litellm" + + client = Langfuse(**parameters) + + return client + + +class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogger): + def __init__( + self, + langfuse_public_key=None, + langfuse_secret=None, + langfuse_host=None, + flush_interval=1, + ): + import langfuse + + self.langfuse_sdk_version = langfuse.version.__version__ + self.Langfuse = langfuse_client_init( + langfuse_public_key=langfuse_public_key, + langfuse_secret=langfuse_secret, + langfuse_host=langfuse_host, + flush_interval=flush_interval, + ) + + @property + def integration_name(self): + return "langfuse" + + def _get_prompt_from_id( + self, langfuse_prompt_id: str, langfuse_client: LangfuseClass + ) -> PROMPT_CLIENT: + return langfuse_client.get_prompt(langfuse_prompt_id) + + def _compile_prompt( + self, + langfuse_prompt_client: PROMPT_CLIENT, + langfuse_prompt_variables: Optional[dict], + call_type: Union[Literal["completion"], Literal["text_completion"]], + ) -> List[AllMessageValues]: + compiled_prompt: Optional[Union[str, list]] = None + + if langfuse_prompt_variables is None: + langfuse_prompt_variables = {} + + compiled_prompt = langfuse_prompt_client.compile(**langfuse_prompt_variables) + + if isinstance(compiled_prompt, str): + compiled_prompt = [ + ChatCompletionSystemMessage(role="system", content=compiled_prompt) + ] + else: + compiled_prompt = cast(List[AllMessageValues], compiled_prompt) + + return compiled_prompt + + def _get_optional_params_from_langfuse( + self, langfuse_prompt_client: PROMPT_CLIENT + ) -> dict: + config = langfuse_prompt_client.config + optional_params = {} + for k, v in config.items(): + if k != "model": + optional_params[k] = v + return optional_params + + async def async_get_chat_completion_prompt( + self, + model: str, + messages: List[AllMessageValues], + non_default_params: dict, + prompt_id: str, + prompt_variables: Optional[dict], + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> Tuple[ + str, + List[AllMessageValues], + dict, + ]: + return self.get_chat_completion_prompt( + model, + messages, + non_default_params, + prompt_id, + prompt_variables, + 
dynamic_callback_params, + ) + + def should_run_prompt_management( + self, + prompt_id: str, + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> bool: + langfuse_client = langfuse_client_init( + langfuse_public_key=dynamic_callback_params.get("langfuse_public_key"), + langfuse_secret=dynamic_callback_params.get("langfuse_secret"), + langfuse_secret_key=dynamic_callback_params.get("langfuse_secret_key"), + langfuse_host=dynamic_callback_params.get("langfuse_host"), + ) + langfuse_prompt_client = self._get_prompt_from_id( + langfuse_prompt_id=prompt_id, langfuse_client=langfuse_client + ) + return langfuse_prompt_client is not None + + def _compile_prompt_helper( + self, + prompt_id: str, + prompt_variables: Optional[dict], + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> PromptManagementClient: + langfuse_client = langfuse_client_init( + langfuse_public_key=dynamic_callback_params.get("langfuse_public_key"), + langfuse_secret=dynamic_callback_params.get("langfuse_secret"), + langfuse_secret_key=dynamic_callback_params.get("langfuse_secret_key"), + langfuse_host=dynamic_callback_params.get("langfuse_host"), + ) + langfuse_prompt_client = self._get_prompt_from_id( + langfuse_prompt_id=prompt_id, langfuse_client=langfuse_client + ) + + ## SET PROMPT + compiled_prompt = self._compile_prompt( + langfuse_prompt_client=langfuse_prompt_client, + langfuse_prompt_variables=prompt_variables, + call_type="completion", + ) + + template_model = langfuse_prompt_client.config.get("model") + + template_optional_params = self._get_optional_params_from_langfuse( + langfuse_prompt_client + ) + + return PromptManagementClient( + prompt_id=prompt_id, + prompt_template=compiled_prompt, + prompt_template_model=template_model, + prompt_template_optional_params=template_optional_params, + completed_messages=None, + ) + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + return run_async_function( + self.async_log_success_event, kwargs, response_obj, start_time, end_time + ) + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + standard_callback_dynamic_params = kwargs.get( + "standard_callback_dynamic_params" + ) + langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request( + globalLangfuseLogger=self, + standard_callback_dynamic_params=standard_callback_dynamic_params, + in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache, + ) + langfuse_logger_to_use.log_event_on_langfuse( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + user_id=kwargs.get("user", None), + ) + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + standard_callback_dynamic_params = kwargs.get( + "standard_callback_dynamic_params" + ) + langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request( + globalLangfuseLogger=self, + standard_callback_dynamic_params=standard_callback_dynamic_params, + in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache, + ) + standard_logging_object = cast( + Optional[StandardLoggingPayload], + kwargs.get("standard_logging_object", None), + ) + if standard_logging_object is None: + return + langfuse_logger_to_use.log_event_on_langfuse( + start_time=start_time, + end_time=end_time, + response_obj=None, + user_id=kwargs.get("user", None), + status_message=standard_logging_object["error_str"], + level="ERROR", + kwargs=kwargs, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/langsmith.py 
b/.venv/lib/python3.12/site-packages/litellm/integrations/langsmith.py new file mode 100644 index 00000000..1ef90c18 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/langsmith.py @@ -0,0 +1,500 @@ +#### What this does #### +# On success, logs events to Langsmith +import asyncio +import os +import random +import traceback +import types +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +import httpx +from pydantic import BaseModel # type: ignore + +import litellm +from litellm._logging import verbose_logger +from litellm.integrations.custom_batch_logger import CustomBatchLogger +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.types.integrations.langsmith import * +from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload + + +def is_serializable(value): + non_serializable_types = ( + types.CoroutineType, + types.FunctionType, + types.GeneratorType, + BaseModel, + ) + return not isinstance(value, non_serializable_types) + + +class LangsmithLogger(CustomBatchLogger): + def __init__( + self, + langsmith_api_key: Optional[str] = None, + langsmith_project: Optional[str] = None, + langsmith_base_url: Optional[str] = None, + **kwargs, + ): + self.default_credentials = self.get_credentials_from_env( + langsmith_api_key=langsmith_api_key, + langsmith_project=langsmith_project, + langsmith_base_url=langsmith_base_url, + ) + self.sampling_rate: float = ( + float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore + if os.getenv("LANGSMITH_SAMPLING_RATE") is not None + and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore + else 1.0 + ) + self.langsmith_default_run_name = os.getenv( + "LANGSMITH_DEFAULT_RUN_NAME", "LLMRun" + ) + self.async_httpx_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + _batch_size = ( + os.getenv("LANGSMITH_BATCH_SIZE", None) or litellm.langsmith_batch_size + ) + if _batch_size: + self.batch_size = int(_batch_size) + self.log_queue: List[LangsmithQueueObject] = [] + asyncio.create_task(self.periodic_flush()) + self.flush_lock = asyncio.Lock() + + super().__init__(**kwargs, flush_lock=self.flush_lock) + + def get_credentials_from_env( + self, + langsmith_api_key: Optional[str] = None, + langsmith_project: Optional[str] = None, + langsmith_base_url: Optional[str] = None, + ) -> LangsmithCredentialsObject: + + _credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY") + if _credentials_api_key is None: + raise Exception( + "Invalid Langsmith API Key given. _credentials_api_key=None." + ) + _credentials_project = ( + langsmith_project or os.getenv("LANGSMITH_PROJECT") or "litellm-completion" + ) + if _credentials_project is None: + raise Exception( + "Invalid Langsmith API Key given. _credentials_project=None." + ) + _credentials_base_url = ( + langsmith_base_url + or os.getenv("LANGSMITH_BASE_URL") + or "https://api.smith.langchain.com" + ) + if _credentials_base_url is None: + raise Exception( + "Invalid Langsmith API Key given. _credentials_base_url=None." 
+ ) + + return LangsmithCredentialsObject( + LANGSMITH_API_KEY=_credentials_api_key, + LANGSMITH_BASE_URL=_credentials_base_url, + LANGSMITH_PROJECT=_credentials_project, + ) + + def _prepare_log_data( + self, + kwargs, + response_obj, + start_time, + end_time, + credentials: LangsmithCredentialsObject, + ): + try: + _litellm_params = kwargs.get("litellm_params", {}) or {} + metadata = _litellm_params.get("metadata", {}) or {} + project_name = metadata.get( + "project_name", credentials["LANGSMITH_PROJECT"] + ) + run_name = metadata.get("run_name", self.langsmith_default_run_name) + run_id = metadata.get("id", metadata.get("run_id", None)) + parent_run_id = metadata.get("parent_run_id", None) + trace_id = metadata.get("trace_id", None) + session_id = metadata.get("session_id", None) + dotted_order = metadata.get("dotted_order", None) + verbose_logger.debug( + f"Langsmith Logging - project_name: {project_name}, run_name {run_name}" + ) + + # Ensure everything in the payload is converted to str + payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + + if payload is None: + raise Exception("Error logging request payload. Payload=none.") + + metadata = payload[ + "metadata" + ] # ensure logged metadata is json serializable + + data = { + "name": run_name, + "run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain" + "inputs": payload, + "outputs": payload["response"], + "session_name": project_name, + "start_time": payload["startTime"], + "end_time": payload["endTime"], + "tags": payload["request_tags"], + "extra": metadata, + } + + if payload["error_str"] is not None and payload["status"] == "failure": + data["error"] = payload["error_str"] + + if run_id: + data["id"] = run_id + + if parent_run_id: + data["parent_run_id"] = parent_run_id + + if trace_id: + data["trace_id"] = trace_id + + if session_id: + data["session_id"] = session_id + + if dotted_order: + data["dotted_order"] = dotted_order + + run_id: Optional[str] = data.get("id") # type: ignore + if "id" not in data or data["id"] is None: + """ + for /batch langsmith requires id, trace_id and dotted_order passed as params + """ + run_id = str(uuid.uuid4()) + + data["id"] = run_id + + if ( + "trace_id" not in data + or data["trace_id"] is None + and (run_id is not None and isinstance(run_id, str)) + ): + data["trace_id"] = run_id + + if ( + "dotted_order" not in data + or data["dotted_order"] is None + and (run_id is not None and isinstance(run_id, str)) + ): + data["dotted_order"] = self.make_dot_order(run_id=run_id) # type: ignore + + verbose_logger.debug("Langsmith Logging data on langsmith: %s", data) + + return data + except Exception: + raise + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + sampling_rate = ( + float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore + if os.getenv("LANGSMITH_SAMPLING_RATE") is not None + and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore + else 1.0 + ) + random_sample = random.random() + if random_sample > sampling_rate: + verbose_logger.info( + "Skipping Langsmith logging. 
Sampling rate={}, random_sample={}".format( + sampling_rate, random_sample + ) + ) + return # Skip logging + verbose_logger.debug( + "Langsmith Sync Layer Logging - kwargs: %s, response_obj: %s", + kwargs, + response_obj, + ) + credentials = self._get_credentials_to_use_for_request(kwargs=kwargs) + data = self._prepare_log_data( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + credentials=credentials, + ) + self.log_queue.append( + LangsmithQueueObject( + data=data, + credentials=credentials, + ) + ) + verbose_logger.debug( + f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..." + ) + + if len(self.log_queue) >= self.batch_size: + self._send_batch() + + except Exception: + verbose_logger.exception("Langsmith Layer Error - log_success_event error") + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + sampling_rate = self.sampling_rate + random_sample = random.random() + if random_sample > sampling_rate: + verbose_logger.info( + "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format( + sampling_rate, random_sample + ) + ) + return # Skip logging + verbose_logger.debug( + "Langsmith Async Layer Logging - kwargs: %s, response_obj: %s", + kwargs, + response_obj, + ) + credentials = self._get_credentials_to_use_for_request(kwargs=kwargs) + data = self._prepare_log_data( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + credentials=credentials, + ) + self.log_queue.append( + LangsmithQueueObject( + data=data, + credentials=credentials, + ) + ) + verbose_logger.debug( + "Langsmith logging: queue length %s, batch size %s", + len(self.log_queue), + self.batch_size, + ) + if len(self.log_queue) >= self.batch_size: + await self.flush_queue() + except Exception: + verbose_logger.exception( + "Langsmith Layer Error - error logging async success event." + ) + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + sampling_rate = self.sampling_rate + random_sample = random.random() + if random_sample > sampling_rate: + verbose_logger.info( + "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format( + sampling_rate, random_sample + ) + ) + return # Skip logging + verbose_logger.info("Langsmith Failure Event Logging!") + try: + credentials = self._get_credentials_to_use_for_request(kwargs=kwargs) + data = self._prepare_log_data( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + credentials=credentials, + ) + self.log_queue.append( + LangsmithQueueObject( + data=data, + credentials=credentials, + ) + ) + verbose_logger.debug( + "Langsmith logging: queue length %s, batch size %s", + len(self.log_queue), + self.batch_size, + ) + if len(self.log_queue) >= self.batch_size: + await self.flush_queue() + except Exception: + verbose_logger.exception( + "Langsmith Layer Error - error logging async failure event." 
+ ) + + async def async_send_batch(self): + """ + Handles sending batches of runs to Langsmith + + self.log_queue contains LangsmithQueueObjects + Each LangsmithQueueObject has the following: + - "credentials" - credentials to use for the request (langsmith_api_key, langsmith_project, langsmith_base_url) + - "data" - data to log on to langsmith for the request + + + This function + - groups the queue objects by credentials + - loops through each unique credentials and sends batches to Langsmith + + + This was added to support key/team based logging on langsmith + """ + if not self.log_queue: + return + + batch_groups = self._group_batches_by_credentials() + for batch_group in batch_groups.values(): + await self._log_batch_on_langsmith( + credentials=batch_group.credentials, + queue_objects=batch_group.queue_objects, + ) + + def _add_endpoint_to_url( + self, url: str, endpoint: str, api_version: str = "/api/v1" + ) -> str: + if api_version not in url: + url = f"{url.rstrip('/')}{api_version}" + + if url.endswith("/"): + return f"{url}{endpoint}" + return f"{url}/{endpoint}" + + async def _log_batch_on_langsmith( + self, + credentials: LangsmithCredentialsObject, + queue_objects: List[LangsmithQueueObject], + ): + """ + Logs a batch of runs to Langsmith + sends runs to /batch endpoint for the given credentials + + Args: + credentials: LangsmithCredentialsObject + queue_objects: List[LangsmithQueueObject] + + Returns: None + + Raises: Does not raise an exception, will only verbose_logger.exception() + """ + langsmith_api_base = credentials["LANGSMITH_BASE_URL"] + langsmith_api_key = credentials["LANGSMITH_API_KEY"] + url = self._add_endpoint_to_url(langsmith_api_base, "runs/batch") + headers = {"x-api-key": langsmith_api_key} + elements_to_log = [queue_object["data"] for queue_object in queue_objects] + + try: + verbose_logger.debug( + "Sending batch of %s runs to Langsmith", len(elements_to_log) + ) + response = await self.async_httpx_client.post( + url=url, + json={"post": elements_to_log}, + headers=headers, + ) + response.raise_for_status() + + if response.status_code >= 300: + verbose_logger.error( + f"Langsmith Error: {response.status_code} - {response.text}" + ) + else: + verbose_logger.debug( + f"Batch of {len(self.log_queue)} runs successfully created" + ) + except httpx.HTTPStatusError as e: + verbose_logger.exception( + f"Langsmith HTTP Error: {e.response.status_code} - {e.response.text}" + ) + except Exception: + verbose_logger.exception( + f"Langsmith Layer Error - {traceback.format_exc()}" + ) + + def _group_batches_by_credentials(self) -> Dict[CredentialsKey, BatchGroup]: + """Groups queue objects by credentials using a proper key structure""" + log_queue_by_credentials: Dict[CredentialsKey, BatchGroup] = {} + + for queue_object in self.log_queue: + credentials = queue_object["credentials"] + key = CredentialsKey( + api_key=credentials["LANGSMITH_API_KEY"], + project=credentials["LANGSMITH_PROJECT"], + base_url=credentials["LANGSMITH_BASE_URL"], + ) + + if key not in log_queue_by_credentials: + log_queue_by_credentials[key] = BatchGroup( + credentials=credentials, queue_objects=[] + ) + + log_queue_by_credentials[key].queue_objects.append(queue_object) + + return log_queue_by_credentials + + def _get_credentials_to_use_for_request( + self, kwargs: Dict[str, Any] + ) -> LangsmithCredentialsObject: + """ + Handles key/team based logging + + If standard_callback_dynamic_params are provided, use those credentials. + + Otherwise, use the default credentials. 
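+
+        For example, a request whose `standard_callback_dynamic_params` carry
+        `langsmith_api_key` / `langsmith_project` / `langsmith_base_url` is
+        logged with those values; any missing value falls back to the
+        corresponding environment variable, and requests without dynamic
+        params use `self.default_credentials` resolved at init time.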
+ """ + standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( + kwargs.get("standard_callback_dynamic_params", None) + ) + if standard_callback_dynamic_params is not None: + credentials = self.get_credentials_from_env( + langsmith_api_key=standard_callback_dynamic_params.get( + "langsmith_api_key", None + ), + langsmith_project=standard_callback_dynamic_params.get( + "langsmith_project", None + ), + langsmith_base_url=standard_callback_dynamic_params.get( + "langsmith_base_url", None + ), + ) + else: + credentials = self.default_credentials + return credentials + + def _send_batch(self): + """Calls async_send_batch in an event loop""" + if not self.log_queue: + return + + try: + # Try to get the existing event loop + loop = asyncio.get_event_loop() + if loop.is_running(): + # If we're already in an event loop, create a task + asyncio.create_task(self.async_send_batch()) + else: + # If no event loop is running, run the coroutine directly + loop.run_until_complete(self.async_send_batch()) + except RuntimeError: + # If we can't get an event loop, create a new one + asyncio.run(self.async_send_batch()) + + def get_run_by_id(self, run_id): + + langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"] + + langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"] + + url = f"{langsmith_api_base}/runs/{run_id}" + response = litellm.module_level_client.get( + url=url, + headers={"x-api-key": langsmith_api_key}, + ) + + return response.json() + + def make_dot_order(self, run_id: str): + st = datetime.now(timezone.utc) + id_ = run_id + return st.strftime("%Y%m%dT%H%M%S%fZ") + str(id_) diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/langtrace.py b/.venv/lib/python3.12/site-packages/litellm/integrations/langtrace.py new file mode 100644 index 00000000..51cd272f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/langtrace.py @@ -0,0 +1,106 @@ +import json +from typing import TYPE_CHECKING, Any + +from litellm.proxy._types import SpanAttributes + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + + +class LangtraceAttributes: + """ + This class is used to save trace attributes to Langtrace's spans + """ + + def set_langtrace_attributes(self, span: Span, kwargs, response_obj): + """ + This function is used to log the event to Langtrace + """ + + vendor = kwargs.get("litellm_params").get("custom_llm_provider") + optional_params = kwargs.get("optional_params", {}) + options = {**kwargs, **optional_params} + self.set_request_attributes(span, options, vendor) + self.set_response_attributes(span, response_obj) + self.set_usage_attributes(span, response_obj) + + def set_request_attributes(self, span: Span, kwargs, vendor): + """ + This function is used to get span attributes for the LLM request + """ + span_attributes = { + "gen_ai.operation.name": "chat", + "langtrace.service.name": vendor, + SpanAttributes.LLM_REQUEST_MODEL.value: kwargs.get("model"), + SpanAttributes.LLM_IS_STREAMING.value: kwargs.get("stream"), + SpanAttributes.LLM_REQUEST_TEMPERATURE.value: kwargs.get("temperature"), + SpanAttributes.LLM_TOP_K.value: kwargs.get("top_k"), + SpanAttributes.LLM_REQUEST_TOP_P.value: kwargs.get("top_p"), + SpanAttributes.LLM_USER.value: kwargs.get("user"), + SpanAttributes.LLM_REQUEST_MAX_TOKENS.value: kwargs.get("max_tokens"), + SpanAttributes.LLM_RESPONSE_STOP_REASON.value: kwargs.get("stop"), + SpanAttributes.LLM_FREQUENCY_PENALTY.value: 
kwargs.get("frequency_penalty"), + SpanAttributes.LLM_PRESENCE_PENALTY.value: kwargs.get("presence_penalty"), + } + + prompts = kwargs.get("messages") + + if prompts: + span.add_event( + name="gen_ai.content.prompt", + attributes={SpanAttributes.LLM_PROMPTS.value: json.dumps(prompts)}, + ) + + self.set_span_attributes(span, span_attributes) + + def set_response_attributes(self, span: Span, response_obj): + """ + This function is used to get span attributes for the LLM response + """ + response_attributes = { + "gen_ai.response_id": response_obj.get("id"), + "gen_ai.system_fingerprint": response_obj.get("system_fingerprint"), + SpanAttributes.LLM_RESPONSE_MODEL.value: response_obj.get("model"), + } + completions = [] + for choice in response_obj.get("choices", []): + role = choice.get("message").get("role") + content = choice.get("message").get("content") + completions.append({"role": role, "content": content}) + + span.add_event( + name="gen_ai.content.completion", + attributes={SpanAttributes.LLM_COMPLETIONS: json.dumps(completions)}, + ) + + self.set_span_attributes(span, response_attributes) + + def set_usage_attributes(self, span: Span, response_obj): + """ + This function is used to get span attributes for the LLM usage + """ + usage = response_obj.get("usage") + if usage: + usage_attributes = { + SpanAttributes.LLM_USAGE_PROMPT_TOKENS.value: usage.get( + "prompt_tokens" + ), + SpanAttributes.LLM_USAGE_COMPLETION_TOKENS.value: usage.get( + "completion_tokens" + ), + SpanAttributes.LLM_USAGE_TOTAL_TOKENS.value: usage.get("total_tokens"), + } + self.set_span_attributes(span, usage_attributes) + + def set_span_attributes(self, span: Span, attributes): + """ + This function is used to set span attributes + """ + for key, value in attributes.items(): + if not value: + continue + span.set_attribute(key, value) diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py b/.venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py new file mode 100644 index 00000000..5bf9afd7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py @@ -0,0 +1,317 @@ +#### What this does #### +# This file contains the LiteralAILogger class which is used to log steps to the LiteralAI observability platform. 
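+#
+# Like the Langsmith logger above, this class builds on CustomBatchLogger:
+# events are appended to an in-memory queue and flushed to the Literal AI
+# GraphQL endpoint (`<LITERAL_API_URL>/api/graphql`, defaulting to
+# https://cloud.getliteral.ai) once the queue reaches the configured batch
+# size. Configuration comes from LITERAL_API_KEY and, optionally,
+# LITERAL_API_URL / LITERAL_BATCH_SIZE.
+#
+# A minimal usage sketch -- the "literalai" callback name is an assumption
+# here, and LITERAL_API_KEY is expected in the environment:
+#
+#   import litellm
+#   litellm.success_callback = ["literalai"]
+#   litellm.completion(
+#       model="gpt-4o-mini",
+#       messages=[{"role": "user", "content": "hello"}],
+#   )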
+import asyncio +import os +import uuid +from typing import List, Optional + +import httpx + +from litellm._logging import verbose_logger +from litellm.integrations.custom_batch_logger import CustomBatchLogger +from litellm.llms.custom_httpx.http_handler import ( + HTTPHandler, + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.types.utils import StandardLoggingPayload + + +class LiteralAILogger(CustomBatchLogger): + def __init__( + self, + literalai_api_key=None, + literalai_api_url="https://cloud.getliteral.ai", + env=None, + **kwargs, + ): + self.literalai_api_url = os.getenv("LITERAL_API_URL") or literalai_api_url + self.headers = { + "Content-Type": "application/json", + "x-api-key": literalai_api_key or os.getenv("LITERAL_API_KEY"), + "x-client-name": "litellm", + } + if env: + self.headers["x-env"] = env + self.async_httpx_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + self.sync_http_handler = HTTPHandler() + batch_size = os.getenv("LITERAL_BATCH_SIZE", None) + self.flush_lock = asyncio.Lock() + super().__init__( + **kwargs, + flush_lock=self.flush_lock, + batch_size=int(batch_size) if batch_size else None, + ) + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + verbose_logger.debug( + "Literal AI Layer Logging - kwargs: %s, response_obj: %s", + kwargs, + response_obj, + ) + data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) + self.log_queue.append(data) + verbose_logger.debug( + "Literal AI logging: queue length %s, batch size %s", + len(self.log_queue), + self.batch_size, + ) + if len(self.log_queue) >= self.batch_size: + self._send_batch() + except Exception: + verbose_logger.exception( + "Literal AI Layer Error - error logging success event." + ) + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + verbose_logger.info("Literal AI Failure Event Logging!") + try: + data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) + self.log_queue.append(data) + verbose_logger.debug( + "Literal AI logging: queue length %s, batch size %s", + len(self.log_queue), + self.batch_size, + ) + if len(self.log_queue) >= self.batch_size: + self._send_batch() + except Exception: + verbose_logger.exception( + "Literal AI Layer Error - error logging failure event." 
+ ) + + def _send_batch(self): + if not self.log_queue: + return + + url = f"{self.literalai_api_url}/api/graphql" + query = self._steps_query_builder(self.log_queue) + variables = self._steps_variables_builder(self.log_queue) + try: + response = self.sync_http_handler.post( + url=url, + json={ + "query": query, + "variables": variables, + }, + headers=self.headers, + ) + + if response.status_code >= 300: + verbose_logger.error( + f"Literal AI Error: {response.status_code} - {response.text}" + ) + else: + verbose_logger.debug( + f"Batch of {len(self.log_queue)} runs successfully created" + ) + except Exception: + verbose_logger.exception("Literal AI Layer Error") + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + verbose_logger.debug( + "Literal AI Async Layer Logging - kwargs: %s, response_obj: %s", + kwargs, + response_obj, + ) + data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) + self.log_queue.append(data) + verbose_logger.debug( + "Literal AI logging: queue length %s, batch size %s", + len(self.log_queue), + self.batch_size, + ) + if len(self.log_queue) >= self.batch_size: + await self.flush_queue() + except Exception: + verbose_logger.exception( + "Literal AI Layer Error - error logging async success event." + ) + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + verbose_logger.info("Literal AI Failure Event Logging!") + try: + data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) + self.log_queue.append(data) + verbose_logger.debug( + "Literal AI logging: queue length %s, batch size %s", + len(self.log_queue), + self.batch_size, + ) + if len(self.log_queue) >= self.batch_size: + await self.flush_queue() + except Exception: + verbose_logger.exception( + "Literal AI Layer Error - error logging async failure event." 
+ ) + + async def async_send_batch(self): + if not self.log_queue: + return + + url = f"{self.literalai_api_url}/api/graphql" + query = self._steps_query_builder(self.log_queue) + variables = self._steps_variables_builder(self.log_queue) + + try: + response = await self.async_httpx_client.post( + url=url, + json={ + "query": query, + "variables": variables, + }, + headers=self.headers, + ) + if response.status_code >= 300: + verbose_logger.error( + f"Literal AI Error: {response.status_code} - {response.text}" + ) + else: + verbose_logger.debug( + f"Batch of {len(self.log_queue)} runs successfully created" + ) + except httpx.HTTPStatusError as e: + verbose_logger.exception( + f"Literal AI HTTP Error: {e.response.status_code} - {e.response.text}" + ) + except Exception: + verbose_logger.exception("Literal AI Layer Error") + + def _prepare_log_data(self, kwargs, response_obj, start_time, end_time) -> dict: + logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + + if logging_payload is None: + raise ValueError("standard_logging_object not found in kwargs") + clean_metadata = logging_payload["metadata"] + metadata = kwargs.get("litellm_params", {}).get("metadata", {}) + + settings = logging_payload["model_parameters"] + messages = logging_payload["messages"] + response = logging_payload["response"] + choices: List = [] + if isinstance(response, dict) and "choices" in response: + choices = response["choices"] + message_completion = choices[0]["message"] if choices else None + prompt_id = None + variables = None + + if messages and isinstance(messages, list) and isinstance(messages[0], dict): + for message in messages: + if literal_prompt := getattr(message, "__literal_prompt__", None): + prompt_id = literal_prompt.get("prompt_id") + variables = literal_prompt.get("variables") + message["uuid"] = literal_prompt.get("uuid") + message["templated"] = True + + tools = settings.pop("tools", None) + + step = { + "id": metadata.get("step_id", str(uuid.uuid4())), + "error": logging_payload["error_str"], + "name": kwargs.get("model", ""), + "threadId": metadata.get("literalai_thread_id", None), + "parentId": metadata.get("literalai_parent_id", None), + "rootRunId": metadata.get("literalai_root_run_id", None), + "input": None, + "output": None, + "type": "llm", + "tags": metadata.get("tags", metadata.get("literalai_tags", None)), + "startTime": str(start_time), + "endTime": str(end_time), + "metadata": clean_metadata, + "generation": { + "inputTokenCount": logging_payload["prompt_tokens"], + "outputTokenCount": logging_payload["completion_tokens"], + "tokenCount": logging_payload["total_tokens"], + "promptId": prompt_id, + "variables": variables, + "provider": kwargs.get("custom_llm_provider", "litellm"), + "model": kwargs.get("model", ""), + "duration": (end_time - start_time).total_seconds(), + "settings": settings, + "messages": messages, + "messageCompletion": message_completion, + "tools": tools, + }, + } + return step + + def _steps_query_variables_builder(self, steps): + generated = "" + for id in range(len(steps)): + generated += f"""$id_{id}: String! + $threadId_{id}: String + $rootRunId_{id}: String + $type_{id}: StepType + $startTime_{id}: DateTime + $endTime_{id}: DateTime + $error_{id}: String + $input_{id}: Json + $output_{id}: Json + $metadata_{id}: Json + $parentId_{id}: String + $name_{id}: String + $tags_{id}: [String!] + $generation_{id}: GenerationPayloadInput + $scores_{id}: [ScorePayloadInput!] 
+ $attachments_{id}: [AttachmentPayloadInput!] + """ + return generated + + def _steps_ingest_steps_builder(self, steps): + generated = "" + for id in range(len(steps)): + generated += f""" + step{id}: ingestStep( + id: $id_{id} + threadId: $threadId_{id} + rootRunId: $rootRunId_{id} + startTime: $startTime_{id} + endTime: $endTime_{id} + type: $type_{id} + error: $error_{id} + input: $input_{id} + output: $output_{id} + metadata: $metadata_{id} + parentId: $parentId_{id} + name: $name_{id} + tags: $tags_{id} + generation: $generation_{id} + scores: $scores_{id} + attachments: $attachments_{id} + ) {{ + ok + message + }} + """ + return generated + + def _steps_query_builder(self, steps): + return f""" + mutation AddStep({self._steps_query_variables_builder(steps)}) {{ + {self._steps_ingest_steps_builder(steps)} + }} + """ + + def _steps_variables_builder(self, steps): + def serialize_step(event, id): + result = {} + + for key, value in event.items(): + # Only keep the keys that are not None to avoid overriding existing values + if value is not None: + result[f"{key}_{id}"] = value + + return result + + variables = {} + for i in range(len(steps)): + step = steps[i] + variables.update(serialize_step(step, i)) + return variables diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/logfire_logger.py b/.venv/lib/python3.12/site-packages/litellm/integrations/logfire_logger.py new file mode 100644 index 00000000..516bd4a8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/logfire_logger.py @@ -0,0 +1,179 @@ +#### What this does #### +# On success + failure, log events to Logfire + +import os +import traceback +import uuid +from enum import Enum +from typing import Any, Dict, NamedTuple + +from typing_extensions import LiteralString + +from litellm._logging import print_verbose, verbose_logger +from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info + + +class SpanConfig(NamedTuple): + message_template: LiteralString + span_data: Dict[str, Any] + + +class LogfireLevel(str, Enum): + INFO = "info" + ERROR = "error" + + +class LogfireLogger: + # Class variables or attributes + def __init__(self): + try: + verbose_logger.debug("in init logfire logger") + import logfire + + # only setting up logfire if we are sending to logfire + # in testing, we don't want to send to logfire + if logfire.DEFAULT_LOGFIRE_INSTANCE.config.send_to_logfire: + logfire.configure(token=os.getenv("LOGFIRE_TOKEN")) + except Exception as e: + print_verbose(f"Got exception on init logfire client {str(e)}") + raise e + + def _get_span_config(self, payload) -> SpanConfig: + if ( + payload["call_type"] == "completion" + or payload["call_type"] == "acompletion" + ): + return SpanConfig( + message_template="Chat Completion with {request_data[model]!r}", + span_data={"request_data": payload}, + ) + elif ( + payload["call_type"] == "embedding" or payload["call_type"] == "aembedding" + ): + return SpanConfig( + message_template="Embedding Creation with {request_data[model]!r}", + span_data={"request_data": payload}, + ) + elif ( + payload["call_type"] == "image_generation" + or payload["call_type"] == "aimage_generation" + ): + return SpanConfig( + message_template="Image Generation with {request_data[model]!r}", + span_data={"request_data": payload}, + ) + else: + return SpanConfig( + message_template="Litellm Call with {request_data[model]!r}", + span_data={"request_data": payload}, + ) + + async def _async_log_event( + self, + kwargs, + response_obj, + start_time, + 
end_time, + print_verbose, + level: LogfireLevel, + ): + self.log_event( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + level=level, + ) + + def log_event( + self, + kwargs, + start_time, + end_time, + print_verbose, + level: LogfireLevel, + response_obj, + ): + try: + import logfire + + verbose_logger.debug( + f"logfire Logging - Enters logging function for model {kwargs}" + ) + + if not response_obj: + response_obj = {} + litellm_params = kwargs.get("litellm_params", {}) + metadata = ( + litellm_params.get("metadata", {}) or {} + ) # if litellm_params['metadata'] == None + messages = kwargs.get("messages") + optional_params = kwargs.get("optional_params", {}) + call_type = kwargs.get("call_type", "completion") + cache_hit = kwargs.get("cache_hit", False) + usage = response_obj.get("usage", {}) + id = response_obj.get("id", str(uuid.uuid4())) + try: + response_time = (end_time - start_time).total_seconds() + except Exception: + response_time = None + + # Clean Metadata before logging - never log raw metadata + # the raw metadata can contain circular references which leads to infinite recursion + # we clean out all extra litellm metadata params before logging + clean_metadata = {} + if isinstance(metadata, dict): + for key, value in metadata.items(): + # clean litellm metadata before logging + if key in [ + "endpoint", + "caching_groups", + "previous_models", + ]: + continue + else: + clean_metadata[key] = value + + clean_metadata = redact_user_api_key_info(metadata=clean_metadata) + + # Build the initial payload + payload = { + "id": id, + "call_type": call_type, + "cache_hit": cache_hit, + "startTime": start_time, + "endTime": end_time, + "responseTime (seconds)": response_time, + "model": kwargs.get("model", ""), + "user": kwargs.get("user", ""), + "modelParameters": optional_params, + "spend": kwargs.get("response_cost", 0), + "messages": messages, + "response": response_obj, + "usage": usage, + "metadata": clean_metadata, + } + logfire_openai = logfire.with_settings(custom_scope_suffix="openai") + message_template, span_data = self._get_span_config(payload) + if level == LogfireLevel.INFO: + logfire_openai.info( + message_template, + **span_data, + ) + elif level == LogfireLevel.ERROR: + logfire_openai.error( + message_template, + **span_data, + _exc_info=True, + ) + print_verbose(f"\ndd Logger - Logging payload = {payload}") + + print_verbose( + f"Logfire Layer Logging - final response object: {response_obj}" + ) + except Exception as e: + verbose_logger.debug( + f"Logfire Layer Error - {str(e)}\n{traceback.format_exc()}" + ) + pass diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/lunary.py b/.venv/lib/python3.12/site-packages/litellm/integrations/lunary.py new file mode 100644 index 00000000..fcd781e4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/lunary.py @@ -0,0 +1,181 @@ +#### What this does #### +# On success + failure, log events to lunary.ai +import importlib +import traceback +from datetime import datetime, timezone + +import packaging + + +# convert to {completion: xx, tokens: xx} +def parse_usage(usage): + return { + "completion": usage["completion_tokens"] if "completion_tokens" in usage else 0, + "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0, + } + + +def parse_tool_calls(tool_calls): + if tool_calls is None: + return None + + def clean_tool_call(tool_call): + + serialized = { + "type": tool_call.type, + "id": tool_call.id, 
+ "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + } + + return serialized + + return [clean_tool_call(tool_call) for tool_call in tool_calls] + + +def parse_messages(input): + + if input is None: + return None + + def clean_message(message): + # if is string, return as is + if isinstance(message, str): + return message + + if "message" in message: + return clean_message(message["message"]) + + serialized = { + "role": message.get("role"), + "content": message.get("content"), + } + + # Only add tool_calls and function_call to res if they are set + if message.get("tool_calls"): + serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls")) + + return serialized + + if isinstance(input, list): + if len(input) == 1: + return clean_message(input[0]) + else: + return [clean_message(msg) for msg in input] + else: + return clean_message(input) + + +class LunaryLogger: + # Class variables or attributes + def __init__(self): + try: + import lunary + + version = importlib.metadata.version("lunary") # type: ignore + # if version < 0.1.43 then raise ImportError + if packaging.version.Version(version) < packaging.version.Version("0.1.43"): # type: ignore + print( # noqa + "Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'" + ) + raise ImportError + + self.lunary_client = lunary + except ImportError: + print( # noqa + "Lunary not installed. Please install it using 'pip install lunary'" + ) # noqa + raise ImportError + + def log_event( + self, + kwargs, + type, + event, + run_id, + model, + print_verbose, + extra={}, + input=None, + user_id=None, + response_obj=None, + start_time=datetime.now(timezone.utc), + end_time=datetime.now(timezone.utc), + error=None, + ): + try: + print_verbose(f"Lunary Logging - Logging request for model {model}") + + template_id = None + litellm_params = kwargs.get("litellm_params", {}) + optional_params = kwargs.get("optional_params", {}) + metadata = litellm_params.get("metadata", {}) or {} + + if optional_params: + extra = {**extra, **optional_params} + + tags = metadata.get("tags", None) + + if extra: + extra.pop("extra_body", None) + extra.pop("user", None) + template_id = extra.pop("extra_headers", {}).get("Template-Id", None) + + # keep only serializable types + for param, value in extra.items(): + if not isinstance(value, (str, int, bool, float)) and param != "tools": + try: + extra[param] = str(value) + except Exception: + pass + + if response_obj: + usage = ( + parse_usage(response_obj["usage"]) + if "usage" in response_obj + else None + ) + + output = response_obj["choices"] if "choices" in response_obj else None + + else: + usage = None + output = None + + if error: + error_obj = {"stack": error} + else: + error_obj = None + + self.lunary_client.track_event( # type: ignore + type, + "start", + run_id, + parent_run_id=metadata.get("parent_run_id", None), + user_id=user_id, + name=model, + input=parse_messages(input), + timestamp=start_time.astimezone(timezone.utc).isoformat(), + template_id=template_id, + metadata=metadata, + runtime="litellm", + tags=tags, + params=extra, + ) + + self.lunary_client.track_event( # type: ignore + type, + event, + run_id, + timestamp=end_time.astimezone(timezone.utc).isoformat(), + runtime="litellm", + error=error_obj, + output=parse_messages(output), + token_usage=usage, + ) + + except Exception: + print_verbose(f"Lunary Logging Error - {traceback.format_exc()}") + pass diff --git 
a/.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py b/.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py new file mode 100644 index 00000000..193d1c4e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py @@ -0,0 +1,269 @@ +import json +import threading +from typing import Optional + +from litellm._logging import verbose_logger +from litellm.integrations.custom_logger import CustomLogger + + +class MlflowLogger(CustomLogger): + def __init__(self): + from mlflow.tracking import MlflowClient + + self._client = MlflowClient() + + self._stream_id_to_span = {} + self._lock = threading.Lock() # lock for _stream_id_to_span + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + self._handle_success(kwargs, response_obj, start_time, end_time) + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + self._handle_success(kwargs, response_obj, start_time, end_time) + + def _handle_success(self, kwargs, response_obj, start_time, end_time): + """ + Log the success event as an MLflow span. + Note that this method is called asynchronously in the background thread. + """ + from mlflow.entities import SpanStatusCode + + try: + verbose_logger.debug("MLflow logging start for success event") + + if kwargs.get("stream"): + self._handle_stream_event(kwargs, response_obj, start_time, end_time) + else: + span = self._start_span_or_trace(kwargs, start_time) + end_time_ns = int(end_time.timestamp() * 1e9) + self._extract_and_set_chat_attributes(span, kwargs, response_obj) + self._end_span_or_trace( + span=span, + outputs=response_obj, + status=SpanStatusCode.OK, + end_time_ns=end_time_ns, + ) + except Exception: + verbose_logger.debug("MLflow Logging Error", stack_info=True) + + def _extract_and_set_chat_attributes(self, span, kwargs, response_obj): + try: + from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools + except ImportError: + return + + inputs = self._construct_input(kwargs) + input_messages = inputs.get("messages", []) + output_messages = [c.message.model_dump(exclude_none=True) + for c in getattr(response_obj, "choices", [])] + if messages := [*input_messages, *output_messages]: + set_span_chat_messages(span, messages) + if tools := inputs.get("tools"): + set_span_chat_tools(span, tools) + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + self._handle_failure(kwargs, response_obj, start_time, end_time) + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + self._handle_failure(kwargs, response_obj, start_time, end_time) + + def _handle_failure(self, kwargs, response_obj, start_time, end_time): + """ + Log the failure event as an MLflow span. + Note that this method is called *synchronously* unlike the success handler. 
+ """ + from mlflow.entities import SpanEvent, SpanStatusCode + + try: + span = self._start_span_or_trace(kwargs, start_time) + + end_time_ns = int(end_time.timestamp() * 1e9) + + # Record exception info as event + if exception := kwargs.get("exception"): + span.add_event(SpanEvent.from_exception(exception)) # type: ignore + + self._extract_and_set_chat_attributes(span, kwargs, response_obj) + self._end_span_or_trace( + span=span, + outputs=response_obj, + status=SpanStatusCode.ERROR, + end_time_ns=end_time_ns, + ) + + except Exception as e: + verbose_logger.debug(f"MLflow Logging Error - {e}", stack_info=True) + + def _handle_stream_event(self, kwargs, response_obj, start_time, end_time): + """ + Handle the success event for a streaming response. For streaming calls, + log_success_event handle is triggered for every chunk of the stream. + We create a single span for the entire stream request as follows: + + 1. For the first chunk, start a new span and store it in the map. + 2. For subsequent chunks, add the chunk as an event to the span. + 3. For the final chunk, end the span and remove the span from the map. + """ + from mlflow.entities import SpanStatusCode + + litellm_call_id = kwargs.get("litellm_call_id") + + if litellm_call_id not in self._stream_id_to_span: + with self._lock: + # Check again after acquiring lock + if litellm_call_id not in self._stream_id_to_span: + # Start a new span for the first chunk of the stream + span = self._start_span_or_trace(kwargs, start_time) + self._stream_id_to_span[litellm_call_id] = span + + # Add chunk as event to the span + span = self._stream_id_to_span[litellm_call_id] + self._add_chunk_events(span, response_obj) + + # If this is the final chunk, end the span. The final chunk + # has complete_streaming_response that gathers the full response. + if final_response := kwargs.get("complete_streaming_response"): + end_time_ns = int(end_time.timestamp() * 1e9) + + self._extract_and_set_chat_attributes(span, kwargs, final_response) + self._end_span_or_trace( + span=span, + outputs=final_response, + status=SpanStatusCode.OK, + end_time_ns=end_time_ns, + ) + + # Remove the stream_id from the map + with self._lock: + self._stream_id_to_span.pop(litellm_call_id) + + def _add_chunk_events(self, span, response_obj): + from mlflow.entities import SpanEvent + + try: + for choice in response_obj.choices: + span.add_event( + SpanEvent( + name="streaming_chunk", + attributes={"delta": json.dumps(choice.delta.model_dump())}, + ) + ) + except Exception: + verbose_logger.debug("Error adding chunk events to span", stack_info=True) + + def _construct_input(self, kwargs): + """Construct span inputs with optional parameters""" + inputs = {"messages": kwargs.get("messages")} + if tools := kwargs.get("tools"): + inputs["tools"] = tools + + for key in ["functions", "tools", "stream", "tool_choice", "user"]: + if value := kwargs.get("optional_params", {}).pop(key, None): + inputs[key] = value + return inputs + + def _extract_attributes(self, kwargs): + """ + Extract span attributes from kwargs. + + With the latest version of litellm, the standard_logging_object contains + canonical information for logging. If it is not present, we extract + subset of attributes from other kwargs. 
+ """ + attributes = { + "litellm_call_id": kwargs.get("litellm_call_id"), + "call_type": kwargs.get("call_type"), + "model": kwargs.get("model"), + } + standard_obj = kwargs.get("standard_logging_object") + if standard_obj: + attributes.update( + { + "api_base": standard_obj.get("api_base"), + "cache_hit": standard_obj.get("cache_hit"), + "usage": { + "completion_tokens": standard_obj.get("completion_tokens"), + "prompt_tokens": standard_obj.get("prompt_tokens"), + "total_tokens": standard_obj.get("total_tokens"), + }, + "raw_llm_response": standard_obj.get("response"), + "response_cost": standard_obj.get("response_cost"), + "saved_cache_cost": standard_obj.get("saved_cache_cost"), + } + ) + else: + litellm_params = kwargs.get("litellm_params", {}) + attributes.update( + { + "model": kwargs.get("model"), + "cache_hit": kwargs.get("cache_hit"), + "custom_llm_provider": kwargs.get("custom_llm_provider"), + "api_base": litellm_params.get("api_base"), + "response_cost": kwargs.get("response_cost"), + } + ) + return attributes + + def _get_span_type(self, call_type: Optional[str]) -> str: + from mlflow.entities import SpanType + + if call_type in ["completion", "acompletion"]: + return SpanType.LLM + elif call_type == "embeddings": + return SpanType.EMBEDDING + else: + return SpanType.LLM + + def _start_span_or_trace(self, kwargs, start_time): + """ + Start an MLflow span or a trace. + + If there is an active span, we start a new span as a child of + that span. Otherwise, we start a new trace. + """ + import mlflow + + call_type = kwargs.get("call_type", "completion") + span_name = f"litellm-{call_type}" + span_type = self._get_span_type(call_type) + start_time_ns = int(start_time.timestamp() * 1e9) + + inputs = self._construct_input(kwargs) + attributes = self._extract_attributes(kwargs) + + if active_span := mlflow.get_current_active_span(): # type: ignore + return self._client.start_span( + name=span_name, + request_id=active_span.request_id, + parent_id=active_span.span_id, + span_type=span_type, + inputs=inputs, + attributes=attributes, + start_time_ns=start_time_ns, + ) + else: + return self._client.start_trace( + name=span_name, + span_type=span_type, + inputs=inputs, + attributes=attributes, + start_time_ns=start_time_ns, + ) + + def _end_span_or_trace(self, span, outputs, end_time_ns, status): + """End an MLflow span or a trace.""" + if span.parent_id is None: + self._client.end_trace( + request_id=span.request_id, + outputs=outputs, + status=status, + end_time_ns=end_time_ns, + ) + else: + self._client.end_span( + request_id=span.request_id, + span_id=span.span_id, + outputs=outputs, + status=status, + end_time_ns=end_time_ns, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/openmeter.py b/.venv/lib/python3.12/site-packages/litellm/integrations/openmeter.py new file mode 100644 index 00000000..ebfed532 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/openmeter.py @@ -0,0 +1,132 @@ +# What is this? 
+## On Success events log cost to OpenMeter - https://github.com/BerriAI/litellm/issues/1268 + +import json +import os + +import httpx + +import litellm +from litellm.integrations.custom_logger import CustomLogger +from litellm.llms.custom_httpx.http_handler import ( + HTTPHandler, + get_async_httpx_client, + httpxSpecialProvider, +) + + +def get_utc_datetime(): + import datetime as dt + from datetime import datetime + + if hasattr(dt, "UTC"): + return datetime.now(dt.UTC) # type: ignore + else: + return datetime.utcnow() # type: ignore + + +class OpenMeterLogger(CustomLogger): + def __init__(self) -> None: + super().__init__() + self.validate_environment() + self.async_http_handler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + self.sync_http_handler = HTTPHandler() + + def validate_environment(self): + """ + Expects + OPENMETER_API_ENDPOINT, + OPENMETER_API_KEY, + + in the environment + """ + missing_keys = [] + if os.getenv("OPENMETER_API_KEY", None) is None: + missing_keys.append("OPENMETER_API_KEY") + + if len(missing_keys) > 0: + raise Exception("Missing keys={} in environment.".format(missing_keys)) + + def _common_logic(self, kwargs: dict, response_obj): + call_id = response_obj.get("id", kwargs.get("litellm_call_id")) + dt = get_utc_datetime().isoformat() + cost = kwargs.get("response_cost", None) + model = kwargs.get("model") + usage = {} + if ( + isinstance(response_obj, litellm.ModelResponse) + or isinstance(response_obj, litellm.EmbeddingResponse) + ) and hasattr(response_obj, "usage"): + usage = { + "prompt_tokens": response_obj["usage"].get("prompt_tokens", 0), + "completion_tokens": response_obj["usage"].get("completion_tokens", 0), + "total_tokens": response_obj["usage"].get("total_tokens"), + } + + subject = (kwargs.get("user", None),) # end-user passed in via 'user' param + if not subject: + raise Exception("OpenMeter: user is required") + + return { + "specversion": "1.0", + "type": os.getenv("OPENMETER_EVENT_TYPE", "litellm_tokens"), + "id": call_id, + "time": dt, + "subject": subject, + "source": "litellm-proxy", + "data": {"model": model, "cost": cost, **usage}, + } + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + _url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud") + if _url.endswith("/"): + _url += "api/v1/events" + else: + _url += "/api/v1/events" + + api_key = os.getenv("OPENMETER_API_KEY") + + _data = self._common_logic(kwargs=kwargs, response_obj=response_obj) + _headers = { + "Content-Type": "application/cloudevents+json", + "Authorization": "Bearer {}".format(api_key), + } + + try: + self.sync_http_handler.post( + url=_url, + data=json.dumps(_data), + headers=_headers, + ) + except httpx.HTTPStatusError as e: + raise Exception(f"OpenMeter logging error: {e.response.text}") + except Exception as e: + raise e + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + _url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud") + if _url.endswith("/"): + _url += "api/v1/events" + else: + _url += "/api/v1/events" + + api_key = os.getenv("OPENMETER_API_KEY") + + _data = self._common_logic(kwargs=kwargs, response_obj=response_obj) + _headers = { + "Content-Type": "application/cloudevents+json", + "Authorization": "Bearer {}".format(api_key), + } + + try: + await self.async_http_handler.post( + url=_url, + data=json.dumps(_data), + headers=_headers, + ) + except httpx.HTTPStatusError as e: + raise Exception(f"OpenMeter logging error: 
{e.response.text}") + except Exception as e: + raise e diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/opentelemetry.py b/.venv/lib/python3.12/site-packages/litellm/integrations/opentelemetry.py new file mode 100644 index 00000000..1572eb81 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/opentelemetry.py @@ -0,0 +1,1023 @@ +import os +from dataclasses import dataclass +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import litellm +from litellm._logging import verbose_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.types.services import ServiceLoggerPayload +from litellm.types.utils import ( + ChatCompletionMessageToolCall, + Function, + StandardCallbackDynamicParams, + StandardLoggingPayload, +) + +if TYPE_CHECKING: + from opentelemetry.sdk.trace.export import SpanExporter as _SpanExporter + from opentelemetry.trace import Span as _Span + + from litellm.proxy._types import ( + ManagementEndpointLoggingPayload as _ManagementEndpointLoggingPayload, + ) + from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth + + Span = _Span + SpanExporter = _SpanExporter + UserAPIKeyAuth = _UserAPIKeyAuth + ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload +else: + Span = Any + SpanExporter = Any + UserAPIKeyAuth = Any + ManagementEndpointLoggingPayload = Any + + +LITELLM_TRACER_NAME = os.getenv("OTEL_TRACER_NAME", "litellm") +LITELLM_RESOURCE: Dict[Any, Any] = { + "service.name": os.getenv("OTEL_SERVICE_NAME", "litellm"), + "deployment.environment": os.getenv("OTEL_ENVIRONMENT_NAME", "production"), + "model_id": os.getenv("OTEL_SERVICE_NAME", "litellm"), +} +RAW_REQUEST_SPAN_NAME = "raw_gen_ai_request" +LITELLM_REQUEST_SPAN_NAME = "litellm_request" + + +@dataclass +class OpenTelemetryConfig: + + exporter: Union[str, SpanExporter] = "console" + endpoint: Optional[str] = None + headers: Optional[str] = None + + @classmethod + def from_env(cls): + """ + OTEL_HEADERS=x-honeycomb-team=B85YgLm9**** + OTEL_EXPORTER="otlp_http" + OTEL_ENDPOINT="https://api.honeycomb.io/v1/traces" + + OTEL_HEADERS gets sent as headers = {"x-honeycomb-team": "B85YgLm96******"} + """ + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + + if os.getenv("OTEL_EXPORTER") == "in_memory": + return cls(exporter=InMemorySpanExporter()) + return cls( + exporter=os.getenv("OTEL_EXPORTER", "console"), + endpoint=os.getenv("OTEL_ENDPOINT"), + headers=os.getenv( + "OTEL_HEADERS" + ), # example: OTEL_HEADERS=x-honeycomb-team=B85YgLm96***" + ) + + +class OpenTelemetry(CustomLogger): + def __init__( + self, + config: Optional[OpenTelemetryConfig] = None, + callback_name: Optional[str] = None, + **kwargs, + ): + from opentelemetry import trace + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.trace import SpanKind + + if config is None: + config = OpenTelemetryConfig.from_env() + + self.config = config + self.OTEL_EXPORTER = self.config.exporter + self.OTEL_ENDPOINT = self.config.endpoint + self.OTEL_HEADERS = self.config.headers + provider = TracerProvider(resource=Resource(attributes=LITELLM_RESOURCE)) + provider.add_span_processor(self._get_span_processor()) + self.callback_name = callback_name + + trace.set_tracer_provider(provider) + self.tracer = trace.get_tracer(LITELLM_TRACER_NAME) + + self.span_kind = SpanKind + + _debug_otel = 
str(os.getenv("DEBUG_OTEL", "False")).lower() + + if _debug_otel == "true": + # Set up logging + import logging + + logging.basicConfig(level=logging.DEBUG) + logging.getLogger(__name__) + + # Enable OpenTelemetry logging + otel_exporter_logger = logging.getLogger("opentelemetry.sdk.trace.export") + otel_exporter_logger.setLevel(logging.DEBUG) + + # init CustomLogger params + super().__init__(**kwargs) + self._init_otel_logger_on_litellm_proxy() + + def _init_otel_logger_on_litellm_proxy(self): + """ + Initializes OpenTelemetry for litellm proxy server + + - Adds Otel as a service callback + - Sets `proxy_server.open_telemetry_logger` to self + """ + from litellm.proxy import proxy_server + + # Add Otel as a service callback + if "otel" not in litellm.service_callback: + litellm.service_callback.append("otel") + setattr(proxy_server, "open_telemetry_logger", self) + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + self._handle_sucess(kwargs, response_obj, start_time, end_time) + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + self._handle_failure(kwargs, response_obj, start_time, end_time) + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + self._handle_sucess(kwargs, response_obj, start_time, end_time) + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + self._handle_failure(kwargs, response_obj, start_time, end_time) + + async def async_service_success_hook( + self, + payload: ServiceLoggerPayload, + parent_otel_span: Optional[Span] = None, + start_time: Optional[Union[datetime, float]] = None, + end_time: Optional[Union[datetime, float]] = None, + event_metadata: Optional[dict] = None, + ): + + from opentelemetry import trace + from opentelemetry.trace import Status, StatusCode + + _start_time_ns = 0 + _end_time_ns = 0 + + if isinstance(start_time, float): + _start_time_ns = int(start_time * 1e9) + else: + _start_time_ns = self._to_ns(start_time) + + if isinstance(end_time, float): + _end_time_ns = int(end_time * 1e9) + else: + _end_time_ns = self._to_ns(end_time) + + if parent_otel_span is not None: + _span_name = payload.service + service_logging_span = self.tracer.start_span( + name=_span_name, + context=trace.set_span_in_context(parent_otel_span), + start_time=_start_time_ns, + ) + self.safe_set_attribute( + span=service_logging_span, + key="call_type", + value=payload.call_type, + ) + self.safe_set_attribute( + span=service_logging_span, + key="service", + value=payload.service.value, + ) + + if event_metadata: + for key, value in event_metadata.items(): + if value is None: + value = "None" + if isinstance(value, dict): + try: + value = str(value) + except Exception: + value = "litellm logging error - could_not_json_serialize" + self.safe_set_attribute( + span=service_logging_span, + key=key, + value=value, + ) + service_logging_span.set_status(Status(StatusCode.OK)) + service_logging_span.end(end_time=_end_time_ns) + + async def async_service_failure_hook( + self, + payload: ServiceLoggerPayload, + error: Optional[str] = "", + parent_otel_span: Optional[Span] = None, + start_time: Optional[Union[datetime, float]] = None, + end_time: Optional[Union[float, datetime]] = None, + event_metadata: Optional[dict] = None, + ): + + from opentelemetry import trace + from opentelemetry.trace import Status, StatusCode + + _start_time_ns = 0 + _end_time_ns = 0 + + if isinstance(start_time, float): + _start_time_ns = int(int(start_time) * 1e9) + else: + 
_start_time_ns = self._to_ns(start_time) + + if isinstance(end_time, float): + _end_time_ns = int(int(end_time) * 1e9) + else: + _end_time_ns = self._to_ns(end_time) + + if parent_otel_span is not None: + _span_name = payload.service + service_logging_span = self.tracer.start_span( + name=_span_name, + context=trace.set_span_in_context(parent_otel_span), + start_time=_start_time_ns, + ) + self.safe_set_attribute( + span=service_logging_span, + key="call_type", + value=payload.call_type, + ) + self.safe_set_attribute( + span=service_logging_span, + key="service", + value=payload.service.value, + ) + if error: + self.safe_set_attribute( + span=service_logging_span, + key="error", + value=error, + ) + if event_metadata: + for key, value in event_metadata.items(): + if isinstance(value, dict): + try: + value = str(value) + except Exception: + value = "litllm logging error - could_not_json_serialize" + self.safe_set_attribute( + span=service_logging_span, + key=key, + value=value, + ) + + service_logging_span.set_status(Status(StatusCode.ERROR)) + service_logging_span.end(end_time=_end_time_ns) + + async def async_post_call_failure_hook( + self, + request_data: dict, + original_exception: Exception, + user_api_key_dict: UserAPIKeyAuth, + ): + from opentelemetry import trace + from opentelemetry.trace import Status, StatusCode + + parent_otel_span = user_api_key_dict.parent_otel_span + if parent_otel_span is not None: + parent_otel_span.set_status(Status(StatusCode.ERROR)) + _span_name = "Failed Proxy Server Request" + + # Exception Logging Child Span + exception_logging_span = self.tracer.start_span( + name=_span_name, + context=trace.set_span_in_context(parent_otel_span), + ) + self.safe_set_attribute( + span=exception_logging_span, + key="exception", + value=str(original_exception), + ) + exception_logging_span.set_status(Status(StatusCode.ERROR)) + exception_logging_span.end(end_time=self._to_ns(datetime.now())) + + # End Parent OTEL Sspan + parent_otel_span.end(end_time=self._to_ns(datetime.now())) + + def _handle_sucess(self, kwargs, response_obj, start_time, end_time): + from opentelemetry import trace + from opentelemetry.trace import Status, StatusCode + + verbose_logger.debug( + "OpenTelemetry Logger: Logging kwargs: %s, OTEL config settings=%s", + kwargs, + self.config, + ) + _parent_context, parent_otel_span = self._get_span_context(kwargs) + + self._add_dynamic_span_processor_if_needed(kwargs) + + # Span 1: Requst sent to litellm SDK + span = self.tracer.start_span( + name=self._get_span_name(kwargs), + start_time=self._to_ns(start_time), + context=_parent_context, + ) + span.set_status(Status(StatusCode.OK)) + self.set_attributes(span, kwargs, response_obj) + + if litellm.turn_off_message_logging is True: + pass + elif self.message_logging is not True: + pass + else: + # Span 2: Raw Request / Response to LLM + raw_request_span = self.tracer.start_span( + name=RAW_REQUEST_SPAN_NAME, + start_time=self._to_ns(start_time), + context=trace.set_span_in_context(span), + ) + + raw_request_span.set_status(Status(StatusCode.OK)) + self.set_raw_request_attributes(raw_request_span, kwargs, response_obj) + raw_request_span.end(end_time=self._to_ns(end_time)) + + span.end(end_time=self._to_ns(end_time)) + + if parent_otel_span is not None: + parent_otel_span.end(end_time=self._to_ns(datetime.now())) + + def _add_dynamic_span_processor_if_needed(self, kwargs): + """ + Helper method to add a span processor with dynamic headers if needed. 
+ + This allows for per-request configuration of telemetry exporters by + extracting headers from standard_callback_dynamic_params. + """ + from opentelemetry import trace + + standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( + kwargs.get("standard_callback_dynamic_params") + ) + if not standard_callback_dynamic_params: + return + + # Extract headers from dynamic params + dynamic_headers = {} + + # Handle Arize headers + if standard_callback_dynamic_params.get("arize_space_key"): + dynamic_headers["space_key"] = standard_callback_dynamic_params.get( + "arize_space_key" + ) + if standard_callback_dynamic_params.get("arize_api_key"): + dynamic_headers["api_key"] = standard_callback_dynamic_params.get( + "arize_api_key" + ) + + # Only create a span processor if we have headers to use + if len(dynamic_headers) > 0: + from opentelemetry.sdk.trace import TracerProvider + + provider = trace.get_tracer_provider() + if isinstance(provider, TracerProvider): + span_processor = self._get_span_processor( + dynamic_headers=dynamic_headers + ) + provider.add_span_processor(span_processor) + + def _handle_failure(self, kwargs, response_obj, start_time, end_time): + from opentelemetry.trace import Status, StatusCode + + verbose_logger.debug( + "OpenTelemetry Logger: Failure HandlerLogging kwargs: %s, OTEL config settings=%s", + kwargs, + self.config, + ) + _parent_context, parent_otel_span = self._get_span_context(kwargs) + + # Span 1: Requst sent to litellm SDK + span = self.tracer.start_span( + name=self._get_span_name(kwargs), + start_time=self._to_ns(start_time), + context=_parent_context, + ) + span.set_status(Status(StatusCode.ERROR)) + self.set_attributes(span, kwargs, response_obj) + span.end(end_time=self._to_ns(end_time)) + + if parent_otel_span is not None: + parent_otel_span.end(end_time=self._to_ns(datetime.now())) + + def set_tools_attributes(self, span: Span, tools): + import json + + from litellm.proxy._types import SpanAttributes + + if not tools: + return + + try: + for i, tool in enumerate(tools): + function = tool.get("function") + if not function: + continue + + prefix = f"{SpanAttributes.LLM_REQUEST_FUNCTIONS}.{i}" + self.safe_set_attribute( + span=span, + key=f"{prefix}.name", + value=function.get("name"), + ) + self.safe_set_attribute( + span=span, + key=f"{prefix}.description", + value=function.get("description"), + ) + self.safe_set_attribute( + span=span, + key=f"{prefix}.parameters", + value=json.dumps(function.get("parameters")), + ) + except Exception as e: + verbose_logger.error( + "OpenTelemetry: Error setting tools attributes: %s", str(e) + ) + pass + + def cast_as_primitive_value_type(self, value) -> Union[str, bool, int, float]: + """ + Casts the value to a primitive OTEL type if it is not already a primitive type. 
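To make the flattening in `set_tools_attributes` above concrete, a standalone sketch that mirrors its logic (the `llm.request.functions` prefix is an assumption; the real prefix comes from `SpanAttributes.LLM_REQUEST_FUNCTIONS`):

import json

def flatten_tools(tools, prefix="llm.request.functions"):
    """Mirror of set_tools_attributes: one indexed span attribute per function field."""
    attrs = {}
    for i, tool in enumerate(tools):
        fn = tool.get("function") or {}
        attrs[f"{prefix}.{i}.name"] = fn.get("name")
        attrs[f"{prefix}.{i}.description"] = fn.get("description")
        attrs[f"{prefix}.{i}.parameters"] = json.dumps(fn.get("parameters"))
    return attrs

print(flatten_tools([
    {"type": "function",
     "function": {"name": "get_weather",
                  "description": "Look up the weather",
                  "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}
]))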
+ + OTEL supports - str, bool, int, float + + If it's not a primitive type, then it's converted to a string + """ + if value is None: + return "" + if isinstance(value, (str, bool, int, float)): + return value + try: + return str(value) + except Exception: + return "" + + @staticmethod + def _tool_calls_kv_pair( + tool_calls: List[ChatCompletionMessageToolCall], + ) -> Dict[str, Any]: + from litellm.proxy._types import SpanAttributes + + kv_pairs: Dict[str, Any] = {} + for idx, tool_call in enumerate(tool_calls): + _function = tool_call.get("function") + if not _function: + continue + + keys = Function.__annotations__.keys() + for key in keys: + _value = _function.get(key) + if _value: + kv_pairs[ + f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.function_call.{key}" + ] = _value + + return kv_pairs + + def set_attributes( # noqa: PLR0915 + self, span: Span, kwargs, response_obj: Optional[Any] + ): + try: + if self.callback_name == "arize_phoenix": + from litellm.integrations.arize.arize_phoenix import ArizePhoenixLogger + + ArizePhoenixLogger.set_arize_phoenix_attributes( + span, kwargs, response_obj + ) + return + elif self.callback_name == "langtrace": + from litellm.integrations.langtrace import LangtraceAttributes + + LangtraceAttributes().set_langtrace_attributes( + span, kwargs, response_obj + ) + return + from litellm.proxy._types import SpanAttributes + + optional_params = kwargs.get("optional_params", {}) + litellm_params = kwargs.get("litellm_params", {}) or {} + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object" + ) + if standard_logging_payload is None: + raise ValueError("standard_logging_object not found in kwargs") + + # https://github.com/open-telemetry/semantic-conventions/blob/main/model/registry/gen-ai.yaml + # Following Conventions here: https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md + ############################################# + ############ LLM CALL METADATA ############## + ############################################# + metadata = standard_logging_payload["metadata"] + for key, value in metadata.items(): + self.safe_set_attribute( + span=span, key="metadata.{}".format(key), value=value + ) + + ############################################# + ########## LLM Request Attributes ########### + ############################################# + + # The name of the LLM a request is being made to + if kwargs.get("model"): + self.safe_set_attribute( + span=span, + key=SpanAttributes.LLM_REQUEST_MODEL, + value=kwargs.get("model"), + ) + + # The LLM request type + self.safe_set_attribute( + span=span, + key=SpanAttributes.LLM_REQUEST_TYPE, + value=standard_logging_payload["call_type"], + ) + + # The Generative AI Provider: Azure, OpenAI, etc. + self.safe_set_attribute( + span=span, + key=SpanAttributes.LLM_SYSTEM, + value=litellm_params.get("custom_llm_provider", "Unknown"), + ) + + # The maximum number of tokens the LLM generates for a request. + if optional_params.get("max_tokens"): + self.safe_set_attribute( + span=span, + key=SpanAttributes.LLM_REQUEST_MAX_TOKENS, + value=optional_params.get("max_tokens"), + ) + + # The temperature setting for the LLM request. + if optional_params.get("temperature"): + self.safe_set_attribute( + span=span, + key=SpanAttributes.LLM_REQUEST_TEMPERATURE, + value=optional_params.get("temperature"), + ) + + # The top_p sampling setting for the LLM request. 
+ if optional_params.get("top_p"): + self.safe_set_attribute( + span=span, + key=SpanAttributes.LLM_REQUEST_TOP_P, + value=optional_params.get("top_p"), + ) + + self.safe_set_attribute( + span=span, + key=SpanAttributes.LLM_IS_STREAMING, + value=str(optional_params.get("stream", False)), + ) + + if optional_params.get("user"): + self.safe_set_attribute( + span=span, + key=SpanAttributes.LLM_USER, + value=optional_params.get("user"), + ) + + # The unique identifier for the completion. + if response_obj and response_obj.get("id"): + self.safe_set_attribute( + span=span, key="gen_ai.response.id", value=response_obj.get("id") + ) + + # The model used to generate the response. + if response_obj and response_obj.get("model"): + self.safe_set_attribute( + span=span, + key=SpanAttributes.LLM_RESPONSE_MODEL, + value=response_obj.get("model"), + ) + + usage = response_obj and response_obj.get("usage") + if usage: + self.safe_set_attribute( + span=span, + key=SpanAttributes.LLM_USAGE_TOTAL_TOKENS, + value=usage.get("total_tokens"), + ) + + # The number of tokens used in the LLM response (completion). + self.safe_set_attribute( + span=span, + key=SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, + value=usage.get("completion_tokens"), + ) + + # The number of tokens used in the LLM prompt. + self.safe_set_attribute( + span=span, + key=SpanAttributes.LLM_USAGE_PROMPT_TOKENS, + value=usage.get("prompt_tokens"), + ) + + ######################################################################## + ########## LLM Request Medssages / tools / content Attributes ########### + ######################################################################### + + if litellm.turn_off_message_logging is True: + return + if self.message_logging is not True: + return + + if optional_params.get("tools"): + tools = optional_params["tools"] + self.set_tools_attributes(span, tools) + + if kwargs.get("messages"): + for idx, prompt in enumerate(kwargs.get("messages")): + if prompt.get("role"): + self.safe_set_attribute( + span=span, + key=f"{SpanAttributes.LLM_PROMPTS}.{idx}.role", + value=prompt.get("role"), + ) + + if prompt.get("content"): + if not isinstance(prompt.get("content"), str): + prompt["content"] = str(prompt.get("content")) + self.safe_set_attribute( + span=span, + key=f"{SpanAttributes.LLM_PROMPTS}.{idx}.content", + value=prompt.get("content"), + ) + ############################################# + ########## LLM Response Attributes ########## + ############################################# + if response_obj is not None: + if response_obj.get("choices"): + for idx, choice in enumerate(response_obj.get("choices")): + if choice.get("finish_reason"): + self.safe_set_attribute( + span=span, + key=f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.finish_reason", + value=choice.get("finish_reason"), + ) + if choice.get("message"): + if choice.get("message").get("role"): + self.safe_set_attribute( + span=span, + key=f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.role", + value=choice.get("message").get("role"), + ) + if choice.get("message").get("content"): + if not isinstance( + choice.get("message").get("content"), str + ): + choice["message"]["content"] = str( + choice.get("message").get("content") + ) + self.safe_set_attribute( + span=span, + key=f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.content", + value=choice.get("message").get("content"), + ) + + message = choice.get("message") + tool_calls = message.get("tool_calls") + if tool_calls: + kv_pairs = OpenTelemetry._tool_calls_kv_pair(tool_calls) # type: ignore + for key, value in 
kv_pairs.items(): + self.safe_set_attribute( + span=span, + key=key, + value=value, + ) + + except Exception as e: + verbose_logger.exception( + "OpenTelemetry logging error in set_attributes %s", str(e) + ) + + def _cast_as_primitive_value_type(self, value) -> Union[str, bool, int, float]: + """ + Casts the value to a primitive OTEL type if it is not already a primitive type. + + OTEL supports - str, bool, int, float + + If it's not a primitive type, then it's converted to a string + """ + if value is None: + return "" + if isinstance(value, (str, bool, int, float)): + return value + try: + return str(value) + except Exception: + return "" + + def safe_set_attribute(self, span: Span, key: str, value: Any): + """ + Safely sets an attribute on the span, ensuring the value is a primitive type. + """ + primitive_value = self._cast_as_primitive_value_type(value) + span.set_attribute(key, primitive_value) + + def set_raw_request_attributes(self, span: Span, kwargs, response_obj): + + kwargs.get("optional_params", {}) + litellm_params = kwargs.get("litellm_params", {}) or {} + custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown") + + _raw_response = kwargs.get("original_response") + _additional_args = kwargs.get("additional_args", {}) or {} + complete_input_dict = _additional_args.get("complete_input_dict") + ############################################# + ########## LLM Request Attributes ########### + ############################################# + + # OTEL Attributes for the RAW Request to https://docs.anthropic.com/en/api/messages + if complete_input_dict and isinstance(complete_input_dict, dict): + for param, val in complete_input_dict.items(): + self.safe_set_attribute( + span=span, key=f"llm.{custom_llm_provider}.{param}", value=val + ) + + ############################################# + ########## LLM Response Attributes ########## + ############################################# + if _raw_response and isinstance(_raw_response, str): + # cast sr -> dict + import json + + try: + _raw_response = json.loads(_raw_response) + for param, val in _raw_response.items(): + self.safe_set_attribute( + span=span, + key=f"llm.{custom_llm_provider}.{param}", + value=val, + ) + except json.JSONDecodeError: + verbose_logger.debug( + "litellm.integrations.opentelemetry.py::set_raw_request_attributes() - raw_response not json string - {}".format( + _raw_response + ) + ) + + self.safe_set_attribute( + span=span, + key=f"llm.{custom_llm_provider}.stringified_raw_response", + value=_raw_response, + ) + + def _to_ns(self, dt): + return int(dt.timestamp() * 1e9) + + def _get_span_name(self, kwargs): + return LITELLM_REQUEST_SPAN_NAME + + def get_traceparent_from_header(self, headers): + if headers is None: + return None + _traceparent = headers.get("traceparent", None) + if _traceparent is None: + return None + + from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator, + ) + + propagator = TraceContextTextMapPropagator() + carrier = {"traceparent": _traceparent} + _parent_context = propagator.extract(carrier=carrier) + + return _parent_context + + def _get_span_context(self, kwargs): + from opentelemetry import trace + from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator, + ) + + litellm_params = kwargs.get("litellm_params", {}) or {} + proxy_server_request = litellm_params.get("proxy_server_request", {}) or {} + headers = proxy_server_request.get("headers", {}) or {} + traceparent = headers.get("traceparent", None) + 
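# For reference, a self-contained sketch of the traceparent extraction used here,
# with the canonical W3C example value as a placeholder:
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

carrier = {"traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"}
parent_context = TraceContextTextMapPropagator().extract(carrier=carrier)
# The returned context is what _get_span_context hands to tracer.start_span(context=...)
# when a traceparent header is present on the incoming proxy request.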
_metadata = litellm_params.get("metadata", {}) or {} + parent_otel_span = _metadata.get("litellm_parent_otel_span", None) + + """ + Two way to use parents in opentelemetry + - using the traceparent header + - using the parent_otel_span in the [metadata][parent_otel_span] + """ + if parent_otel_span is not None: + return trace.set_span_in_context(parent_otel_span), parent_otel_span + + if traceparent is None: + return None, None + else: + carrier = {"traceparent": traceparent} + return TraceContextTextMapPropagator().extract(carrier=carrier), None + + def _get_span_processor(self, dynamic_headers: Optional[dict] = None): + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( + OTLPSpanExporter as OTLPSpanExporterGRPC, + ) + from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( + OTLPSpanExporter as OTLPSpanExporterHTTP, + ) + from opentelemetry.sdk.trace.export import ( + BatchSpanProcessor, + ConsoleSpanExporter, + SimpleSpanProcessor, + SpanExporter, + ) + + verbose_logger.debug( + "OpenTelemetry Logger, initializing span processor \nself.OTEL_EXPORTER: %s\nself.OTEL_ENDPOINT: %s\nself.OTEL_HEADERS: %s", + self.OTEL_EXPORTER, + self.OTEL_ENDPOINT, + self.OTEL_HEADERS, + ) + _split_otel_headers = OpenTelemetry._get_headers_dictionary( + headers=dynamic_headers or self.OTEL_HEADERS + ) + + if isinstance(self.OTEL_EXPORTER, SpanExporter): + verbose_logger.debug( + "OpenTelemetry: intiializing SpanExporter. Value of OTEL_EXPORTER: %s", + self.OTEL_EXPORTER, + ) + return SimpleSpanProcessor(self.OTEL_EXPORTER) + + if self.OTEL_EXPORTER == "console": + verbose_logger.debug( + "OpenTelemetry: intiializing console exporter. Value of OTEL_EXPORTER: %s", + self.OTEL_EXPORTER, + ) + return BatchSpanProcessor(ConsoleSpanExporter()) + elif self.OTEL_EXPORTER == "otlp_http": + verbose_logger.debug( + "OpenTelemetry: intiializing http exporter. Value of OTEL_EXPORTER: %s", + self.OTEL_EXPORTER, + ) + return BatchSpanProcessor( + OTLPSpanExporterHTTP( + endpoint=self.OTEL_ENDPOINT, headers=_split_otel_headers + ), + ) + elif self.OTEL_EXPORTER == "otlp_grpc": + verbose_logger.debug( + "OpenTelemetry: intiializing grpc exporter. Value of OTEL_EXPORTER: %s", + self.OTEL_EXPORTER, + ) + return BatchSpanProcessor( + OTLPSpanExporterGRPC( + endpoint=self.OTEL_ENDPOINT, headers=_split_otel_headers + ), + ) + else: + verbose_logger.debug( + "OpenTelemetry: intiializing console exporter. Value of OTEL_EXPORTER: %s", + self.OTEL_EXPORTER, + ) + return BatchSpanProcessor(ConsoleSpanExporter()) + + @staticmethod + def _get_headers_dictionary(headers: Optional[Union[str, dict]]) -> Dict[str, str]: + """ + Convert a string or dictionary of headers into a dictionary of headers. 
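Putting the exporter configuration together, a minimal sketch (the endpoint and header values are placeholders; the environment variables are the ones read by `OpenTelemetryConfig.from_env` earlier in this file):

import os

os.environ["OTEL_EXPORTER"] = "otlp_http"   # one of: console, otlp_http, otlp_grpc, in_memory
os.environ["OTEL_ENDPOINT"] = "https://api.honeycomb.io/v1/traces"
os.environ["OTEL_HEADERS"] = "x-honeycomb-team=PLACEHOLDER"  # split on the first '=' by _get_headers_dictionary

from litellm.integrations.opentelemetry import OpenTelemetry

otel_logger = OpenTelemetry()  # with no explicit config, OpenTelemetryConfig.from_env() is used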
+ """ + _split_otel_headers: Dict[str, str] = {} + if headers: + if isinstance(headers, str): + # when passed HEADERS="x-honeycomb-team=B85YgLm96******" + # Split only on first '=' occurrence + parts = headers.split("=", 1) + if len(parts) == 2: + _split_otel_headers = {parts[0]: parts[1]} + else: + _split_otel_headers = {} + elif isinstance(headers, dict): + _split_otel_headers = headers + return _split_otel_headers + + async def async_management_endpoint_success_hook( + self, + logging_payload: ManagementEndpointLoggingPayload, + parent_otel_span: Optional[Span] = None, + ): + + from opentelemetry import trace + from opentelemetry.trace import Status, StatusCode + + _start_time_ns = 0 + _end_time_ns = 0 + + start_time = logging_payload.start_time + end_time = logging_payload.end_time + + if isinstance(start_time, float): + _start_time_ns = int(start_time * 1e9) + else: + _start_time_ns = self._to_ns(start_time) + + if isinstance(end_time, float): + _end_time_ns = int(end_time * 1e9) + else: + _end_time_ns = self._to_ns(end_time) + + if parent_otel_span is not None: + _span_name = logging_payload.route + management_endpoint_span = self.tracer.start_span( + name=_span_name, + context=trace.set_span_in_context(parent_otel_span), + start_time=_start_time_ns, + ) + + _request_data = logging_payload.request_data + if _request_data is not None: + for key, value in _request_data.items(): + self.safe_set_attribute( + span=management_endpoint_span, + key=f"request.{key}", + value=value, + ) + + _response = logging_payload.response + if _response is not None: + for key, value in _response.items(): + self.safe_set_attribute( + span=management_endpoint_span, + key=f"response.{key}", + value=value, + ) + + management_endpoint_span.set_status(Status(StatusCode.OK)) + management_endpoint_span.end(end_time=_end_time_ns) + + async def async_management_endpoint_failure_hook( + self, + logging_payload: ManagementEndpointLoggingPayload, + parent_otel_span: Optional[Span] = None, + ): + + from opentelemetry import trace + from opentelemetry.trace import Status, StatusCode + + _start_time_ns = 0 + _end_time_ns = 0 + + start_time = logging_payload.start_time + end_time = logging_payload.end_time + + if isinstance(start_time, float): + _start_time_ns = int(int(start_time) * 1e9) + else: + _start_time_ns = self._to_ns(start_time) + + if isinstance(end_time, float): + _end_time_ns = int(int(end_time) * 1e9) + else: + _end_time_ns = self._to_ns(end_time) + + if parent_otel_span is not None: + _span_name = logging_payload.route + management_endpoint_span = self.tracer.start_span( + name=_span_name, + context=trace.set_span_in_context(parent_otel_span), + start_time=_start_time_ns, + ) + + _request_data = logging_payload.request_data + if _request_data is not None: + for key, value in _request_data.items(): + self.safe_set_attribute( + span=management_endpoint_span, + key=f"request.{key}", + value=value, + ) + + _exception = logging_payload.exception + self.safe_set_attribute( + span=management_endpoint_span, + key="exception", + value=str(_exception), + ) + management_endpoint_span.set_status(Status(StatusCode.ERROR)) + management_endpoint_span.end(end_time=_end_time_ns) + + def create_litellm_proxy_request_started_span( + self, + start_time: datetime, + headers: dict, + ) -> Optional[Span]: + """ + Create a span for the received proxy server request. 
+ """ + return self.tracer.start_span( + name="Received Proxy Server Request", + start_time=self._to_ns(start_time), + context=self.get_traceparent_from_header(headers=headers), + kind=self.span_kind.SERVER, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/opik/opik.py b/.venv/lib/python3.12/site-packages/litellm/integrations/opik/opik.py new file mode 100644 index 00000000..1f7f18f3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/opik/opik.py @@ -0,0 +1,326 @@ +""" +Opik Logger that logs LLM events to an Opik server +""" + +import asyncio +import json +import traceback +from typing import Dict, List + +from litellm._logging import verbose_logger +from litellm.integrations.custom_batch_logger import CustomBatchLogger +from litellm.llms.custom_httpx.http_handler import ( + _get_httpx_client, + get_async_httpx_client, + httpxSpecialProvider, +) + +from .utils import ( + create_usage_object, + create_uuid7, + get_opik_config_variable, + get_traces_and_spans_from_payload, +) + + +class OpikLogger(CustomBatchLogger): + """ + Opik Logger for logging events to an Opik Server + """ + + def __init__(self, **kwargs): + self.async_httpx_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + self.sync_httpx_client = _get_httpx_client() + + self.opik_project_name = get_opik_config_variable( + "project_name", + user_value=kwargs.get("project_name", None), + default_value="Default Project", + ) + + opik_base_url = get_opik_config_variable( + "url_override", + user_value=kwargs.get("url", None), + default_value="https://www.comet.com/opik/api", + ) + opik_api_key = get_opik_config_variable( + "api_key", user_value=kwargs.get("api_key", None), default_value=None + ) + opik_workspace = get_opik_config_variable( + "workspace", user_value=kwargs.get("workspace", None), default_value=None + ) + + self.trace_url = f"{opik_base_url}/v1/private/traces/batch" + self.span_url = f"{opik_base_url}/v1/private/spans/batch" + + self.headers = {} + if opik_workspace: + self.headers["Comet-Workspace"] = opik_workspace + + if opik_api_key: + self.headers["authorization"] = opik_api_key + + self.opik_workspace = opik_workspace + self.opik_api_key = opik_api_key + try: + asyncio.create_task(self.periodic_flush()) + self.flush_lock = asyncio.Lock() + except Exception as e: + verbose_logger.exception( + f"OpikLogger - Asynchronous processing not initialized as we are not running in an async context {str(e)}" + ) + self.flush_lock = None + + super().__init__(**kwargs, flush_lock=self.flush_lock) + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + opik_payload = self._create_opik_payload( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + ) + + self.log_queue.extend(opik_payload) + verbose_logger.debug( + f"OpikLogger added event to log_queue - Will flush in {self.flush_interval} seconds..." 
+ ) + + if len(self.log_queue) >= self.batch_size: + verbose_logger.debug("OpikLogger - Flushing batch") + await self.flush_queue() + except Exception as e: + verbose_logger.exception( + f"OpikLogger failed to log success event - {str(e)}\n{traceback.format_exc()}" + ) + + def _sync_send(self, url: str, headers: Dict[str, str], batch: Dict): + try: + response = self.sync_httpx_client.post( + url=url, headers=headers, json=batch # type: ignore + ) + response.raise_for_status() + if response.status_code != 204: + raise Exception( + f"Response from opik API status_code: {response.status_code}, text: {response.text}" + ) + except Exception as e: + verbose_logger.exception( + f"OpikLogger failed to send batch - {str(e)}\n{traceback.format_exc()}" + ) + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + opik_payload = self._create_opik_payload( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + ) + + traces, spans = get_traces_and_spans_from_payload(opik_payload) + if len(traces) > 0: + self._sync_send( + url=self.trace_url, headers=self.headers, batch={"traces": traces} + ) + if len(spans) > 0: + self._sync_send( + url=self.span_url, headers=self.headers, batch={"spans": spans} + ) + except Exception as e: + verbose_logger.exception( + f"OpikLogger failed to log success event - {str(e)}\n{traceback.format_exc()}" + ) + + async def _submit_batch(self, url: str, headers: Dict[str, str], batch: Dict): + try: + response = await self.async_httpx_client.post( + url=url, headers=headers, json=batch # type: ignore + ) + response.raise_for_status() + + if response.status_code >= 300: + verbose_logger.error( + f"OpikLogger - Error: {response.status_code} - {response.text}" + ) + else: + verbose_logger.info( + f"OpikLogger - {len(self.log_queue)} Opik events submitted" + ) + except Exception as e: + verbose_logger.exception(f"OpikLogger failed to send batch - {str(e)}") + + def _create_opik_headers(self): + headers = {} + if self.opik_workspace: + headers["Comet-Workspace"] = self.opik_workspace + + if self.opik_api_key: + headers["authorization"] = self.opik_api_key + return headers + + async def async_send_batch(self): + verbose_logger.info("Calling async_send_batch") + if not self.log_queue: + return + + # Split the log_queue into traces and spans + traces, spans = get_traces_and_spans_from_payload(self.log_queue) + + # Send trace batch + if len(traces) > 0: + await self._submit_batch( + url=self.trace_url, headers=self.headers, batch={"traces": traces} + ) + verbose_logger.info(f"Sent {len(traces)} traces") + if len(spans) > 0: + await self._submit_batch( + url=self.span_url, headers=self.headers, batch={"spans": spans} + ) + verbose_logger.info(f"Sent {len(spans)} spans") + + def _create_opik_payload( # noqa: PLR0915 + self, kwargs, response_obj, start_time, end_time + ) -> List[Dict]: + + # Get metadata + _litellm_params = kwargs.get("litellm_params", {}) or {} + litellm_params_metadata = _litellm_params.get("metadata", {}) or {} + + # Extract opik metadata + litellm_opik_metadata = litellm_params_metadata.get("opik", {}) + verbose_logger.debug( + f"litellm_opik_metadata - {json.dumps(litellm_opik_metadata, default=str)}" + ) + project_name = litellm_opik_metadata.get("project_name", self.opik_project_name) + + # Extract trace_id and parent_span_id + current_span_data = litellm_opik_metadata.get("current_span_data", None) + if isinstance(current_span_data, dict): + trace_id = current_span_data.get("trace_id", None) + 
parent_span_id = current_span_data.get("id", None) + elif current_span_data: + trace_id = current_span_data.trace_id + parent_span_id = current_span_data.id + else: + trace_id = None + parent_span_id = None + # Create Opik tags + opik_tags = litellm_opik_metadata.get("tags", []) + if kwargs.get("custom_llm_provider"): + opik_tags.append(kwargs["custom_llm_provider"]) + + # Use standard_logging_object to create metadata and input/output data + standard_logging_object = kwargs.get("standard_logging_object", None) + if standard_logging_object is None: + verbose_logger.debug( + "OpikLogger skipping event; no standard_logging_object found" + ) + return [] + + # Create input and output data + input_data = standard_logging_object.get("messages", {}) + output_data = standard_logging_object.get("response", {}) + + # Create usage object + usage = create_usage_object(response_obj["usage"]) + + # Define span and trace names + span_name = "%s_%s_%s" % ( + response_obj.get("model", "unknown-model"), + response_obj.get("object", "unknown-object"), + response_obj.get("created", 0), + ) + trace_name = response_obj.get("object", "unknown type") + + # Create metadata object, we add the opik metadata first and then + # update it with the standard_logging_object metadata + metadata = litellm_opik_metadata + if "current_span_data" in metadata: + del metadata["current_span_data"] + metadata["created_from"] = "litellm" + + metadata.update(standard_logging_object.get("metadata", {})) + if "call_type" in standard_logging_object: + metadata["type"] = standard_logging_object["call_type"] + if "status" in standard_logging_object: + metadata["status"] = standard_logging_object["status"] + if "response_cost" in kwargs: + metadata["cost"] = { + "total_tokens": kwargs["response_cost"], + "currency": "USD", + } + if "response_cost_failure_debug_info" in kwargs: + metadata["response_cost_failure_debug_info"] = kwargs[ + "response_cost_failure_debug_info" + ] + if "model_map_information" in standard_logging_object: + metadata["model_map_information"] = standard_logging_object[ + "model_map_information" + ] + if "model" in standard_logging_object: + metadata["model"] = standard_logging_object["model"] + if "model_id" in standard_logging_object: + metadata["model_id"] = standard_logging_object["model_id"] + if "model_group" in standard_logging_object: + metadata["model_group"] = standard_logging_object["model_group"] + if "api_base" in standard_logging_object: + metadata["api_base"] = standard_logging_object["api_base"] + if "cache_hit" in standard_logging_object: + metadata["cache_hit"] = standard_logging_object["cache_hit"] + if "saved_cache_cost" in standard_logging_object: + metadata["saved_cache_cost"] = standard_logging_object["saved_cache_cost"] + if "error_str" in standard_logging_object: + metadata["error_str"] = standard_logging_object["error_str"] + if "model_parameters" in standard_logging_object: + metadata["model_parameters"] = standard_logging_object["model_parameters"] + if "hidden_params" in standard_logging_object: + metadata["hidden_params"] = standard_logging_object["hidden_params"] + + payload = [] + if trace_id is None: + trace_id = create_uuid7() + verbose_logger.debug( + f"OpikLogger creating payload for trace with id {trace_id}" + ) + + payload.append( + { + "project_name": project_name, + "id": trace_id, + "name": trace_name, + "start_time": start_time.isoformat() + "Z", + "end_time": end_time.isoformat() + "Z", + "input": input_data, + "output": output_data, + "metadata": metadata, + "tags": 
opik_tags, + } + ) + + span_id = create_uuid7() + verbose_logger.debug( + f"OpikLogger creating payload for trace with id {trace_id} and span with id {span_id}" + ) + payload.append( + { + "id": span_id, + "project_name": project_name, + "trace_id": trace_id, + "parent_span_id": parent_span_id, + "name": span_name, + "type": "llm", + "start_time": start_time.isoformat() + "Z", + "end_time": end_time.isoformat() + "Z", + "input": input_data, + "output": output_data, + "metadata": metadata, + "tags": opik_tags, + "usage": usage, + } + ) + verbose_logger.debug(f"Payload: {payload}") + return payload diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/opik/utils.py b/.venv/lib/python3.12/site-packages/litellm/integrations/opik/utils.py new file mode 100644 index 00000000..7b3b64dc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/opik/utils.py @@ -0,0 +1,110 @@ +import configparser +import os +import time +from typing import Dict, Final, List, Optional + +CONFIG_FILE_PATH_DEFAULT: Final[str] = "~/.opik.config" + + +def create_uuid7(): + ns = time.time_ns() + last = [0, 0, 0, 0] + + # Simple uuid7 implementation + sixteen_secs = 16_000_000_000 + t1, rest1 = divmod(ns, sixteen_secs) + t2, rest2 = divmod(rest1 << 16, sixteen_secs) + t3, _ = divmod(rest2 << 12, sixteen_secs) + t3 |= 7 << 12 # Put uuid version in top 4 bits, which are 0 in t3 + + # The next two bytes are an int (t4) with two bits for + # the variant 2 and a 14 bit sequence counter which increments + # if the time is unchanged. + if t1 == last[0] and t2 == last[1] and t3 == last[2]: + # Stop the seq counter wrapping past 0x3FFF. + # This won't happen in practice, but if it does, + # uuids after the 16383rd with that same timestamp + # will not longer be correctly ordered but + # are still unique due to the 6 random bytes. + if last[3] < 0x3FFF: + last[3] += 1 + else: + last[:] = (t1, t2, t3, 0) + t4 = (2 << 14) | last[3] # Put variant 0b10 in top two bits + + # Six random bytes for the lower part of the uuid + rand = os.urandom(6) + return f"{t1:>08x}-{t2:>04x}-{t3:>04x}-{t4:>04x}-{rand.hex()}" + + +def _read_opik_config_file() -> Dict[str, str]: + config_path = os.path.expanduser(CONFIG_FILE_PATH_DEFAULT) + + config = configparser.ConfigParser() + config.read(config_path) + + config_values = { + section: dict(config.items(section)) for section in config.sections() + } + + if "opik" in config_values: + return config_values["opik"] + + return {} + + +def _get_env_variable(key: str) -> Optional[str]: + env_prefix = "opik_" + return os.getenv((env_prefix + key).upper(), None) + + +def get_opik_config_variable( + key: str, user_value: Optional[str] = None, default_value: Optional[str] = None +) -> Optional[str]: + """ + Get the configuration value of a variable, order priority is: + 1. user provided value + 2. environment variable + 3. Opik configuration file + 4. 
default value + """ + # Return user provided value if it is not None + if user_value is not None: + return user_value + + # Return environment variable if it is not None + env_value = _get_env_variable(key) + if env_value is not None: + return env_value + + # Return value from Opik configuration file if it is not None + config_values = _read_opik_config_file() + + if key in config_values: + return config_values[key] + + # Return default value if it is not None + return default_value + + +def create_usage_object(usage): + usage_dict = {} + + if usage.completion_tokens is not None: + usage_dict["completion_tokens"] = usage.completion_tokens + if usage.prompt_tokens is not None: + usage_dict["prompt_tokens"] = usage.prompt_tokens + if usage.total_tokens is not None: + usage_dict["total_tokens"] = usage.total_tokens + return usage_dict + + +def _remove_nulls(x): + x_ = {k: v for k, v in x.items() if v is not None} + return x_ + + +def get_traces_and_spans_from_payload(payload: List): + traces = [_remove_nulls(x) for x in payload if "type" not in x] + spans = [_remove_nulls(x) for x in payload if "type" in x] + return traces, spans diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/pagerduty/pagerduty.py b/.venv/lib/python3.12/site-packages/litellm/integrations/pagerduty/pagerduty.py new file mode 100644 index 00000000..6085bc23 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/pagerduty/pagerduty.py @@ -0,0 +1,305 @@ +""" +PagerDuty Alerting Integration + +Handles two types of alerts: +- High LLM API Failure Rate. Configure X fails in Y seconds to trigger an alert. +- High Number of Hanging LLM Requests. Configure X hangs in Y seconds to trigger an alert. +""" + +import asyncio +import os +from datetime import datetime, timedelta, timezone +from typing import List, Literal, Optional, Union + +from litellm._logging import verbose_logger +from litellm.caching import DualCache +from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.proxy._types import UserAPIKeyAuth +from litellm.types.integrations.pagerduty import ( + AlertingConfig, + PagerDutyInternalEvent, + PagerDutyPayload, + PagerDutyRequestBody, +) +from litellm.types.utils import ( + StandardLoggingPayload, + StandardLoggingPayloadErrorInformation, +) + +PAGERDUTY_DEFAULT_FAILURE_THRESHOLD = 60 +PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS = 60 +PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS = 60 +PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS = 600 + + +class PagerDutyAlerting(SlackAlerting): + """ + Tracks failed requests and hanging requests separately. + If threshold is crossed for either type, triggers a PagerDuty alert. 
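+ Example (illustrative) alerting_args: {"failure_threshold": 60, "failure_threshold_window_seconds": 60, "hanging_threshold_seconds": 60, "hanging_threshold_window_seconds": 600} - these values mirror the module-level PAGERDUTY_DEFAULT_* constants above.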
+ """ + + def __init__( + self, alerting_args: Optional[Union[AlertingConfig, dict]] = None, **kwargs + ): + from litellm.proxy.proxy_server import CommonProxyErrors, premium_user + + super().__init__() + _api_key = os.getenv("PAGERDUTY_API_KEY") + if not _api_key: + raise ValueError("PAGERDUTY_API_KEY is not set") + + self.api_key: str = _api_key + alerting_args = alerting_args or {} + self.alerting_args: AlertingConfig = AlertingConfig( + failure_threshold=alerting_args.get( + "failure_threshold", PAGERDUTY_DEFAULT_FAILURE_THRESHOLD + ), + failure_threshold_window_seconds=alerting_args.get( + "failure_threshold_window_seconds", + PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS, + ), + hanging_threshold_seconds=alerting_args.get( + "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS + ), + hanging_threshold_window_seconds=alerting_args.get( + "hanging_threshold_window_seconds", + PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS, + ), + ) + + # Separate storage for failures vs. hangs + self._failure_events: List[PagerDutyInternalEvent] = [] + self._hanging_events: List[PagerDutyInternalEvent] = [] + + # premium user check + if premium_user is not True: + raise ValueError( + f"PagerDutyAlerting is only available for LiteLLM Enterprise users. {CommonProxyErrors.not_premium_user.value}" + ) + + # ------------------ MAIN LOGIC ------------------ # + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + """ + Record a failure event. Only send an alert to PagerDuty if the + configured *failure* threshold is exceeded in the specified window. + """ + now = datetime.now(timezone.utc) + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object" + ) + if not standard_logging_payload: + raise ValueError( + "standard_logging_object is required for PagerDutyAlerting" + ) + + # Extract error details + error_info: Optional[StandardLoggingPayloadErrorInformation] = ( + standard_logging_payload.get("error_information") or {} + ) + _meta = standard_logging_payload.get("metadata") or {} + + self._failure_events.append( + PagerDutyInternalEvent( + failure_event_type="failed_response", + timestamp=now, + error_class=error_info.get("error_class"), + error_code=error_info.get("error_code"), + error_llm_provider=error_info.get("llm_provider"), + user_api_key_hash=_meta.get("user_api_key_hash"), + user_api_key_alias=_meta.get("user_api_key_alias"), + user_api_key_org_id=_meta.get("user_api_key_org_id"), + user_api_key_team_id=_meta.get("user_api_key_team_id"), + user_api_key_user_id=_meta.get("user_api_key_user_id"), + user_api_key_team_alias=_meta.get("user_api_key_team_alias"), + user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"), + user_api_key_user_email=_meta.get("user_api_key_user_email"), + ) + ) + + # Prune + Possibly alert + window_seconds = self.alerting_args.get("failure_threshold_window_seconds", 60) + threshold = self.alerting_args.get("failure_threshold", 1) + + # If threshold is crossed, send PD alert for failures + await self._send_alert_if_thresholds_crossed( + events=self._failure_events, + window_seconds=window_seconds, + threshold=threshold, + alert_prefix="High LLM API Failure Rate", + ) + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: Literal[ + "completion", + "text_completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + "pass_through_endpoint", + "rerank", + ], + ) 
-> Optional[Union[Exception, str, dict]]: + """ + Example of detecting hanging requests by waiting a given threshold. + If the request didn't finish by then, we treat it as 'hanging'. + """ + verbose_logger.info("Inside Proxy Logging Pre-call hook!") + asyncio.create_task( + self.hanging_response_handler( + request_data=data, user_api_key_dict=user_api_key_dict + ) + ) + return None + + async def hanging_response_handler( + self, request_data: Optional[dict], user_api_key_dict: UserAPIKeyAuth + ): + """ + Checks if request completed by the time 'hanging_threshold_seconds' elapses. + If not, we classify it as a hanging request. + """ + verbose_logger.debug( + f"Inside Hanging Response Handler!..sleeping for {self.alerting_args.get('hanging_threshold_seconds', PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS)} seconds" + ) + await asyncio.sleep( + self.alerting_args.get( + "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS + ) + ) + + if await self._request_is_completed(request_data=request_data): + return # It's not hanging if completed + + # Otherwise, record it as hanging + self._hanging_events.append( + PagerDutyInternalEvent( + failure_event_type="hanging_response", + timestamp=datetime.now(timezone.utc), + error_class="HangingRequest", + error_code="HangingRequest", + error_llm_provider="HangingRequest", + user_api_key_hash=user_api_key_dict.api_key, + user_api_key_alias=user_api_key_dict.key_alias, + user_api_key_org_id=user_api_key_dict.org_id, + user_api_key_team_id=user_api_key_dict.team_id, + user_api_key_user_id=user_api_key_dict.user_id, + user_api_key_team_alias=user_api_key_dict.team_alias, + user_api_key_end_user_id=user_api_key_dict.end_user_id, + user_api_key_user_email=user_api_key_dict.user_email, + ) + ) + + # Prune + Possibly alert + window_seconds = self.alerting_args.get( + "hanging_threshold_window_seconds", + PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS, + ) + threshold: int = self.alerting_args.get( + "hanging_threshold_fails", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS + ) + + # If threshold is crossed, send PD alert for hangs + await self._send_alert_if_thresholds_crossed( + events=self._hanging_events, + window_seconds=window_seconds, + threshold=threshold, + alert_prefix="High Number of Hanging LLM Requests", + ) + + # ------------------ HELPERS ------------------ # + + async def _send_alert_if_thresholds_crossed( + self, + events: List[PagerDutyInternalEvent], + window_seconds: int, + threshold: int, + alert_prefix: str, + ): + """ + 1. Prune old events + 2. If threshold is reached, build alert, send to PagerDuty + 3. Clear those events + """ + cutoff = datetime.now(timezone.utc) - timedelta(seconds=window_seconds) + pruned = [e for e in events if e.get("timestamp", datetime.min) > cutoff] + + # Update the reference list + events.clear() + events.extend(pruned) + + # Check threshold + verbose_logger.debug( + f"Have {len(events)} events in the last {window_seconds} seconds. Threshold is {threshold}" + ) + if len(events) >= threshold: + # Build short summary of last N events + error_summaries = self._build_error_summaries(events, max_errors=5) + alert_message = ( + f"{alert_prefix}: {len(events)} in the last {window_seconds} seconds." 
+ ) + custom_details = {"recent_errors": error_summaries} + + await self.send_alert_to_pagerduty( + alert_message=alert_message, + custom_details=custom_details, + ) + + # Clear them after sending an alert, so we don't spam + events.clear() + + def _build_error_summaries( + self, events: List[PagerDutyInternalEvent], max_errors: int = 5 + ) -> List[PagerDutyInternalEvent]: + """ + Build short text summaries for the last `max_errors`. + Example: "ValueError (code: 500, provider: openai)" + """ + recent = events[-max_errors:] + summaries = [] + for fe in recent: + # If any of these is None, show "N/A" to avoid messing up the summary string + fe.pop("timestamp") + summaries.append(fe) + return summaries + + async def send_alert_to_pagerduty(self, alert_message: str, custom_details: dict): + """ + Send [critical] Alert to PagerDuty + + https://developer.pagerduty.com/api-reference/YXBpOjI3NDgyNjU-pager-duty-v2-events-api + """ + try: + verbose_logger.debug(f"Sending alert to PagerDuty: {alert_message}") + async_client: AsyncHTTPHandler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + payload: PagerDutyRequestBody = PagerDutyRequestBody( + payload=PagerDutyPayload( + summary=alert_message, + severity="critical", + source="LiteLLM Alert", + component="LiteLLM", + custom_details=custom_details, + ), + routing_key=self.api_key, + event_action="trigger", + ) + + return await async_client.post( + url="https://events.pagerduty.com/v2/enqueue", + json=dict(payload), + headers={"Content-Type": "application/json"}, + ) + except Exception as e: + verbose_logger.exception(f"Error sending alert to PagerDuty: {e}") diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus.py b/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus.py new file mode 100644 index 00000000..d6e47b87 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus.py @@ -0,0 +1,1789 @@ +# used for /metrics endpoint on LiteLLM Proxy +#### What this does #### +# On success, log events to Prometheus +import asyncio +import sys +from datetime import datetime, timedelta +from typing import Any, Awaitable, Callable, List, Literal, Optional, Tuple, cast + +import litellm +from litellm._logging import print_verbose, verbose_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import LiteLLM_TeamTable, UserAPIKeyAuth +from litellm.types.integrations.prometheus import * +from litellm.types.utils import StandardLoggingPayload +from litellm.utils import get_end_user_id_for_cost_tracking + + +class PrometheusLogger(CustomLogger): + # Class variables or attributes + def __init__( + self, + **kwargs, + ): + try: + from prometheus_client import Counter, Gauge, Histogram + + from litellm.proxy.proxy_server import CommonProxyErrors, premium_user + + if premium_user is not True: + verbose_logger.warning( + f"🚨🚨🚨 Prometheus Metrics is on LiteLLM Enterprise\n🚨 {CommonProxyErrors.not_premium_user.value}" + ) + self.litellm_not_a_premium_user_metric = Counter( + name="litellm_not_a_premium_user_metric", + documentation=f"🚨🚨🚨 Prometheus Metrics is on LiteLLM Enterprise. 
🚨 {CommonProxyErrors.not_premium_user.value}", + ) + return + + self.litellm_proxy_failed_requests_metric = Counter( + name="litellm_proxy_failed_requests_metric", + documentation="Total number of failed responses from proxy - the client did not get a success response from litellm proxy", + labelnames=PrometheusMetricLabels.get_labels( + label_name="litellm_proxy_failed_requests_metric" + ), + ) + + self.litellm_proxy_total_requests_metric = Counter( + name="litellm_proxy_total_requests_metric", + documentation="Total number of requests made to the proxy server - track number of client side requests", + labelnames=PrometheusMetricLabels.get_labels( + label_name="litellm_proxy_total_requests_metric" + ), + ) + + # request latency metrics + self.litellm_request_total_latency_metric = Histogram( + "litellm_request_total_latency_metric", + "Total latency (seconds) for a request to LiteLLM", + labelnames=PrometheusMetricLabels.get_labels( + label_name="litellm_request_total_latency_metric" + ), + buckets=LATENCY_BUCKETS, + ) + + self.litellm_llm_api_latency_metric = Histogram( + "litellm_llm_api_latency_metric", + "Total latency (seconds) for a models LLM API call", + labelnames=PrometheusMetricLabels.get_labels( + label_name="litellm_llm_api_latency_metric" + ), + buckets=LATENCY_BUCKETS, + ) + + self.litellm_llm_api_time_to_first_token_metric = Histogram( + "litellm_llm_api_time_to_first_token_metric", + "Time to first token for a models LLM API call", + labelnames=[ + "model", + "hashed_api_key", + "api_key_alias", + "team", + "team_alias", + ], + buckets=LATENCY_BUCKETS, + ) + + # Counter for spend + self.litellm_spend_metric = Counter( + "litellm_spend_metric", + "Total spend on LLM requests", + labelnames=[ + "end_user", + "hashed_api_key", + "api_key_alias", + "model", + "team", + "team_alias", + "user", + ], + ) + + # Counter for total_output_tokens + self.litellm_tokens_metric = Counter( + "litellm_total_tokens", + "Total number of input + output tokens from LLM requests", + labelnames=[ + "end_user", + "hashed_api_key", + "api_key_alias", + "model", + "team", + "team_alias", + "user", + ], + ) + + self.litellm_input_tokens_metric = Counter( + "litellm_input_tokens", + "Total number of input tokens from LLM requests", + labelnames=PrometheusMetricLabels.get_labels( + label_name="litellm_input_tokens_metric" + ), + ) + + self.litellm_output_tokens_metric = Counter( + "litellm_output_tokens", + "Total number of output tokens from LLM requests", + labelnames=PrometheusMetricLabels.get_labels( + label_name="litellm_output_tokens_metric" + ), + ) + + # Remaining Budget for Team + self.litellm_remaining_team_budget_metric = Gauge( + "litellm_remaining_team_budget_metric", + "Remaining budget for team", + labelnames=PrometheusMetricLabels.get_labels( + label_name="litellm_remaining_team_budget_metric" + ), + ) + + # Max Budget for Team + self.litellm_team_max_budget_metric = Gauge( + "litellm_team_max_budget_metric", + "Maximum budget set for team", + labelnames=PrometheusMetricLabels.get_labels( + label_name="litellm_team_max_budget_metric" + ), + ) + + # Team Budget Reset At + self.litellm_team_budget_remaining_hours_metric = Gauge( + "litellm_team_budget_remaining_hours_metric", + "Remaining days for team budget to be reset", + labelnames=PrometheusMetricLabels.get_labels( + label_name="litellm_team_budget_remaining_hours_metric" + ), + ) + + # Remaining Budget for API Key + self.litellm_remaining_api_key_budget_metric = Gauge( + "litellm_remaining_api_key_budget_metric", + "Remaining 
budget for api key", + labelnames=PrometheusMetricLabels.get_labels( + label_name="litellm_remaining_api_key_budget_metric" + ), + ) + + # Max Budget for API Key + self.litellm_api_key_max_budget_metric = Gauge( + "litellm_api_key_max_budget_metric", + "Maximum budget set for api key", + labelnames=PrometheusMetricLabels.get_labels( + label_name="litellm_api_key_max_budget_metric" + ), + ) + + self.litellm_api_key_budget_remaining_hours_metric = Gauge( + "litellm_api_key_budget_remaining_hours_metric", + "Remaining hours for api key budget to be reset", + labelnames=PrometheusMetricLabels.get_labels( + label_name="litellm_api_key_budget_remaining_hours_metric" + ), + ) + + ######################################## + # LiteLLM Virtual API KEY metrics + ######################################## + # Remaining MODEL RPM limit for API Key + self.litellm_remaining_api_key_requests_for_model = Gauge( + "litellm_remaining_api_key_requests_for_model", + "Remaining Requests API Key can make for model (model based rpm limit on key)", + labelnames=["hashed_api_key", "api_key_alias", "model"], + ) + + # Remaining MODEL TPM limit for API Key + self.litellm_remaining_api_key_tokens_for_model = Gauge( + "litellm_remaining_api_key_tokens_for_model", + "Remaining Tokens API Key can make for model (model based tpm limit on key)", + labelnames=["hashed_api_key", "api_key_alias", "model"], + ) + + ######################################## + # LLM API Deployment Metrics / analytics + ######################################## + + # Remaining Rate Limit for model + self.litellm_remaining_requests_metric = Gauge( + "litellm_remaining_requests", + "LLM Deployment Analytics - remaining requests for model, returned from LLM API Provider", + labelnames=[ + "model_group", + "api_provider", + "api_base", + "litellm_model_name", + "hashed_api_key", + "api_key_alias", + ], + ) + + self.litellm_remaining_tokens_metric = Gauge( + "litellm_remaining_tokens", + "remaining tokens for model, returned from LLM API Provider", + labelnames=[ + "model_group", + "api_provider", + "api_base", + "litellm_model_name", + "hashed_api_key", + "api_key_alias", + ], + ) + + self.litellm_overhead_latency_metric = Histogram( + "litellm_overhead_latency_metric", + "Latency overhead (milliseconds) added by LiteLLM processing", + labelnames=[ + "model_group", + "api_provider", + "api_base", + "litellm_model_name", + "hashed_api_key", + "api_key_alias", + ], + buckets=LATENCY_BUCKETS, + ) + # llm api provider budget metrics + self.litellm_provider_remaining_budget_metric = Gauge( + "litellm_provider_remaining_budget_metric", + "Remaining budget for provider - used when you set provider budget limits", + labelnames=["api_provider"], + ) + + # Get all keys + _logged_llm_labels = [ + UserAPIKeyLabelNames.v2_LITELLM_MODEL_NAME.value, + UserAPIKeyLabelNames.MODEL_ID.value, + UserAPIKeyLabelNames.API_BASE.value, + UserAPIKeyLabelNames.API_PROVIDER.value, + ] + team_and_key_labels = [ + "hashed_api_key", + "api_key_alias", + "team", + "team_alias", + ] + + # Metric for deployment state + self.litellm_deployment_state = Gauge( + "litellm_deployment_state", + "LLM Deployment Analytics - The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage", + labelnames=_logged_llm_labels, + ) + + self.litellm_deployment_cooled_down = Counter( + "litellm_deployment_cooled_down", + "LLM Deployment Analytics - Number of times a deployment has been cooled down by LiteLLM load balancing logic. 
exception_status is the status of the exception that caused the deployment to be cooled down", + labelnames=_logged_llm_labels + [EXCEPTION_STATUS], + ) + + self.litellm_deployment_success_responses = Counter( + name="litellm_deployment_success_responses", + documentation="LLM Deployment Analytics - Total number of successful LLM API calls via litellm", + labelnames=[REQUESTED_MODEL] + _logged_llm_labels + team_and_key_labels, + ) + self.litellm_deployment_failure_responses = Counter( + name="litellm_deployment_failure_responses", + documentation="LLM Deployment Analytics - Total number of failed LLM API calls for a specific LLM deployment. exception_status is the status of the exception from the llm api", + labelnames=[REQUESTED_MODEL] + + _logged_llm_labels + + EXCEPTION_LABELS + + team_and_key_labels, + ) + self.litellm_deployment_failure_by_tag_responses = Counter( + "litellm_deployment_failure_by_tag_responses", + "Total number of failed LLM API calls for a specific LLM deployment by custom metadata tags", + labelnames=[ + UserAPIKeyLabelNames.REQUESTED_MODEL.value, + UserAPIKeyLabelNames.TAG.value, + ] + + _logged_llm_labels + + EXCEPTION_LABELS, + ) + self.litellm_deployment_total_requests = Counter( + name="litellm_deployment_total_requests", + documentation="LLM Deployment Analytics - Total number of LLM API calls via litellm - success + failure", + labelnames=[REQUESTED_MODEL] + _logged_llm_labels + team_and_key_labels, + ) + + # Deployment Latency tracking + team_and_key_labels = [ + "hashed_api_key", + "api_key_alias", + "team", + "team_alias", + ] + self.litellm_deployment_latency_per_output_token = Histogram( + name="litellm_deployment_latency_per_output_token", + documentation="LLM Deployment Analytics - Latency per output token", + labelnames=PrometheusMetricLabels.get_labels( + label_name="litellm_deployment_latency_per_output_token" + ), + ) + + self.litellm_deployment_successful_fallbacks = Counter( + "litellm_deployment_successful_fallbacks", + "LLM Deployment Analytics - Number of successful fallback requests from primary model -> fallback model", + PrometheusMetricLabels.get_labels( + "litellm_deployment_successful_fallbacks" + ), + ) + + self.litellm_deployment_failed_fallbacks = Counter( + "litellm_deployment_failed_fallbacks", + "LLM Deployment Analytics - Number of failed fallback requests from primary model -> fallback model", + PrometheusMetricLabels.get_labels( + "litellm_deployment_failed_fallbacks" + ), + ) + + self.litellm_llm_api_failed_requests_metric = Counter( + name="litellm_llm_api_failed_requests_metric", + documentation="deprecated - use litellm_proxy_failed_requests_metric", + labelnames=[ + "end_user", + "hashed_api_key", + "api_key_alias", + "model", + "team", + "team_alias", + "user", + ], + ) + + self.litellm_requests_metric = Counter( + name="litellm_requests_metric", + documentation="deprecated - use litellm_proxy_total_requests_metric. 
Total number of LLM calls to litellm - track total per API Key, team, user", + labelnames=PrometheusMetricLabels.get_labels( + label_name="litellm_requests_metric" + ), + ) + self._initialize_prometheus_startup_metrics() + + except Exception as e: + print_verbose(f"Got exception on init prometheus client {str(e)}") + raise e + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + # Define prometheus client + from litellm.types.utils import StandardLoggingPayload + + verbose_logger.debug( + f"prometheus Logging - Enters success logging function for kwargs {kwargs}" + ) + + # unpack kwargs + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object" + ) + + if standard_logging_payload is None or not isinstance( + standard_logging_payload, dict + ): + raise ValueError( + f"standard_logging_object is required, got={standard_logging_payload}" + ) + + model = kwargs.get("model", "") + litellm_params = kwargs.get("litellm_params", {}) or {} + _metadata = litellm_params.get("metadata", {}) + end_user_id = get_end_user_id_for_cost_tracking( + litellm_params, service_type="prometheus" + ) + user_id = standard_logging_payload["metadata"]["user_api_key_user_id"] + user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"] + user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"] + user_api_team = standard_logging_payload["metadata"]["user_api_key_team_id"] + user_api_team_alias = standard_logging_payload["metadata"][ + "user_api_key_team_alias" + ] + output_tokens = standard_logging_payload["completion_tokens"] + tokens_used = standard_logging_payload["total_tokens"] + response_cost = standard_logging_payload["response_cost"] + _requester_metadata = standard_logging_payload["metadata"].get( + "requester_metadata" + ) + if standard_logging_payload is not None and isinstance( + standard_logging_payload, dict + ): + _tags = standard_logging_payload["request_tags"] + else: + _tags = [] + + print_verbose( + f"inside track_prometheus_metrics, model {model}, response_cost {response_cost}, tokens_used {tokens_used}, end_user_id {end_user_id}, user_api_key {user_api_key}" + ) + + enum_values = UserAPIKeyLabelValues( + end_user=end_user_id, + hashed_api_key=user_api_key, + api_key_alias=user_api_key_alias, + requested_model=standard_logging_payload["model_group"], + team=user_api_team, + team_alias=user_api_team_alias, + user=user_id, + user_email=standard_logging_payload["metadata"]["user_api_key_user_email"], + status_code="200", + model=model, + litellm_model_name=model, + tags=_tags, + model_id=standard_logging_payload["model_id"], + api_base=standard_logging_payload["api_base"], + api_provider=standard_logging_payload["custom_llm_provider"], + exception_status=None, + exception_class=None, + custom_metadata_labels=get_custom_labels_from_metadata( + metadata=standard_logging_payload["metadata"].get("requester_metadata") + or {} + ), + ) + + if ( + user_api_key is not None + and isinstance(user_api_key, str) + and user_api_key.startswith("sk-") + ): + from litellm.proxy.utils import hash_token + + user_api_key = hash_token(user_api_key) + + # increment total LLM requests and spend metric + self._increment_top_level_request_and_spend_metrics( + end_user_id=end_user_id, + user_api_key=user_api_key, + user_api_key_alias=user_api_key_alias, + model=model, + user_api_team=user_api_team, + user_api_team_alias=user_api_team_alias, + user_id=user_id, + response_cost=response_cost, + 
enum_values=enum_values, + ) + + # input, output, total token metrics + self._increment_token_metrics( + # why type ignore below? + # 1. We just checked if isinstance(standard_logging_payload, dict). Pyright complains. + # 2. Pyright does not allow us to run isinstance(standard_logging_payload, StandardLoggingPayload) <- this would be ideal + standard_logging_payload=standard_logging_payload, # type: ignore + end_user_id=end_user_id, + user_api_key=user_api_key, + user_api_key_alias=user_api_key_alias, + model=model, + user_api_team=user_api_team, + user_api_team_alias=user_api_team_alias, + user_id=user_id, + enum_values=enum_values, + ) + + # remaining budget metrics + await self._increment_remaining_budget_metrics( + user_api_team=user_api_team, + user_api_team_alias=user_api_team_alias, + user_api_key=user_api_key, + user_api_key_alias=user_api_key_alias, + litellm_params=litellm_params, + response_cost=response_cost, + ) + + # set proxy virtual key rpm/tpm metrics + self._set_virtual_key_rate_limit_metrics( + user_api_key=user_api_key, + user_api_key_alias=user_api_key_alias, + kwargs=kwargs, + metadata=_metadata, + ) + + # set latency metrics + self._set_latency_metrics( + kwargs=kwargs, + model=model, + user_api_key=user_api_key, + user_api_key_alias=user_api_key_alias, + user_api_team=user_api_team, + user_api_team_alias=user_api_team_alias, + # why type ignore below? + # 1. We just checked if isinstance(standard_logging_payload, dict). Pyright complains. + # 2. Pyright does not allow us to run isinstance(standard_logging_payload, StandardLoggingPayload) <- this would be ideal + enum_values=enum_values, + ) + + # set x-ratelimit headers + self.set_llm_deployment_success_metrics( + kwargs, start_time, end_time, enum_values, output_tokens + ) + + if ( + standard_logging_payload["stream"] is True + ): # log successful streaming requests from logging event hook. 
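+ # (see async_post_call_success_hook below, which also increments litellm_proxy_total_requests_metric when the proxy returns a success response to the client)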
+ _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_proxy_total_requests_metric" + ), + enum_values=enum_values, + ) + self.litellm_proxy_total_requests_metric.labels(**_labels).inc() + + def _increment_token_metrics( + self, + standard_logging_payload: StandardLoggingPayload, + end_user_id: Optional[str], + user_api_key: Optional[str], + user_api_key_alias: Optional[str], + model: Optional[str], + user_api_team: Optional[str], + user_api_team_alias: Optional[str], + user_id: Optional[str], + enum_values: UserAPIKeyLabelValues, + ): + # token metrics + self.litellm_tokens_metric.labels( + end_user_id, + user_api_key, + user_api_key_alias, + model, + user_api_team, + user_api_team_alias, + user_id, + ).inc(standard_logging_payload["total_tokens"]) + + if standard_logging_payload is not None and isinstance( + standard_logging_payload, dict + ): + _tags = standard_logging_payload["request_tags"] + + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_input_tokens_metric" + ), + enum_values=enum_values, + ) + self.litellm_input_tokens_metric.labels(**_labels).inc( + standard_logging_payload["prompt_tokens"] + ) + + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_output_tokens_metric" + ), + enum_values=enum_values, + ) + + self.litellm_output_tokens_metric.labels(**_labels).inc( + standard_logging_payload["completion_tokens"] + ) + + async def _increment_remaining_budget_metrics( + self, + user_api_team: Optional[str], + user_api_team_alias: Optional[str], + user_api_key: Optional[str], + user_api_key_alias: Optional[str], + litellm_params: dict, + response_cost: float, + ): + _team_spend = litellm_params.get("metadata", {}).get( + "user_api_key_team_spend", None + ) + _team_max_budget = litellm_params.get("metadata", {}).get( + "user_api_key_team_max_budget", None + ) + + _api_key_spend = litellm_params.get("metadata", {}).get( + "user_api_key_spend", None + ) + _api_key_max_budget = litellm_params.get("metadata", {}).get( + "user_api_key_max_budget", None + ) + await self._set_api_key_budget_metrics_after_api_request( + user_api_key=user_api_key, + user_api_key_alias=user_api_key_alias, + response_cost=response_cost, + key_max_budget=_api_key_max_budget, + key_spend=_api_key_spend, + ) + + await self._set_team_budget_metrics_after_api_request( + user_api_team=user_api_team, + user_api_team_alias=user_api_team_alias, + team_spend=_team_spend, + team_max_budget=_team_max_budget, + response_cost=response_cost, + ) + + def _increment_top_level_request_and_spend_metrics( + self, + end_user_id: Optional[str], + user_api_key: Optional[str], + user_api_key_alias: Optional[str], + model: Optional[str], + user_api_team: Optional[str], + user_api_team_alias: Optional[str], + user_id: Optional[str], + response_cost: float, + enum_values: UserAPIKeyLabelValues, + ): + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_requests_metric" + ), + enum_values=enum_values, + ) + self.litellm_requests_metric.labels(**_labels).inc() + + self.litellm_spend_metric.labels( + end_user_id, + user_api_key, + user_api_key_alias, + model, + user_api_team, + user_api_team_alias, + user_id, + ).inc(response_cost) + + def _set_virtual_key_rate_limit_metrics( + self, + user_api_key: Optional[str], + user_api_key_alias: Optional[str], + kwargs: dict, + metadata: 
dict, + ): + from litellm.proxy.common_utils.callback_utils import ( + get_model_group_from_litellm_kwargs, + ) + + # Set remaining rpm/tpm for API Key + model + # see parallel_request_limiter.py - variables are set there + model_group = get_model_group_from_litellm_kwargs(kwargs) + remaining_requests_variable_name = ( + f"litellm-key-remaining-requests-{model_group}" + ) + remaining_tokens_variable_name = f"litellm-key-remaining-tokens-{model_group}" + + remaining_requests = ( + metadata.get(remaining_requests_variable_name, sys.maxsize) or sys.maxsize + ) + remaining_tokens = ( + metadata.get(remaining_tokens_variable_name, sys.maxsize) or sys.maxsize + ) + + self.litellm_remaining_api_key_requests_for_model.labels( + user_api_key, user_api_key_alias, model_group + ).set(remaining_requests) + + self.litellm_remaining_api_key_tokens_for_model.labels( + user_api_key, user_api_key_alias, model_group + ).set(remaining_tokens) + + def _set_latency_metrics( + self, + kwargs: dict, + model: Optional[str], + user_api_key: Optional[str], + user_api_key_alias: Optional[str], + user_api_team: Optional[str], + user_api_team_alias: Optional[str], + enum_values: UserAPIKeyLabelValues, + ): + # latency metrics + end_time: datetime = kwargs.get("end_time") or datetime.now() + start_time: Optional[datetime] = kwargs.get("start_time") + api_call_start_time = kwargs.get("api_call_start_time", None) + completion_start_time = kwargs.get("completion_start_time", None) + time_to_first_token_seconds = self._safe_duration_seconds( + start_time=api_call_start_time, + end_time=completion_start_time, + ) + if ( + time_to_first_token_seconds is not None + and kwargs.get("stream", False) is True # only emit for streaming requests + ): + self.litellm_llm_api_time_to_first_token_metric.labels( + model, + user_api_key, + user_api_key_alias, + user_api_team, + user_api_team_alias, + ).observe(time_to_first_token_seconds) + else: + verbose_logger.debug( + "Time to first token metric not emitted, stream option in model_parameters is not True" + ) + + api_call_total_time_seconds = self._safe_duration_seconds( + start_time=api_call_start_time, + end_time=end_time, + ) + if api_call_total_time_seconds is not None: + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_llm_api_latency_metric" + ), + enum_values=enum_values, + ) + self.litellm_llm_api_latency_metric.labels(**_labels).observe( + api_call_total_time_seconds + ) + + # total request latency + total_time_seconds = self._safe_duration_seconds( + start_time=start_time, + end_time=end_time, + ) + if total_time_seconds is not None: + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_request_total_latency_metric" + ), + enum_values=enum_values, + ) + self.litellm_request_total_latency_metric.labels(**_labels).observe( + total_time_seconds + ) + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + from litellm.types.utils import StandardLoggingPayload + + verbose_logger.debug( + f"prometheus Logging - Enters failure logging function for kwargs {kwargs}" + ) + + # unpack kwargs + model = kwargs.get("model", "") + standard_logging_payload: StandardLoggingPayload = kwargs.get( + "standard_logging_object", {} + ) + litellm_params = kwargs.get("litellm_params", {}) or {} + end_user_id = get_end_user_id_for_cost_tracking( + litellm_params, service_type="prometheus" + ) + user_id = 
standard_logging_payload["metadata"]["user_api_key_user_id"] + user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"] + user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"] + user_api_team = standard_logging_payload["metadata"]["user_api_key_team_id"] + user_api_team_alias = standard_logging_payload["metadata"][ + "user_api_key_team_alias" + ] + kwargs.get("exception", None) + + try: + self.litellm_llm_api_failed_requests_metric.labels( + end_user_id, + user_api_key, + user_api_key_alias, + model, + user_api_team, + user_api_team_alias, + user_id, + ).inc() + self.set_llm_deployment_failure_metrics(kwargs) + except Exception as e: + verbose_logger.exception( + "prometheus Layer Error(): Exception occured - {}".format(str(e)) + ) + pass + pass + + async def async_post_call_failure_hook( + self, + request_data: dict, + original_exception: Exception, + user_api_key_dict: UserAPIKeyAuth, + ): + """ + Track client side failures + + Proxy level tracking - failed client side requests + + labelnames=[ + "end_user", + "hashed_api_key", + "api_key_alias", + REQUESTED_MODEL, + "team", + "team_alias", + ] + EXCEPTION_LABELS, + """ + try: + _tags = cast(List[str], request_data.get("tags") or []) + enum_values = UserAPIKeyLabelValues( + end_user=user_api_key_dict.end_user_id, + user=user_api_key_dict.user_id, + user_email=user_api_key_dict.user_email, + hashed_api_key=user_api_key_dict.api_key, + api_key_alias=user_api_key_dict.key_alias, + team=user_api_key_dict.team_id, + team_alias=user_api_key_dict.team_alias, + requested_model=request_data.get("model", ""), + status_code=str(getattr(original_exception, "status_code", None)), + exception_status=str(getattr(original_exception, "status_code", None)), + exception_class=str(original_exception.__class__.__name__), + tags=_tags, + ) + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_proxy_failed_requests_metric" + ), + enum_values=enum_values, + ) + self.litellm_proxy_failed_requests_metric.labels(**_labels).inc() + + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_proxy_total_requests_metric" + ), + enum_values=enum_values, + ) + self.litellm_proxy_total_requests_metric.labels(**_labels).inc() + + except Exception as e: + verbose_logger.exception( + "prometheus Layer Error(): Exception occured - {}".format(str(e)) + ) + pass + + async def async_post_call_success_hook( + self, data: dict, user_api_key_dict: UserAPIKeyAuth, response + ): + """ + Proxy level tracking - triggered when the proxy responds with a success response to the client + """ + try: + enum_values = UserAPIKeyLabelValues( + end_user=user_api_key_dict.end_user_id, + hashed_api_key=user_api_key_dict.api_key, + api_key_alias=user_api_key_dict.key_alias, + requested_model=data.get("model", ""), + team=user_api_key_dict.team_id, + team_alias=user_api_key_dict.team_alias, + user=user_api_key_dict.user_id, + user_email=user_api_key_dict.user_email, + status_code="200", + ) + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_proxy_total_requests_metric" + ), + enum_values=enum_values, + ) + self.litellm_proxy_total_requests_metric.labels(**_labels).inc() + + except Exception as e: + verbose_logger.exception( + "prometheus Layer Error(): Exception occured - {}".format(str(e)) + ) + pass + + def set_llm_deployment_failure_metrics(self, request_kwargs: 
dict): + """ + Sets Failure metrics when an LLM API call fails + + - mark the deployment as partial outage + - increment deployment failure responses metric + - increment deployment total requests metric + + Args: + request_kwargs: dict + + """ + try: + verbose_logger.debug("setting remaining tokens requests metric") + standard_logging_payload: StandardLoggingPayload = request_kwargs.get( + "standard_logging_object", {} + ) + _litellm_params = request_kwargs.get("litellm_params", {}) or {} + litellm_model_name = request_kwargs.get("model", None) + model_group = standard_logging_payload.get("model_group", None) + api_base = standard_logging_payload.get("api_base", None) + model_id = standard_logging_payload.get("model_id", None) + exception: Exception = request_kwargs.get("exception", None) + + llm_provider = _litellm_params.get("custom_llm_provider", None) + + """ + log these labels + ["litellm_model_name", "model_id", "api_base", "api_provider"] + """ + self.set_deployment_partial_outage( + litellm_model_name=litellm_model_name, + model_id=model_id, + api_base=api_base, + api_provider=llm_provider, + ) + self.litellm_deployment_failure_responses.labels( + litellm_model_name=litellm_model_name, + model_id=model_id, + api_base=api_base, + api_provider=llm_provider, + exception_status=str(getattr(exception, "status_code", None)), + exception_class=exception.__class__.__name__, + requested_model=model_group, + hashed_api_key=standard_logging_payload["metadata"][ + "user_api_key_hash" + ], + api_key_alias=standard_logging_payload["metadata"][ + "user_api_key_alias" + ], + team=standard_logging_payload["metadata"]["user_api_key_team_id"], + team_alias=standard_logging_payload["metadata"][ + "user_api_key_team_alias" + ], + ).inc() + + # tag based tracking + if standard_logging_payload is not None and isinstance( + standard_logging_payload, dict + ): + _tags = standard_logging_payload["request_tags"] + for tag in _tags: + self.litellm_deployment_failure_by_tag_responses.labels( + **{ + UserAPIKeyLabelNames.REQUESTED_MODEL.value: model_group, + UserAPIKeyLabelNames.TAG.value: tag, + UserAPIKeyLabelNames.v2_LITELLM_MODEL_NAME.value: litellm_model_name, + UserAPIKeyLabelNames.MODEL_ID.value: model_id, + UserAPIKeyLabelNames.API_BASE.value: api_base, + UserAPIKeyLabelNames.API_PROVIDER.value: llm_provider, + UserAPIKeyLabelNames.EXCEPTION_CLASS.value: exception.__class__.__name__, + UserAPIKeyLabelNames.EXCEPTION_STATUS.value: str( + getattr(exception, "status_code", None) + ), + } + ).inc() + + self.litellm_deployment_total_requests.labels( + litellm_model_name=litellm_model_name, + model_id=model_id, + api_base=api_base, + api_provider=llm_provider, + requested_model=model_group, + hashed_api_key=standard_logging_payload["metadata"][ + "user_api_key_hash" + ], + api_key_alias=standard_logging_payload["metadata"][ + "user_api_key_alias" + ], + team=standard_logging_payload["metadata"]["user_api_key_team_id"], + team_alias=standard_logging_payload["metadata"][ + "user_api_key_team_alias" + ], + ).inc() + + pass + except Exception as e: + verbose_logger.debug( + "Prometheus Error: set_llm_deployment_failure_metrics. 
Exception occured - {}".format( + str(e) + ) + ) + + def set_llm_deployment_success_metrics( + self, + request_kwargs: dict, + start_time, + end_time, + enum_values: UserAPIKeyLabelValues, + output_tokens: float = 1.0, + ): + try: + verbose_logger.debug("setting remaining tokens requests metric") + standard_logging_payload: Optional[StandardLoggingPayload] = ( + request_kwargs.get("standard_logging_object") + ) + + if standard_logging_payload is None: + return + + model_group = standard_logging_payload["model_group"] + api_base = standard_logging_payload["api_base"] + _response_headers = request_kwargs.get("response_headers") + _litellm_params = request_kwargs.get("litellm_params", {}) or {} + _metadata = _litellm_params.get("metadata", {}) + litellm_model_name = request_kwargs.get("model", None) + llm_provider = _litellm_params.get("custom_llm_provider", None) + _model_info = _metadata.get("model_info") or {} + model_id = _model_info.get("id", None) + + remaining_requests: Optional[int] = None + remaining_tokens: Optional[int] = None + if additional_headers := standard_logging_payload["hidden_params"][ + "additional_headers" + ]: + # OpenAI / OpenAI Compatible headers + remaining_requests = additional_headers.get( + "x_ratelimit_remaining_requests", None + ) + remaining_tokens = additional_headers.get( + "x_ratelimit_remaining_tokens", None + ) + + if litellm_overhead_time_ms := standard_logging_payload[ + "hidden_params" + ].get("litellm_overhead_time_ms"): + self.litellm_overhead_latency_metric.labels( + model_group, + llm_provider, + api_base, + litellm_model_name, + standard_logging_payload["metadata"]["user_api_key_hash"], + standard_logging_payload["metadata"]["user_api_key_alias"], + ).observe( + litellm_overhead_time_ms / 1000 + ) # set as seconds + + if remaining_requests: + """ + "model_group", + "api_provider", + "api_base", + "litellm_model_name" + """ + self.litellm_remaining_requests_metric.labels( + model_group, + llm_provider, + api_base, + litellm_model_name, + standard_logging_payload["metadata"]["user_api_key_hash"], + standard_logging_payload["metadata"]["user_api_key_alias"], + ).set(remaining_requests) + + if remaining_tokens: + self.litellm_remaining_tokens_metric.labels( + model_group, + llm_provider, + api_base, + litellm_model_name, + standard_logging_payload["metadata"]["user_api_key_hash"], + standard_logging_payload["metadata"]["user_api_key_alias"], + ).set(remaining_tokens) + + """ + log these labels + ["litellm_model_name", "requested_model", model_id", "api_base", "api_provider"] + """ + self.set_deployment_healthy( + litellm_model_name=litellm_model_name, + model_id=model_id, + api_base=api_base, + api_provider=llm_provider, + ) + + self.litellm_deployment_success_responses.labels( + litellm_model_name=litellm_model_name, + model_id=model_id, + api_base=api_base, + api_provider=llm_provider, + requested_model=model_group, + hashed_api_key=standard_logging_payload["metadata"][ + "user_api_key_hash" + ], + api_key_alias=standard_logging_payload["metadata"][ + "user_api_key_alias" + ], + team=standard_logging_payload["metadata"]["user_api_key_team_id"], + team_alias=standard_logging_payload["metadata"][ + "user_api_key_team_alias" + ], + ).inc() + + self.litellm_deployment_total_requests.labels( + litellm_model_name=litellm_model_name, + model_id=model_id, + api_base=api_base, + api_provider=llm_provider, + requested_model=model_group, + hashed_api_key=standard_logging_payload["metadata"][ + "user_api_key_hash" + ], + 
api_key_alias=standard_logging_payload["metadata"][ + "user_api_key_alias" + ], + team=standard_logging_payload["metadata"]["user_api_key_team_id"], + team_alias=standard_logging_payload["metadata"][ + "user_api_key_team_alias" + ], + ).inc() + + # Track deployment Latency + response_ms: timedelta = end_time - start_time + time_to_first_token_response_time: Optional[timedelta] = None + + if ( + request_kwargs.get("stream", None) is not None + and request_kwargs["stream"] is True + ): + # only log ttft for streaming request + time_to_first_token_response_time = ( + request_kwargs.get("completion_start_time", end_time) - start_time + ) + + # use the metric that is not None + # if streaming - use time_to_first_token_response + # if not streaming - use response_ms + _latency: timedelta = time_to_first_token_response_time or response_ms + _latency_seconds = _latency.total_seconds() + + # latency per output token + latency_per_token = None + if output_tokens is not None and output_tokens > 0: + latency_per_token = _latency_seconds / output_tokens + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_deployment_latency_per_output_token" + ), + enum_values=enum_values, + ) + self.litellm_deployment_latency_per_output_token.labels( + **_labels + ).observe(latency_per_token) + + except Exception as e: + verbose_logger.error( + "Prometheus Error: set_llm_deployment_success_metrics. Exception occured - {}".format( + str(e) + ) + ) + return + + async def log_success_fallback_event( + self, original_model_group: str, kwargs: dict, original_exception: Exception + ): + """ + + Logs a successful LLM fallback event on prometheus + + """ + from litellm.litellm_core_utils.litellm_logging import ( + StandardLoggingMetadata, + StandardLoggingPayloadSetup, + ) + + verbose_logger.debug( + "Prometheus: log_success_fallback_event, original_model_group: %s, kwargs: %s", + original_model_group, + kwargs, + ) + _metadata = kwargs.get("metadata", {}) + standard_metadata: StandardLoggingMetadata = ( + StandardLoggingPayloadSetup.get_standard_logging_metadata( + metadata=_metadata + ) + ) + _new_model = kwargs.get("model") + _tags = cast(List[str], kwargs.get("tags") or []) + + enum_values = UserAPIKeyLabelValues( + requested_model=original_model_group, + fallback_model=_new_model, + hashed_api_key=standard_metadata["user_api_key_hash"], + api_key_alias=standard_metadata["user_api_key_alias"], + team=standard_metadata["user_api_key_team_id"], + team_alias=standard_metadata["user_api_key_team_alias"], + exception_status=str(getattr(original_exception, "status_code", None)), + exception_class=str(original_exception.__class__.__name__), + tags=_tags, + ) + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_deployment_successful_fallbacks" + ), + enum_values=enum_values, + ) + self.litellm_deployment_successful_fallbacks.labels(**_labels).inc() + + async def log_failure_fallback_event( + self, original_model_group: str, kwargs: dict, original_exception: Exception + ): + """ + Logs a failed LLM fallback event on prometheus + """ + from litellm.litellm_core_utils.litellm_logging import ( + StandardLoggingMetadata, + StandardLoggingPayloadSetup, + ) + + verbose_logger.debug( + "Prometheus: log_failure_fallback_event, original_model_group: %s, kwargs: %s", + original_model_group, + kwargs, + ) + _new_model = kwargs.get("model") + _metadata = kwargs.get("metadata", {}) + _tags = cast(List[str], 
kwargs.get("tags") or []) + standard_metadata: StandardLoggingMetadata = ( + StandardLoggingPayloadSetup.get_standard_logging_metadata( + metadata=_metadata + ) + ) + + enum_values = UserAPIKeyLabelValues( + requested_model=original_model_group, + fallback_model=_new_model, + hashed_api_key=standard_metadata["user_api_key_hash"], + api_key_alias=standard_metadata["user_api_key_alias"], + team=standard_metadata["user_api_key_team_id"], + team_alias=standard_metadata["user_api_key_team_alias"], + exception_status=str(getattr(original_exception, "status_code", None)), + exception_class=str(original_exception.__class__.__name__), + tags=_tags, + ) + + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_deployment_failed_fallbacks" + ), + enum_values=enum_values, + ) + self.litellm_deployment_failed_fallbacks.labels(**_labels).inc() + + def set_litellm_deployment_state( + self, + state: int, + litellm_model_name: str, + model_id: Optional[str], + api_base: Optional[str], + api_provider: str, + ): + self.litellm_deployment_state.labels( + litellm_model_name, model_id, api_base, api_provider + ).set(state) + + def set_deployment_healthy( + self, + litellm_model_name: str, + model_id: str, + api_base: str, + api_provider: str, + ): + self.set_litellm_deployment_state( + 0, litellm_model_name, model_id, api_base, api_provider + ) + + def set_deployment_partial_outage( + self, + litellm_model_name: str, + model_id: Optional[str], + api_base: Optional[str], + api_provider: str, + ): + self.set_litellm_deployment_state( + 1, litellm_model_name, model_id, api_base, api_provider + ) + + def set_deployment_complete_outage( + self, + litellm_model_name: str, + model_id: Optional[str], + api_base: Optional[str], + api_provider: str, + ): + self.set_litellm_deployment_state( + 2, litellm_model_name, model_id, api_base, api_provider + ) + + def increment_deployment_cooled_down( + self, + litellm_model_name: str, + model_id: str, + api_base: str, + api_provider: str, + exception_status: str, + ): + """ + increment metric when litellm.Router / load balancing logic places a deployment in cool down + """ + self.litellm_deployment_cooled_down.labels( + litellm_model_name, model_id, api_base, api_provider, exception_status + ).inc() + + def track_provider_remaining_budget( + self, provider: str, spend: float, budget_limit: float + ): + """ + Track provider remaining budget in Prometheus + """ + self.litellm_provider_remaining_budget_metric.labels(provider).set( + self._safe_get_remaining_budget( + max_budget=budget_limit, + spend=spend, + ) + ) + + def _safe_get_remaining_budget( + self, max_budget: Optional[float], spend: Optional[float] + ) -> float: + if max_budget is None: + return float("inf") + + if spend is None: + return max_budget + + return max_budget - spend + + def _initialize_prometheus_startup_metrics(self): + """ + Initialize prometheus startup metrics + + Helper to create tasks for initializing metrics that are required on startup - eg. 
remaining budget metrics + """ + if litellm.prometheus_initialize_budget_metrics is not True: + verbose_logger.debug("Prometheus: skipping budget metrics initialization") + return + + try: + if asyncio.get_running_loop(): + asyncio.create_task(self._initialize_remaining_budget_metrics()) + except RuntimeError as e: # no running event loop + verbose_logger.exception( + f"No running event loop - skipping budget metrics initialization: {str(e)}" + ) + + async def _initialize_budget_metrics( + self, + data_fetch_function: Callable[..., Awaitable[Tuple[List[Any], Optional[int]]]], + set_metrics_function: Callable[[List[Any]], Awaitable[None]], + data_type: Literal["teams", "keys"], + ): + """ + Generic method to initialize budget metrics for teams or API keys. + + Args: + data_fetch_function: Function to fetch data with pagination. + set_metrics_function: Function to set metrics for the fetched data. + data_type: String representing the type of data ("teams" or "keys") for logging purposes. + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + return + + try: + page = 1 + page_size = 50 + data, total_count = await data_fetch_function( + page_size=page_size, page=page + ) + + if total_count is None: + total_count = len(data) + + # Calculate total pages needed + total_pages = (total_count + page_size - 1) // page_size + + # Set metrics for first page of data + await set_metrics_function(data) + + # Get and set metrics for remaining pages + for page in range(2, total_pages + 1): + data, _ = await data_fetch_function(page_size=page_size, page=page) + await set_metrics_function(data) + + except Exception as e: + verbose_logger.exception( + f"Error initializing {data_type} budget metrics: {str(e)}" + ) + + async def _initialize_team_budget_metrics(self): + """ + Initialize team budget metrics by reusing the generic pagination logic. + """ + from litellm.proxy.management_endpoints.team_endpoints import ( + get_paginated_teams, + ) + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + verbose_logger.debug( + "Prometheus: skipping team metrics initialization, DB not initialized" + ) + return + + async def fetch_teams( + page_size: int, page: int + ) -> Tuple[List[LiteLLM_TeamTable], Optional[int]]: + teams, total_count = await get_paginated_teams( + prisma_client=prisma_client, page_size=page_size, page=page + ) + if total_count is None: + total_count = len(teams) + return teams, total_count + + await self._initialize_budget_metrics( + data_fetch_function=fetch_teams, + set_metrics_function=self._set_team_list_budget_metrics, + data_type="teams", + ) + + async def _initialize_api_key_budget_metrics(self): + """ + Initialize API key budget metrics by reusing the generic pagination logic. 
+ """ + from typing import Union + + from litellm.constants import UI_SESSION_TOKEN_TEAM_ID + from litellm.proxy.management_endpoints.key_management_endpoints import ( + _list_key_helper, + ) + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + verbose_logger.debug( + "Prometheus: skipping key metrics initialization, DB not initialized" + ) + return + + async def fetch_keys( + page_size: int, page: int + ) -> Tuple[List[Union[str, UserAPIKeyAuth]], Optional[int]]: + key_list_response = await _list_key_helper( + prisma_client=prisma_client, + page=page, + size=page_size, + user_id=None, + team_id=None, + key_alias=None, + exclude_team_id=UI_SESSION_TOKEN_TEAM_ID, + return_full_object=True, + organization_id=None, + ) + keys = key_list_response.get("keys", []) + total_count = key_list_response.get("total_count") + if total_count is None: + total_count = len(keys) + return keys, total_count + + await self._initialize_budget_metrics( + data_fetch_function=fetch_keys, + set_metrics_function=self._set_key_list_budget_metrics, + data_type="keys", + ) + + async def _initialize_remaining_budget_metrics(self): + """ + Initialize remaining budget metrics for all teams to avoid metric discrepancies. + + Runs when prometheus logger starts up. + """ + await self._initialize_team_budget_metrics() + await self._initialize_api_key_budget_metrics() + + async def _set_key_list_budget_metrics( + self, keys: List[Union[str, UserAPIKeyAuth]] + ): + """Helper function to set budget metrics for a list of keys""" + for key in keys: + if isinstance(key, UserAPIKeyAuth): + self._set_key_budget_metrics(key) + + async def _set_team_list_budget_metrics(self, teams: List[LiteLLM_TeamTable]): + """Helper function to set budget metrics for a list of teams""" + for team in teams: + self._set_team_budget_metrics(team) + + async def _set_team_budget_metrics_after_api_request( + self, + user_api_team: Optional[str], + user_api_team_alias: Optional[str], + team_spend: float, + team_max_budget: float, + response_cost: float, + ): + """ + Set team budget metrics after an LLM API request + + - Assemble a LiteLLM_TeamTable object + - looks up team info from db if not available in metadata + - Set team budget metrics + """ + if user_api_team: + team_object = await self._assemble_team_object( + team_id=user_api_team, + team_alias=user_api_team_alias or "", + spend=team_spend, + max_budget=team_max_budget, + response_cost=response_cost, + ) + + self._set_team_budget_metrics(team_object) + + async def _assemble_team_object( + self, + team_id: str, + team_alias: str, + spend: Optional[float], + max_budget: Optional[float], + response_cost: float, + ) -> LiteLLM_TeamTable: + """ + Assemble a LiteLLM_TeamTable object + + for fields not available in metadata, we fetch from db + Fields not available in metadata: + - `budget_reset_at` + """ + from litellm.proxy.auth.auth_checks import get_team_object + from litellm.proxy.proxy_server import prisma_client, user_api_key_cache + + _total_team_spend = (spend or 0) + response_cost + team_object = LiteLLM_TeamTable( + team_id=team_id, + team_alias=team_alias, + spend=_total_team_spend, + max_budget=max_budget, + ) + try: + team_info = await get_team_object( + team_id=team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + ) + except Exception as e: + verbose_logger.debug( + f"[Non-Blocking] Prometheus: Error getting team info: {str(e)}" + ) + return team_object + + if team_info: + team_object.budget_reset_at = team_info.budget_reset_at 
+ + return team_object + + def _set_team_budget_metrics( + self, + team: LiteLLM_TeamTable, + ): + """ + Set team budget metrics for a single team + + - Remaining Budget + - Max Budget + - Budget Reset At + """ + enum_values = UserAPIKeyLabelValues( + team=team.team_id, + team_alias=team.team_alias or "", + ) + + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_remaining_team_budget_metric" + ), + enum_values=enum_values, + ) + self.litellm_remaining_team_budget_metric.labels(**_labels).set( + self._safe_get_remaining_budget( + max_budget=team.max_budget, + spend=team.spend, + ) + ) + + if team.max_budget is not None: + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_team_max_budget_metric" + ), + enum_values=enum_values, + ) + self.litellm_team_max_budget_metric.labels(**_labels).set(team.max_budget) + + if team.budget_reset_at is not None: + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_team_budget_remaining_hours_metric" + ), + enum_values=enum_values, + ) + self.litellm_team_budget_remaining_hours_metric.labels(**_labels).set( + self._get_remaining_hours_for_budget_reset( + budget_reset_at=team.budget_reset_at + ) + ) + + def _set_key_budget_metrics(self, user_api_key_dict: UserAPIKeyAuth): + """ + Set virtual key budget metrics + + - Remaining Budget + - Max Budget + - Budget Reset At + """ + enum_values = UserAPIKeyLabelValues( + hashed_api_key=user_api_key_dict.token, + api_key_alias=user_api_key_dict.key_alias or "", + ) + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_remaining_api_key_budget_metric" + ), + enum_values=enum_values, + ) + self.litellm_remaining_api_key_budget_metric.labels(**_labels).set( + self._safe_get_remaining_budget( + max_budget=user_api_key_dict.max_budget, + spend=user_api_key_dict.spend, + ) + ) + + if user_api_key_dict.max_budget is not None: + _labels = prometheus_label_factory( + supported_enum_labels=PrometheusMetricLabels.get_labels( + label_name="litellm_api_key_max_budget_metric" + ), + enum_values=enum_values, + ) + self.litellm_api_key_max_budget_metric.labels(**_labels).set( + user_api_key_dict.max_budget + ) + + if user_api_key_dict.budget_reset_at is not None: + self.litellm_api_key_budget_remaining_hours_metric.labels(**_labels).set( + self._get_remaining_hours_for_budget_reset( + budget_reset_at=user_api_key_dict.budget_reset_at + ) + ) + + async def _set_api_key_budget_metrics_after_api_request( + self, + user_api_key: Optional[str], + user_api_key_alias: Optional[str], + response_cost: float, + key_max_budget: float, + key_spend: Optional[float], + ): + if user_api_key: + user_api_key_dict = await self._assemble_key_object( + user_api_key=user_api_key, + user_api_key_alias=user_api_key_alias or "", + key_max_budget=key_max_budget, + key_spend=key_spend, + response_cost=response_cost, + ) + self._set_key_budget_metrics(user_api_key_dict) + + async def _assemble_key_object( + self, + user_api_key: str, + user_api_key_alias: str, + key_max_budget: float, + key_spend: Optional[float], + response_cost: float, + ) -> UserAPIKeyAuth: + """ + Assemble a UserAPIKeyAuth object + """ + from litellm.proxy.auth.auth_checks import get_key_object + from litellm.proxy.proxy_server import prisma_client, user_api_key_cache + + _total_key_spend = (key_spend or 0) + response_cost + 
user_api_key_dict = UserAPIKeyAuth( + token=user_api_key, + key_alias=user_api_key_alias, + max_budget=key_max_budget, + spend=_total_key_spend, + ) + try: + if user_api_key_dict.token: + key_object = await get_key_object( + hashed_token=user_api_key_dict.token, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + ) + if key_object: + user_api_key_dict.budget_reset_at = key_object.budget_reset_at + except Exception as e: + verbose_logger.debug( + f"[Non-Blocking] Prometheus: Error getting key info: {str(e)}" + ) + + return user_api_key_dict + + def _get_remaining_hours_for_budget_reset(self, budget_reset_at: datetime) -> float: + """ + Get remaining hours for budget reset + """ + return ( + budget_reset_at - datetime.now(budget_reset_at.tzinfo) + ).total_seconds() / 3600 + + def _safe_duration_seconds( + self, + start_time: Any, + end_time: Any, + ) -> Optional[float]: + """ + Compute the duration in seconds between two objects. + + Returns the duration as a float if both start and end are instances of datetime, + otherwise returns None. + """ + if isinstance(start_time, datetime) and isinstance(end_time, datetime): + return (end_time - start_time).total_seconds() + return None + + +def prometheus_label_factory( + supported_enum_labels: List[str], + enum_values: UserAPIKeyLabelValues, + tag: Optional[str] = None, +) -> dict: + """ + Returns a dictionary of label + values for prometheus. + + Ensures end_user param is not sent to prometheus if it is not supported. + """ + # Extract dictionary from Pydantic object + enum_dict = enum_values.model_dump() + + # Filter supported labels + filtered_labels = { + label: value + for label, value in enum_dict.items() + if label in supported_enum_labels + } + + if UserAPIKeyLabelNames.END_USER.value in filtered_labels: + filtered_labels["end_user"] = get_end_user_id_for_cost_tracking( + litellm_params={"user_api_key_end_user_id": enum_values.end_user}, + service_type="prometheus", + ) + + if enum_values.custom_metadata_labels is not None: + for key, value in enum_values.custom_metadata_labels.items(): + if key in supported_enum_labels: + filtered_labels[key] = value + + for label in supported_enum_labels: + if label not in filtered_labels: + filtered_labels[label] = None + + return filtered_labels + + +def get_custom_labels_from_metadata(metadata: dict) -> Dict[str, str]: + """ + Get custom labels from metadata + """ + keys = litellm.custom_prometheus_metadata_labels + if keys is None or len(keys) == 0: + return {} + + result: Dict[str, str] = {} + + for key in keys: + # Split the dot notation key into parts + original_key = key + key = key.replace("metadata.", "", 1) if key.startswith("metadata.") else key + + keys_parts = key.split(".") + # Traverse through the dictionary using the parts + value = metadata + for part in keys_parts: + value = value.get(part, None) # Get the value, return None if not found + if value is None: + break + + if value is not None and isinstance(value, str): + result[original_key.replace(".", "_")] = value + + return result diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus_helpers/prometheus_api.py b/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus_helpers/prometheus_api.py new file mode 100644 index 00000000..b25da577 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus_helpers/prometheus_api.py @@ -0,0 +1,137 @@ +""" +Helper functions to query prometheus API +""" + +import time +from datetime import datetime, timedelta 
+from typing import Optional + +from litellm import get_secret +from litellm._logging import verbose_logger +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) + +PROMETHEUS_URL: Optional[str] = get_secret("PROMETHEUS_URL") # type: ignore +PROMETHEUS_SELECTED_INSTANCE: Optional[str] = get_secret("PROMETHEUS_SELECTED_INSTANCE") # type: ignore +async_http_handler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback +) + + +async def get_metric_from_prometheus( + metric_name: str, +): + # Get the start of the current day in Unix timestamp + if PROMETHEUS_URL is None: + raise ValueError( + "PROMETHEUS_URL not set please set 'PROMETHEUS_URL=<>' in .env" + ) + + query = f"{metric_name}[24h]" + now = int(time.time()) + response = await async_http_handler.get( + f"{PROMETHEUS_URL}/api/v1/query", params={"query": query, "time": now} + ) # End of the day + _json_response = response.json() + verbose_logger.debug("json response from prometheus /query api %s", _json_response) + results = response.json()["data"]["result"] + return results + + +async def get_fallback_metric_from_prometheus(): + """ + Gets fallback metrics from prometheus for the last 24 hours + """ + response_message = "" + relevant_metrics = [ + "litellm_deployment_successful_fallbacks_total", + "litellm_deployment_failed_fallbacks_total", + ] + for metric in relevant_metrics: + response_json = await get_metric_from_prometheus( + metric_name=metric, + ) + + if response_json: + verbose_logger.debug("response json %s", response_json) + for result in response_json: + verbose_logger.debug("result= %s", result) + metric = result["metric"] + metric_values = result["values"] + most_recent_value = metric_values[0] + + if PROMETHEUS_SELECTED_INSTANCE is not None: + if metric.get("instance") != PROMETHEUS_SELECTED_INSTANCE: + continue + + value = int(float(most_recent_value[1])) # Convert value to integer + primary_model = metric.get("primary_model", "Unknown") + fallback_model = metric.get("fallback_model", "Unknown") + response_message += f"`{value} successful fallback requests` with primary model=`{primary_model}` -> fallback model=`{fallback_model}`" + response_message += "\n" + verbose_logger.debug("response message %s", response_message) + return response_message + + +def is_prometheus_connected() -> bool: + if PROMETHEUS_URL is not None: + return True + return False + + +async def get_daily_spend_from_prometheus(api_key: Optional[str]): + """ + Expected Response Format: + [ + { + "date": "2024-08-18T00:00:00+00:00", + "spend": 1.001818099998933 + }, + ...] 
+ """ + if PROMETHEUS_URL is None: + raise ValueError( + "PROMETHEUS_URL not set please set 'PROMETHEUS_URL=<>' in .env" + ) + + # Calculate the start and end dates for the last 30 days + end_date = datetime.utcnow() + start_date = end_date - timedelta(days=30) + + # Format dates as ISO 8601 strings with UTC offset + start_str = start_date.isoformat() + "+00:00" + end_str = end_date.isoformat() + "+00:00" + + url = f"{PROMETHEUS_URL}/api/v1/query_range" + + if api_key is None: + query = "sum(delta(litellm_spend_metric_total[1d]))" + else: + query = ( + f'sum(delta(litellm_spend_metric_total{{hashed_api_key="{api_key}"}}[1d]))' + ) + + params = { + "query": query, + "start": start_str, + "end": end_str, + "step": "86400", # Step size of 1 day in seconds + } + + response = await async_http_handler.get(url, params=params) + _json_response = response.json() + verbose_logger.debug("json response from prometheus /query api %s", _json_response) + results = response.json()["data"]["result"] + formatted_results = [] + + for result in results: + metric_data = result["values"] + for timestamp, value in metric_data: + # Convert timestamp to ISO 8601 string with UTC offset + date = datetime.fromtimestamp(float(timestamp)).isoformat() + "+00:00" + spend = float(value) + formatted_results.append({"date": date, "spend": spend}) + + return formatted_results diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus_services.py b/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus_services.py new file mode 100644 index 00000000..4bf293fb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus_services.py @@ -0,0 +1,222 @@ +# used for monitoring litellm services health on `/metrics` endpoint on LiteLLM Proxy +#### What this does #### +# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers) + + +from typing import List, Optional, Union + +from litellm._logging import print_verbose, verbose_logger +from litellm.types.integrations.prometheus import LATENCY_BUCKETS +from litellm.types.services import ServiceLoggerPayload, ServiceTypes + +FAILED_REQUESTS_LABELS = ["error_class", "function_name"] + + +class PrometheusServicesLogger: + # Class variables or attributes + litellm_service_latency = None # Class-level attribute to store the Histogram + + def __init__( + self, + mock_testing: bool = False, + **kwargs, + ): + try: + try: + from prometheus_client import REGISTRY, Counter, Histogram + except ImportError: + raise Exception( + "Missing prometheus_client. 
Run `pip install prometheus-client`" + ) + + self.Histogram = Histogram + self.Counter = Counter + self.REGISTRY = REGISTRY + + verbose_logger.debug("in init prometheus services metrics") + + self.services = [item.value for item in ServiceTypes] + + self.payload_to_prometheus_map = ( + {} + ) # store the prometheus histogram/counter we need to call for each field in payload + + for service in self.services: + histogram = self.create_histogram(service, type_of_request="latency") + counter_failed_request = self.create_counter( + service, + type_of_request="failed_requests", + additional_labels=FAILED_REQUESTS_LABELS, + ) + counter_total_requests = self.create_counter( + service, type_of_request="total_requests" + ) + self.payload_to_prometheus_map[service] = [ + histogram, + counter_failed_request, + counter_total_requests, + ] + + self.prometheus_to_amount_map: dict = ( + {} + ) # the field / value in ServiceLoggerPayload the object needs to be incremented by + + ### MOCK TESTING ### + self.mock_testing = mock_testing + self.mock_testing_success_calls = 0 + self.mock_testing_failure_calls = 0 + + except Exception as e: + print_verbose(f"Got exception on init prometheus client {str(e)}") + raise e + + def is_metric_registered(self, metric_name) -> bool: + for metric in self.REGISTRY.collect(): + if metric_name == metric.name: + return True + return False + + def _get_metric(self, metric_name): + """ + Helper function to get a metric from the registry by name. + """ + return self.REGISTRY._names_to_collectors.get(metric_name) + + def create_histogram(self, service: str, type_of_request: str): + metric_name = "litellm_{}_{}".format(service, type_of_request) + is_registered = self.is_metric_registered(metric_name) + if is_registered: + return self._get_metric(metric_name) + return self.Histogram( + metric_name, + "Latency for {} service".format(service), + labelnames=[service], + buckets=LATENCY_BUCKETS, + ) + + def create_counter( + self, + service: str, + type_of_request: str, + additional_labels: Optional[List[str]] = None, + ): + metric_name = "litellm_{}_{}".format(service, type_of_request) + is_registered = self.is_metric_registered(metric_name) + if is_registered: + return self._get_metric(metric_name) + return self.Counter( + metric_name, + "Total {} for {} service".format(type_of_request, service), + labelnames=[service] + (additional_labels or []), + ) + + def observe_histogram( + self, + histogram, + labels: str, + amount: float, + ): + assert isinstance(histogram, self.Histogram) + + histogram.labels(labels).observe(amount) + + def increment_counter( + self, + counter, + labels: str, + amount: float, + additional_labels: Optional[List[str]] = [], + ): + assert isinstance(counter, self.Counter) + + if additional_labels: + counter.labels(labels, *additional_labels).inc(amount) + else: + counter.labels(labels).inc(amount) + + def service_success_hook(self, payload: ServiceLoggerPayload): + if self.mock_testing: + self.mock_testing_success_calls += 1 + + if payload.service.value in self.payload_to_prometheus_map: + prom_objects = self.payload_to_prometheus_map[payload.service.value] + for obj in prom_objects: + if isinstance(obj, self.Histogram): + self.observe_histogram( + histogram=obj, + labels=payload.service.value, + amount=payload.duration, + ) + elif isinstance(obj, self.Counter) and "total_requests" in obj._name: + self.increment_counter( + counter=obj, + labels=payload.service.value, + amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS + ) + + def service_failure_hook(self, payload: 
ServiceLoggerPayload): + if self.mock_testing: + self.mock_testing_failure_calls += 1 + + if payload.service.value in self.payload_to_prometheus_map: + prom_objects = self.payload_to_prometheus_map[payload.service.value] + for obj in prom_objects: + if isinstance(obj, self.Counter): + self.increment_counter( + counter=obj, + labels=payload.service.value, + amount=1, # LOG ERROR COUNT / TOTAL REQUESTS TO PROMETHEUS + ) + + async def async_service_success_hook(self, payload: ServiceLoggerPayload): + """ + Log successful call to prometheus + """ + if self.mock_testing: + self.mock_testing_success_calls += 1 + + if payload.service.value in self.payload_to_prometheus_map: + prom_objects = self.payload_to_prometheus_map[payload.service.value] + for obj in prom_objects: + if isinstance(obj, self.Histogram): + self.observe_histogram( + histogram=obj, + labels=payload.service.value, + amount=payload.duration, + ) + elif isinstance(obj, self.Counter) and "total_requests" in obj._name: + self.increment_counter( + counter=obj, + labels=payload.service.value, + amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS + ) + + async def async_service_failure_hook( + self, + payload: ServiceLoggerPayload, + error: Union[str, Exception], + ): + if self.mock_testing: + self.mock_testing_failure_calls += 1 + error_class = error.__class__.__name__ + function_name = payload.call_type + + if payload.service.value in self.payload_to_prometheus_map: + prom_objects = self.payload_to_prometheus_map[payload.service.value] + for obj in prom_objects: + # increment both failed and total requests + if isinstance(obj, self.Counter): + if "failed_requests" in obj._name: + self.increment_counter( + counter=obj, + labels=payload.service.value, + # log additional_labels=["error_class", "function_name"], used for debugging what's going wrong with the DB + additional_labels=[error_class, function_name], + amount=1, # LOG ERROR COUNT TO PROMETHEUS + ) + else: + self.increment_counter( + counter=obj, + labels=payload.service.value, + amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_layer.py b/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_layer.py new file mode 100644 index 00000000..190b995f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_layer.py @@ -0,0 +1,91 @@ +#### What this does #### +# On success, logs events to Promptlayer +import os +import traceback + +from pydantic import BaseModel + +import litellm + + +class PromptLayerLogger: + # Class variables or attributes + def __init__(self): + # Instance variables + self.key = os.getenv("PROMPTLAYER_API_KEY") + + def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): + # Method definition + try: + new_kwargs = {} + new_kwargs["model"] = kwargs["model"] + new_kwargs["messages"] = kwargs["messages"] + + # add kwargs["optional_params"] to new_kwargs + for optional_param in kwargs["optional_params"]: + new_kwargs[optional_param] = kwargs["optional_params"][optional_param] + + # Extract PromptLayer tags from metadata, if such exists + tags = [] + metadata = {} + if "metadata" in kwargs["litellm_params"]: + if "pl_tags" in kwargs["litellm_params"]["metadata"]: + tags = kwargs["litellm_params"]["metadata"]["pl_tags"] + + # Remove "pl_tags" from metadata + metadata = { + k: v + for k, v in kwargs["litellm_params"]["metadata"].items() + if k != "pl_tags" + } + + print_verbose( + f"Prompt Layer Logging - Enters logging function for model kwargs: 
{new_kwargs}\n, response: {response_obj}" + ) + + # python-openai >= 1.0.0 returns Pydantic objects instead of jsons + if isinstance(response_obj, BaseModel): + response_obj = response_obj.model_dump() + + request_response = litellm.module_level_client.post( + "https://api.promptlayer.com/rest/track-request", + json={ + "function_name": "openai.ChatCompletion.create", + "kwargs": new_kwargs, + "tags": tags, + "request_response": dict(response_obj), + "request_start_time": int(start_time.timestamp()), + "request_end_time": int(end_time.timestamp()), + "api_key": self.key, + # Optional params for PromptLayer + # "prompt_id": "<PROMPT ID>", + # "prompt_input_variables": "<Dictionary of variables for prompt>", + # "prompt_version":1, + }, + ) + + response_json = request_response.json() + if not request_response.json().get("success", False): + raise Exception("Promptlayer did not successfully log the response!") + + print_verbose( + f"Prompt Layer Logging: success - final response object: {request_response.text}" + ) + + if "request_id" in response_json: + if metadata: + response = litellm.module_level_client.post( + "https://api.promptlayer.com/rest/track-metadata", + json={ + "request_id": response_json["request_id"], + "api_key": self.key, + "metadata": metadata, + }, + ) + print_verbose( + f"Prompt Layer Logging: success - metadata post response object: {response.text}" + ) + + except Exception: + print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}") + pass diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_management_base.py b/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_management_base.py new file mode 100644 index 00000000..3fe3b31e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_management_base.py @@ -0,0 +1,118 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Tuple, TypedDict + +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import StandardCallbackDynamicParams + + +class PromptManagementClient(TypedDict): + prompt_id: str + prompt_template: List[AllMessageValues] + prompt_template_model: Optional[str] + prompt_template_optional_params: Optional[Dict[str, Any]] + completed_messages: Optional[List[AllMessageValues]] + + +class PromptManagementBase(ABC): + + @property + @abstractmethod + def integration_name(self) -> str: + pass + + @abstractmethod + def should_run_prompt_management( + self, + prompt_id: str, + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> bool: + pass + + @abstractmethod + def _compile_prompt_helper( + self, + prompt_id: str, + prompt_variables: Optional[dict], + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> PromptManagementClient: + pass + + def merge_messages( + self, + prompt_template: List[AllMessageValues], + client_messages: List[AllMessageValues], + ) -> List[AllMessageValues]: + return prompt_template + client_messages + + def compile_prompt( + self, + prompt_id: str, + prompt_variables: Optional[dict], + client_messages: List[AllMessageValues], + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> PromptManagementClient: + compiled_prompt_client = self._compile_prompt_helper( + prompt_id=prompt_id, + prompt_variables=prompt_variables, + dynamic_callback_params=dynamic_callback_params, + ) + + try: + messages = compiled_prompt_client["prompt_template"] + client_messages + except Exception as e: + raise ValueError( + f"Error compiling prompt: {e}. 
Prompt id={prompt_id}, prompt_variables={prompt_variables}, client_messages={client_messages}, dynamic_callback_params={dynamic_callback_params}" + ) + + compiled_prompt_client["completed_messages"] = messages + return compiled_prompt_client + + def _get_model_from_prompt( + self, prompt_management_client: PromptManagementClient, model: str + ) -> str: + if prompt_management_client["prompt_template_model"] is not None: + return prompt_management_client["prompt_template_model"] + else: + return model.replace("{}/".format(self.integration_name), "") + + def get_chat_completion_prompt( + self, + model: str, + messages: List[AllMessageValues], + non_default_params: dict, + prompt_id: str, + prompt_variables: Optional[dict], + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> Tuple[ + str, + List[AllMessageValues], + dict, + ]: + if not self.should_run_prompt_management( + prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params + ): + return model, messages, non_default_params + + prompt_template = self.compile_prompt( + prompt_id=prompt_id, + prompt_variables=prompt_variables, + client_messages=messages, + dynamic_callback_params=dynamic_callback_params, + ) + + completed_messages = prompt_template["completed_messages"] or messages + + prompt_template_optional_params = ( + prompt_template["prompt_template_optional_params"] or {} + ) + + updated_non_default_params = { + **non_default_params, + **prompt_template_optional_params, + } + + model = self._get_model_from_prompt( + prompt_management_client=prompt_template, model=model + ) + + return model, completed_messages, updated_non_default_params diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/s3.py b/.venv/lib/python3.12/site-packages/litellm/integrations/s3.py new file mode 100644 index 00000000..4a0c2735 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/s3.py @@ -0,0 +1,196 @@ +#### What this does #### +# On success + failure, log events to Supabase + +from datetime import datetime +from typing import Optional, cast + +import litellm +from litellm._logging import print_verbose, verbose_logger +from litellm.types.utils import StandardLoggingPayload + + +class S3Logger: + # Class variables or attributes + def __init__( + self, + s3_bucket_name=None, + s3_path=None, + s3_region_name=None, + s3_api_version=None, + s3_use_ssl=True, + s3_verify=None, + s3_endpoint_url=None, + s3_aws_access_key_id=None, + s3_aws_secret_access_key=None, + s3_aws_session_token=None, + s3_config=None, + **kwargs, + ): + import boto3 + + try: + verbose_logger.debug( + f"in init s3 logger - s3_callback_params {litellm.s3_callback_params}" + ) + + s3_use_team_prefix = False + + if litellm.s3_callback_params is not None: + # read in .env variables - example os.environ/AWS_BUCKET_NAME + for key, value in litellm.s3_callback_params.items(): + if type(value) is str and value.startswith("os.environ/"): + litellm.s3_callback_params[key] = litellm.get_secret(value) + # now set s3 params from litellm.s3_logger_params + s3_bucket_name = litellm.s3_callback_params.get("s3_bucket_name") + s3_region_name = litellm.s3_callback_params.get("s3_region_name") + s3_api_version = litellm.s3_callback_params.get("s3_api_version") + s3_use_ssl = litellm.s3_callback_params.get("s3_use_ssl", True) + s3_verify = litellm.s3_callback_params.get("s3_verify") + s3_endpoint_url = litellm.s3_callback_params.get("s3_endpoint_url") + s3_aws_access_key_id = litellm.s3_callback_params.get( + "s3_aws_access_key_id" + ) + 
s3_aws_secret_access_key = litellm.s3_callback_params.get( + "s3_aws_secret_access_key" + ) + s3_aws_session_token = litellm.s3_callback_params.get( + "s3_aws_session_token" + ) + s3_config = litellm.s3_callback_params.get("s3_config") + s3_path = litellm.s3_callback_params.get("s3_path") + # done reading litellm.s3_callback_params + s3_use_team_prefix = bool( + litellm.s3_callback_params.get("s3_use_team_prefix", False) + ) + self.s3_use_team_prefix = s3_use_team_prefix + self.bucket_name = s3_bucket_name + self.s3_path = s3_path + verbose_logger.debug(f"s3 logger using endpoint url {s3_endpoint_url}") + # Create an S3 client with custom endpoint URL + self.s3_client = boto3.client( + "s3", + region_name=s3_region_name, + endpoint_url=s3_endpoint_url, + api_version=s3_api_version, + use_ssl=s3_use_ssl, + verify=s3_verify, + aws_access_key_id=s3_aws_access_key_id, + aws_secret_access_key=s3_aws_secret_access_key, + aws_session_token=s3_aws_session_token, + config=s3_config, + **kwargs, + ) + except Exception as e: + print_verbose(f"Got exception on init s3 client {str(e)}") + raise e + + async def _async_log_event( + self, kwargs, response_obj, start_time, end_time, print_verbose + ): + self.log_event(kwargs, response_obj, start_time, end_time, print_verbose) + + def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): + try: + verbose_logger.debug( + f"s3 Logging - Enters logging function for model {kwargs}" + ) + + # construct payload to send to s3 + # follows the same params as langfuse.py + litellm_params = kwargs.get("litellm_params", {}) + metadata = ( + litellm_params.get("metadata", {}) or {} + ) # if litellm_params['metadata'] == None + + # Clean Metadata before logging - never log raw metadata + # the raw metadata can contain circular references which leads to infinite recursion + # we clean out all extra litellm metadata params before logging + clean_metadata = {} + if isinstance(metadata, dict): + for key, value in metadata.items(): + # clean litellm metadata before logging + if key in [ + "headers", + "endpoint", + "caching_groups", + "previous_models", + ]: + continue + else: + clean_metadata[key] = value + + # Ensure everything in the payload is converted to str + payload: Optional[StandardLoggingPayload] = cast( + Optional[StandardLoggingPayload], + kwargs.get("standard_logging_object", None), + ) + + if payload is None: + return + + team_alias = payload["metadata"].get("user_api_key_team_alias") + + team_alias_prefix = "" + if ( + litellm.enable_preview_features + and self.s3_use_team_prefix + and team_alias is not None + ): + team_alias_prefix = f"{team_alias}/" + + s3_file_name = litellm.utils.get_logging_id(start_time, payload) or "" + s3_object_key = get_s3_object_key( + cast(Optional[str], self.s3_path) or "", + team_alias_prefix, + start_time, + s3_file_name, + ) + + s3_object_download_filename = ( + "time-" + + start_time.strftime("%Y-%m-%dT%H-%M-%S-%f") + + "_" + + payload["id"] + + ".json" + ) + + import json + + payload_str = json.dumps(payload) + + print_verbose(f"\ns3 Logger - Logging payload = {payload_str}") + + response = self.s3_client.put_object( + Bucket=self.bucket_name, + Key=s3_object_key, + Body=payload_str, + ContentType="application/json", + ContentLanguage="en", + ContentDisposition=f'inline; filename="{s3_object_download_filename}"', + CacheControl="private, immutable, max-age=31536000, s-maxage=0", + ) + + print_verbose(f"Response from s3:{str(response)}") + + print_verbose(f"s3 Layer Logging - final response object: 
{response_obj}") + return response + except Exception as e: + verbose_logger.exception(f"s3 Layer Error - {str(e)}") + pass + + +def get_s3_object_key( + s3_path: str, + team_alias_prefix: str, + start_time: datetime, + s3_file_name: str, +) -> str: + s3_object_key = ( + (s3_path.rstrip("/") + "/" if s3_path else "") + + team_alias_prefix + + start_time.strftime("%Y-%m-%d") + + "/" + + s3_file_name + ) # we need the s3 key to include the time, so we log cache hits too + s3_object_key += ".json" + return s3_object_key diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/supabase.py b/.venv/lib/python3.12/site-packages/litellm/integrations/supabase.py new file mode 100644 index 00000000..7eb007f8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/supabase.py @@ -0,0 +1,120 @@ +#### What this does #### +# On success + failure, log events to Supabase + +import os +import subprocess +import sys +import traceback + +import litellm + + +class Supabase: + # Class variables or attributes + supabase_table_name = "request_logs" + + def __init__(self): + # Instance variables + self.supabase_url = os.getenv("SUPABASE_URL") + self.supabase_key = os.getenv("SUPABASE_KEY") + try: + import supabase + except ImportError: + subprocess.check_call([sys.executable, "-m", "pip", "install", "supabase"]) + import supabase + + if self.supabase_url is None or self.supabase_key is None: + raise ValueError( + "LiteLLM Error, trying to use Supabase but url or key not passed. Create a table and set `litellm.supabase_url=<your-url>` and `litellm.supabase_key=<your-key>`" + ) + self.supabase_client = supabase.create_client( # type: ignore + self.supabase_url, self.supabase_key + ) + + def input_log_event( + self, model, messages, end_user, litellm_call_id, print_verbose + ): + try: + print_verbose( + f"Supabase Logging - Enters input logging function for model {model}" + ) + supabase_data_obj = { + "model": model, + "messages": messages, + "end_user": end_user, + "status": "initiated", + "litellm_call_id": litellm_call_id, + } + data, count = ( + self.supabase_client.table(self.supabase_table_name) + .insert(supabase_data_obj) + .execute() + ) + print_verbose(f"data: {data}") + except Exception: + print_verbose(f"Supabase Logging Error - {traceback.format_exc()}") + pass + + def log_event( + self, + model, + messages, + end_user, + response_obj, + start_time, + end_time, + litellm_call_id, + print_verbose, + ): + try: + print_verbose( + f"Supabase Logging - Enters logging function for model {model}, response_obj: {response_obj}" + ) + + total_cost = litellm.completion_cost(completion_response=response_obj) + + response_time = (end_time - start_time).total_seconds() + if "choices" in response_obj: + supabase_data_obj = { + "response_time": response_time, + "model": response_obj["model"], + "total_cost": total_cost, + "messages": messages, + "response": response_obj["choices"][0]["message"]["content"], + "end_user": end_user, + "litellm_call_id": litellm_call_id, + "status": "success", + } + print_verbose( + f"Supabase Logging - final data object: {supabase_data_obj}" + ) + data, count = ( + self.supabase_client.table(self.supabase_table_name) + .upsert(supabase_data_obj, on_conflict="litellm_call_id") + .execute() + ) + elif "error" in response_obj: + if "Unable to map your input to a model." 
in response_obj["error"]: + total_cost = 0 + supabase_data_obj = { + "response_time": response_time, + "model": response_obj["model"], + "total_cost": total_cost, + "messages": messages, + "error": response_obj["error"], + "end_user": end_user, + "litellm_call_id": litellm_call_id, + "status": "failure", + } + print_verbose( + f"Supabase Logging - final data object: {supabase_data_obj}" + ) + data, count = ( + self.supabase_client.table(self.supabase_table_name) + .upsert(supabase_data_obj, on_conflict="litellm_call_id") + .execute() + ) + + except Exception: + print_verbose(f"Supabase Logging Error - {traceback.format_exc()}") + pass diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/test_httpx.py b/.venv/lib/python3.12/site-packages/litellm/integrations/test_httpx.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/test_httpx.py diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/traceloop.py b/.venv/lib/python3.12/site-packages/litellm/integrations/traceloop.py new file mode 100644 index 00000000..b4f3905c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/traceloop.py @@ -0,0 +1,152 @@ +import traceback + +from litellm._logging import verbose_logger + + +class TraceloopLogger: + """ + WARNING: DEPRECATED + Use the OpenTelemetry standard integration instead + """ + + def __init__(self): + try: + from traceloop.sdk import Traceloop + from traceloop.sdk.tracing.tracing import TracerWrapper + except ModuleNotFoundError as e: + verbose_logger.error( + f"Traceloop not installed, try running 'pip install traceloop-sdk' to fix this error: {e}\n{traceback.format_exc()}" + ) + raise e + + Traceloop.init( + app_name="Litellm-Server", + disable_batch=True, + ) + self.tracer_wrapper = TracerWrapper() + + def log_event( + self, + kwargs, + response_obj, + start_time, + end_time, + user_id, + print_verbose, + level="DEFAULT", + status_message=None, + ): + from opentelemetry.semconv.ai import SpanAttributes + from opentelemetry.trace import SpanKind, Status, StatusCode + + try: + print_verbose( + f"Traceloop Logging - Enters logging function for model {kwargs}" + ) + + tracer = self.tracer_wrapper.get_tracer() + + optional_params = kwargs.get("optional_params", {}) + start_time = int(start_time.timestamp()) + end_time = int(end_time.timestamp()) + span = tracer.start_span( + "litellm.completion", kind=SpanKind.CLIENT, start_time=start_time + ) + + if span.is_recording(): + span.set_attribute( + SpanAttributes.LLM_REQUEST_MODEL, kwargs.get("model") + ) + if "stop" in optional_params: + span.set_attribute( + SpanAttributes.LLM_CHAT_STOP_SEQUENCES, + optional_params.get("stop"), + ) + if "frequency_penalty" in optional_params: + span.set_attribute( + SpanAttributes.LLM_FREQUENCY_PENALTY, + optional_params.get("frequency_penalty"), + ) + if "presence_penalty" in optional_params: + span.set_attribute( + SpanAttributes.LLM_PRESENCE_PENALTY, + optional_params.get("presence_penalty"), + ) + if "top_p" in optional_params: + span.set_attribute( + SpanAttributes.LLM_REQUEST_TOP_P, optional_params.get("top_p") + ) + if "tools" in optional_params or "functions" in optional_params: + span.set_attribute( + SpanAttributes.LLM_REQUEST_FUNCTIONS, + optional_params.get("tools", optional_params.get("functions")), + ) + if "user" in optional_params: + span.set_attribute( + SpanAttributes.LLM_USER, optional_params.get("user") + ) + if "max_tokens" in optional_params: + span.set_attribute( + 
SpanAttributes.LLM_REQUEST_MAX_TOKENS, + kwargs.get("max_tokens"), + ) + if "temperature" in optional_params: + span.set_attribute( + SpanAttributes.LLM_REQUEST_TEMPERATURE, # type: ignore + kwargs.get("temperature"), + ) + + for idx, prompt in enumerate(kwargs.get("messages")): + span.set_attribute( + f"{SpanAttributes.LLM_PROMPTS}.{idx}.role", + prompt.get("role"), + ) + span.set_attribute( + f"{SpanAttributes.LLM_PROMPTS}.{idx}.content", + prompt.get("content"), + ) + + span.set_attribute( + SpanAttributes.LLM_RESPONSE_MODEL, response_obj.get("model") + ) + usage = response_obj.get("usage") + if usage: + span.set_attribute( + SpanAttributes.LLM_USAGE_TOTAL_TOKENS, + usage.get("total_tokens"), + ) + span.set_attribute( + SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, + usage.get("completion_tokens"), + ) + span.set_attribute( + SpanAttributes.LLM_USAGE_PROMPT_TOKENS, + usage.get("prompt_tokens"), + ) + + for idx, choice in enumerate(response_obj.get("choices")): + span.set_attribute( + f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.finish_reason", + choice.get("finish_reason"), + ) + span.set_attribute( + f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.role", + choice.get("message").get("role"), + ) + span.set_attribute( + f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.content", + choice.get("message").get("content"), + ) + + if ( + level == "ERROR" + and status_message is not None + and isinstance(status_message, str) + ): + span.record_exception(Exception(status_message)) + span.set_status(Status(StatusCode.ERROR, status_message)) + + span.end(end_time) + + except Exception as e: + print_verbose(f"Traceloop Layer Error - {e}") diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/weights_biases.py b/.venv/lib/python3.12/site-packages/litellm/integrations/weights_biases.py new file mode 100644 index 00000000..5fcbab04 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/weights_biases.py @@ -0,0 +1,217 @@ +imported_openAIResponse = True +try: + import io + import logging + import sys + from typing import Any, Dict, List, Optional, TypeVar + + from wandb.sdk.data_types import trace_tree + + if sys.version_info >= (3, 8): + from typing import Literal, Protocol + else: + from typing_extensions import Literal, Protocol + + logger = logging.getLogger(__name__) + + K = TypeVar("K", bound=str) + V = TypeVar("V") + + class OpenAIResponse(Protocol[K, V]): # type: ignore + # contains a (known) object attribute + object: Literal["chat.completion", "edit", "text_completion"] + + def __getitem__(self, key: K) -> V: ... # noqa + + def get( # noqa + self, key: K, default: Optional[V] = None + ) -> Optional[V]: ... 
# pragma: no cover + + class OpenAIRequestResponseResolver: + def __call__( + self, + request: Dict[str, Any], + response: OpenAIResponse, + time_elapsed: float, + ) -> Optional[trace_tree.WBTraceTree]: + try: + if response["object"] == "edit": + return self._resolve_edit(request, response, time_elapsed) + elif response["object"] == "text_completion": + return self._resolve_completion(request, response, time_elapsed) + elif response["object"] == "chat.completion": + return self._resolve_chat_completion( + request, response, time_elapsed + ) + else: + logger.info(f"Unknown OpenAI response object: {response['object']}") + except Exception as e: + logger.warning(f"Failed to resolve request/response: {e}") + return None + + @staticmethod + def results_to_trace_tree( + request: Dict[str, Any], + response: OpenAIResponse, + results: List[trace_tree.Result], + time_elapsed: float, + ) -> trace_tree.WBTraceTree: + """Converts the request, response, and results into a trace tree. + + params: + request: The request dictionary + response: The response object + results: A list of results object + time_elapsed: The time elapsed in seconds + returns: + A wandb trace tree object. + """ + start_time_ms = int(round(response["created"] * 1000)) + end_time_ms = start_time_ms + int(round(time_elapsed * 1000)) + span = trace_tree.Span( + name=f"{response.get('model', 'openai')}_{response['object']}_{response.get('created')}", + attributes=dict(response), # type: ignore + start_time_ms=start_time_ms, + end_time_ms=end_time_ms, + span_kind=trace_tree.SpanKind.LLM, + results=results, + ) + model_obj = {"request": request, "response": response, "_kind": "openai"} + return trace_tree.WBTraceTree(root_span=span, model_dict=model_obj) + + def _resolve_edit( + self, + request: Dict[str, Any], + response: OpenAIResponse, + time_elapsed: float, + ) -> trace_tree.WBTraceTree: + """Resolves the request and response objects for `openai.Edit`.""" + request_str = ( + f"\n\n**Instruction**: {request['instruction']}\n\n" + f"**Input**: {request['input']}\n" + ) + choices = [ + f"\n\n**Edited**: {choice['text']}\n" for choice in response["choices"] + ] + + return self._request_response_result_to_trace( + request=request, + response=response, + request_str=request_str, + choices=choices, + time_elapsed=time_elapsed, + ) + + def _resolve_completion( + self, + request: Dict[str, Any], + response: OpenAIResponse, + time_elapsed: float, + ) -> trace_tree.WBTraceTree: + """Resolves the request and response objects for `openai.Completion`.""" + request_str = f"\n\n**Prompt**: {request['prompt']}\n" + choices = [ + f"\n\n**Completion**: {choice['text']}\n" + for choice in response["choices"] + ] + + return self._request_response_result_to_trace( + request=request, + response=response, + request_str=request_str, + choices=choices, + time_elapsed=time_elapsed, + ) + + def _resolve_chat_completion( + self, + request: Dict[str, Any], + response: OpenAIResponse, + time_elapsed: float, + ) -> trace_tree.WBTraceTree: + """Resolves the request and response objects for `openai.Completion`.""" + prompt = io.StringIO() + for message in request["messages"]: + prompt.write(f"\n\n**{message['role']}**: {message['content']}\n") + request_str = prompt.getvalue() + + choices = [ + f"\n\n**{choice['message']['role']}**: {choice['message']['content']}\n" + for choice in response["choices"] + ] + + return self._request_response_result_to_trace( + request=request, + response=response, + request_str=request_str, + choices=choices, + 
time_elapsed=time_elapsed, + ) + + def _request_response_result_to_trace( + self, + request: Dict[str, Any], + response: OpenAIResponse, + request_str: str, + choices: List[str], + time_elapsed: float, + ) -> trace_tree.WBTraceTree: + """Resolves the request and response objects for `openai.Completion`.""" + results = [ + trace_tree.Result( + inputs={"request": request_str}, + outputs={"response": choice}, + ) + for choice in choices + ] + trace = self.results_to_trace_tree(request, response, results, time_elapsed) + return trace + +except Exception: + imported_openAIResponse = False + + +#### What this does #### +# On success, logs events to Langfuse +import traceback + + +class WeightsBiasesLogger: + # Class variables or attributes + def __init__(self): + try: + pass + except Exception: + raise Exception( + "\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m" + ) + if imported_openAIResponse is False: + raise Exception( + "\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m" + ) + self.resolver = OpenAIRequestResponseResolver() + + def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): + # Method definition + import wandb + + try: + print_verbose(f"W&B Logging - Enters logging function for model {kwargs}") + run = wandb.init() + print_verbose(response_obj) + + trace = self.resolver( + kwargs, response_obj, (end_time - start_time).total_seconds() + ) + + if trace is not None and run is not None: + run.log({"trace": trace}) + + if run is not None: + run.finish() + print_verbose( + f"W&B Logging Logging - final response object: {response_obj}" + ) + except Exception: + print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}") + pass |