diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/pagerduty/pagerduty.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/integrations/pagerduty/pagerduty.py | 305 |
1 files changed, 305 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/pagerduty/pagerduty.py b/.venv/lib/python3.12/site-packages/litellm/integrations/pagerduty/pagerduty.py new file mode 100644 index 00000000..6085bc23 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/pagerduty/pagerduty.py @@ -0,0 +1,305 @@ +""" +PagerDuty Alerting Integration + +Handles two types of alerts: +- High LLM API Failure Rate. Configure X fails in Y seconds to trigger an alert. +- High Number of Hanging LLM Requests. Configure X hangs in Y seconds to trigger an alert. +""" + +import asyncio +import os +from datetime import datetime, timedelta, timezone +from typing import List, Literal, Optional, Union + +from litellm._logging import verbose_logger +from litellm.caching import DualCache +from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.proxy._types import UserAPIKeyAuth +from litellm.types.integrations.pagerduty import ( + AlertingConfig, + PagerDutyInternalEvent, + PagerDutyPayload, + PagerDutyRequestBody, +) +from litellm.types.utils import ( + StandardLoggingPayload, + StandardLoggingPayloadErrorInformation, +) + +PAGERDUTY_DEFAULT_FAILURE_THRESHOLD = 60 +PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS = 60 +PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS = 60 +PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS = 600 + + +class PagerDutyAlerting(SlackAlerting): + """ + Tracks failed requests and hanging requests separately. + If threshold is crossed for either type, triggers a PagerDuty alert. + """ + + def __init__( + self, alerting_args: Optional[Union[AlertingConfig, dict]] = None, **kwargs + ): + from litellm.proxy.proxy_server import CommonProxyErrors, premium_user + + super().__init__() + _api_key = os.getenv("PAGERDUTY_API_KEY") + if not _api_key: + raise ValueError("PAGERDUTY_API_KEY is not set") + + self.api_key: str = _api_key + alerting_args = alerting_args or {} + self.alerting_args: AlertingConfig = AlertingConfig( + failure_threshold=alerting_args.get( + "failure_threshold", PAGERDUTY_DEFAULT_FAILURE_THRESHOLD + ), + failure_threshold_window_seconds=alerting_args.get( + "failure_threshold_window_seconds", + PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS, + ), + hanging_threshold_seconds=alerting_args.get( + "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS + ), + hanging_threshold_window_seconds=alerting_args.get( + "hanging_threshold_window_seconds", + PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS, + ), + ) + + # Separate storage for failures vs. hangs + self._failure_events: List[PagerDutyInternalEvent] = [] + self._hanging_events: List[PagerDutyInternalEvent] = [] + + # premium user check + if premium_user is not True: + raise ValueError( + f"PagerDutyAlerting is only available for LiteLLM Enterprise users. {CommonProxyErrors.not_premium_user.value}" + ) + + # ------------------ MAIN LOGIC ------------------ # + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + """ + Record a failure event. Only send an alert to PagerDuty if the + configured *failure* threshold is exceeded in the specified window. + """ + now = datetime.now(timezone.utc) + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object" + ) + if not standard_logging_payload: + raise ValueError( + "standard_logging_object is required for PagerDutyAlerting" + ) + + # Extract error details + error_info: Optional[StandardLoggingPayloadErrorInformation] = ( + standard_logging_payload.get("error_information") or {} + ) + _meta = standard_logging_payload.get("metadata") or {} + + self._failure_events.append( + PagerDutyInternalEvent( + failure_event_type="failed_response", + timestamp=now, + error_class=error_info.get("error_class"), + error_code=error_info.get("error_code"), + error_llm_provider=error_info.get("llm_provider"), + user_api_key_hash=_meta.get("user_api_key_hash"), + user_api_key_alias=_meta.get("user_api_key_alias"), + user_api_key_org_id=_meta.get("user_api_key_org_id"), + user_api_key_team_id=_meta.get("user_api_key_team_id"), + user_api_key_user_id=_meta.get("user_api_key_user_id"), + user_api_key_team_alias=_meta.get("user_api_key_team_alias"), + user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"), + user_api_key_user_email=_meta.get("user_api_key_user_email"), + ) + ) + + # Prune + Possibly alert + window_seconds = self.alerting_args.get("failure_threshold_window_seconds", 60) + threshold = self.alerting_args.get("failure_threshold", 1) + + # If threshold is crossed, send PD alert for failures + await self._send_alert_if_thresholds_crossed( + events=self._failure_events, + window_seconds=window_seconds, + threshold=threshold, + alert_prefix="High LLM API Failure Rate", + ) + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: Literal[ + "completion", + "text_completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + "pass_through_endpoint", + "rerank", + ], + ) -> Optional[Union[Exception, str, dict]]: + """ + Example of detecting hanging requests by waiting a given threshold. + If the request didn't finish by then, we treat it as 'hanging'. + """ + verbose_logger.info("Inside Proxy Logging Pre-call hook!") + asyncio.create_task( + self.hanging_response_handler( + request_data=data, user_api_key_dict=user_api_key_dict + ) + ) + return None + + async def hanging_response_handler( + self, request_data: Optional[dict], user_api_key_dict: UserAPIKeyAuth + ): + """ + Checks if request completed by the time 'hanging_threshold_seconds' elapses. + If not, we classify it as a hanging request. + """ + verbose_logger.debug( + f"Inside Hanging Response Handler!..sleeping for {self.alerting_args.get('hanging_threshold_seconds', PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS)} seconds" + ) + await asyncio.sleep( + self.alerting_args.get( + "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS + ) + ) + + if await self._request_is_completed(request_data=request_data): + return # It's not hanging if completed + + # Otherwise, record it as hanging + self._hanging_events.append( + PagerDutyInternalEvent( + failure_event_type="hanging_response", + timestamp=datetime.now(timezone.utc), + error_class="HangingRequest", + error_code="HangingRequest", + error_llm_provider="HangingRequest", + user_api_key_hash=user_api_key_dict.api_key, + user_api_key_alias=user_api_key_dict.key_alias, + user_api_key_org_id=user_api_key_dict.org_id, + user_api_key_team_id=user_api_key_dict.team_id, + user_api_key_user_id=user_api_key_dict.user_id, + user_api_key_team_alias=user_api_key_dict.team_alias, + user_api_key_end_user_id=user_api_key_dict.end_user_id, + user_api_key_user_email=user_api_key_dict.user_email, + ) + ) + + # Prune + Possibly alert + window_seconds = self.alerting_args.get( + "hanging_threshold_window_seconds", + PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS, + ) + threshold: int = self.alerting_args.get( + "hanging_threshold_fails", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS + ) + + # If threshold is crossed, send PD alert for hangs + await self._send_alert_if_thresholds_crossed( + events=self._hanging_events, + window_seconds=window_seconds, + threshold=threshold, + alert_prefix="High Number of Hanging LLM Requests", + ) + + # ------------------ HELPERS ------------------ # + + async def _send_alert_if_thresholds_crossed( + self, + events: List[PagerDutyInternalEvent], + window_seconds: int, + threshold: int, + alert_prefix: str, + ): + """ + 1. Prune old events + 2. If threshold is reached, build alert, send to PagerDuty + 3. Clear those events + """ + cutoff = datetime.now(timezone.utc) - timedelta(seconds=window_seconds) + pruned = [e for e in events if e.get("timestamp", datetime.min) > cutoff] + + # Update the reference list + events.clear() + events.extend(pruned) + + # Check threshold + verbose_logger.debug( + f"Have {len(events)} events in the last {window_seconds} seconds. Threshold is {threshold}" + ) + if len(events) >= threshold: + # Build short summary of last N events + error_summaries = self._build_error_summaries(events, max_errors=5) + alert_message = ( + f"{alert_prefix}: {len(events)} in the last {window_seconds} seconds." + ) + custom_details = {"recent_errors": error_summaries} + + await self.send_alert_to_pagerduty( + alert_message=alert_message, + custom_details=custom_details, + ) + + # Clear them after sending an alert, so we don't spam + events.clear() + + def _build_error_summaries( + self, events: List[PagerDutyInternalEvent], max_errors: int = 5 + ) -> List[PagerDutyInternalEvent]: + """ + Build short text summaries for the last `max_errors`. + Example: "ValueError (code: 500, provider: openai)" + """ + recent = events[-max_errors:] + summaries = [] + for fe in recent: + # If any of these is None, show "N/A" to avoid messing up the summary string + fe.pop("timestamp") + summaries.append(fe) + return summaries + + async def send_alert_to_pagerduty(self, alert_message: str, custom_details: dict): + """ + Send [critical] Alert to PagerDuty + + https://developer.pagerduty.com/api-reference/YXBpOjI3NDgyNjU-pager-duty-v2-events-api + """ + try: + verbose_logger.debug(f"Sending alert to PagerDuty: {alert_message}") + async_client: AsyncHTTPHandler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + payload: PagerDutyRequestBody = PagerDutyRequestBody( + payload=PagerDutyPayload( + summary=alert_message, + severity="critical", + source="LiteLLM Alert", + component="LiteLLM", + custom_details=custom_details, + ), + routing_key=self.api_key, + event_action="trigger", + ) + + return await async_client.post( + url="https://events.pagerduty.com/v2/enqueue", + json=dict(payload), + headers={"Content-Type": "application/json"}, + ) + except Exception as e: + verbose_logger.exception(f"Error sending alert to PagerDuty: {e}") |