Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/Readme.md | 5
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/Readme.md | 13
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/batching_handler.py | 82
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/slack_alerting.py | 1822
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/utils.py | 92
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/__init__.py | 1
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/_types/open_inference.py | 286
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/additional_logging_utils.py | 36
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/argilla.py | 392
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/arize/_utils.py | 126
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/arize/arize.py | 105
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/arize/arize_phoenix.py | 73
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/athina.py | 102
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/azure_storage/azure_storage.py | 381
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/braintrust_logging.py | 399
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/custom_batch_logger.py | 59
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py | 274
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/custom_logger.py | 388
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/custom_prompt_management.py | 49
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/datadog/datadog.py | 580
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/datadog/datadog_llm_obs.py | 203
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/dynamodb.py | 89
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/email_alerting.py | 136
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/email_templates/templates.py | 62
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/galileo.py | 157
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/Readme.md | 12
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py | 237
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py | 326
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/gcs_pubsub/pub_sub.py | 203
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/greenscale.py | 72
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/helicone.py | 188
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/humanloop.py | 197
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/lago.py | 202
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse.py | 955
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse_handler.py | 169
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse_prompt_management.py | 287
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/langsmith.py | 500
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/langtrace.py | 106
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py | 317
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/logfire_logger.py | 179
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/lunary.py | 181
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py | 269
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/openmeter.py | 132
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/opentelemetry.py | 1023
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/opik/opik.py | 326
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/opik/utils.py | 110
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/pagerduty/pagerduty.py | 305
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/prometheus.py | 1789
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/prometheus_helpers/prometheus_api.py | 137
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/prometheus_services.py | 222
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/prompt_layer.py | 91
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/prompt_management_base.py | 118
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/s3.py | 196
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/supabase.py | 120
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/test_httpx.py | 0
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/traceloop.py | 152
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/weights_biases.py | 217
57 files changed, 15250 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/Readme.md b/.venv/lib/python3.12/site-packages/litellm/integrations/Readme.md
new file mode 100644
index 00000000..2b0b530a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/Readme.md
@@ -0,0 +1,5 @@
+# Integrations
+
+This folder contains logging integrations for litellm
+
+eg. logging to Datadog, Langfuse, Prometheus, s3, GCS Bucket, etc.
\ No newline at end of file
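A minimal usage sketch of how these logging integrations are typically enabled, assuming litellm's standard callback lists and provider credentials already present in the environment; the model name and the chosen callbacks below are illustrative only:

```
# Sketch: route request logs to named integrations via litellm's callback lists.
# Assumes OPENAI_API_KEY plus the relevant LANGFUSE_* / DATADOG_* env vars are set.
import litellm

litellm.success_callback = ["langfuse", "datadog"]  # log successful calls
litellm.failure_callback = ["langfuse"]             # log failed calls

response = litellm.completion(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "hello"}],
)
print(response.choices[0].message.content)
```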
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/Readme.md b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/Readme.md
new file mode 100644
index 00000000..f28f7150
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/Readme.md
@@ -0,0 +1,13 @@
+# Slack Alerting on LiteLLM Gateway
+
+This folder contains the Slack Alerting integration for LiteLLM Gateway.
+
+## Folder Structure
+
+- `slack_alerting.py`: This is the main file that handles sending different types of alerts
+- `batching_handler.py`: Handles Batching + sending Httpx Post requests to slack. Slack alerts are sent every 10s or when events are greater than X events. Done to ensure litellm has good performance under high traffic
+- `types.py`: This file contains the AlertType enum which is used to define the different types of alerts that can be sent to Slack.
+- `utils.py`: This file contains common utils used specifically for slack alerting
+
+## Further Reading
+- [Doc setting up Alerting on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/proxy/alerting)
\ No newline at end of file
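A minimal sketch of driving the `SlackAlerting` class from `slack_alerting.py` directly, based on the constructor and `send_alert` signature added in this diff; the webhook URL is a placeholder, and the explicit `flush_queue()` call (inherited from the batching base class) is only there to force delivery in the sketch, since in the proxy the batching handler flushes the queue on its own schedule:

```
# Sketch: queue a Slack alert and flush it immediately.
import asyncio

from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
from litellm.proxy._types import AlertType


async def main() -> None:
    alerting = SlackAlerting(
        alerting=["slack"],
        alert_types=[AlertType.budget_alerts],
        default_webhook_url="https://hooks.slack.com/services/T000/B000/XXXX",  # placeholder
    )
    # send_alert only enqueues the formatted message ...
    await alerting.send_alert(
        message="Spend is at 95% of the configured budget",
        level="High",
        alert_type=AlertType.budget_alerts,
        alerting_metadata={},  # optional key/value pairs appended to the message
    )
    # ... so force a flush here instead of waiting for the periodic batch.
    await alerting.flush_queue()


asyncio.run(main())
```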
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/batching_handler.py b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/batching_handler.py
new file mode 100644
index 00000000..e35cf61d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/batching_handler.py
@@ -0,0 +1,82 @@
+"""
+Handles Batching + sending Httpx Post requests to slack
+
+Slack alerts are sent every 10s or when events are greater than X events
+
+see custom_batch_logger.py for more details / defaults
+"""
+
+from typing import TYPE_CHECKING, Any
+
+from litellm._logging import verbose_proxy_logger
+
+if TYPE_CHECKING:
+ from .slack_alerting import SlackAlerting as _SlackAlerting
+
+ SlackAlertingType = _SlackAlerting
+else:
+ SlackAlertingType = Any
+
+
+def squash_payloads(queue):
+
+ squashed = {}
+ if len(queue) == 0:
+ return squashed
+ if len(queue) == 1:
+ return {"key": {"item": queue[0], "count": 1}}
+
+ for item in queue:
+ url = item["url"]
+ alert_type = item["alert_type"]
+ _key = (url, alert_type)
+
+ if _key in squashed:
+ squashed[_key]["count"] += 1
+ # Merge the payloads
+
+ else:
+ squashed[_key] = {"item": item, "count": 1}
+
+ return squashed
+
+
+def _print_alerting_payload_warning(
+ payload: dict, slackAlertingInstance: SlackAlertingType
+):
+ """
+ Print the payload to the console when
+ slackAlertingInstance.alerting_args.log_to_console is True
+
+ Relevant issue: https://github.com/BerriAI/litellm/issues/7372
+ """
+ if slackAlertingInstance.alerting_args.log_to_console is True:
+ verbose_proxy_logger.warning(payload)
+
+
+async def send_to_webhook(slackAlertingInstance: SlackAlertingType, item, count):
+ """
+ Send a single slack alert to the webhook
+ """
+ import json
+
+ payload = item.get("payload", {})
+ try:
+ if count > 1:
+ payload["text"] = f"[Num Alerts: {count}]\n\n{payload['text']}"
+
+ response = await slackAlertingInstance.async_http_handler.post(
+ url=item["url"],
+ headers=item["headers"],
+ data=json.dumps(payload),
+ )
+ if response.status_code != 200:
+ verbose_proxy_logger.debug(
+ f"Error sending slack alert to url={item['url']}. Error={response.text}"
+ )
+ except Exception as e:
+ verbose_proxy_logger.debug(f"Error sending slack alert: {str(e)}")
+ finally:
+ _print_alerting_payload_warning(
+ payload, slackAlertingInstance=slackAlertingInstance
+ )
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/slack_alerting.py b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/slack_alerting.py
new file mode 100644
index 00000000..a2e62647
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/slack_alerting.py
@@ -0,0 +1,1822 @@
+#### What this does ####
+# Class for sending Slack Alerts #
+import asyncio
+import datetime
+import os
+import random
+import time
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
+
+from openai import APIError
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+import litellm.types
+from litellm._logging import verbose_logger, verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.litellm_core_utils.duration_parser import duration_in_seconds
+from litellm.litellm_core_utils.exception_mapping_utils import (
+ _add_key_name_and_team_to_alert,
+)
+from litellm.llms.custom_httpx.http_handler import (
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+from litellm.proxy._types import AlertType, CallInfo, VirtualKeyEvent, WebhookEvent
+from litellm.types.integrations.slack_alerting import *
+
+from ..email_templates.templates import *
+from .batching_handler import send_to_webhook, squash_payloads
+from .utils import _add_langfuse_trace_id_to_alert, process_slack_alerting_variables
+
+if TYPE_CHECKING:
+ from litellm.router import Router as _Router
+
+ Router = _Router
+else:
+ Router = Any
+
+
+class SlackAlerting(CustomBatchLogger):
+ """
+ Class for sending Slack Alerts
+ """
+
+ # Class variables or attributes
+ def __init__(
+ self,
+ internal_usage_cache: Optional[DualCache] = None,
+ alerting_threshold: Optional[
+ float
+ ] = None, # threshold for slow / hanging llm responses (in seconds)
+ alerting: Optional[List] = [],
+ alert_types: List[AlertType] = DEFAULT_ALERT_TYPES,
+ alert_to_webhook_url: Optional[
+ Dict[AlertType, Union[List[str], str]]
+ ] = None, # if user wants to separate alerts to diff channels
+ alerting_args={},
+ default_webhook_url: Optional[str] = None,
+ **kwargs,
+ ):
+ if alerting_threshold is None:
+ alerting_threshold = 300
+ self.alerting_threshold = alerting_threshold
+ self.alerting = alerting
+ self.alert_types = alert_types
+ self.internal_usage_cache = internal_usage_cache or DualCache()
+ self.async_http_handler = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+ self.alert_to_webhook_url = process_slack_alerting_variables(
+ alert_to_webhook_url=alert_to_webhook_url
+ )
+ self.is_running = False
+ self.alerting_args = SlackAlertingArgs(**alerting_args)
+ self.default_webhook_url = default_webhook_url
+ self.flush_lock = asyncio.Lock()
+ super().__init__(**kwargs, flush_lock=self.flush_lock)
+
+ def update_values(
+ self,
+ alerting: Optional[List] = None,
+ alerting_threshold: Optional[float] = None,
+ alert_types: Optional[List[AlertType]] = None,
+ alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]] = None,
+ alerting_args: Optional[Dict] = None,
+ llm_router: Optional[Router] = None,
+ ):
+ if alerting is not None:
+ self.alerting = alerting
+ asyncio.create_task(self.periodic_flush())
+ if alerting_threshold is not None:
+ self.alerting_threshold = alerting_threshold
+ if alert_types is not None:
+ self.alert_types = alert_types
+ if alerting_args is not None:
+ self.alerting_args = SlackAlertingArgs(**alerting_args)
+ if alert_to_webhook_url is not None:
+ # update the dict
+ if self.alert_to_webhook_url is None:
+ self.alert_to_webhook_url = process_slack_alerting_variables(
+ alert_to_webhook_url=alert_to_webhook_url
+ )
+ else:
+ _new_values = (
+ process_slack_alerting_variables(
+ alert_to_webhook_url=alert_to_webhook_url
+ )
+ or {}
+ )
+ self.alert_to_webhook_url.update(_new_values)
+ if llm_router is not None:
+ self.llm_router = llm_router
+
+ async def deployment_in_cooldown(self):
+ pass
+
+ async def deployment_removed_from_cooldown(self):
+ pass
+
+ def _all_possible_alert_types(self):
+ # used by the UI to show all supported alert types
+ # Note: This is not the alerts the user has configured, instead it's all possible alert types a user can select
+ # return list of all values AlertType enum
+ return list(AlertType)
+
+ def _response_taking_too_long_callback_helper(
+ self,
+ kwargs, # kwargs to completion
+ start_time,
+ end_time, # start/end time
+ ):
+ try:
+ time_difference = end_time - start_time
+ # Convert the timedelta to float (in seconds)
+ time_difference_float = time_difference.total_seconds()
+ litellm_params = kwargs.get("litellm_params", {})
+ model = kwargs.get("model", "")
+ api_base = litellm.get_api_base(model=model, optional_params=litellm_params)
+ messages = kwargs.get("messages", None)
+ # if messages does not exist fallback to "input"
+ if messages is None:
+ messages = kwargs.get("input", None)
+
+ # only use first 100 chars for alerting
+ _messages = str(messages)[:100]
+
+ return time_difference_float, model, api_base, _messages
+ except Exception as e:
+ raise e
+
+ def _get_deployment_latencies_to_alert(self, metadata=None):
+ if metadata is None:
+ return None
+
+ if "_latency_per_deployment" in metadata:
+ # Translate model_id to -> api_base
+ # _latency_per_deployment is a dictionary that looks like this:
+ """
+ _latency_per_deployment: {
+ api_base: 0.01336697916666667
+ }
+ """
+ _message_to_send = ""
+ _deployment_latencies = metadata["_latency_per_deployment"]
+ if len(_deployment_latencies) == 0:
+ return None
+ _deployment_latency_map: Optional[dict] = None
+ try:
+ # try sorting deployments by latency
+ _deployment_latencies = sorted(
+ _deployment_latencies.items(), key=lambda x: x[1]
+ )
+ _deployment_latency_map = dict(_deployment_latencies)
+ except Exception:
+ pass
+
+ if _deployment_latency_map is None:
+ return
+
+ for api_base, latency in _deployment_latency_map.items():
+ _message_to_send += f"\n{api_base}: {round(latency,2)}s"
+ _message_to_send = "```" + _message_to_send + "```"
+ return _message_to_send
+
+ async def response_taking_too_long_callback(
+ self,
+ kwargs, # kwargs to completion
+ completion_response, # response from completion
+ start_time,
+ end_time, # start/end time
+ ):
+ if self.alerting is None or self.alert_types is None:
+ return
+
+ time_difference_float, model, api_base, messages = (
+ self._response_taking_too_long_callback_helper(
+ kwargs=kwargs,
+ start_time=start_time,
+ end_time=end_time,
+ )
+ )
+ if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions:
+ messages = "Message not logged. litellm.redact_messages_in_exceptions=True"
+ request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
+ slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
+ alerting_metadata: dict = {}
+ if time_difference_float > self.alerting_threshold:
+ # add deployment latencies to alert
+ if (
+ kwargs is not None
+ and "litellm_params" in kwargs
+ and "metadata" in kwargs["litellm_params"]
+ ):
+ _metadata: dict = kwargs["litellm_params"]["metadata"]
+ request_info = _add_key_name_and_team_to_alert(
+ request_info=request_info, metadata=_metadata
+ )
+
+ _deployment_latency_map = self._get_deployment_latencies_to_alert(
+ metadata=_metadata
+ )
+ if _deployment_latency_map is not None:
+ request_info += (
+ f"\nAvailable Deployment Latencies\n{_deployment_latency_map}"
+ )
+
+ if "alerting_metadata" in _metadata:
+ alerting_metadata = _metadata["alerting_metadata"]
+ await self.send_alert(
+ message=slow_message + request_info,
+ level="Low",
+ alert_type=AlertType.llm_too_slow,
+ alerting_metadata=alerting_metadata,
+ )
+
+ async def async_update_daily_reports(
+ self, deployment_metrics: DeploymentMetrics
+ ) -> int:
+ """
+ Store the perf by deployment in cache
+ - Number of failed requests per deployment
+ - Latency / output tokens per deployment
+
+ 'deployment_id:daily_metrics:failed_requests'
+ 'deployment_id:daily_metrics:latency_per_output_token'
+
+ Returns
+ int - count of metrics set (1 - if just latency, 2 - if failed + latency)
+ """
+
+ return_val = 0
+ try:
+ ## FAILED REQUESTS ##
+ if deployment_metrics.failed_request:
+ await self.internal_usage_cache.async_increment_cache(
+ key="{}:{}".format(
+ deployment_metrics.id,
+ SlackAlertingCacheKeys.failed_requests_key.value,
+ ),
+ value=1,
+ parent_otel_span=None, # no attached request, this is a background operation
+ )
+
+ return_val += 1
+
+ ## LATENCY ##
+ if deployment_metrics.latency_per_output_token is not None:
+ await self.internal_usage_cache.async_increment_cache(
+ key="{}:{}".format(
+ deployment_metrics.id, SlackAlertingCacheKeys.latency_key.value
+ ),
+ value=deployment_metrics.latency_per_output_token,
+ parent_otel_span=None, # no attached request, this is a background operation
+ )
+
+ return_val += 1
+
+ return return_val
+ except Exception:
+ return 0
+
+ async def send_daily_reports(self, router) -> bool: # noqa: PLR0915
+ """
+ Send a daily report on:
+ - Top 5 deployments with most failed requests
+ - Top 5 slowest deployments (normalized by latency/output tokens)
+
+ Get the value from redis cache (if available) or in-memory and send it
+
+ Cleanup:
+ - reset values in cache -> prevent memory leak
+
+ Returns:
+ True -> if successfuly sent
+ False -> if not sent
+ """
+
+ ids = router.get_model_ids()
+
+ # get keys
+ failed_request_keys = [
+ "{}:{}".format(id, SlackAlertingCacheKeys.failed_requests_key.value)
+ for id in ids
+ ]
+ latency_keys = [
+ "{}:{}".format(id, SlackAlertingCacheKeys.latency_key.value) for id in ids
+ ]
+
+ combined_metrics_keys = failed_request_keys + latency_keys # reduce cache calls
+
+ combined_metrics_values = await self.internal_usage_cache.async_batch_get_cache(
+ keys=combined_metrics_keys
+ ) # [1, 2, None, ..]
+
+ if combined_metrics_values is None:
+ return False
+
+ all_none = True
+ for val in combined_metrics_values:
+ if val is not None and val > 0:
+ all_none = False
+ break
+
+ if all_none:
+ return False
+
+ failed_request_values = combined_metrics_values[
+ : len(failed_request_keys)
+ ] # # [1, 2, None, ..]
+ latency_values = combined_metrics_values[len(failed_request_keys) :]
+
+ # find top 5 failed
+ ## Replace None values with a placeholder value (-1 in this case)
+ placeholder_value = 0
+ replaced_failed_values = [
+ value if value is not None else placeholder_value
+ for value in failed_request_values
+ ]
+
+ ## Get the indices of top 5 keys with the highest numerical values (ignoring None and 0 values)
+ top_5_failed = sorted(
+ range(len(replaced_failed_values)),
+ key=lambda i: replaced_failed_values[i],
+ reverse=True,
+ )[:5]
+ top_5_failed = [
+ index for index in top_5_failed if replaced_failed_values[index] > 0
+ ]
+
+ # find top 5 slowest
+ # Replace None values with a placeholder value (-1 in this case)
+ placeholder_value = 0
+ replaced_slowest_values = [
+ value if value is not None else placeholder_value
+ for value in latency_values
+ ]
+
+ # Get the indices of top 5 values with the highest numerical values (ignoring None and 0 values)
+ top_5_slowest = sorted(
+ range(len(replaced_slowest_values)),
+ key=lambda i: replaced_slowest_values[i],
+ reverse=True,
+ )[:5]
+ top_5_slowest = [
+ index for index in top_5_slowest if replaced_slowest_values[index] > 0
+ ]
+
+ # format alert -> return the litellm model name + api base
+ message = f"\n\nTime: `{time.time()}`s\nHere are today's key metrics 📈: \n\n"
+
+ message += "\n\n*❗️ Top Deployments with Most Failed Requests:*\n\n"
+ if not top_5_failed:
+ message += "\tNone\n"
+ for i in range(len(top_5_failed)):
+ key = failed_request_keys[top_5_failed[i]].split(":")[0]
+ _deployment = router.get_model_info(key)
+ if isinstance(_deployment, dict):
+ deployment_name = _deployment["litellm_params"].get("model", "")
+ else:
+ return False
+
+ api_base = litellm.get_api_base(
+ model=deployment_name,
+ optional_params=(
+ _deployment["litellm_params"] if _deployment is not None else {}
+ ),
+ )
+ if api_base is None:
+ api_base = ""
+ value = replaced_failed_values[top_5_failed[i]]
+ message += f"\t{i+1}. Deployment: `{deployment_name}`, Failed Requests: `{value}`, API Base: `{api_base}`\n"
+
+ message += "\n\n*😅 Top Slowest Deployments:*\n\n"
+ if not top_5_slowest:
+ message += "\tNone\n"
+ for i in range(len(top_5_slowest)):
+ key = latency_keys[top_5_slowest[i]].split(":")[0]
+ _deployment = router.get_model_info(key)
+ if _deployment is not None:
+ deployment_name = _deployment["litellm_params"].get("model", "")
+ else:
+ deployment_name = ""
+ api_base = litellm.get_api_base(
+ model=deployment_name,
+ optional_params=(
+ _deployment["litellm_params"] if _deployment is not None else {}
+ ),
+ )
+ value = round(replaced_slowest_values[top_5_slowest[i]], 3)
+ message += f"\t{i+1}. Deployment: `{deployment_name}`, Latency per output token: `{value}s/token`, API Base: `{api_base}`\n\n"
+
+ # cache cleanup -> reset values to 0
+ latency_cache_keys = [(key, 0) for key in latency_keys]
+ failed_request_cache_keys = [(key, 0) for key in failed_request_keys]
+ combined_metrics_cache_keys = latency_cache_keys + failed_request_cache_keys
+ await self.internal_usage_cache.async_set_cache_pipeline(
+ cache_list=combined_metrics_cache_keys
+ )
+
+ message += f"\n\nNext Run is at: `{time.time() + self.alerting_args.daily_report_frequency}`s"
+
+ # send alert
+ await self.send_alert(
+ message=message,
+ level="Low",
+ alert_type=AlertType.daily_reports,
+ alerting_metadata={},
+ )
+
+ return True
+
+ async def response_taking_too_long(
+ self,
+ start_time: Optional[datetime.datetime] = None,
+ end_time: Optional[datetime.datetime] = None,
+ type: Literal["hanging_request", "slow_response"] = "hanging_request",
+ request_data: Optional[dict] = None,
+ ):
+ if self.alerting is None or self.alert_types is None:
+ return
+ model: str = ""
+ if request_data is not None:
+ model = request_data.get("model", "")
+ messages = request_data.get("messages", None)
+ if messages is None:
+ # if messages does not exist fallback to "input"
+ messages = request_data.get("input", None)
+
+ # try casting messages to str and get the first 100 characters, else mark as None
+ try:
+ messages = str(messages)
+ messages = messages[:100]
+ except Exception:
+ messages = ""
+
+ if (
+ litellm.turn_off_message_logging
+ or litellm.redact_messages_in_exceptions
+ ):
+ messages = (
+ "Message not logged. litellm.redact_messages_in_exceptions=True"
+ )
+ request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
+ else:
+ request_info = ""
+
+ if type == "hanging_request":
+ await asyncio.sleep(
+ self.alerting_threshold
+ ) # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
+ alerting_metadata: dict = {}
+ if await self._request_is_completed(request_data=request_data) is True:
+ return
+
+ if request_data is not None:
+ if request_data.get("deployment", None) is not None and isinstance(
+ request_data["deployment"], dict
+ ):
+ _api_base = litellm.get_api_base(
+ model=model,
+ optional_params=request_data["deployment"].get(
+ "litellm_params", {}
+ ),
+ )
+
+ if _api_base is None:
+ _api_base = ""
+
+ request_info += f"\nAPI Base: {_api_base}"
+ elif request_data.get("metadata", None) is not None and isinstance(
+ request_data["metadata"], dict
+ ):
+ # In hanging requests sometime it has not made it to the point where the deployment is passed to the `request_data``
+ # in that case we fallback to the api base set in the request metadata
+ _metadata: dict = request_data["metadata"]
+ _api_base = _metadata.get("api_base", "")
+
+ request_info = _add_key_name_and_team_to_alert(
+ request_info=request_info, metadata=_metadata
+ )
+
+ if _api_base is None:
+ _api_base = ""
+
+ if "alerting_metadata" in _metadata:
+ alerting_metadata = _metadata["alerting_metadata"]
+ request_info += f"\nAPI Base: `{_api_base}`"
+ # only alert hanging responses if they have not been marked as success
+ alerting_message = (
+ f"`Requests are hanging - {self.alerting_threshold}s+ request time`"
+ )
+
+ if "langfuse" in litellm.success_callback:
+ langfuse_url = await _add_langfuse_trace_id_to_alert(
+ request_data=request_data,
+ )
+
+ if langfuse_url is not None:
+ request_info += "\n🪢 Langfuse Trace: {}".format(langfuse_url)
+
+ # add deployment latencies to alert
+ _deployment_latency_map = self._get_deployment_latencies_to_alert(
+ metadata=request_data.get("metadata", {})
+ )
+ if _deployment_latency_map is not None:
+ request_info += f"\nDeployment Latencies\n{_deployment_latency_map}"
+
+ await self.send_alert(
+ message=alerting_message + request_info,
+ level="Medium",
+ alert_type=AlertType.llm_requests_hanging,
+ alerting_metadata=alerting_metadata,
+ )
+
+ async def failed_tracking_alert(self, error_message: str, failing_model: str):
+ """
+ Raise alert when tracking failed for specific model
+
+ Args:
+ error_message (str): Error message
+ failing_model (str): Model that failed tracking
+ """
+ if self.alerting is None or self.alert_types is None:
+ # do nothing if alerting is not switched on
+ return
+ if "failed_tracking_spend" not in self.alert_types:
+ return
+
+ _cache: DualCache = self.internal_usage_cache
+ message = "Failed Tracking Cost for " + error_message
+ _cache_key = "budget_alerts:failed_tracking:{}".format(failing_model)
+ result = await _cache.async_get_cache(key=_cache_key)
+ if result is None:
+ await self.send_alert(
+ message=message,
+ level="High",
+ alert_type=AlertType.failed_tracking_spend,
+ alerting_metadata={},
+ )
+ await _cache.async_set_cache(
+ key=_cache_key,
+ value="SENT",
+ ttl=self.alerting_args.budget_alert_ttl,
+ )
+
+ async def budget_alerts( # noqa: PLR0915
+ self,
+ type: Literal[
+ "token_budget",
+ "soft_budget",
+ "user_budget",
+ "team_budget",
+ "proxy_budget",
+ "projected_limit_exceeded",
+ ],
+ user_info: CallInfo,
+ ):
+ ## PREVENTITIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727
+ # - Alert once within 24hr period
+ # - Cache this information
+ # - Don't re-alert, if alert already sent
+ _cache: DualCache = self.internal_usage_cache
+
+ if self.alerting is None or self.alert_types is None:
+ # do nothing if alerting is not switched on
+ return
+ if "budget_alerts" not in self.alert_types:
+ return
+ _id: Optional[str] = "default_id" # used for caching
+ user_info_json = user_info.model_dump(exclude_none=True)
+ user_info_str = self._get_user_info_str(user_info)
+ event: Optional[
+ Literal[
+ "budget_crossed",
+ "threshold_crossed",
+ "projected_limit_exceeded",
+ "soft_budget_crossed",
+ ]
+ ] = None
+ event_group: Optional[
+ Literal["internal_user", "team", "key", "proxy", "customer"]
+ ] = None
+ event_message: str = ""
+ webhook_event: Optional[WebhookEvent] = None
+ if type == "proxy_budget":
+ event_group = "proxy"
+ event_message += "Proxy Budget: "
+ elif type == "soft_budget":
+ event_group = "proxy"
+ event_message += "Soft Budget Crossed: "
+ elif type == "user_budget":
+ event_group = "internal_user"
+ event_message += "User Budget: "
+ _id = user_info.user_id or _id
+ elif type == "team_budget":
+ event_group = "team"
+ event_message += "Team Budget: "
+ _id = user_info.team_id or _id
+ elif type == "token_budget":
+ event_group = "key"
+ event_message += "Key Budget: "
+ _id = user_info.token
+ elif type == "projected_limit_exceeded":
+ event_group = "key"
+ event_message += "Key Budget: Projected Limit Exceeded"
+ event = "projected_limit_exceeded"
+ _id = user_info.token
+
+ # percent of max_budget left to spend
+ if user_info.max_budget is None and user_info.soft_budget is None:
+ return
+ percent_left: float = 0
+ if user_info.max_budget is not None:
+ if user_info.max_budget > 0:
+ percent_left = (
+ user_info.max_budget - user_info.spend
+ ) / user_info.max_budget
+
+ # check if crossed budget
+ if user_info.max_budget is not None:
+ if user_info.spend >= user_info.max_budget:
+ event = "budget_crossed"
+ event_message += (
+ f"Budget Crossed\n Total Budget:`{user_info.max_budget}`"
+ )
+ elif percent_left <= 0.05:
+ event = "threshold_crossed"
+ event_message += "5% Threshold Crossed "
+ elif percent_left <= 0.15:
+ event = "threshold_crossed"
+ event_message += "15% Threshold Crossed"
+ elif user_info.soft_budget is not None:
+ if user_info.spend >= user_info.soft_budget:
+ event = "soft_budget_crossed"
+ if event is not None and event_group is not None:
+ _cache_key = "budget_alerts:{}:{}".format(event, _id)
+ result = await _cache.async_get_cache(key=_cache_key)
+ if result is None:
+ webhook_event = WebhookEvent(
+ event=event,
+ event_group=event_group,
+ event_message=event_message,
+ **user_info_json,
+ )
+ await self.send_alert(
+ message=event_message + "\n\n" + user_info_str,
+ level="High",
+ alert_type=AlertType.budget_alerts,
+ user_info=webhook_event,
+ alerting_metadata={},
+ )
+ await _cache.async_set_cache(
+ key=_cache_key,
+ value="SENT",
+ ttl=self.alerting_args.budget_alert_ttl,
+ )
+
+ return
+ return
+
+ def _get_user_info_str(self, user_info: CallInfo) -> str:
+ """
+ Create a standard message for a budget alert
+ """
+ _all_fields_as_dict = user_info.model_dump(exclude_none=True)
+ _all_fields_as_dict.pop("token")
+ msg = ""
+ for k, v in _all_fields_as_dict.items():
+ msg += f"*{k}:* `{v}`\n"
+
+ return msg
+
+ async def customer_spend_alert(
+ self,
+ token: Optional[str],
+ key_alias: Optional[str],
+ end_user_id: Optional[str],
+ response_cost: Optional[float],
+ max_budget: Optional[float],
+ ):
+ if (
+ self.alerting is not None
+ and "webhook" in self.alerting
+ and end_user_id is not None
+ and token is not None
+ and response_cost is not None
+ ):
+ # log customer spend
+ event = WebhookEvent(
+ spend=response_cost,
+ max_budget=max_budget,
+ token=token,
+ customer_id=end_user_id,
+ user_id=None,
+ team_id=None,
+ user_email=None,
+ key_alias=key_alias,
+ projected_exceeded_date=None,
+ projected_spend=None,
+ event="spend_tracked",
+ event_group="customer",
+ event_message="Customer spend tracked. Customer={}, spend={}".format(
+ end_user_id, response_cost
+ ),
+ )
+
+ await self.send_webhook_alert(webhook_event=event)
+
+ def _count_outage_alerts(self, alerts: List[int]) -> str:
+ """
+ Parameters:
+ - alerts: List[int] -> list of error codes (either 408 or 500+)
+
+ Returns:
+ - str -> formatted string. This is an alert message, giving a human-friendly description of the errors.
+ """
+ error_breakdown = {"Timeout Errors": 0, "API Errors": 0, "Unknown Errors": 0}
+ for alert in alerts:
+ if alert == 408:
+ error_breakdown["Timeout Errors"] += 1
+ elif alert >= 500:
+ error_breakdown["API Errors"] += 1
+ else:
+ error_breakdown["Unknown Errors"] += 1
+
+ error_msg = ""
+ for key, value in error_breakdown.items():
+ if value > 0:
+ error_msg += "\n{}: {}\n".format(key, value)
+
+ return error_msg
+
+ def _outage_alert_msg_factory(
+ self,
+ alert_type: Literal["Major", "Minor"],
+ key: Literal["Model", "Region"],
+ key_val: str,
+ provider: str,
+ api_base: Optional[str],
+ outage_value: BaseOutageModel,
+ ) -> str:
+ """Format an alert message for slack"""
+ headers = {f"{key} Name": key_val, "Provider": provider}
+ if api_base is not None:
+ headers["API Base"] = api_base # type: ignore
+
+ headers_str = "\n"
+ for k, v in headers.items():
+ headers_str += f"*{k}:* `{v}`\n"
+ return f"""\n\n
+*⚠️ {alert_type} Service Outage*
+
+{headers_str}
+
+*Errors:*
+{self._count_outage_alerts(alerts=outage_value["alerts"])}
+
+*Last Check:* `{round(time.time() - outage_value["last_updated_at"], 4)}s ago`\n\n
+"""
+
+ async def region_outage_alerts(
+ self,
+ exception: APIError,
+ deployment_id: str,
+ ) -> None:
+ """
+ Send slack alert if specific provider region is having an outage.
+
+ Track for 408 (Timeout) and >=500 Error codes
+ """
+ ## CREATE (PROVIDER+REGION) ID ##
+ if self.llm_router is None:
+ return
+
+ deployment = self.llm_router.get_deployment(model_id=deployment_id)
+
+ if deployment is None:
+ return
+
+ model = deployment.litellm_params.model
+ ### GET PROVIDER ###
+ provider = deployment.litellm_params.custom_llm_provider
+ if provider is None:
+ model, provider, _, _ = litellm.get_llm_provider(model=model)
+
+ ### GET REGION ###
+ region_name = deployment.litellm_params.region_name
+ if region_name is None:
+ region_name = litellm.utils._get_model_region(
+ custom_llm_provider=provider, litellm_params=deployment.litellm_params
+ )
+
+ if region_name is None:
+ return
+
+ ### UNIQUE CACHE KEY ###
+ cache_key = provider + region_name
+
+ outage_value: Optional[ProviderRegionOutageModel] = (
+ await self.internal_usage_cache.async_get_cache(key=cache_key)
+ )
+
+ if (
+ getattr(exception, "status_code", None) is None
+ or (
+ exception.status_code != 408 # type: ignore
+ and exception.status_code < 500 # type: ignore
+ )
+ or self.llm_router is None
+ ):
+ return
+
+ if outage_value is None:
+ _deployment_set = set()
+ _deployment_set.add(deployment_id)
+ outage_value = ProviderRegionOutageModel(
+ provider_region_id=cache_key,
+ alerts=[exception.status_code], # type: ignore
+ minor_alert_sent=False,
+ major_alert_sent=False,
+ last_updated_at=time.time(),
+ deployment_ids=_deployment_set,
+ )
+
+ ## add to cache ##
+ await self.internal_usage_cache.async_set_cache(
+ key=cache_key,
+ value=outage_value,
+ ttl=self.alerting_args.region_outage_alert_ttl,
+ )
+ return
+
+ if len(outage_value["alerts"]) < self.alerting_args.max_outage_alert_list_size:
+ outage_value["alerts"].append(exception.status_code) # type: ignore
+ else: # prevent memory leaks
+ pass
+ _deployment_set = outage_value["deployment_ids"]
+ _deployment_set.add(deployment_id)
+ outage_value["deployment_ids"] = _deployment_set
+ outage_value["last_updated_at"] = time.time()
+
+ ## MINOR OUTAGE ALERT SENT ##
+ if (
+ outage_value["minor_alert_sent"] is False
+ and len(outage_value["alerts"])
+ >= self.alerting_args.minor_outage_alert_threshold
+ and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
+ ):
+ msg = self._outage_alert_msg_factory(
+ alert_type="Minor",
+ key="Region",
+ key_val=region_name,
+ api_base=None,
+ outage_value=outage_value,
+ provider=provider,
+ )
+ # send minor alert
+ await self.send_alert(
+ message=msg,
+ level="Medium",
+ alert_type=AlertType.outage_alerts,
+ alerting_metadata={},
+ )
+ # set to true
+ outage_value["minor_alert_sent"] = True
+
+ ## MAJOR OUTAGE ALERT SENT ##
+ elif (
+ outage_value["major_alert_sent"] is False
+ and len(outage_value["alerts"])
+ >= self.alerting_args.major_outage_alert_threshold
+ and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
+ ):
+ msg = self._outage_alert_msg_factory(
+ alert_type="Major",
+ key="Region",
+ key_val=region_name,
+ api_base=None,
+ outage_value=outage_value,
+ provider=provider,
+ )
+
+ # send minor alert
+ await self.send_alert(
+ message=msg,
+ level="High",
+ alert_type=AlertType.outage_alerts,
+ alerting_metadata={},
+ )
+ # set to true
+ outage_value["major_alert_sent"] = True
+
+ ## update cache ##
+ await self.internal_usage_cache.async_set_cache(
+ key=cache_key, value=outage_value
+ )
+
+ async def outage_alerts(
+ self,
+ exception: APIError,
+ deployment_id: str,
+ ) -> None:
+ """
+ Send slack alert if model is badly configured / having an outage (408, 401, 429, >=500).
+
+ key = model_id
+
+ value = {
+ - model_id
+ - threshold
+ - alerts []
+ }
+
+ ttl = 1hr
+ max_alerts_size = 10
+ """
+ try:
+ outage_value: Optional[OutageModel] = await self.internal_usage_cache.async_get_cache(key=deployment_id) # type: ignore
+ if (
+ getattr(exception, "status_code", None) is None
+ or (
+ exception.status_code != 408 # type: ignore
+ and exception.status_code < 500 # type: ignore
+ )
+ or self.llm_router is None
+ ):
+ return
+
+ ### EXTRACT MODEL DETAILS ###
+ deployment = self.llm_router.get_deployment(model_id=deployment_id)
+ if deployment is None:
+ return
+
+ model = deployment.litellm_params.model
+ provider = deployment.litellm_params.custom_llm_provider
+ if provider is None:
+ try:
+ model, provider, _, _ = litellm.get_llm_provider(model=model)
+ except Exception:
+ provider = ""
+ api_base = litellm.get_api_base(
+ model=model, optional_params=deployment.litellm_params
+ )
+
+ if outage_value is None:
+ outage_value = OutageModel(
+ model_id=deployment_id,
+ alerts=[exception.status_code], # type: ignore
+ minor_alert_sent=False,
+ major_alert_sent=False,
+ last_updated_at=time.time(),
+ )
+
+ ## add to cache ##
+ await self.internal_usage_cache.async_set_cache(
+ key=deployment_id,
+ value=outage_value,
+ ttl=self.alerting_args.outage_alert_ttl,
+ )
+ return
+
+ if (
+ len(outage_value["alerts"])
+ < self.alerting_args.max_outage_alert_list_size
+ ):
+ outage_value["alerts"].append(exception.status_code) # type: ignore
+ else: # prevent memory leaks
+ pass
+
+ outage_value["last_updated_at"] = time.time()
+
+ ## MINOR OUTAGE ALERT SENT ##
+ if (
+ outage_value["minor_alert_sent"] is False
+ and len(outage_value["alerts"])
+ >= self.alerting_args.minor_outage_alert_threshold
+ ):
+ msg = self._outage_alert_msg_factory(
+ alert_type="Minor",
+ key="Model",
+ key_val=model,
+ api_base=api_base,
+ outage_value=outage_value,
+ provider=provider,
+ )
+ # send minor alert
+ await self.send_alert(
+ message=msg,
+ level="Medium",
+ alert_type=AlertType.outage_alerts,
+ alerting_metadata={},
+ )
+ # set to true
+ outage_value["minor_alert_sent"] = True
+ elif (
+ outage_value["major_alert_sent"] is False
+ and len(outage_value["alerts"])
+ >= self.alerting_args.major_outage_alert_threshold
+ ):
+ msg = self._outage_alert_msg_factory(
+ alert_type="Major",
+ key="Model",
+ key_val=model,
+ api_base=api_base,
+ outage_value=outage_value,
+ provider=provider,
+ )
+ # send minor alert
+ await self.send_alert(
+ message=msg,
+ level="High",
+ alert_type=AlertType.outage_alerts,
+ alerting_metadata={},
+ )
+ # set to true
+ outage_value["major_alert_sent"] = True
+
+ ## update cache ##
+ await self.internal_usage_cache.async_set_cache(
+ key=deployment_id, value=outage_value
+ )
+ except Exception:
+ pass
+
+ async def model_added_alert(
+ self, model_name: str, litellm_model_name: str, passed_model_info: Any
+ ):
+ base_model_from_user = getattr(passed_model_info, "base_model", None)
+ model_info = {}
+ base_model = ""
+ if base_model_from_user is not None:
+ model_info = litellm.model_cost.get(base_model_from_user, {})
+ base_model = f"Base Model: `{base_model_from_user}`\n"
+ else:
+ model_info = litellm.model_cost.get(litellm_model_name, {})
+ model_info_str = ""
+ for k, v in model_info.items():
+ if k == "input_cost_per_token" or k == "output_cost_per_token":
+ # when converting to string it should not be 1.63e-06
+ v = "{:.8f}".format(v)
+
+ model_info_str += f"{k}: {v}\n"
+
+ message = f"""
+*🚅 New Model Added*
+Model Name: `{model_name}`
+{base_model}
+
+Usage OpenAI Python SDK:
+```
+import openai
+client = openai.OpenAI(
+ api_key="your_api_key",
+ base_url={os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")}
+)
+
+response = client.chat.completions.create(
+ model="{model_name}", # model to send to the proxy
+ messages = [
+ {{
+ "role": "user",
+ "content": "this is a test request, write a short poem"
+ }}
+ ]
+)
+```
+
+Model Info:
+```
+{model_info_str}
+```
+"""
+
+ alert_val = self.send_alert(
+ message=message,
+ level="Low",
+ alert_type=AlertType.new_model_added,
+ alerting_metadata={},
+ )
+
+ if alert_val is not None and asyncio.iscoroutine(alert_val):
+ await alert_val
+
+ async def model_removed_alert(self, model_name: str):
+ pass
+
+ async def send_webhook_alert(self, webhook_event: WebhookEvent) -> bool:
+ """
+ Sends structured alert to webhook, if set.
+
+ Currently only implemented for budget alerts
+
+ Returns -> True if sent, False if not.
+
+ Raises Exception
+ - if WEBHOOK_URL is not set
+ """
+
+ webhook_url = os.getenv("WEBHOOK_URL", None)
+ if webhook_url is None:
+ raise Exception("Missing webhook_url from environment")
+
+ payload = webhook_event.model_dump_json()
+ headers = {"Content-type": "application/json"}
+
+ response = await self.async_http_handler.post(
+ url=webhook_url,
+ headers=headers,
+ data=payload,
+ )
+ if response.status_code == 200:
+ return True
+ else:
+ print("Error sending webhook alert. Error=", response.text) # noqa
+
+ return False
+
+ async def _check_if_using_premium_email_feature(
+ self,
+ premium_user: bool,
+ email_logo_url: Optional[str] = None,
+ email_support_contact: Optional[str] = None,
+ ):
+ from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
+
+ if premium_user is not True:
+ if email_logo_url is not None or email_support_contact is not None:
+ raise ValueError(
+ f"Trying to Customize Email Alerting\n {CommonProxyErrors.not_premium_user.value}"
+ )
+ return
+
+ async def send_key_created_or_user_invited_email(
+ self, webhook_event: WebhookEvent
+ ) -> bool:
+ try:
+ from litellm.proxy.utils import send_email
+
+ if self.alerting is None or "email" not in self.alerting:
+ # do nothing if user does not want email alerts
+ verbose_proxy_logger.error(
+ "Error sending email alert - 'email' not in self.alerting %s",
+ self.alerting,
+ )
+ return False
+ from litellm.proxy.proxy_server import premium_user, prisma_client
+
+ email_logo_url = os.getenv(
+ "SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None)
+ )
+ email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)
+ await self._check_if_using_premium_email_feature(
+ premium_user, email_logo_url, email_support_contact
+ )
+ if email_logo_url is None:
+ email_logo_url = LITELLM_LOGO_URL
+ if email_support_contact is None:
+ email_support_contact = LITELLM_SUPPORT_CONTACT
+
+ event_name = webhook_event.event_message
+ recipient_email = webhook_event.user_email
+ recipient_user_id = webhook_event.user_id
+ if (
+ recipient_email is None
+ and recipient_user_id is not None
+ and prisma_client is not None
+ ):
+ user_row = await prisma_client.db.litellm_usertable.find_unique(
+ where={"user_id": recipient_user_id}
+ )
+
+ if user_row is not None:
+ recipient_email = user_row.user_email
+
+ key_token = webhook_event.token
+ key_budget = webhook_event.max_budget
+ base_url = os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")
+
+ email_html_content = "Alert from LiteLLM Server"
+ if recipient_email is None:
+ verbose_proxy_logger.error(
+ "Trying to send email alert to no recipient",
+ extra=webhook_event.dict(),
+ )
+
+ if webhook_event.event == "key_created":
+ email_html_content = KEY_CREATED_EMAIL_TEMPLATE.format(
+ email_logo_url=email_logo_url,
+ recipient_email=recipient_email,
+ key_budget=key_budget,
+ key_token=key_token,
+ base_url=base_url,
+ email_support_contact=email_support_contact,
+ )
+ elif webhook_event.event == "internal_user_created":
+ # GET TEAM NAME
+ team_id = webhook_event.team_id
+ team_name = "Default Team"
+ if team_id is not None and prisma_client is not None:
+ team_row = await prisma_client.db.litellm_teamtable.find_unique(
+ where={"team_id": team_id}
+ )
+ if team_row is not None:
+ team_name = team_row.team_alias or "-"
+ email_html_content = USER_INVITED_EMAIL_TEMPLATE.format(
+ email_logo_url=email_logo_url,
+ recipient_email=recipient_email,
+ team_name=team_name,
+ base_url=base_url,
+ email_support_contact=email_support_contact,
+ )
+ else:
+ verbose_proxy_logger.error(
+ "Trying to send email alert on unknown webhook event",
+ extra=webhook_event.model_dump(),
+ )
+
+ webhook_event.model_dump_json()
+ email_event = {
+ "to": recipient_email,
+ "subject": f"LiteLLM: {event_name}",
+ "html": email_html_content,
+ }
+
+ await send_email(
+ receiver_email=email_event["to"],
+ subject=email_event["subject"],
+ html=email_event["html"],
+ )
+
+ return True
+
+ except Exception as e:
+ verbose_proxy_logger.error("Error sending email alert %s", str(e))
+ return False
+
+ async def send_email_alert_using_smtp(
+ self, webhook_event: WebhookEvent, alert_type: str
+ ) -> bool:
+ """
+ Sends structured Email alert to an SMTP server
+
+ Currently only implemented for budget alerts
+
+ Returns -> True if sent, False if not.
+ """
+ from litellm.proxy.proxy_server import premium_user
+ from litellm.proxy.utils import send_email
+
+ email_logo_url = os.getenv(
+ "SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None)
+ )
+ email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)
+ await self._check_if_using_premium_email_feature(
+ premium_user, email_logo_url, email_support_contact
+ )
+
+ if email_logo_url is None:
+ email_logo_url = LITELLM_LOGO_URL
+ if email_support_contact is None:
+ email_support_contact = LITELLM_SUPPORT_CONTACT
+
+ event_name = webhook_event.event_message
+ recipient_email = webhook_event.user_email
+ user_name = webhook_event.user_id
+ max_budget = webhook_event.max_budget
+ email_html_content = "Alert from LiteLLM Server"
+ if recipient_email is None:
+ verbose_proxy_logger.error(
+ "Trying to send email alert to no recipient", extra=webhook_event.dict()
+ )
+
+ if webhook_event.event == "budget_crossed":
+ email_html_content = f"""
+ <img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
+
+ <p> Hi {user_name}, <br/>
+
+ Your LLM API usage this month has reached your account's <b> monthly budget of ${max_budget} </b> <br /> <br />
+
+ API requests will be rejected until either (a) you increase your monthly budget or (b) your monthly usage resets at the beginning of the next calendar month. <br /> <br />
+
+ If you have any questions, please send an email to {email_support_contact} <br /> <br />
+
+ Best, <br />
+ The LiteLLM team <br />
+ """
+
+ webhook_event.model_dump_json()
+ email_event = {
+ "to": recipient_email,
+ "subject": f"LiteLLM: {event_name}",
+ "html": email_html_content,
+ }
+
+ await send_email(
+ receiver_email=email_event["to"],
+ subject=email_event["subject"],
+ html=email_event["html"],
+ )
+ if webhook_event.event_group == "team":
+ from litellm.integrations.email_alerting import send_team_budget_alert
+
+ await send_team_budget_alert(webhook_event=webhook_event)
+
+ return False
+
+ async def send_alert(
+ self,
+ message: str,
+ level: Literal["Low", "Medium", "High"],
+ alert_type: AlertType,
+ alerting_metadata: dict,
+ user_info: Optional[WebhookEvent] = None,
+ **kwargs,
+ ):
+ """
+ Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
+
+ - Responses taking too long
+ - Requests are hanging
+ - Calls are failing
+ - DB Read/Writes are failing
+ - Proxy Close to max budget
+ - Key Close to max budget
+
+ Parameters:
+ level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'.
+ message: str - what is the alert about
+ """
+ if self.alerting is None:
+ return
+
+ if (
+ "webhook" in self.alerting
+ and alert_type == "budget_alerts"
+ and user_info is not None
+ ):
+ await self.send_webhook_alert(webhook_event=user_info)
+
+ if (
+ "email" in self.alerting
+ and alert_type == "budget_alerts"
+ and user_info is not None
+ ):
+ # only send budget alerts over Email
+ await self.send_email_alert_using_smtp(
+ webhook_event=user_info, alert_type=alert_type
+ )
+
+ if "slack" not in self.alerting:
+ return
+ if alert_type not in self.alert_types:
+ return
+
+ from datetime import datetime
+
+ # Get the current timestamp
+ current_time = datetime.now().strftime("%H:%M:%S")
+ _proxy_base_url = os.getenv("PROXY_BASE_URL", None)
+ if alert_type == "daily_reports" or alert_type == "new_model_added":
+ formatted_message = message
+ else:
+ formatted_message = (
+ f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
+ )
+
+ if kwargs:
+ for key, value in kwargs.items():
+ formatted_message += f"\n\n{key}: `{value}`\n\n"
+ if alerting_metadata:
+ for key, value in alerting_metadata.items():
+ formatted_message += f"\n\n*Alerting Metadata*: \n{key}: `{value}`\n\n"
+ if _proxy_base_url is not None:
+ formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
+
+ # check if we find the slack webhook url in self.alert_to_webhook_url
+ if (
+ self.alert_to_webhook_url is not None
+ and alert_type in self.alert_to_webhook_url
+ ):
+ slack_webhook_url: Optional[Union[str, List[str]]] = (
+ self.alert_to_webhook_url[alert_type]
+ )
+ elif self.default_webhook_url is not None:
+ slack_webhook_url = self.default_webhook_url
+ else:
+ slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
+
+ if slack_webhook_url is None:
+ raise ValueError("Missing SLACK_WEBHOOK_URL from environment")
+ payload = {"text": formatted_message}
+ headers = {"Content-type": "application/json"}
+
+ if isinstance(slack_webhook_url, list):
+ for url in slack_webhook_url:
+ self.log_queue.append(
+ {
+ "url": url,
+ "headers": headers,
+ "payload": payload,
+ "alert_type": alert_type,
+ }
+ )
+ else:
+ self.log_queue.append(
+ {
+ "url": slack_webhook_url,
+ "headers": headers,
+ "payload": payload,
+ "alert_type": alert_type,
+ }
+ )
+
+ if len(self.log_queue) >= self.batch_size:
+ await self.flush_queue()
+
+ async def async_send_batch(self):
+ if not self.log_queue:
+ return
+
+ squashed_queue = squash_payloads(self.log_queue)
+ tasks = [
+ send_to_webhook(
+ slackAlertingInstance=self, item=item["item"], count=item["count"]
+ )
+ for item in squashed_queue.values()
+ ]
+ await asyncio.gather(*tasks)
+ self.log_queue.clear()
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ """Log deployment latency"""
+ try:
+ if "daily_reports" in self.alert_types:
+ litellm_params = kwargs.get("litellm_params", {}) or {}
+ model_info = litellm_params.get("model_info", {}) or {}
+ model_id = model_info.get("id", "") or ""
+ response_s: timedelta = end_time - start_time
+
+ final_value = response_s
+
+ if isinstance(response_obj, litellm.ModelResponse) and (
+ hasattr(response_obj, "usage")
+ and response_obj.usage is not None # type: ignore
+ and hasattr(response_obj.usage, "completion_tokens") # type: ignore
+ ):
+ completion_tokens = response_obj.usage.completion_tokens # type: ignore
+ if completion_tokens is not None and completion_tokens > 0:
+ final_value = float(
+ response_s.total_seconds() / completion_tokens
+ )
+ if isinstance(final_value, timedelta):
+ final_value = final_value.total_seconds()
+
+ await self.async_update_daily_reports(
+ DeploymentMetrics(
+ id=model_id,
+ failed_request=False,
+ latency_per_output_token=final_value,
+ updated_at=litellm.utils.get_utc_datetime(),
+ )
+ )
+ except Exception as e:
+ verbose_proxy_logger.error(
+ f"[Non-Blocking Error] Slack Alerting: Got error in logging LLM deployment latency: {str(e)}"
+ )
+ pass
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ """Log failure + deployment latency"""
+ _litellm_params = kwargs.get("litellm_params", {})
+ _model_info = _litellm_params.get("model_info", {}) or {}
+ model_id = _model_info.get("id", "")
+ try:
+ if "daily_reports" in self.alert_types:
+ try:
+ await self.async_update_daily_reports(
+ DeploymentMetrics(
+ id=model_id,
+ failed_request=True,
+ latency_per_output_token=None,
+ updated_at=litellm.utils.get_utc_datetime(),
+ )
+ )
+ except Exception as e:
+ verbose_logger.debug(f"Exception raises -{str(e)}")
+
+ if isinstance(kwargs.get("exception", ""), APIError):
+ if "outage_alerts" in self.alert_types:
+ await self.outage_alerts(
+ exception=kwargs["exception"],
+ deployment_id=model_id,
+ )
+
+ if "region_outage_alerts" in self.alert_types:
+ await self.region_outage_alerts(
+ exception=kwargs["exception"], deployment_id=model_id
+ )
+ except Exception:
+ pass
+
+ async def _run_scheduler_helper(self, llm_router) -> bool:
+ """
+ Returns:
+ - True -> report sent
+ - False -> report not sent
+ """
+ report_sent_bool = False
+
+ report_sent = await self.internal_usage_cache.async_get_cache(
+ key=SlackAlertingCacheKeys.report_sent_key.value,
+ parent_otel_span=None,
+ ) # None | float
+
+ current_time = time.time()
+
+ if report_sent is None:
+ await self.internal_usage_cache.async_set_cache(
+ key=SlackAlertingCacheKeys.report_sent_key.value,
+ value=current_time,
+ )
+ elif isinstance(report_sent, float):
+ # Check if current time - interval >= time last sent
+ interval_seconds = self.alerting_args.daily_report_frequency
+
+ if current_time - report_sent >= interval_seconds:
+ # Sneak in the reporting logic here
+ await self.send_daily_reports(router=llm_router)
+ # Also, don't forget to update the report_sent time after sending the report!
+ await self.internal_usage_cache.async_set_cache(
+ key=SlackAlertingCacheKeys.report_sent_key.value,
+ value=current_time,
+ )
+ report_sent_bool = True
+
+ return report_sent_bool
+
+ async def _run_scheduled_daily_report(self, llm_router: Optional[Any] = None):
+ """
+ If 'daily_reports' enabled
+
+ Ping redis cache every 5 minutes to check if we should send the report
+
+ If yes -> call send_daily_report()
+ """
+ if llm_router is None or self.alert_types is None:
+ return
+
+ if "daily_reports" in self.alert_types:
+ while True:
+ await self._run_scheduler_helper(llm_router=llm_router)
+ interval = random.randint(
+ self.alerting_args.report_check_interval - 3,
+ self.alerting_args.report_check_interval + 3,
+ ) # shuffle to prevent collisions
+ await asyncio.sleep(interval)
+ return
+
+ async def send_weekly_spend_report(
+ self,
+ time_range: str = "7d",
+ ):
+ """
+ Send a spend report for a configurable time range.
+
+ Args:
+ time_range: A string specifying the time range for the report, e.g., "1d", "7d", "30d"
+ """
+ if self.alerting is None or "spend_reports" not in self.alert_types:
+ return
+
+ try:
+ from litellm.proxy.spend_tracking.spend_management_endpoints import (
+ _get_spend_report_for_time_range,
+ )
+
+ # Parse the time range
+ days = int(time_range[:-1])
+ if time_range[-1].lower() != "d":
+ raise ValueError("Time range must be specified in days, e.g., '7d'")
+
+ todays_date = datetime.datetime.now().date()
+ start_date = todays_date - datetime.timedelta(days=days)
+
+ _event_cache_key = f"weekly_spend_report_sent_{start_date.strftime('%Y-%m-%d')}_{todays_date.strftime('%Y-%m-%d')}"
+ if await self.internal_usage_cache.async_get_cache(key=_event_cache_key):
+ return
+
+ _resp = await _get_spend_report_for_time_range(
+ start_date=start_date.strftime("%Y-%m-%d"),
+ end_date=todays_date.strftime("%Y-%m-%d"),
+ )
+ if _resp is None or _resp == ([], []):
+ return
+
+ spend_per_team, spend_per_tag = _resp
+
+ _spend_message = f"*💸 Spend Report for `{start_date.strftime('%m-%d-%Y')} - {todays_date.strftime('%m-%d-%Y')}` ({days} days)*\n"
+
+ if spend_per_team is not None:
+ _spend_message += "\n*Team Spend Report:*\n"
+ for spend in spend_per_team:
+ _team_spend = round(float(spend["total_spend"]), 4)
+ _spend_message += (
+ f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n"
+ )
+
+ if spend_per_tag is not None:
+ _spend_message += "\n*Tag Spend Report:*\n"
+ for spend in spend_per_tag:
+ _tag_spend = round(float(spend["total_spend"]), 4)
+ _spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n"
+
+ await self.send_alert(
+ message=_spend_message,
+ level="Low",
+ alert_type=AlertType.spend_reports,
+ alerting_metadata={},
+ )
+
+ await self.internal_usage_cache.async_set_cache(
+ key=_event_cache_key,
+ value="SENT",
+ ttl=duration_in_seconds(time_range),
+ )
+
+ except ValueError as ve:
+ verbose_proxy_logger.error(f"Invalid time range format: {ve}")
+ except Exception as e:
+ verbose_proxy_logger.error(f"Error sending spend report: {e}")
+
+ async def send_monthly_spend_report(self):
+ """ """
+ try:
+ from calendar import monthrange
+
+ from litellm.proxy.spend_tracking.spend_management_endpoints import (
+ _get_spend_report_for_time_range,
+ )
+
+ todays_date = datetime.datetime.now().date()
+ first_day_of_month = todays_date.replace(day=1)
+ _, last_day_of_month = monthrange(todays_date.year, todays_date.month)
+ last_day_of_month = first_day_of_month + datetime.timedelta(
+ days=last_day_of_month - 1
+ )
+
+ _event_cache_key = f"monthly_spend_report_sent_{first_day_of_month.strftime('%Y-%m-%d')}_{last_day_of_month.strftime('%Y-%m-%d')}"
+ if await self.internal_usage_cache.async_get_cache(key=_event_cache_key):
+ return
+
+ _resp = await _get_spend_report_for_time_range(
+ start_date=first_day_of_month.strftime("%Y-%m-%d"),
+ end_date=last_day_of_month.strftime("%Y-%m-%d"),
+ )
+
+ if _resp is None or _resp == ([], []):
+ return
+
+ monthly_spend_per_team, monthly_spend_per_tag = _resp
+
+ _spend_message = f"*💸 Monthly Spend Report for `{first_day_of_month.strftime('%m-%d-%Y')} - {last_day_of_month.strftime('%m-%d-%Y')}` *\n"
+
+ if monthly_spend_per_team is not None:
+ _spend_message += "\n*Team Spend Report:*\n"
+ for spend in monthly_spend_per_team:
+ _team_spend = spend["total_spend"]
+ _team_spend = float(_team_spend)
+ # round to 4 decimal places
+ _team_spend = round(_team_spend, 4)
+ _spend_message += (
+ f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n"
+ )
+
+ if monthly_spend_per_tag is not None:
+ _spend_message += "\n*Tag Spend Report:*\n"
+ for spend in monthly_spend_per_tag:
+ _tag_spend = spend["total_spend"]
+ _tag_spend = float(_tag_spend)
+ # round to 4 decimal places
+ _tag_spend = round(_tag_spend, 4)
+ _spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n"
+
+ await self.send_alert(
+ message=_spend_message,
+ level="Low",
+ alert_type=AlertType.spend_reports,
+ alerting_metadata={},
+ )
+
+ await self.internal_usage_cache.async_set_cache(
+ key=_event_cache_key,
+ value="SENT",
+ ttl=(30 * 24 * 60 * 60), # 1 month
+ )
+
+ except Exception as e:
+ verbose_proxy_logger.exception("Error sending weekly spend report %s", e)
+
+ async def send_fallback_stats_from_prometheus(self):
+ """
+ Helper to send fallback statistics from prometheus server -> to slack
+
+ This runs once per day and sends an overview of all the fallback statistics
+ """
+ try:
+ from litellm.integrations.prometheus_helpers.prometheus_api import (
+ get_fallback_metric_from_prometheus,
+ )
+
+            # query fallback success metrics from the Prometheus server
+            fallback_success_info_prometheus = (
+                await get_fallback_metric_from_prometheus()
+            )
+
+            fallback_message = (
+                f"*Fallback Statistics:*\n{fallback_success_info_prometheus}"
+            )
+
+ await self.send_alert(
+ message=fallback_message,
+ level="Low",
+ alert_type=AlertType.fallback_reports,
+ alerting_metadata={},
+ )
+
+        except Exception as e:
+            verbose_proxy_logger.error("Error sending fallback stats report %s", e)
+
+ async def send_virtual_key_event_slack(
+ self,
+ key_event: VirtualKeyEvent,
+ alert_type: AlertType,
+ event_name: str,
+ ):
+ """
+ Handles sending Virtual Key related alerts
+
+ Example:
+ - New Virtual Key Created
+ - Internal User Updated
+ - Team Created, Updated, Deleted
+ """
+ try:
+
+ message = f"`{event_name}`\n"
+
+ key_event_dict = key_event.model_dump()
+
+ # Add Created by information first
+ message += "*Action Done by:*\n"
+ for key, value in key_event_dict.items():
+ if "created_by" in key:
+ message += f"{key}: `{value}`\n"
+
+ # Add args sent to function in the alert
+ message += "\n*Arguments passed:*\n"
+ request_kwargs = key_event.request_kwargs
+ for key, value in request_kwargs.items():
+ if key == "user_api_key_dict":
+ continue
+ message += f"{key}: `{value}`\n"
+
+ await self.send_alert(
+ message=message,
+ level="High",
+ alert_type=alert_type,
+ alerting_metadata={},
+ )
+
+ except Exception as e:
+ verbose_proxy_logger.error(
+ "Error sending send_virtual_key_event_slack %s", e
+ )
+
+ return
+
+ async def _request_is_completed(self, request_data: Optional[dict]) -> bool:
+ """
+ Returns True if the request is completed - either as a success or failure
+ """
+ if request_data is None:
+ return False
+
+ if (
+ request_data.get("litellm_status", "") != "success"
+ and request_data.get("litellm_status", "") != "fail"
+ ):
+ ## CHECK IF CACHE IS UPDATED
+ litellm_call_id = request_data.get("litellm_call_id", "")
+ status: Optional[str] = await self.internal_usage_cache.async_get_cache(
+ key="request_status:{}".format(litellm_call_id), local_only=True
+ )
+ if status is not None and (status == "success" or status == "fail"):
+ return True
+ return False
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/utils.py b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/utils.py
new file mode 100644
index 00000000..0dc8bae5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/SlackAlerting/utils.py
@@ -0,0 +1,92 @@
+"""
+Utils used for slack alerting
+"""
+
+import asyncio
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from litellm.proxy._types import AlertType
+from litellm.secret_managers.main import get_secret
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as _Logging
+
+ Logging = _Logging
+else:
+ Logging = Any
+
+
+def process_slack_alerting_variables(
+ alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]]
+) -> Optional[Dict[AlertType, Union[List[str], str]]]:
+ """
+ process alert_to_webhook_url
+    - if any urls are set as os.environ/SLACK_WEBHOOK_URL_1, read the env var and set the resolved value
+ """
+ if alert_to_webhook_url is None:
+ return None
+
+ for alert_type, webhook_urls in alert_to_webhook_url.items():
+ if isinstance(webhook_urls, list):
+ _webhook_values: List[str] = []
+ for webhook_url in webhook_urls:
+ if "os.environ/" in webhook_url:
+ _env_value = get_secret(secret_name=webhook_url)
+ if not isinstance(_env_value, str):
+ raise ValueError(
+ f"Invalid webhook url value for: {webhook_url}. Got type={type(_env_value)}"
+ )
+ _webhook_values.append(_env_value)
+ else:
+ _webhook_values.append(webhook_url)
+
+ alert_to_webhook_url[alert_type] = _webhook_values
+ else:
+ _webhook_value_str: str = webhook_urls
+ if "os.environ/" in webhook_urls:
+ _env_value = get_secret(secret_name=webhook_urls)
+ if not isinstance(_env_value, str):
+ raise ValueError(
+ f"Invalid webhook url value for: {webhook_urls}. Got type={type(_env_value)}"
+ )
+ _webhook_value_str = _env_value
+ else:
+ _webhook_value_str = webhook_urls
+
+ alert_to_webhook_url[alert_type] = _webhook_value_str
+
+ return alert_to_webhook_url
+
+
+async def _add_langfuse_trace_id_to_alert(
+ request_data: Optional[dict] = None,
+) -> Optional[str]:
+ """
+ Returns langfuse trace url
+
+ - check:
+ -> existing_trace_id
+ -> trace_id
+ -> litellm_call_id
+ """
+    # the trace id may not be set immediately; retry a few times before giving up
+ if (
+ request_data is not None
+ and request_data.get("litellm_logging_obj", None) is not None
+ ):
+ trace_id: Optional[str] = None
+ litellm_logging_obj: Logging = request_data["litellm_logging_obj"]
+
+ for _ in range(3):
+ trace_id = litellm_logging_obj._get_trace_id(service_name="langfuse")
+ if trace_id is not None:
+ break
+ await asyncio.sleep(3) # wait 3s before retrying for trace id
+
+ _langfuse_object = litellm_logging_obj._get_callback_object(
+ service_name="langfuse"
+ )
+ if _langfuse_object is not None:
+ base_url = _langfuse_object.Langfuse.base_url
+ return f"{base_url}/trace/{trace_id}"
+ return None
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/__init__.py b/.venv/lib/python3.12/site-packages/litellm/integrations/__init__.py
new file mode 100644
index 00000000..b6e690fd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/__init__.py
@@ -0,0 +1 @@
+from . import *
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/_types/open_inference.py b/.venv/lib/python3.12/site-packages/litellm/integrations/_types/open_inference.py
new file mode 100644
index 00000000..b5076c0e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/_types/open_inference.py
@@ -0,0 +1,286 @@
+from enum import Enum
+
+
+class SpanAttributes:
+ OUTPUT_VALUE = "output.value"
+ OUTPUT_MIME_TYPE = "output.mime_type"
+ """
+ The type of output.value. If unspecified, the type is plain text by default.
+ If type is JSON, the value is a string representing a JSON object.
+ """
+ INPUT_VALUE = "input.value"
+ INPUT_MIME_TYPE = "input.mime_type"
+ """
+ The type of input.value. If unspecified, the type is plain text by default.
+ If type is JSON, the value is a string representing a JSON object.
+ """
+
+ EMBEDDING_EMBEDDINGS = "embedding.embeddings"
+ """
+ A list of objects containing embedding data, including the vector and represented piece of text.
+ """
+ EMBEDDING_MODEL_NAME = "embedding.model_name"
+ """
+ The name of the embedding model.
+ """
+
+ LLM_FUNCTION_CALL = "llm.function_call"
+ """
+ For models and APIs that support function calling. Records attributes such as the function
+ name and arguments to the called function.
+ """
+ LLM_INVOCATION_PARAMETERS = "llm.invocation_parameters"
+ """
+ Invocation parameters passed to the LLM or API, such as the model name, temperature, etc.
+ """
+ LLM_INPUT_MESSAGES = "llm.input_messages"
+ """
+ Messages provided to a chat API.
+ """
+ LLM_OUTPUT_MESSAGES = "llm.output_messages"
+ """
+ Messages received from a chat API.
+ """
+ LLM_MODEL_NAME = "llm.model_name"
+ """
+ The name of the model being used.
+ """
+ LLM_PROMPTS = "llm.prompts"
+ """
+ Prompts provided to a completions API.
+ """
+ LLM_PROMPT_TEMPLATE = "llm.prompt_template.template"
+ """
+ The prompt template as a Python f-string.
+ """
+ LLM_PROMPT_TEMPLATE_VARIABLES = "llm.prompt_template.variables"
+ """
+ A list of input variables to the prompt template.
+ """
+ LLM_PROMPT_TEMPLATE_VERSION = "llm.prompt_template.version"
+ """
+ The version of the prompt template being used.
+ """
+ LLM_TOKEN_COUNT_PROMPT = "llm.token_count.prompt"
+ """
+ Number of tokens in the prompt.
+ """
+ LLM_TOKEN_COUNT_COMPLETION = "llm.token_count.completion"
+ """
+ Number of tokens in the completion.
+ """
+ LLM_TOKEN_COUNT_TOTAL = "llm.token_count.total"
+ """
+ Total number of tokens, including both prompt and completion.
+ """
+
+ TOOL_NAME = "tool.name"
+ """
+ Name of the tool being used.
+ """
+ TOOL_DESCRIPTION = "tool.description"
+ """
+ Description of the tool's purpose, typically used to select the tool.
+ """
+ TOOL_PARAMETERS = "tool.parameters"
+ """
+    Parameters of the tool, represented as a dictionary JSON string, e.g.
+ see https://platform.openai.com/docs/guides/gpt/function-calling
+ """
+
+ RETRIEVAL_DOCUMENTS = "retrieval.documents"
+
+ METADATA = "metadata"
+ """
+ Metadata attributes are used to store user-defined key-value pairs.
+ For example, LangChain uses metadata to store user-defined attributes for a chain.
+ """
+
+ TAG_TAGS = "tag.tags"
+ """
+ Custom categorical tags for the span.
+ """
+
+ OPENINFERENCE_SPAN_KIND = "openinference.span.kind"
+
+ SESSION_ID = "session.id"
+ """
+ The id of the session
+ """
+ USER_ID = "user.id"
+ """
+ The id of the user
+ """
+
+
+class MessageAttributes:
+ """
+ Attributes for a message sent to or from an LLM
+ """
+
+ MESSAGE_ROLE = "message.role"
+ """
+ The role of the message, such as "user", "agent", "function".
+ """
+ MESSAGE_CONTENT = "message.content"
+ """
+ The content of the message to or from the llm, must be a string.
+ """
+ MESSAGE_CONTENTS = "message.contents"
+ """
+ The message contents to the llm, it is an array of
+ `message_content` prefixed attributes.
+ """
+ MESSAGE_NAME = "message.name"
+ """
+ The name of the message, often used to identify the function
+ that was used to generate the message.
+ """
+ MESSAGE_TOOL_CALLS = "message.tool_calls"
+ """
+ The tool calls generated by the model, such as function calls.
+ """
+ MESSAGE_FUNCTION_CALL_NAME = "message.function_call_name"
+ """
+ The function name that is a part of the message list.
+ This is populated for role 'function' or 'agent' as a mechanism to identify
+ the function that was called during the execution of a tool.
+ """
+ MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = "message.function_call_arguments_json"
+ """
+ The JSON string representing the arguments passed to the function
+ during a function call.
+ """
+
+
+class MessageContentAttributes:
+ """
+ Attributes for the contents of user messages sent to an LLM.
+ """
+
+ MESSAGE_CONTENT_TYPE = "message_content.type"
+ """
+ The type of the content, such as "text" or "image".
+ """
+ MESSAGE_CONTENT_TEXT = "message_content.text"
+ """
+ The text content of the message, if the type is "text".
+ """
+ MESSAGE_CONTENT_IMAGE = "message_content.image"
+ """
+ The image content of the message, if the type is "image".
+ An image can be made available to the model by passing a link to
+ the image or by passing the base64 encoded image directly in the
+ request.
+ """
+
+
+class ImageAttributes:
+ """
+ Attributes for images
+ """
+
+ IMAGE_URL = "image.url"
+ """
+ An http or base64 image url
+ """
+
+
+class DocumentAttributes:
+ """
+ Attributes for a document.
+ """
+
+ DOCUMENT_ID = "document.id"
+ """
+ The id of the document.
+ """
+ DOCUMENT_SCORE = "document.score"
+ """
+ The score of the document
+ """
+ DOCUMENT_CONTENT = "document.content"
+ """
+ The content of the document.
+ """
+ DOCUMENT_METADATA = "document.metadata"
+ """
+ The metadata of the document represented as a dictionary
+ JSON string, e.g. `"{ 'title': 'foo' }"`
+ """
+
+
+class RerankerAttributes:
+ """
+ Attributes for a reranker
+ """
+
+ RERANKER_INPUT_DOCUMENTS = "reranker.input_documents"
+ """
+ List of documents as input to the reranker
+ """
+ RERANKER_OUTPUT_DOCUMENTS = "reranker.output_documents"
+ """
+ List of documents as output from the reranker
+ """
+ RERANKER_QUERY = "reranker.query"
+ """
+ Query string for the reranker
+ """
+ RERANKER_MODEL_NAME = "reranker.model_name"
+ """
+ Model name of the reranker
+ """
+ RERANKER_TOP_K = "reranker.top_k"
+ """
+ Top K parameter of the reranker
+ """
+
+
+class EmbeddingAttributes:
+ """
+ Attributes for an embedding
+ """
+
+ EMBEDDING_TEXT = "embedding.text"
+ """
+ The text represented by the embedding.
+ """
+ EMBEDDING_VECTOR = "embedding.vector"
+ """
+ The embedding vector.
+ """
+
+
+class ToolCallAttributes:
+ """
+ Attributes for a tool call
+ """
+
+ TOOL_CALL_FUNCTION_NAME = "tool_call.function.name"
+ """
+ The name of function that is being called during a tool call.
+ """
+ TOOL_CALL_FUNCTION_ARGUMENTS_JSON = "tool_call.function.arguments"
+ """
+ The JSON string representing the arguments passed to the function
+ during a tool call.
+ """
+
+
+class OpenInferenceSpanKindValues(Enum):
+ TOOL = "TOOL"
+ CHAIN = "CHAIN"
+ LLM = "LLM"
+ RETRIEVER = "RETRIEVER"
+ EMBEDDING = "EMBEDDING"
+ AGENT = "AGENT"
+ RERANKER = "RERANKER"
+ UNKNOWN = "UNKNOWN"
+ GUARDRAIL = "GUARDRAIL"
+ EVALUATOR = "EVALUATOR"
+
+
+class OpenInferenceMimeTypeValues(Enum):
+ TEXT = "text/plain"
+ JSON = "application/json" \ No newline at end of file
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/additional_logging_utils.py b/.venv/lib/python3.12/site-packages/litellm/integrations/additional_logging_utils.py
new file mode 100644
index 00000000..795afd81
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/additional_logging_utils.py
@@ -0,0 +1,36 @@
+"""
+Base class for Additional Logging Utils for CustomLoggers
+
+- Health Check for the logging util
+- Get Request / Response Payload for the logging util
+"""
+
+from abc import ABC, abstractmethod
+from datetime import datetime
+from typing import Optional
+
+from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
+
+
+class AdditionalLoggingUtils(ABC):
+ def __init__(self):
+ super().__init__()
+
+ @abstractmethod
+ async def async_health_check(self) -> IntegrationHealthCheckStatus:
+ """
+ Check if the service is healthy
+ """
+ pass
+
+ @abstractmethod
+ async def get_request_response_payload(
+ self,
+ request_id: str,
+ start_time_utc: Optional[datetime],
+ end_time_utc: Optional[datetime],
+ ) -> Optional[dict]:
+ """
+ Get the request and response payload for a given `request_id`
+ """
+ return None
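+
+
+# Minimal sketch of a subclass (illustrative only; `MyLoggerUtils` is a hypothetical name):
+#
+#   class MyLoggerUtils(AdditionalLoggingUtils):
+#       async def async_health_check(self) -> IntegrationHealthCheckStatus:
+#           ...  # report whether the logging backend is reachable/healthy
+#
+#       async def get_request_response_payload(
+#           self, request_id, start_time_utc, end_time_utc
+#       ) -> Optional[dict]:
+#           return None  # look up the stored payload for `request_id`, if any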
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py b/.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py
new file mode 100644
index 00000000..055ad902
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py
@@ -0,0 +1,392 @@
+"""
+Send logs to Argilla for annotation
+"""
+
+import asyncio
+import json
+import os
+import random
+import types
+from typing import Any, Dict, List, Optional
+
+import httpx
+from pydantic import BaseModel # type: ignore
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.llms.custom_httpx.http_handler import (
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+from litellm.types.integrations.argilla import (
+ SUPPORTED_PAYLOAD_FIELDS,
+ ArgillaCredentialsObject,
+ ArgillaItem,
+)
+from litellm.types.utils import StandardLoggingPayload
+
+
+def is_serializable(value):
+ non_serializable_types = (
+ types.CoroutineType,
+ types.FunctionType,
+ types.GeneratorType,
+ BaseModel,
+ )
+ return not isinstance(value, non_serializable_types)
+
+
+class ArgillaLogger(CustomBatchLogger):
+ def __init__(
+ self,
+ argilla_api_key: Optional[str] = None,
+ argilla_dataset_name: Optional[str] = None,
+ argilla_base_url: Optional[str] = None,
+ **kwargs,
+ ):
+ if litellm.argilla_transformation_object is None:
+ raise Exception(
+ "'litellm.argilla_transformation_object' is required, to log your payload to Argilla."
+ )
+ self.validate_argilla_transformation_object(
+ litellm.argilla_transformation_object
+ )
+ self.argilla_transformation_object = litellm.argilla_transformation_object
+ self.default_credentials = self.get_credentials_from_env(
+ argilla_api_key=argilla_api_key,
+ argilla_dataset_name=argilla_dataset_name,
+ argilla_base_url=argilla_base_url,
+ )
+ self.sampling_rate: float = (
+ float(os.getenv("ARGILLA_SAMPLING_RATE")) # type: ignore
+ if os.getenv("ARGILLA_SAMPLING_RATE") is not None
+ and os.getenv("ARGILLA_SAMPLING_RATE").strip().isdigit() # type: ignore
+ else 1.0
+ )
+
+ self.async_httpx_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+ _batch_size = (
+ os.getenv("ARGILLA_BATCH_SIZE", None) or litellm.argilla_batch_size
+ )
+ if _batch_size:
+ self.batch_size = int(_batch_size)
+ asyncio.create_task(self.periodic_flush())
+ self.flush_lock = asyncio.Lock()
+ super().__init__(**kwargs, flush_lock=self.flush_lock)
+
+ def validate_argilla_transformation_object(
+ self, argilla_transformation_object: Dict[str, Any]
+ ):
+ if not isinstance(argilla_transformation_object, dict):
+ raise Exception(
+ "'argilla_transformation_object' must be a dictionary, to log your payload to Argilla."
+ )
+
+ for v in argilla_transformation_object.values():
+ if v not in SUPPORTED_PAYLOAD_FIELDS:
+ raise Exception(
+ f"All values in argilla_transformation_object must be a key in SUPPORTED_PAYLOAD_FIELDS, {v} is not a valid key."
+ )
+
+ def get_credentials_from_env(
+ self,
+ argilla_api_key: Optional[str],
+ argilla_dataset_name: Optional[str],
+ argilla_base_url: Optional[str],
+ ) -> ArgillaCredentialsObject:
+
+ _credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
+ if _credentials_api_key is None:
+ raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.")
+
+ _credentials_base_url = (
+ argilla_base_url
+ or os.getenv("ARGILLA_BASE_URL")
+ or "http://localhost:6900/"
+ )
+ if _credentials_base_url is None:
+ raise Exception(
+ "Invalid Argilla Base URL given. _credentials_base_url=None."
+ )
+
+ _credentials_dataset_name = (
+ argilla_dataset_name
+ or os.getenv("ARGILLA_DATASET_NAME")
+ or "litellm-completion"
+ )
+ if _credentials_dataset_name is None:
+ raise Exception("Invalid Argilla Dataset give. Value=None.")
+ else:
+ dataset_response = litellm.module_level_client.get(
+ url=f"{_credentials_base_url}/api/v1/me/datasets?name={_credentials_dataset_name}",
+ headers={"X-Argilla-Api-Key": _credentials_api_key},
+ )
+ json_response = dataset_response.json()
+ if (
+ "items" in json_response
+ and isinstance(json_response["items"], list)
+ and len(json_response["items"]) > 0
+ ):
+ _credentials_dataset_name = json_response["items"][0]["id"]
+
+ return ArgillaCredentialsObject(
+ ARGILLA_API_KEY=_credentials_api_key,
+ ARGILLA_BASE_URL=_credentials_base_url,
+ ARGILLA_DATASET_NAME=_credentials_dataset_name,
+ )
+
+ def get_chat_messages(
+ self, payload: StandardLoggingPayload
+ ) -> List[Dict[str, Any]]:
+ payload_messages = payload.get("messages", None)
+
+ if payload_messages is None:
+ raise Exception("No chat messages found in payload.")
+
+ if (
+ isinstance(payload_messages, list)
+ and len(payload_messages) > 0
+ and isinstance(payload_messages[0], dict)
+ ):
+ return payload_messages
+ elif isinstance(payload_messages, dict):
+ return [payload_messages]
+ else:
+ raise Exception(f"Invalid chat messages format: {payload_messages}")
+
+ def get_str_response(self, payload: StandardLoggingPayload) -> str:
+ response = payload["response"]
+
+ if response is None:
+ raise Exception("No response found in payload.")
+
+ if isinstance(response, str):
+ return response
+ elif isinstance(response, dict):
+ return (
+ response.get("choices", [{}])[0].get("message", {}).get("content", "")
+ )
+ else:
+ raise Exception(f"Invalid response format: {response}")
+
+ def _prepare_log_data(
+ self, kwargs, response_obj, start_time, end_time
+ ) -> Optional[ArgillaItem]:
+ try:
+            # Build the Argilla item from the standard logging payload
+ payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object", None
+ )
+
+ if payload is None:
+ raise Exception("Error logging request payload. Payload=none.")
+
+ argilla_message = self.get_chat_messages(payload)
+ argilla_response = self.get_str_response(payload)
+ argilla_item: ArgillaItem = {"fields": {}}
+ for k, v in self.argilla_transformation_object.items():
+ if v == "messages":
+ argilla_item["fields"][k] = argilla_message
+ elif v == "response":
+ argilla_item["fields"][k] = argilla_response
+ else:
+ argilla_item["fields"][k] = payload.get(v, None)
+
+ return argilla_item
+ except Exception:
+ raise
+
+ def _send_batch(self):
+ if not self.log_queue:
+ return
+
+ argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
+ argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
+
+ url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"
+
+ argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
+
+ headers = {"X-Argilla-Api-Key": argilla_api_key}
+
+ try:
+ response = litellm.module_level_client.post(
+ url=url,
+ json=self.log_queue,
+ headers=headers,
+ )
+
+ if response.status_code >= 300:
+ verbose_logger.error(
+ f"Argilla Error: {response.status_code} - {response.text}"
+ )
+ else:
+ verbose_logger.debug(
+ f"Batch of {len(self.log_queue)} runs successfully created"
+ )
+
+ self.log_queue.clear()
+ except Exception:
+ verbose_logger.exception("Argilla Layer Error - Error sending batch.")
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+            sampling_rate = self.sampling_rate
+ random_sample = random.random()
+ if random_sample > sampling_rate:
+ verbose_logger.info(
+ "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
+ sampling_rate, random_sample
+ )
+ )
+ return # Skip logging
+ verbose_logger.debug(
+ "Langsmith Sync Layer Logging - kwargs: %s, response_obj: %s",
+ kwargs,
+ response_obj,
+ )
+ data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+ if data is None:
+ return
+
+ self.log_queue.append(data)
+ verbose_logger.debug(
+ f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
+ )
+
+ if len(self.log_queue) >= self.batch_size:
+ self._send_batch()
+
+ except Exception:
+ verbose_logger.exception("Langsmith Layer Error - log_success_event error")
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ sampling_rate = self.sampling_rate
+ random_sample = random.random()
+ if random_sample > sampling_rate:
+ verbose_logger.info(
+ "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
+ sampling_rate, random_sample
+ )
+ )
+ return # Skip logging
+ verbose_logger.debug(
+ "Langsmith Async Layer Logging - kwargs: %s, response_obj: %s",
+ kwargs,
+ response_obj,
+ )
+ payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object", None
+ )
+
+ data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+
+ ## ALLOW CUSTOM LOGGERS TO MODIFY / FILTER DATA BEFORE LOGGING
+ for callback in litellm.callbacks:
+ if isinstance(callback, CustomLogger):
+ try:
+ if data is None:
+ break
+ data = await callback.async_dataset_hook(data, payload)
+ except NotImplementedError:
+ pass
+
+ if data is None:
+ return
+
+ self.log_queue.append(data)
+ verbose_logger.debug(
+ "Langsmith logging: queue length %s, batch size %s",
+ len(self.log_queue),
+ self.batch_size,
+ )
+ if len(self.log_queue) >= self.batch_size:
+ await self.flush_queue()
+ except Exception:
+ verbose_logger.exception(
+ "Argilla Layer Error - error logging async success event."
+ )
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ sampling_rate = self.sampling_rate
+ random_sample = random.random()
+ if random_sample > sampling_rate:
+ verbose_logger.info(
+ "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
+ sampling_rate, random_sample
+ )
+ )
+ return # Skip logging
+ verbose_logger.info("Langsmith Failure Event Logging!")
+ try:
+ data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+ self.log_queue.append(data)
+ verbose_logger.debug(
+ "Langsmith logging: queue length %s, batch size %s",
+ len(self.log_queue),
+ self.batch_size,
+ )
+ if len(self.log_queue) >= self.batch_size:
+ await self.flush_queue()
+ except Exception:
+ verbose_logger.exception(
+ "Langsmith Layer Error - error logging async failure event."
+ )
+
+ async def async_send_batch(self):
+ """
+ sends runs to /batch endpoint
+
+ Sends runs from self.log_queue
+
+ Returns: None
+
+ Raises: Does not raise an exception, will only verbose_logger.exception()
+ """
+ if not self.log_queue:
+ return
+
+ argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
+ argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
+
+ url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"
+
+ argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
+
+ headers = {"X-Argilla-Api-Key": argilla_api_key}
+
+ try:
+ response = await self.async_httpx_client.put(
+ url=url,
+ data=json.dumps(
+ {
+ "items": self.log_queue,
+ }
+ ),
+ headers=headers,
+ timeout=60000,
+ )
+ response.raise_for_status()
+
+ if response.status_code >= 300:
+ verbose_logger.error(
+ f"Argilla Error: {response.status_code} - {response.text}"
+ )
+ else:
+ verbose_logger.debug(
+ "Batch of %s runs successfully created", len(self.log_queue)
+ )
+ except httpx.HTTPStatusError:
+ verbose_logger.exception("Argilla HTTP Error")
+ except Exception:
+ verbose_logger.exception("Argilla Layer Error")
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/arize/_utils.py b/.venv/lib/python3.12/site-packages/litellm/integrations/arize/_utils.py
new file mode 100644
index 00000000..487304cc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/arize/_utils.py
@@ -0,0 +1,126 @@
+from typing import TYPE_CHECKING, Any, Optional
+
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
+from litellm.types.utils import StandardLoggingPayload
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+
+ Span = _Span
+else:
+ Span = Any
+
+
+def set_attributes(span: Span, kwargs, response_obj):
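+    """
+    Populate an OpenTelemetry span with OpenInference semantic-convention attributes
+    (model name, input/output messages, invocation parameters, token counts) derived
+    from the litellm request kwargs and response object.
+    """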
+ from litellm.integrations._types.open_inference import (
+ MessageAttributes,
+ OpenInferenceSpanKindValues,
+ SpanAttributes,
+ )
+
+ try:
+ standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object"
+ )
+
+ #############################################
+ ############ LLM CALL METADATA ##############
+ #############################################
+
+ if standard_logging_payload and (
+ metadata := standard_logging_payload["metadata"]
+ ):
+ span.set_attribute(SpanAttributes.METADATA, safe_dumps(metadata))
+
+ #############################################
+ ########## LLM Request Attributes ###########
+ #############################################
+
+ # The name of the LLM a request is being made to
+ if kwargs.get("model"):
+ span.set_attribute(SpanAttributes.LLM_MODEL_NAME, kwargs.get("model"))
+
+ span.set_attribute(
+ SpanAttributes.OPENINFERENCE_SPAN_KIND,
+ OpenInferenceSpanKindValues.LLM.value,
+ )
+ messages = kwargs.get("messages")
+
+ # for /chat/completions
+ # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
+ if messages:
+ span.set_attribute(
+ SpanAttributes.INPUT_VALUE,
+ messages[-1].get("content", ""), # get the last message for input
+ )
+
+ # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page
+ for idx, msg in enumerate(messages):
+ # Set the role per message
+ span.set_attribute(
+ f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_ROLE}",
+ msg["role"],
+ )
+ # Set the content per message
+ span.set_attribute(
+ f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_CONTENT}",
+ msg.get("content", ""),
+ )
+
+ if standard_logging_payload and (
+ model_params := standard_logging_payload["model_parameters"]
+ ):
+ # The Generative AI Provider: Azure, OpenAI, etc.
+ span.set_attribute(
+ SpanAttributes.LLM_INVOCATION_PARAMETERS, safe_dumps(model_params)
+ )
+
+ if model_params.get("user"):
+ user_id = model_params.get("user")
+ if user_id is not None:
+ span.set_attribute(SpanAttributes.USER_ID, user_id)
+
+ #############################################
+ ########## LLM Response Attributes ##########
+ # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
+ #############################################
+ if hasattr(response_obj, "get"):
+ for choice in response_obj.get("choices", []):
+ response_message = choice.get("message", {})
+ span.set_attribute(
+ SpanAttributes.OUTPUT_VALUE, response_message.get("content", "")
+ )
+
+ # This shows up under `output_messages` tab on the span page
+ # This code assumes a single response
+ span.set_attribute(
+ f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
+ response_message.get("role"),
+ )
+ span.set_attribute(
+ f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
+ response_message.get("content", ""),
+ )
+
+ usage = response_obj.get("usage")
+ if usage:
+ span.set_attribute(
+ SpanAttributes.LLM_TOKEN_COUNT_TOTAL,
+ usage.get("total_tokens"),
+ )
+
+ # The number of tokens used in the LLM response (completion).
+ span.set_attribute(
+ SpanAttributes.LLM_TOKEN_COUNT_COMPLETION,
+ usage.get("completion_tokens"),
+ )
+
+ # The number of tokens used in the LLM prompt.
+ span.set_attribute(
+ SpanAttributes.LLM_TOKEN_COUNT_PROMPT,
+ usage.get("prompt_tokens"),
+ )
+ except Exception as e:
+ verbose_logger.error(f"Error setting arize attributes: {e}")
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/arize/arize.py b/.venv/lib/python3.12/site-packages/litellm/integrations/arize/arize.py
new file mode 100644
index 00000000..7a0fb785
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/arize/arize.py
@@ -0,0 +1,105 @@
+"""
+Arize AI is OTEL-compatible
+
+This file has Arize AI-specific helper functions
+"""
+
+import os
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from litellm.integrations.arize import _utils
+from litellm.integrations.opentelemetry import OpenTelemetry
+from litellm.types.integrations.arize import ArizeConfig
+from litellm.types.services import ServiceLoggerPayload
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+
+ from litellm.types.integrations.arize import Protocol as _Protocol
+
+ Protocol = _Protocol
+ Span = _Span
+else:
+ Protocol = Any
+ Span = Any
+
+
+class ArizeLogger(OpenTelemetry):
+
+ def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
+ ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
+ return
+
+ @staticmethod
+ def set_arize_attributes(span: Span, kwargs, response_obj):
+ _utils.set_attributes(span, kwargs, response_obj)
+ return
+
+ @staticmethod
+ def get_arize_config() -> ArizeConfig:
+ """
+ Helper function to get Arize configuration.
+
+ Returns:
+ ArizeConfig: A Pydantic model containing Arize configuration.
+
+ Raises:
+ ValueError: If required environment variables are not set.
+ """
+ space_key = os.environ.get("ARIZE_SPACE_KEY")
+ api_key = os.environ.get("ARIZE_API_KEY")
+
+ grpc_endpoint = os.environ.get("ARIZE_ENDPOINT")
+ http_endpoint = os.environ.get("ARIZE_HTTP_ENDPOINT")
+
+ endpoint = None
+ protocol: Protocol = "otlp_grpc"
+
+ if grpc_endpoint:
+ protocol = "otlp_grpc"
+ endpoint = grpc_endpoint
+ elif http_endpoint:
+ protocol = "otlp_http"
+ endpoint = http_endpoint
+ else:
+ protocol = "otlp_grpc"
+ endpoint = "https://otlp.arize.com/v1"
+
+ return ArizeConfig(
+ space_key=space_key,
+ api_key=api_key,
+ protocol=protocol,
+ endpoint=endpoint,
+ )
+
+ async def async_service_success_hook(
+ self,
+ payload: ServiceLoggerPayload,
+ parent_otel_span: Optional[Span] = None,
+ start_time: Optional[Union[datetime, float]] = None,
+ end_time: Optional[Union[datetime, float]] = None,
+ event_metadata: Optional[dict] = None,
+ ):
+ """Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
+ pass
+
+ async def async_service_failure_hook(
+ self,
+ payload: ServiceLoggerPayload,
+ error: Optional[str] = "",
+ parent_otel_span: Optional[Span] = None,
+ start_time: Optional[Union[datetime, float]] = None,
+ end_time: Optional[Union[float, datetime]] = None,
+ event_metadata: Optional[dict] = None,
+ ):
+ """Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
+ pass
+
+ def create_litellm_proxy_request_started_span(
+ self,
+ start_time: datetime,
+ headers: dict,
+ ):
+ """Arize is used mainly for LLM I/O tracing, sending Proxy Server Request adds bloat to arize logs"""
+ pass
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/arize/arize_phoenix.py b/.venv/lib/python3.12/site-packages/litellm/integrations/arize/arize_phoenix.py
new file mode 100644
index 00000000..d7b7d581
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/arize/arize_phoenix.py
@@ -0,0 +1,73 @@
+import os
+from typing import TYPE_CHECKING, Any
+from litellm.integrations.arize import _utils
+from litellm._logging import verbose_logger
+from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig
+
+if TYPE_CHECKING:
+ from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig
+ from litellm.types.integrations.arize import Protocol as _Protocol
+ from opentelemetry.trace import Span as _Span
+
+ Protocol = _Protocol
+ OpenTelemetryConfig = _OpenTelemetryConfig
+ Span = _Span
+else:
+ Protocol = Any
+ OpenTelemetryConfig = Any
+ Span = Any
+
+
+ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://app.phoenix.arize.com/v1/traces"
+
+class ArizePhoenixLogger:
+ @staticmethod
+ def set_arize_phoenix_attributes(span: Span, kwargs, response_obj):
+ _utils.set_attributes(span, kwargs, response_obj)
+ return
+
+ @staticmethod
+ def get_arize_phoenix_config() -> ArizePhoenixConfig:
+ """
+ Retrieves the Arize Phoenix configuration based on environment variables.
+
+ Returns:
+ ArizePhoenixConfig: A Pydantic model containing Arize Phoenix configuration.
+ """
+ api_key = os.environ.get("PHOENIX_API_KEY", None)
+ grpc_endpoint = os.environ.get("PHOENIX_COLLECTOR_ENDPOINT", None)
+ http_endpoint = os.environ.get("PHOENIX_COLLECTOR_HTTP_ENDPOINT", None)
+
+ endpoint = None
+ protocol: Protocol = "otlp_http"
+
+ if http_endpoint:
+ endpoint = http_endpoint
+ protocol = "otlp_http"
+ elif grpc_endpoint:
+ endpoint = grpc_endpoint
+ protocol = "otlp_grpc"
+ else:
+ endpoint = ARIZE_HOSTED_PHOENIX_ENDPOINT
+ protocol = "otlp_http"
+ verbose_logger.debug(
+ f"No PHOENIX_COLLECTOR_ENDPOINT or PHOENIX_COLLECTOR_HTTP_ENDPOINT found, using default endpoint with http: {ARIZE_HOSTED_PHOENIX_ENDPOINT}"
+ )
+
+ otlp_auth_headers = None
+    # If the endpoint is the Arize hosted Phoenix endpoint, use the api_key as the auth header,
+    # since it currently uses a slightly different auth header format than self-hosted Phoenix.
+ if endpoint == ARIZE_HOSTED_PHOENIX_ENDPOINT:
+ if api_key is None:
+ raise ValueError("PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used.")
+ otlp_auth_headers = f"api_key={api_key}"
+ elif api_key is not None:
+ # api_key/auth is optional for self hosted phoenix
+ otlp_auth_headers = f"Authorization=Bearer {api_key}"
+
+ return ArizePhoenixConfig(
+ otlp_auth_headers=otlp_auth_headers,
+ protocol=protocol,
+ endpoint=endpoint
+ )
+
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/athina.py b/.venv/lib/python3.12/site-packages/litellm/integrations/athina.py
new file mode 100644
index 00000000..705dc11f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/athina.py
@@ -0,0 +1,102 @@
+import datetime
+
+import litellm
+
+
+class AthinaLogger:
+ def __init__(self):
+ import os
+
+ self.athina_api_key = os.getenv("ATHINA_API_KEY")
+ self.headers = {
+ "athina-api-key": self.athina_api_key,
+ "Content-Type": "application/json",
+ }
+ self.athina_logging_url = os.getenv("ATHINA_BASE_URL", "https://log.athina.ai") + "/api/v1/log/inference"
+ self.additional_keys = [
+ "environment",
+ "prompt_slug",
+ "customer_id",
+ "customer_user_id",
+ "session_id",
+ "external_reference_id",
+ "context",
+ "expected_response",
+ "user_query",
+ "tags",
+ "user_feedback",
+ "model_options",
+ "custom_attributes",
+ ]
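+        # These keys are copied verbatim from request metadata (litellm_params["metadata"])
+        # into the logged payload when present; see log_event below.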
+
+ def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
+ import json
+ import traceback
+
+ try:
+ is_stream = kwargs.get("stream", False)
+ if is_stream:
+ if "complete_streaming_response" in kwargs:
+ # Log the completion response in streaming mode
+ completion_response = kwargs["complete_streaming_response"]
+ response_json = (
+ completion_response.model_dump() if completion_response else {}
+ )
+ else:
+ # Skip logging if the completion response is not available
+ return
+ else:
+ # Log the completion response in non streaming mode
+ response_json = response_obj.model_dump() if response_obj else {}
+ data = {
+ "language_model_id": kwargs.get("model"),
+ "request": kwargs,
+ "response": response_json,
+ "prompt_tokens": response_json.get("usage", {}).get("prompt_tokens"),
+ "completion_tokens": response_json.get("usage", {}).get(
+ "completion_tokens"
+ ),
+ "total_tokens": response_json.get("usage", {}).get("total_tokens"),
+ }
+
+ if (
+ type(end_time) is datetime.datetime
+ and type(start_time) is datetime.datetime
+ ):
+ data["response_time"] = int(
+ (end_time - start_time).total_seconds() * 1000
+ )
+
+ if "messages" in kwargs:
+ data["prompt"] = kwargs.get("messages", None)
+
+ # Directly add tools or functions if present
+ optional_params = kwargs.get("optional_params", {})
+ data.update(
+ (k, v)
+ for k, v in optional_params.items()
+ if k in ["tools", "functions"]
+ )
+
+ # Add additional metadata keys
+ metadata = kwargs.get("litellm_params", {}).get("metadata", {})
+ if metadata:
+ for key in self.additional_keys:
+ if key in metadata:
+ data[key] = metadata[key]
+ response = litellm.module_level_client.post(
+ self.athina_logging_url,
+ headers=self.headers,
+ data=json.dumps(data, default=str),
+ )
+ if response.status_code != 200:
+ print_verbose(
+ f"Athina Logger Error - {response.text}, {response.status_code}"
+ )
+ else:
+ print_verbose(f"Athina Logger Succeeded - {response.text}")
+ except Exception as e:
+ print_verbose(
+ f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}"
+ )
+ pass
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/azure_storage/azure_storage.py b/.venv/lib/python3.12/site-packages/litellm/integrations/azure_storage/azure_storage.py
new file mode 100644
index 00000000..ddc46b11
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/azure_storage/azure_storage.py
@@ -0,0 +1,381 @@
+import asyncio
+import json
+import os
+import uuid
+from datetime import datetime, timedelta
+from typing import List, Optional
+
+from litellm._logging import verbose_logger
+from litellm.constants import AZURE_STORAGE_MSFT_VERSION
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.llms.azure.common_utils import get_azure_ad_token_from_entrata_id
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+from litellm.types.utils import StandardLoggingPayload
+
+
+class AzureBlobStorageLogger(CustomBatchLogger):
+ def __init__(
+ self,
+ **kwargs,
+ ):
+ try:
+ verbose_logger.debug(
+ "AzureBlobStorageLogger: in init azure blob storage logger"
+ )
+
+ # Env Variables used for Azure Storage Authentication
+ self.tenant_id = os.getenv("AZURE_STORAGE_TENANT_ID")
+ self.client_id = os.getenv("AZURE_STORAGE_CLIENT_ID")
+ self.client_secret = os.getenv("AZURE_STORAGE_CLIENT_SECRET")
+ self.azure_storage_account_key: Optional[str] = os.getenv(
+ "AZURE_STORAGE_ACCOUNT_KEY"
+ )
+
+ # Required Env Variables for Azure Storage
+ _azure_storage_account_name = os.getenv("AZURE_STORAGE_ACCOUNT_NAME")
+ if not _azure_storage_account_name:
+ raise ValueError(
+ "Missing required environment variable: AZURE_STORAGE_ACCOUNT_NAME"
+ )
+ self.azure_storage_account_name: str = _azure_storage_account_name
+ _azure_storage_file_system = os.getenv("AZURE_STORAGE_FILE_SYSTEM")
+ if not _azure_storage_file_system:
+ raise ValueError(
+ "Missing required environment variable: AZURE_STORAGE_FILE_SYSTEM"
+ )
+ self.azure_storage_file_system: str = _azure_storage_file_system
+
+ # Internal variables used for Token based authentication
+ self.azure_auth_token: Optional[str] = (
+ None # the Azure AD token to use for Azure Storage API requests
+ )
+ self.token_expiry: Optional[datetime] = (
+                None  # the expiry time of the current Azure AD token
+ )
+
+ asyncio.create_task(self.periodic_flush())
+ self.flush_lock = asyncio.Lock()
+ self.log_queue: List[StandardLoggingPayload] = []
+ super().__init__(**kwargs, flush_lock=self.flush_lock)
+ except Exception as e:
+ verbose_logger.exception(
+ f"AzureBlobStorageLogger: Got exception on init AzureBlobStorageLogger client {str(e)}"
+ )
+ raise e
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ """
+ Async Log success events to Azure Blob Storage
+
+ Raises:
+ Raises a NON Blocking verbose_logger.exception if an error occurs
+ """
+ try:
+ self._premium_user_check()
+ verbose_logger.debug(
+ "AzureBlobStorageLogger: Logging - Enters logging function for model %s",
+ kwargs,
+ )
+ standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object"
+ )
+
+ if standard_logging_payload is None:
+ raise ValueError("standard_logging_payload is not set")
+
+ self.log_queue.append(standard_logging_payload)
+
+ except Exception as e:
+ verbose_logger.exception(f"AzureBlobStorageLogger Layer Error - {str(e)}")
+ pass
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ """
+ Async Log failure events to Azure Blob Storage
+
+ Raises:
+ Raises a NON Blocking verbose_logger.exception if an error occurs
+ """
+ try:
+ self._premium_user_check()
+ verbose_logger.debug(
+ "AzureBlobStorageLogger: Logging - Enters logging function for model %s",
+ kwargs,
+ )
+ standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object"
+ )
+
+ if standard_logging_payload is None:
+ raise ValueError("standard_logging_payload is not set")
+
+ self.log_queue.append(standard_logging_payload)
+ except Exception as e:
+ verbose_logger.exception(f"AzureBlobStorageLogger Layer Error - {str(e)}")
+ pass
+
+ async def async_send_batch(self):
+ """
+ Sends the in memory logs queue to Azure Blob Storage
+
+ Raises:
+ Raises a NON Blocking verbose_logger.exception if an error occurs
+ """
+ try:
+ if not self.log_queue:
+ verbose_logger.exception("Datadog: log_queue does not exist")
+ return
+
+ verbose_logger.debug(
+ "AzureBlobStorageLogger - about to flush %s events",
+ len(self.log_queue),
+ )
+
+ for payload in self.log_queue:
+ await self.async_upload_payload_to_azure_blob_storage(payload=payload)
+
+ except Exception as e:
+ verbose_logger.exception(
+ f"AzureBlobStorageLogger Error sending batch API - {str(e)}"
+ )
+
+ async def async_upload_payload_to_azure_blob_storage(
+ self, payload: StandardLoggingPayload
+ ):
+ """
+ Uploads the payload to Azure Blob Storage using a 3-step process:
+ 1. Create file resource
+ 2. Append data
+ 3. Flush the data
+ """
+ try:
+
+ if self.azure_storage_account_key:
+ await self.upload_to_azure_data_lake_with_azure_account_key(
+ payload=payload
+ )
+ else:
+ # Get a valid token instead of always requesting a new one
+ await self.set_valid_azure_ad_token()
+ async_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+ json_payload = (
+ json.dumps(payload) + "\n"
+ ) # Add newline for each log entry
+ payload_bytes = json_payload.encode("utf-8")
+ filename = f"{payload.get('id') or str(uuid.uuid4())}.json"
+ base_url = f"https://{self.azure_storage_account_name}.dfs.core.windows.net/{self.azure_storage_file_system}/{filename}"
+
+ # Execute the 3-step upload process
+ await self._create_file(async_client, base_url)
+ await self._append_data(async_client, base_url, json_payload)
+ await self._flush_data(async_client, base_url, len(payload_bytes))
+
+ verbose_logger.debug(
+ f"Successfully uploaded log to Azure Blob Storage: {filename}"
+ )
+
+ except Exception as e:
+ verbose_logger.exception(f"Error uploading to Azure Blob Storage: {str(e)}")
+ raise e
+
+ async def _create_file(self, client: AsyncHTTPHandler, base_url: str):
+ """Helper method to create the file resource"""
+ try:
+ verbose_logger.debug(f"Creating file resource at: {base_url}")
+ headers = {
+ "x-ms-version": AZURE_STORAGE_MSFT_VERSION,
+ "Content-Length": "0",
+ "Authorization": f"Bearer {self.azure_auth_token}",
+ }
+ response = await client.put(f"{base_url}?resource=file", headers=headers)
+ response.raise_for_status()
+ verbose_logger.debug("Successfully created file resource")
+ except Exception as e:
+ verbose_logger.exception(f"Error creating file resource: {str(e)}")
+ raise
+
+ async def _append_data(
+ self, client: AsyncHTTPHandler, base_url: str, json_payload: str
+ ):
+ """Helper method to append data to the file"""
+ try:
+ verbose_logger.debug(f"Appending data to file: {base_url}")
+ headers = {
+ "x-ms-version": AZURE_STORAGE_MSFT_VERSION,
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {self.azure_auth_token}",
+ }
+ response = await client.patch(
+ f"{base_url}?action=append&position=0",
+ headers=headers,
+ data=json_payload,
+ )
+ response.raise_for_status()
+ verbose_logger.debug("Successfully appended data")
+ except Exception as e:
+ verbose_logger.exception(f"Error appending data: {str(e)}")
+ raise
+
+ async def _flush_data(self, client: AsyncHTTPHandler, base_url: str, position: int):
+ """Helper method to flush the data"""
+ try:
+ verbose_logger.debug(f"Flushing data at position {position}")
+ headers = {
+ "x-ms-version": AZURE_STORAGE_MSFT_VERSION,
+ "Content-Length": "0",
+ "Authorization": f"Bearer {self.azure_auth_token}",
+ }
+ response = await client.patch(
+ f"{base_url}?action=flush&position={position}", headers=headers
+ )
+ response.raise_for_status()
+ verbose_logger.debug("Successfully flushed data")
+ except Exception as e:
+ verbose_logger.exception(f"Error flushing data: {str(e)}")
+ raise
+
+ ####### Helper methods to managing Authentication to Azure Storage #######
+ ##########################################################################
+
+ async def set_valid_azure_ad_token(self):
+ """
+ Wrapper to set self.azure_auth_token to a valid Azure AD token, refreshing if necessary
+
+ Refreshes the token when:
+ - Token is expired
+ - Token is not set
+ """
+ # Check if token needs refresh
+ if self._azure_ad_token_is_expired() or self.azure_auth_token is None:
+ verbose_logger.debug("Azure AD token needs refresh")
+ self.azure_auth_token = self.get_azure_ad_token_from_azure_storage(
+ tenant_id=self.tenant_id,
+ client_id=self.client_id,
+ client_secret=self.client_secret,
+ )
+ # Token typically expires in 1 hour
+ self.token_expiry = datetime.now() + timedelta(hours=1)
+ verbose_logger.debug(f"New token will expire at {self.token_expiry}")
+
+ def get_azure_ad_token_from_azure_storage(
+ self,
+ tenant_id: Optional[str],
+ client_id: Optional[str],
+ client_secret: Optional[str],
+ ) -> str:
+ """
+ Gets Azure AD token to use for Azure Storage API requests
+ """
+ verbose_logger.debug("Getting Azure AD Token from Azure Storage")
+ verbose_logger.debug(
+ "tenant_id %s, client_id %s, client_secret %s",
+ tenant_id,
+ client_id,
+ client_secret,
+ )
+ if tenant_id is None:
+ raise ValueError(
+ "Missing required environment variable: AZURE_STORAGE_TENANT_ID"
+ )
+ if client_id is None:
+ raise ValueError(
+ "Missing required environment variable: AZURE_STORAGE_CLIENT_ID"
+ )
+ if client_secret is None:
+ raise ValueError(
+ "Missing required environment variable: AZURE_STORAGE_CLIENT_SECRET"
+ )
+
+ token_provider = get_azure_ad_token_from_entrata_id(
+ tenant_id=tenant_id,
+ client_id=client_id,
+ client_secret=client_secret,
+ scope="https://storage.azure.com/.default",
+ )
+ token = token_provider()
+
+ verbose_logger.debug("azure auth token %s", token)
+
+ return token
+
+ def _azure_ad_token_is_expired(self):
+ """
+ Returns True if Azure AD token is expired, False otherwise
+ """
+ if self.azure_auth_token and self.token_expiry:
+ if datetime.now() + timedelta(minutes=5) >= self.token_expiry:
+ verbose_logger.debug("Azure AD token is expired. Requesting new token")
+ return True
+ return False
+
+ def _premium_user_check(self):
+ """
+ Checks if the user is a premium user, raises an error if not
+ """
+ from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
+
+ if premium_user is not True:
+ raise ValueError(
+ f"AzureBlobStorageLogger is only available for premium users. {CommonProxyErrors.not_premium_user}"
+ )
+
+ async def upload_to_azure_data_lake_with_azure_account_key(
+ self, payload: StandardLoggingPayload
+ ):
+ """
+ Uploads the payload to Azure Data Lake using the Azure SDK
+
+ This is used when Azure Storage Account Key is set - Azure Storage Account Key does not work directly with Azure Rest API
+ """
+ from azure.storage.filedatalake.aio import DataLakeServiceClient
+
+ # Create an async service client
+ service_client = DataLakeServiceClient(
+ account_url=f"https://{self.azure_storage_account_name}.dfs.core.windows.net",
+ credential=self.azure_storage_account_key,
+ )
+ # Get file system client
+ file_system_client = service_client.get_file_system_client(
+ file_system=self.azure_storage_file_system
+ )
+
+ try:
+ # Create directory with today's date
+ from datetime import datetime
+
+ today = datetime.now().strftime("%Y-%m-%d")
+ directory_client = file_system_client.get_directory_client(today)
+
+ # check if the directory exists
+ if not await directory_client.exists():
+ await directory_client.create_directory()
+ verbose_logger.debug(f"Created directory: {today}")
+
+ # Create a file client
+ file_name = f"{payload.get('id') or str(uuid.uuid4())}.json"
+ file_client = directory_client.get_file_client(file_name)
+
+ # Create the file
+ await file_client.create_file()
+
+ # Content to append
+ content = json.dumps(payload).encode("utf-8")
+
+ # Append content to the file
+ await file_client.append_data(data=content, offset=0, length=len(content))
+
+ # Flush the content to finalize the file
+ await file_client.flush_data(position=len(content), offset=0)
+
+ verbose_logger.debug(
+ f"Successfully uploaded and wrote to {today}/{file_name}"
+ )
+
+ except Exception as e:
+ verbose_logger.exception(f"Error occurred: {str(e)}")
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/braintrust_logging.py b/.venv/lib/python3.12/site-packages/litellm/integrations/braintrust_logging.py
new file mode 100644
index 00000000..281fbda0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/braintrust_logging.py
@@ -0,0 +1,399 @@
+# What is this?
+## Log success + failure events to Braintrust
+
+import copy
+import os
+from datetime import datetime
+from typing import Optional, Dict
+
+import httpx
+from pydantic import BaseModel
+
+import litellm
+from litellm import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.llms.custom_httpx.http_handler import (
+ HTTPHandler,
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+from litellm.utils import print_verbose
+
+global_braintrust_http_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.LoggingCallback)
+global_braintrust_sync_http_handler = HTTPHandler()
+API_BASE = "https://api.braintrustdata.com/v1"
+
+
+def get_utc_datetime():
+ import datetime as dt
+ from datetime import datetime
+
+ if hasattr(dt, "UTC"):
+ return datetime.now(dt.UTC) # type: ignore
+ else:
+ return datetime.utcnow() # type: ignore
+
+
+class BraintrustLogger(CustomLogger):
+ def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None) -> None:
+ super().__init__()
+ self.validate_environment(api_key=api_key)
+ self.api_base = api_base or API_BASE
+ self.default_project_id = None
+ self.api_key: str = api_key or os.getenv("BRAINTRUST_API_KEY") # type: ignore
+ self.headers = {
+ "Authorization": "Bearer " + self.api_key,
+ "Content-Type": "application/json",
+ }
+ self._project_id_cache: Dict[str, str] = {} # Cache mapping project names to IDs
+
+ def validate_environment(self, api_key: Optional[str]):
+ """
+ Expects
+ BRAINTRUST_API_KEY
+
+ in the environment
+ """
+ missing_keys = []
+ if api_key is None and os.getenv("BRAINTRUST_API_KEY", None) is None:
+ missing_keys.append("BRAINTRUST_API_KEY")
+
+ if len(missing_keys) > 0:
+ raise Exception("Missing keys={} in environment.".format(missing_keys))
+
+ def get_project_id_sync(self, project_name: str) -> str:
+ """
+ Get project ID from name, using cache if available.
+ If project doesn't exist, creates it.
+ """
+ if project_name in self._project_id_cache:
+ return self._project_id_cache[project_name]
+
+ try:
+ response = global_braintrust_sync_http_handler.post(
+ f"{self.api_base}/project", headers=self.headers, json={"name": project_name}
+ )
+ project_dict = response.json()
+ project_id = project_dict["id"]
+ self._project_id_cache[project_name] = project_id
+ return project_id
+ except httpx.HTTPStatusError as e:
+ raise Exception(f"Failed to register project: {e.response.text}")
+
+ async def get_project_id_async(self, project_name: str) -> str:
+ """
+ Async version of get_project_id_sync
+ """
+ if project_name in self._project_id_cache:
+ return self._project_id_cache[project_name]
+
+ try:
+ response = await global_braintrust_http_handler.post(
+ f"{self.api_base}/project/register", headers=self.headers, json={"name": project_name}
+ )
+ project_dict = response.json()
+ project_id = project_dict["id"]
+ self._project_id_cache[project_name] = project_id
+ return project_id
+ except httpx.HTTPStatusError as e:
+ raise Exception(f"Failed to register project: {e.response.text}")
+
+ @staticmethod
+ def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict:
+ """
+ Adds metadata from proxy request headers to Braintrust logging if the header keys start
+ with "braintrust", overwriting litellm_params.metadata if the key is already present.
+
+ For example, to attach extra metadata via a header, send
+ `headers: { ..., braintrust_<your-key>: <your-value> }` with the proxy request; the literal
+ "braintrust" prefix is stripped and the remainder of the header name is used as the metadata key.
+ """
+ if litellm_params is None:
+ return metadata
+
+ if litellm_params.get("proxy_server_request") is None:
+ return metadata
+
+ if metadata is None:
+ metadata = {}
+
+ proxy_headers = litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
+
+ for metadata_param_key in proxy_headers:
+ if metadata_param_key.startswith("braintrust"):
+ trace_param_key = metadata_param_key.replace("braintrust", "", 1)
+ if trace_param_key in metadata:
+ verbose_logger.warning(f"Overwriting Braintrust `{trace_param_key}` from request header")
+ else:
+ verbose_logger.debug(f"Found Braintrust `{trace_param_key}` in request header")
+ metadata[trace_param_key] = proxy_headers.get(metadata_param_key)
+
+ return metadata
+
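+ # Worked example of the mapping above (illustrative header name):
+ #
+ #     proxy header:  braintrust_project_id: proj-123
+ #     resulting key: metadata["_project_id"] = "proj-123"
+ #
+ # i.e. only the literal "braintrust" prefix is stripped; the rest of the header name,
+ # including its underscore, is kept as the metadata key.
+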
+ async def create_default_project_and_experiment(self):
+ project = await global_braintrust_http_handler.post(
+ f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
+ )
+
+ project_dict = project.json()
+
+ self.default_project_id = project_dict["id"]
+
+ def create_sync_default_project_and_experiment(self):
+ project = global_braintrust_sync_http_handler.post(
+ f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
+ )
+
+ project_dict = project.json()
+
+ self.default_project_id = project_dict["id"]
+
+ def log_success_event( # noqa: PLR0915
+ self, kwargs, response_obj, start_time, end_time
+ ):
+ verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
+ try:
+ litellm_call_id = kwargs.get("litellm_call_id")
+ prompt = {"messages": kwargs.get("messages")}
+ output = None
+ choices = []
+ if response_obj is not None and (
+ kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
+ ):
+ output = None
+ elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
+ output = response_obj["choices"][0]["message"].json()
+ choices = response_obj["choices"]
+ elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
+ output = response_obj.choices[0].text
+ choices = response_obj.choices
+ elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
+ output = response_obj["data"]
+
+ litellm_params = kwargs.get("litellm_params", {})
+ metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
+ metadata = self.add_metadata_from_header(litellm_params, metadata)
+ clean_metadata = {}
+ try:
+ metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata
+ except Exception:
+ new_metadata = {}
+ for key, value in metadata.items():
+ if (
+ isinstance(value, list)
+ or isinstance(value, dict)
+ or isinstance(value, str)
+ or isinstance(value, int)
+ or isinstance(value, float)
+ ):
+ new_metadata[key] = copy.deepcopy(value)
+ metadata = new_metadata
+
+ # Get project_id from metadata or create default if needed
+ project_id = metadata.get("project_id")
+ if project_id is None:
+ project_name = metadata.get("project_name")
+ project_id = self.get_project_id_sync(project_name) if project_name else None
+
+ if project_id is None:
+ if self.default_project_id is None:
+ self.create_sync_default_project_and_experiment()
+ project_id = self.default_project_id
+
+ tags = []
+ if isinstance(metadata, dict):
+ for key, value in metadata.items():
+ # generate tags from litellm.langfuse_default_tags (default tags configured on the LiteLLM Proxy)
+ if (
+ litellm.langfuse_default_tags is not None
+ and isinstance(litellm.langfuse_default_tags, list)
+ and key in litellm.langfuse_default_tags
+ ):
+ tags.append(f"{key}:{value}")
+
+ # clean litellm metadata before logging
+ if key in [
+ "headers",
+ "endpoint",
+ "caching_groups",
+ "previous_models",
+ ]:
+ continue
+ else:
+ clean_metadata[key] = value
+
+ cost = kwargs.get("response_cost", None)
+ if cost is not None:
+ clean_metadata["litellm_response_cost"] = cost
+
+ metrics: Optional[dict] = None
+ usage_obj = getattr(response_obj, "usage", None)
+ if usage_obj and isinstance(usage_obj, litellm.Usage):
+ litellm.utils.get_logging_id(start_time, response_obj)
+ metrics = {
+ "prompt_tokens": usage_obj.prompt_tokens,
+ "completion_tokens": usage_obj.completion_tokens,
+ "total_tokens": usage_obj.total_tokens,
+ "total_cost": cost,
+ "time_to_first_token": end_time.timestamp() - start_time.timestamp(),
+ "start": start_time.timestamp(),
+ "end": end_time.timestamp(),
+ }
+
+ request_data = {
+ "id": litellm_call_id,
+ "input": prompt["messages"],
+ "metadata": clean_metadata,
+ "tags": tags,
+ "span_attributes": {"name": "Chat Completion", "type": "llm"},
+ }
+ if choices:
+ request_data["output"] = [choice.dict() for choice in choices]
+ else:
+ request_data["output"] = output
+
+ if metrics is not None:
+ request_data["metrics"] = metrics
+
+ try:
+ print_verbose(f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}")
+ global_braintrust_sync_http_handler.post(
+ url=f"{self.api_base}/project_logs/{project_id}/insert",
+ json={"events": [request_data]},
+ headers=self.headers,
+ )
+ except httpx.HTTPStatusError as e:
+ raise Exception(e.response.text)
+ except Exception as e:
+ raise e # don't use verbose_logger.exception, if exception is raised
+
+ async def async_log_success_event( # noqa: PLR0915
+ self, kwargs, response_obj, start_time, end_time
+ ):
+ verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
+ try:
+ litellm_call_id = kwargs.get("litellm_call_id")
+ prompt = {"messages": kwargs.get("messages")}
+ output = None
+ choices = []
+ if response_obj is not None and (
+ kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
+ ):
+ output = None
+ elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
+ output = response_obj["choices"][0]["message"].json()
+ choices = response_obj["choices"]
+ elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
+ output = response_obj.choices[0].text
+ choices = response_obj.choices
+ elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
+ output = response_obj["data"]
+
+ litellm_params = kwargs.get("litellm_params", {})
+ metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
+ metadata = self.add_metadata_from_header(litellm_params, metadata)
+ clean_metadata = {}
+ new_metadata = {}
+ for key, value in metadata.items():
+ if (
+ isinstance(value, list)
+ or isinstance(value, str)
+ or isinstance(value, int)
+ or isinstance(value, float)
+ ):
+ new_metadata[key] = value
+ elif isinstance(value, BaseModel):
+ new_metadata[key] = value.model_dump_json()
+ elif isinstance(value, dict):
+ for k, v in value.items():
+ if isinstance(v, datetime):
+ value[k] = v.isoformat()
+ new_metadata[key] = value
+
+ # Get project_id from metadata or create default if needed
+ project_id = metadata.get("project_id")
+ if project_id is None:
+ project_name = metadata.get("project_name")
+ project_id = await self.get_project_id_async(project_name) if project_name else None
+
+ if project_id is None:
+ if self.default_project_id is None:
+ await self.create_default_project_and_experiment()
+ project_id = self.default_project_id
+
+ tags = []
+ if isinstance(metadata, dict):
+ for key, value in metadata.items():
+ # generate tags from litellm.langfuse_default_tags (default tags configured on the LiteLLM Proxy)
+ if (
+ litellm.langfuse_default_tags is not None
+ and isinstance(litellm.langfuse_default_tags, list)
+ and key in litellm.langfuse_default_tags
+ ):
+ tags.append(f"{key}:{value}")
+
+ # clean litellm metadata before logging
+ if key in [
+ "headers",
+ "endpoint",
+ "caching_groups",
+ "previous_models",
+ ]:
+ continue
+ else:
+ clean_metadata[key] = value
+
+ cost = kwargs.get("response_cost", None)
+ if cost is not None:
+ clean_metadata["litellm_response_cost"] = cost
+
+ metrics: Optional[dict] = None
+ usage_obj = getattr(response_obj, "usage", None)
+ if usage_obj and isinstance(usage_obj, litellm.Usage):
+ litellm.utils.get_logging_id(start_time, response_obj)
+ metrics = {
+ "prompt_tokens": usage_obj.prompt_tokens,
+ "completion_tokens": usage_obj.completion_tokens,
+ "total_tokens": usage_obj.total_tokens,
+ "total_cost": cost,
+ "start": start_time.timestamp(),
+ "end": end_time.timestamp(),
+ }
+
+ api_call_start_time = kwargs.get("api_call_start_time")
+ completion_start_time = kwargs.get("completion_start_time")
+
+ if api_call_start_time is not None and completion_start_time is not None:
+ metrics["time_to_first_token"] = completion_start_time.timestamp() - api_call_start_time.timestamp()
+
+ request_data = {
+ "id": litellm_call_id,
+ "input": prompt["messages"],
+ "output": output,
+ "metadata": clean_metadata,
+ "tags": tags,
+ "span_attributes": {"name": "Chat Completion", "type": "llm"},
+ }
+ if choices:
+ request_data["output"] = [choice.dict() for choice in choices]
+ else:
+ request_data["output"] = output
+
+ if metrics is not None:
+ request_data["metrics"] = metrics
+
+ try:
+ await global_braintrust_http_handler.post(
+ url=f"{self.api_base}/project_logs/{project_id}/insert",
+ json={"events": [request_data]},
+ headers=self.headers,
+ )
+ except httpx.HTTPStatusError as e:
+ raise Exception(e.response.text)
+ except Exception as e:
+ raise e # don't use verbose_logger.exception, if exception is raised
+
+ def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ return super().log_failure_event(kwargs, response_obj, start_time, end_time)
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/custom_batch_logger.py b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_batch_logger.py
new file mode 100644
index 00000000..3cfdf82c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_batch_logger.py
@@ -0,0 +1,59 @@
+"""
+Custom Logger that handles batching logic
+
+Use this if you want your logs to be stored in memory and flushed periodically.
+"""
+
+import asyncio
+import time
+from typing import List, Optional
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+
+
+class CustomBatchLogger(CustomLogger):
+
+ def __init__(
+ self,
+ flush_lock: Optional[asyncio.Lock] = None,
+ batch_size: Optional[int] = None,
+ flush_interval: Optional[int] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Args:
+ flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching.
+ batch_size (Optional[int], optional): Max number of events held in memory before a flush. Defaults to litellm.DEFAULT_BATCH_SIZE.
+ flush_interval (Optional[int], optional): Seconds between periodic flushes. Defaults to litellm.DEFAULT_FLUSH_INTERVAL_SECONDS.
+ """
+ self.log_queue: List = []
+ self.flush_interval = flush_interval or litellm.DEFAULT_FLUSH_INTERVAL_SECONDS
+ self.batch_size: int = batch_size or litellm.DEFAULT_BATCH_SIZE
+ self.last_flush_time = time.time()
+ self.flush_lock = flush_lock
+
+ super().__init__(**kwargs)
+
+ async def periodic_flush(self):
+ while True:
+ await asyncio.sleep(self.flush_interval)
+ verbose_logger.debug(
+ f"CustomLogger periodic flush after {self.flush_interval} seconds"
+ )
+ await self.flush_queue()
+
+ async def flush_queue(self):
+ if self.flush_lock is None:
+ return
+
+ async with self.flush_lock:
+ if self.log_queue:
+ verbose_logger.debug(
+ "CustomLogger: Flushing batch of %s events", len(self.log_queue)
+ )
+ await self.async_send_batch()
+ self.log_queue.clear()
+ self.last_flush_time = time.time()
+
+ async def async_send_batch(self, *args, **kwargs):
+ pass
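+
+# Minimal subclass sketch (illustrative only - the payload shape and backend call are placeholders):
+#
+#     class MyBatchLogger(CustomBatchLogger):
+#         def __init__(self, **kwargs):
+#             self.flush_lock = asyncio.Lock()
+#             super().__init__(flush_lock=self.flush_lock, **kwargs)
+#             asyncio.create_task(self.periodic_flush())
+#
+#         async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+#             self.log_queue.append({"model": kwargs.get("model")})
+#             if len(self.log_queue) >= self.batch_size:
+#                 await self.flush_queue()
+#
+#         async def async_send_batch(self):
+#             ...  # send self.log_queue to your backend here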
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py
new file mode 100644
index 00000000..4421664b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py
@@ -0,0 +1,274 @@
+from typing import Dict, List, Literal, Optional, Union
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.guardrails import DynamicGuardrailParams, GuardrailEventHooks
+from litellm.types.utils import StandardLoggingGuardrailInformation
+
+
+class CustomGuardrail(CustomLogger):
+
+ def __init__(
+ self,
+ guardrail_name: Optional[str] = None,
+ supported_event_hooks: Optional[List[GuardrailEventHooks]] = None,
+ event_hook: Optional[
+ Union[GuardrailEventHooks, List[GuardrailEventHooks]]
+ ] = None,
+ default_on: bool = False,
+ **kwargs,
+ ):
+ """
+ Initialize the CustomGuardrail class
+
+ Args:
+ guardrail_name: The name of the guardrail. This is the name used in your requests.
+ supported_event_hooks: The event hooks that the guardrail supports
+ event_hook: The event hook to run the guardrail on
+ default_on: If True, the guardrail will be run by default on all requests
+ """
+ self.guardrail_name = guardrail_name
+ self.supported_event_hooks = supported_event_hooks
+ self.event_hook: Optional[
+ Union[GuardrailEventHooks, List[GuardrailEventHooks]]
+ ] = event_hook
+ self.default_on: bool = default_on
+
+ if supported_event_hooks:
+ ## validate event_hook is in supported_event_hooks
+ self._validate_event_hook(event_hook, supported_event_hooks)
+ super().__init__(**kwargs)
+
+ def _validate_event_hook(
+ self,
+ event_hook: Optional[Union[GuardrailEventHooks, List[GuardrailEventHooks]]],
+ supported_event_hooks: List[GuardrailEventHooks],
+ ) -> None:
+ if event_hook is None:
+ return
+ if isinstance(event_hook, list):
+ for hook in event_hook:
+ if hook not in supported_event_hooks:
+ raise ValueError(
+ f"Event hook {hook} is not in the supported event hooks {supported_event_hooks}"
+ )
+ elif isinstance(event_hook, GuardrailEventHooks):
+ if event_hook not in supported_event_hooks:
+ raise ValueError(
+ f"Event hook {event_hook} is not in the supported event hooks {supported_event_hooks}"
+ )
+
+ def get_guardrail_from_metadata(
+ self, data: dict
+ ) -> Union[List[str], List[Dict[str, DynamicGuardrailParams]]]:
+ """
+ Returns the guardrail(s) to be run from the metadata
+ """
+ metadata = data.get("metadata") or {}
+ requested_guardrails = metadata.get("guardrails") or []
+ return requested_guardrails
+
+ def _guardrail_is_in_requested_guardrails(
+ self,
+ requested_guardrails: Union[List[str], List[Dict[str, DynamicGuardrailParams]]],
+ ) -> bool:
+ for _guardrail in requested_guardrails:
+ if isinstance(_guardrail, dict):
+ if self.guardrail_name in _guardrail:
+ return True
+ elif isinstance(_guardrail, str):
+ if self.guardrail_name == _guardrail:
+ return True
+ return False
+
+ def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
+ """
+ Returns True if the guardrail should be run on the event_type
+ """
+ requested_guardrails = self.get_guardrail_from_metadata(data)
+
+ verbose_logger.debug(
+ "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s self.default_on= %s",
+ self.guardrail_name,
+ event_type,
+ self.event_hook,
+ requested_guardrails,
+ self.default_on,
+ )
+
+ if self.default_on is True:
+ if self._event_hook_is_event_type(event_type):
+ return True
+ return False
+
+ if (
+ self.event_hook
+ and not self._guardrail_is_in_requested_guardrails(requested_guardrails)
+ and event_type.value != "logging_only"
+ ):
+ return False
+
+ if not self._event_hook_is_event_type(event_type):
+ return False
+
+ return True
+
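+ # Worked example (illustrative): with guardrail_name="my-guard", event_hook="pre_call" and
+ # default_on=False, a request whose metadata contains
+ #     {"metadata": {"guardrails": ["my-guard"]}}
+ # makes should_run_guardrail(data, GuardrailEventHooks.pre_call) return True, while a request
+ # without "my-guard" in metadata["guardrails"] returns False.
+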
+ def _event_hook_is_event_type(self, event_type: GuardrailEventHooks) -> bool:
+ """
+ Returns True if the event_hook is the same as the event_type
+
+ eg. if `self.event_hook == "pre_call" and event_type == "pre_call"` -> then True
+ eg. if `self.event_hook == "pre_call" and event_type == "post_call"` -> then False
+ """
+
+ if self.event_hook is None:
+ return True
+ if isinstance(self.event_hook, list):
+ return event_type.value in self.event_hook
+ return self.event_hook == event_type.value
+
+ def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
+ """
+ Returns `extra_body` to be added to the request body for the Guardrail API call
+
+ Use this to pass dynamic params to the guardrail API call - eg. success_threshold, failure_threshold, etc.
+
+ For example, if the request metadata contains:
+ ```
+ [{"lakera_guard": {"extra_body": {"foo": "bar"}}}]
+ ```
+
+ then for the guardrail named `lakera_guard` this returns:
+ {
+ "foo": "bar"
+ }
+
+ Args:
+ request_data: The original `request_data` passed to LiteLLM Proxy
+ """
+ requested_guardrails = self.get_guardrail_from_metadata(request_data)
+
+ # Look for the guardrail configuration matching self.guardrail_name
+ for guardrail in requested_guardrails:
+ if isinstance(guardrail, dict) and self.guardrail_name in guardrail:
+ # Get the configuration for this guardrail
+ guardrail_config: DynamicGuardrailParams = DynamicGuardrailParams(
+ **guardrail[self.guardrail_name]
+ )
+ if self._validate_premium_user() is not True:
+ return {}
+
+ # Return the extra_body if it exists, otherwise empty dict
+ return guardrail_config.get("extra_body", {})
+
+ return {}
+
+ def _validate_premium_user(self) -> bool:
+ """
+ Returns True if the user is a premium user
+ """
+ from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
+
+ if premium_user is not True:
+ verbose_logger.warning(
+ f"Trying to use premium guardrail without premium user {CommonProxyErrors.not_premium_user.value}"
+ )
+ return False
+ return True
+
+ def add_standard_logging_guardrail_information_to_request_data(
+ self,
+ guardrail_json_response: Union[Exception, str, dict],
+ request_data: dict,
+ guardrail_status: Literal["success", "failure"],
+ ) -> None:
+ """
+ Builds `StandardLoggingGuardrailInformation` and adds it to the request metadata so it can be used for logging to DataDog, Langfuse, etc.
+ """
+ from litellm.proxy.proxy_server import premium_user
+
+ if premium_user is not True:
+ verbose_logger.warning(
+ f"Guardrail Tracing is only available for premium users. Skipping guardrail logging for guardrail={self.guardrail_name} event_hook={self.event_hook}"
+ )
+ return
+ if isinstance(guardrail_json_response, Exception):
+ guardrail_json_response = str(guardrail_json_response)
+ slg = StandardLoggingGuardrailInformation(
+ guardrail_name=self.guardrail_name,
+ guardrail_mode=self.event_hook,
+ guardrail_response=guardrail_json_response,
+ guardrail_status=guardrail_status,
+ )
+ if "metadata" in request_data:
+ request_data["metadata"]["standard_logging_guardrail_information"] = slg
+ elif "litellm_metadata" in request_data:
+ request_data["litellm_metadata"][
+ "standard_logging_guardrail_information"
+ ] = slg
+ else:
+ verbose_logger.warning(
+ "unable to log guardrail information. No metadata found in request_data"
+ )
+
+
+def log_guardrail_information(func):
+ """
+ Decorator to add standard logging guardrail information to any function
+
+ Add this decorator to ensure your guardrail response is logged to DataDog, OTEL, s3, GCS etc.
+
+ Logs for:
+ - pre_call
+ - during_call
+ - TODO: log post_call. This is more involved since the logs are sent to DD, s3 before the guardrail is even run
+ """
+ import asyncio
+ import functools
+
+ def process_response(self, response, request_data):
+ self.add_standard_logging_guardrail_information_to_request_data(
+ guardrail_json_response=response,
+ request_data=request_data,
+ guardrail_status="success",
+ )
+ return response
+
+ def process_error(self, e, request_data):
+ self.add_standard_logging_guardrail_information_to_request_data(
+ guardrail_json_response=e,
+ request_data=request_data,
+ guardrail_status="failure",
+ )
+ raise e
+
+ @functools.wraps(func)
+ async def async_wrapper(*args, **kwargs):
+ self: CustomGuardrail = args[0]
+ request_data: Optional[dict] = (
+ kwargs.get("data") or kwargs.get("request_data") or {}
+ )
+ try:
+ response = await func(*args, **kwargs)
+ return process_response(self, response, request_data)
+ except Exception as e:
+ return process_error(self, e, request_data)
+
+ @functools.wraps(func)
+ def sync_wrapper(*args, **kwargs):
+ self: CustomGuardrail = args[0]
+ request_data: Optional[dict] = (
+ kwargs.get("data") or kwargs.get("request_data") or {}
+ )
+ try:
+ response = func(*args, **kwargs)
+ return process_response(self, response, request_data)
+ except Exception as e:
+ return process_error(self, e, request_data)
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ if asyncio.iscoroutinefunction(func):
+ return async_wrapper(*args, **kwargs)
+ return sync_wrapper(*args, **kwargs)
+
+ return wrapper
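+
+# Usage sketch (illustrative - MyGuardrail and its hook body are placeholders):
+#
+#     class MyGuardrail(CustomGuardrail):
+#         @log_guardrail_information
+#         async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
+#             # run your guardrail check here; raising marks the guardrail "failure",
+#             # returning normally marks it "success" in the logged metadata
+#             return data
+#
+# Note: the decorator reads the request payload from the `data` / `request_data` keyword arguments.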
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/custom_logger.py b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_logger.py
new file mode 100644
index 00000000..6f1ec88d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_logger.py
@@ -0,0 +1,388 @@
+#### What this does ####
+# Defines the base CustomLogger class - the hooks litellm invokes for success/failure logging and proxy call lifecycle events
+import traceback
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ AsyncGenerator,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ Union,
+)
+
+from pydantic import BaseModel
+
+from litellm.caching.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.types.integrations.argilla import ArgillaItem
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
+from litellm.types.utils import (
+ AdapterCompletionStreamWrapper,
+ EmbeddingResponse,
+ ImageResponse,
+ ModelResponse,
+ ModelResponseStream,
+ StandardCallbackDynamicParams,
+ StandardLoggingPayload,
+)
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+
+ Span = _Span
+else:
+ Span = Any
+
+
+class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
+ # Class variables or attributes
+ def __init__(self, message_logging: bool = True) -> None:
+ self.message_logging = message_logging
+ pass
+
+ def log_pre_api_call(self, model, messages, kwargs):
+ pass
+
+ def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
+ pass
+
+ def log_stream_event(self, kwargs, response_obj, start_time, end_time):
+ pass
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ pass
+
+ def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ pass
+
+ #### ASYNC ####
+
+ async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
+ pass
+
+ async def async_log_pre_api_call(self, model, messages, kwargs):
+ pass
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ pass
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ pass
+
+ #### PROMPT MANAGEMENT HOOKS ####
+
+ async def async_get_chat_completion_prompt(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ non_default_params: dict,
+ prompt_id: str,
+ prompt_variables: Optional[dict],
+ dynamic_callback_params: StandardCallbackDynamicParams,
+ ) -> Tuple[str, List[AllMessageValues], dict]:
+ """
+ Returns:
+ - model: str - the model to use (can be pulled from prompt management tool)
+ - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
+ - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
+ """
+ return model, messages, non_default_params
+
+ def get_chat_completion_prompt(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ non_default_params: dict,
+ prompt_id: str,
+ prompt_variables: Optional[dict],
+ dynamic_callback_params: StandardCallbackDynamicParams,
+ ) -> Tuple[str, List[AllMessageValues], dict]:
+ """
+ Returns:
+ - model: str - the model to use (can be pulled from prompt management tool)
+ - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
+ - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
+ """
+ return model, messages, non_default_params
+
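+ # Override sketch (illustrative - the prompt store lookup and templating are placeholders):
+ #
+ #     def get_chat_completion_prompt(self, model, messages, non_default_params,
+ #                                    prompt_id, prompt_variables, dynamic_callback_params):
+ #         stored = my_prompt_store.get(prompt_id)           # hypothetical lookup
+ #         messages = stored.render(prompt_variables or {})  # hypothetical templating
+ #         return stored.model or model, messages, non_default_params
+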
+ #### PRE-CALL CHECKS - router/proxy only ####
+ """
+ Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
+ """
+
+ async def async_filter_deployments(
+ self,
+ model: str,
+ healthy_deployments: List,
+ messages: Optional[List[AllMessageValues]],
+ request_kwargs: Optional[dict] = None,
+ parent_otel_span: Optional[Span] = None,
+ ) -> List[dict]:
+ return healthy_deployments
+
+ async def async_pre_call_check(
+ self, deployment: dict, parent_otel_span: Optional[Span]
+ ) -> Optional[dict]:
+ pass
+
+ def pre_call_check(self, deployment: dict) -> Optional[dict]:
+ pass
+
+ #### Fallback Events - router/proxy only ####
+ async def log_model_group_rate_limit_error(
+ self, exception: Exception, original_model_group: Optional[str], kwargs: dict
+ ):
+ pass
+
+ async def log_success_fallback_event(
+ self, original_model_group: str, kwargs: dict, original_exception: Exception
+ ):
+ pass
+
+ async def log_failure_fallback_event(
+ self, original_model_group: str, kwargs: dict, original_exception: Exception
+ ):
+ pass
+
+ #### ADAPTERS #### Allow calling 100+ LLMs in custom format - https://github.com/BerriAI/litellm/pulls
+
+ def translate_completion_input_params(
+ self, kwargs
+ ) -> Optional[ChatCompletionRequest]:
+ """
+ Translates the input params, from the provider's native format to the litellm.completion() format.
+ """
+ pass
+
+ def translate_completion_output_params(
+ self, response: ModelResponse
+ ) -> Optional[BaseModel]:
+ """
+ Translates the output params, from the OpenAI format to the custom format.
+ """
+ pass
+
+ def translate_completion_output_params_streaming(
+ self, completion_stream: Any
+ ) -> Optional[AdapterCompletionStreamWrapper]:
+ """
+ Translates the streaming chunk, from the OpenAI format to the custom format.
+ """
+ pass
+
+ ### DATASET HOOKS #### - currently only used for Argilla
+
+ async def async_dataset_hook(
+ self,
+ logged_item: ArgillaItem,
+ standard_logging_payload: Optional[StandardLoggingPayload],
+ ) -> Optional[ArgillaItem]:
+ """
+ - Decide if the result should be logged to Argilla.
+ - Modify the result before logging to Argilla.
+ - Return None if the result should not be logged to Argilla.
+ """
+ raise NotImplementedError("async_dataset_hook not implemented")
+
+ #### CALL HOOKS - proxy only ####
+ """
+ Modify the incoming / outgoing data before calling the model
+ """
+
+ async def async_pre_call_hook(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ cache: DualCache,
+ data: dict,
+ call_type: Literal[
+ "completion",
+ "text_completion",
+ "embeddings",
+ "image_generation",
+ "moderation",
+ "audio_transcription",
+ "pass_through_endpoint",
+ "rerank",
+ ],
+ ) -> Optional[
+ Union[Exception, str, dict]
+ ]: # raise an exception if the request is invalid, return a str for the user to receive if the request is rejected, or return a modified dict to pass into litellm
+ pass
+
+ async def async_post_call_failure_hook(
+ self,
+ request_data: dict,
+ original_exception: Exception,
+ user_api_key_dict: UserAPIKeyAuth,
+ ):
+ pass
+
+ async def async_post_call_success_hook(
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
+ ) -> Any:
+ pass
+
+ async def async_logging_hook(
+ self, kwargs: dict, result: Any, call_type: str
+ ) -> Tuple[dict, Any]:
+ """For masking logged request/response. Return a modified version of the request/result."""
+ return kwargs, result
+
+ def logging_hook(
+ self, kwargs: dict, result: Any, call_type: str
+ ) -> Tuple[dict, Any]:
+ """For masking logged request/response. Return a modified version of the request/result."""
+ return kwargs, result
+
+ async def async_moderation_hook(
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ call_type: Literal[
+ "completion",
+ "embeddings",
+ "image_generation",
+ "moderation",
+ "audio_transcription",
+ "responses",
+ ],
+ ) -> Any:
+ pass
+
+ async def async_post_call_streaming_hook(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ response: str,
+ ) -> Any:
+ pass
+
+ async def async_post_call_streaming_iterator_hook(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ response: Any,
+ request_data: dict,
+ ) -> AsyncGenerator[ModelResponseStream, None]:
+ async for item in response:
+ yield item
+
+ #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
+
+ def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
+ try:
+ kwargs["model"] = model
+ kwargs["messages"] = messages
+ kwargs["log_event_type"] = "pre_api_call"
+ callback_func(
+ kwargs,
+ )
+ print_verbose(f"Custom Logger - model call details: {kwargs}")
+ except Exception:
+ print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
+
+ async def async_log_input_event(
+ self, model, messages, kwargs, print_verbose, callback_func
+ ):
+ try:
+ kwargs["model"] = model
+ kwargs["messages"] = messages
+ kwargs["log_event_type"] = "pre_api_call"
+ await callback_func(
+ kwargs,
+ )
+ print_verbose(f"Custom Logger - model call details: {kwargs}")
+ except Exception:
+ print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
+
+ def log_event(
+ self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
+ ):
+ # Method definition
+ try:
+ kwargs["log_event_type"] = "post_api_call"
+ callback_func(
+ kwargs, # kwargs to func
+ response_obj,
+ start_time,
+ end_time,
+ )
+ except Exception:
+ print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
+ pass
+
+ async def async_log_event(
+ self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
+ ):
+ # Method definition
+ try:
+ kwargs["log_event_type"] = "post_api_call"
+ await callback_func(
+ kwargs, # kwargs to func
+ response_obj,
+ start_time,
+ end_time,
+ )
+ except Exception:
+ print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
+ pass
+
+ # Useful helpers for custom logger classes
+
+ def truncate_standard_logging_payload_content(
+ self,
+ standard_logging_object: StandardLoggingPayload,
+ ):
+ """
+ Truncate error strings and message content in logging payload
+
+ Some loggers like DataDog/ GCS Bucket have a limit on the size of the payload. (1MB)
+
+ This function truncates the error string and the message content if they exceed a certain length.
+ """
+ MAX_STR_LENGTH = 10_000
+
+ # Truncate fields that might exceed max length
+ fields_to_truncate = ["error_str", "messages", "response"]
+ for field in fields_to_truncate:
+ self._truncate_field(
+ standard_logging_object=standard_logging_object,
+ field_name=field,
+ max_length=MAX_STR_LENGTH,
+ )
+
+ def _truncate_field(
+ self,
+ standard_logging_object: StandardLoggingPayload,
+ field_name: str,
+ max_length: int,
+ ) -> None:
+ """
+ Helper function to truncate a field in the logging payload
+
+ This converts the field to a string and then truncates it if it exceeds the max length.
+
+ Why convert to string ?
+ 1. User was sending a poorly formatted list for `messages` field, we could not predict where they would send content
+ - Converting to string and then truncating the logged content catches this
+ 2. We want to avoid modifying the original `messages`, `response`, and `error_str` in the logging payload since these are in kwargs and could be returned to the user
+ """
+ field_value = standard_logging_object.get(field_name) # type: ignore
+ if field_value:
+ str_value = str(field_value)
+ if len(str_value) > max_length:
+ standard_logging_object[field_name] = self._truncate_text( # type: ignore
+ text=str_value, max_length=max_length
+ )
+
+ def _truncate_text(self, text: str, max_length: int) -> str:
+ """Truncate text if it exceeds max_length"""
+ return (
+ text[:max_length]
+ + "...truncated by litellm, this logger does not support large content"
+ if len(text) > max_length
+ else text
+ )
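+
+# Worked example (illustrative): with MAX_STR_LENGTH = 10_000, a logging payload whose "messages"
+# field stringifies to 25,000 characters is replaced by its first 10,000 characters plus the
+# "...truncated by litellm..." marker, while "error_str" and "response" values under the limit
+# are left untouched.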
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/custom_prompt_management.py b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_prompt_management.py
new file mode 100644
index 00000000..5b34ef0c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_prompt_management.py
@@ -0,0 +1,49 @@
+from typing import List, Optional, Tuple
+
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.prompt_management_base import (
+ PromptManagementBase,
+ PromptManagementClient,
+)
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import StandardCallbackDynamicParams
+
+
+class CustomPromptManagement(CustomLogger, PromptManagementBase):
+ def get_chat_completion_prompt(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ non_default_params: dict,
+ prompt_id: str,
+ prompt_variables: Optional[dict],
+ dynamic_callback_params: StandardCallbackDynamicParams,
+ ) -> Tuple[str, List[AllMessageValues], dict]:
+ """
+ Returns:
+ - model: str - the model to use (can be pulled from prompt management tool)
+ - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
+ - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
+ """
+ return model, messages, non_default_params
+
+ @property
+ def integration_name(self) -> str:
+ return "custom-prompt-management"
+
+ def should_run_prompt_management(
+ self,
+ prompt_id: str,
+ dynamic_callback_params: StandardCallbackDynamicParams,
+ ) -> bool:
+ return True
+
+ def _compile_prompt_helper(
+ self,
+ prompt_id: str,
+ prompt_variables: Optional[dict],
+ dynamic_callback_params: StandardCallbackDynamicParams,
+ ) -> PromptManagementClient:
+ raise NotImplementedError(
+ "Custom prompt management does not support compile prompt helper"
+ )
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/datadog/datadog.py b/.venv/lib/python3.12/site-packages/litellm/integrations/datadog/datadog.py
new file mode 100644
index 00000000..4f4b05c8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/datadog/datadog.py
@@ -0,0 +1,580 @@
+"""
+DataDog Integration - sends logs to /api/v2/log
+
+DD Reference API: https://docs.datadoghq.com/api/latest/logs
+
+`async_log_success_event` - used by litellm proxy to send logs to datadog
+`log_success_event` - sync version of logging to DataDog, only used on litellm Python SDK, if user opts in to using sync functions
+
+async_log_success_event: stores a batch of up to DD_MAX_BATCH_SIZE events in memory and flushes to Datadog once the batch is full or on the periodic flush interval
+
+async_service_failure_hook: Logs failures from Redis, Postgres (Adjacent systems), as 'WARNING' on DataDog
+
+For batching specific details see CustomBatchLogger class
+"""
+
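+# Configuration sketch (illustrative - callback registration may vary by litellm version):
+# set DD_API_KEY and DD_SITE (e.g. DD_SITE="us5.datadoghq.com") in the environment, then:
+#
+#     import litellm
+#     litellm.success_callback = ["datadog"]
+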
+import asyncio
+import datetime
+import json
+import os
+import traceback
+import uuid
+from datetime import datetime as datetimeObj
+from typing import Any, List, Optional, Union
+
+import httpx
+from httpx import Response
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.llms.custom_httpx.http_handler import (
+ _get_httpx_client,
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
+from litellm.types.integrations.datadog import *
+from litellm.types.services import ServiceLoggerPayload, ServiceTypes
+from litellm.types.utils import StandardLoggingPayload
+
+from ..additional_logging_utils import AdditionalLoggingUtils
+
+# max number of logs DD API can accept
+DD_MAX_BATCH_SIZE = 1000
+
+# specify what ServiceTypes are logged as success events to DD. (We don't want to spam DD traces with large number of service types)
+DD_LOGGED_SUCCESS_SERVICE_TYPES = [
+ ServiceTypes.RESET_BUDGET_JOB,
+]
+
+
+class DataDogLogger(
+ CustomBatchLogger,
+ AdditionalLoggingUtils,
+):
+ # Class variables or attributes
+ def __init__(
+ self,
+ **kwargs,
+ ):
+ """
+ Initializes the datadog logger, checks if the correct env variables are set
+
+ Required environment variables:
+ `DD_API_KEY` - your datadog api key
+ `DD_SITE` - your datadog site, example = `"us5.datadoghq.com"`
+ """
+ try:
+ verbose_logger.debug("Datadog: in init datadog logger")
+ # check if the correct env variables are set
+ if os.getenv("DD_API_KEY", None) is None:
+ raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>")
+ if os.getenv("DD_SITE", None) is None:
+ raise Exception("DD_SITE is not set in .env, set 'DD_SITE=<>")
+ self.async_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+ self.DD_API_KEY = os.getenv("DD_API_KEY")
+ self.intake_url = (
+ f"https://http-intake.logs.{os.getenv('DD_SITE')}/api/v2/logs"
+ )
+
+ ###################################
+ # OPTIONAL -only used for testing
+ dd_base_url: Optional[str] = (
+ os.getenv("_DATADOG_BASE_URL")
+ or os.getenv("DATADOG_BASE_URL")
+ or os.getenv("DD_BASE_URL")
+ )
+ if dd_base_url is not None:
+ self.intake_url = f"{dd_base_url}/api/v2/logs"
+ ###################################
+ self.sync_client = _get_httpx_client()
+ asyncio.create_task(self.periodic_flush())
+ self.flush_lock = asyncio.Lock()
+ super().__init__(
+ **kwargs, flush_lock=self.flush_lock, batch_size=DD_MAX_BATCH_SIZE
+ )
+ except Exception as e:
+ verbose_logger.exception(
+ f"Datadog: Got exception on init Datadog client {str(e)}"
+ )
+ raise e
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ """
+ Async Log success events to Datadog
+
+ - Creates a Datadog payload
+ - Adds the Payload to the in memory logs queue
+ - Payload is flushed on the periodic flush interval or once the batch reaches DD_MAX_BATCH_SIZE
+
+
+ Raises:
+ Raises a NON Blocking verbose_logger.exception if an error occurs
+ """
+ try:
+ verbose_logger.debug(
+ "Datadog: Logging - Enters logging function for model %s", kwargs
+ )
+ await self._log_async_event(kwargs, response_obj, start_time, end_time)
+
+ except Exception as e:
+ verbose_logger.exception(
+ f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
+ )
+ pass
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ verbose_logger.debug(
+ "Datadog: Logging - Enters logging function for model %s", kwargs
+ )
+ await self._log_async_event(kwargs, response_obj, start_time, end_time)
+
+ except Exception as e:
+ verbose_logger.exception(
+ f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
+ )
+ pass
+
+ async def async_send_batch(self):
+ """
+ Sends the in memory logs queue to datadog api
+
+ Logs sent to /api/v2/logs
+
+ DD Ref: https://docs.datadoghq.com/api/latest/logs/
+
+ Raises:
+ Raises a NON Blocking verbose_logger.exception if an error occurs
+ """
+ try:
+ if not self.log_queue:
+ verbose_logger.exception("Datadog: log_queue does not exist")
+ return
+
+ verbose_logger.debug(
+ "Datadog - about to flush %s events on %s",
+ len(self.log_queue),
+ self.intake_url,
+ )
+
+ response = await self.async_send_compressed_data(self.log_queue)
+ if response.status_code == 413:
+ verbose_logger.exception(DD_ERRORS.DATADOG_413_ERROR.value)
+ return
+
+ response.raise_for_status()
+ if response.status_code != 202:
+ raise Exception(
+ f"Response from datadog API status_code: {response.status_code}, text: {response.text}"
+ )
+
+ verbose_logger.debug(
+ "Datadog: Response from datadog API status_code: %s, text: %s",
+ response.status_code,
+ response.text,
+ )
+ except Exception as e:
+ verbose_logger.exception(
+ f"Datadog Error sending batch API - {str(e)}\n{traceback.format_exc()}"
+ )
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ """
+ Sync Log success events to Datadog
+
+ - Creates a Datadog payload
+ - instantly logs it on DD API
+ """
+ try:
+ if litellm.datadog_use_v1 is True:
+ dd_payload = self._create_v0_logging_payload(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ start_time=start_time,
+ end_time=end_time,
+ )
+ else:
+ dd_payload = self.create_datadog_logging_payload(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ start_time=start_time,
+ end_time=end_time,
+ )
+
+ response = self.sync_client.post(
+ url=self.intake_url,
+ json=dd_payload, # type: ignore
+ headers={
+ "DD-API-KEY": self.DD_API_KEY,
+ },
+ )
+
+ response.raise_for_status()
+ if response.status_code != 202:
+ raise Exception(
+ f"Response from datadog API status_code: {response.status_code}, text: {response.text}"
+ )
+
+ verbose_logger.debug(
+ "Datadog: Response from datadog API status_code: %s, text: %s",
+ response.status_code,
+ response.text,
+ )
+
+ except Exception as e:
+ verbose_logger.exception(
+ f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
+ )
+ pass
+ pass
+
+ async def _log_async_event(self, kwargs, response_obj, start_time, end_time):
+
+ dd_payload = self.create_datadog_logging_payload(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ start_time=start_time,
+ end_time=end_time,
+ )
+
+ self.log_queue.append(dd_payload)
+ verbose_logger.debug(
+ f"Datadog, event added to queue. Will flush in {self.flush_interval} seconds..."
+ )
+
+ if len(self.log_queue) >= self.batch_size:
+ await self.async_send_batch()
+
+ def _create_datadog_logging_payload_helper(
+ self,
+ standard_logging_object: StandardLoggingPayload,
+ status: DataDogStatus,
+ ) -> DatadogPayload:
+ json_payload = json.dumps(standard_logging_object, default=str)
+ verbose_logger.debug("Datadog: Logger - Logging payload = %s", json_payload)
+ dd_payload = DatadogPayload(
+ ddsource=self._get_datadog_source(),
+ ddtags=self._get_datadog_tags(
+ standard_logging_object=standard_logging_object
+ ),
+ hostname=self._get_datadog_hostname(),
+ message=json_payload,
+ service=self._get_datadog_service(),
+ status=status,
+ )
+ return dd_payload
+
+ def create_datadog_logging_payload(
+ self,
+ kwargs: Union[dict, Any],
+ response_obj: Any,
+ start_time: datetime.datetime,
+ end_time: datetime.datetime,
+ ) -> DatadogPayload:
+ """
+ Helper function to create a datadog payload for logging
+
+ Args:
+ kwargs (Union[dict, Any]): request kwargs
+ response_obj (Any): llm api response
+ start_time (datetime.datetime): start time of request
+ end_time (datetime.datetime): end time of request
+
+ Returns:
+ DatadogPayload: defined in types.py
+ """
+
+ standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object", None
+ )
+ if standard_logging_object is None:
+ raise ValueError("standard_logging_object not found in kwargs")
+
+ status = DataDogStatus.INFO
+ if standard_logging_object.get("status") == "failure":
+ status = DataDogStatus.ERROR
+
+ # Build the initial payload
+ self.truncate_standard_logging_payload_content(standard_logging_object)
+
+ dd_payload = self._create_datadog_logging_payload_helper(
+ standard_logging_object=standard_logging_object,
+ status=status,
+ )
+ return dd_payload
+
+ async def async_send_compressed_data(self, data: List) -> Response:
+ """
+ Async helper to send compressed data to datadog self.intake_url
+
+ Datadog recommends using gzip to compress data
+ https://docs.datadoghq.com/api/latest/logs/
+
+ "Datadog recommends sending your logs compressed. Add the Content-Encoding: gzip header to the request when sending"
+ """
+
+ import gzip
+ import json
+
+ compressed_data = gzip.compress(json.dumps(data, default=str).encode("utf-8"))
+ response = await self.async_client.post(
+ url=self.intake_url,
+ data=compressed_data, # type: ignore
+ headers={
+ "DD-API-KEY": self.DD_API_KEY,
+ "Content-Encoding": "gzip",
+ "Content-Type": "application/json",
+ },
+ )
+ return response
+
+ async def async_service_failure_hook(
+ self,
+ payload: ServiceLoggerPayload,
+ error: Optional[str] = "",
+ parent_otel_span: Optional[Any] = None,
+ start_time: Optional[Union[datetimeObj, float]] = None,
+ end_time: Optional[Union[float, datetimeObj]] = None,
+ event_metadata: Optional[dict] = None,
+ ):
+ """
+ Logs failures from Redis, Postgres (Adjacent systems), as 'WARNING' on DataDog
+
+ - example - Redis is failing / erroring, will be logged on DataDog
+ """
+ try:
+ _payload_dict = payload.model_dump()
+ _payload_dict.update(event_metadata or {})
+ _dd_message_str = json.dumps(_payload_dict, default=str)
+ _dd_payload = DatadogPayload(
+ ddsource=self._get_datadog_source(),
+ ddtags=self._get_datadog_tags(),
+ hostname=self._get_datadog_hostname(),
+ message=_dd_message_str,
+ service=self._get_datadog_service(),
+ status=DataDogStatus.WARN,
+ )
+
+ self.log_queue.append(_dd_payload)
+
+ except Exception as e:
+ verbose_logger.exception(
+ f"Datadog: Logger - Exception in async_service_failure_hook: {e}"
+ )
+ pass
+
+ async def async_service_success_hook(
+ self,
+ payload: ServiceLoggerPayload,
+ error: Optional[str] = "",
+ parent_otel_span: Optional[Any] = None,
+ start_time: Optional[Union[datetimeObj, float]] = None,
+ end_time: Optional[Union[float, datetimeObj]] = None,
+ event_metadata: Optional[dict] = None,
+ ):
+ """
+ Logs success from Redis, Postgres (Adjacent systems), as 'INFO' on DataDog
+
+ No user has asked for this so far; it might be spammy on Datadog. If the need arises we can implement it.
+ """
+ try:
+ # intentionally done. Don't want to log all service types to DD
+ if payload.service not in DD_LOGGED_SUCCESS_SERVICE_TYPES:
+ return
+
+ _payload_dict = payload.model_dump()
+ _payload_dict.update(event_metadata or {})
+
+ _dd_message_str = json.dumps(_payload_dict, default=str)
+ _dd_payload = DatadogPayload(
+ ddsource=self._get_datadog_source(),
+ ddtags=self._get_datadog_tags(),
+ hostname=self._get_datadog_hostname(),
+ message=_dd_message_str,
+ service=self._get_datadog_service(),
+ status=DataDogStatus.INFO,
+ )
+
+ self.log_queue.append(_dd_payload)
+
+ except Exception as e:
+ verbose_logger.exception(
+ f"Datadog: Logger - Exception in async_service_failure_hook: {e}"
+ )
+
+ def _create_v0_logging_payload(
+ self,
+ kwargs: Union[dict, Any],
+ response_obj: Any,
+ start_time: datetime.datetime,
+ end_time: datetime.datetime,
+ ) -> DatadogPayload:
+ """
+ Note: This is the legacy (v0) version of the DataDog logging payload.
+
+
+ (Not Recommended) If you want this format to be logged, set `litellm.datadog_use_v1 = True`
+ """
+ import json
+
+ litellm_params = kwargs.get("litellm_params", {})
+ metadata = (
+ litellm_params.get("metadata", {}) or {}
+ ) # if litellm_params['metadata'] == None
+ messages = kwargs.get("messages")
+ optional_params = kwargs.get("optional_params", {})
+ call_type = kwargs.get("call_type", "litellm.completion")
+ cache_hit = kwargs.get("cache_hit", False)
+ usage = response_obj["usage"]
+ id = response_obj.get("id", str(uuid.uuid4()))
+ usage = dict(usage)
+ try:
+ response_time = (end_time - start_time).total_seconds() * 1000
+ except Exception:
+ response_time = None
+
+ try:
+ response_obj = dict(response_obj)
+ except Exception:
+ response_obj = response_obj
+
+ # Clean Metadata before logging - never log raw metadata
+ # the raw metadata can contain circular references which leads to infinite recursion
+ # we clean out all extra litellm metadata params before logging
+ clean_metadata = {}
+ if isinstance(metadata, dict):
+ for key, value in metadata.items():
+ # clean litellm metadata before logging
+ if key in [
+ "endpoint",
+ "caching_groups",
+ "previous_models",
+ ]:
+ continue
+ else:
+ clean_metadata[key] = value
+
+ # Build the initial payload
+ payload = {
+ "id": id,
+ "call_type": call_type,
+ "cache_hit": cache_hit,
+ "start_time": start_time,
+ "end_time": end_time,
+ "response_time": response_time,
+ "model": kwargs.get("model", ""),
+ "user": kwargs.get("user", ""),
+ "model_parameters": optional_params,
+ "spend": kwargs.get("response_cost", 0),
+ "messages": messages,
+ "response": response_obj,
+ "usage": usage,
+ "metadata": clean_metadata,
+ }
+
+ json_payload = json.dumps(payload, default=str)
+
+ verbose_logger.debug("Datadog: Logger - Logging payload = %s", json_payload)
+
+ dd_payload = DatadogPayload(
+ ddsource=self._get_datadog_source(),
+ ddtags=self._get_datadog_tags(),
+ hostname=self._get_datadog_hostname(),
+ message=json_payload,
+ service=self._get_datadog_service(),
+ status=DataDogStatus.INFO,
+ )
+ return dd_payload
+
+ @staticmethod
+ def _get_datadog_tags(
+ standard_logging_object: Optional[StandardLoggingPayload] = None,
+ ) -> str:
+ """
+ Get the datadog tags for the request
+
+ DD tags need to be as follows:
+ - tags: ["user_handle:dog@gmail.com", "app_version:1.0.0"]
+ """
+ base_tags = {
+ "env": os.getenv("DD_ENV", "unknown"),
+ "service": os.getenv("DD_SERVICE", "litellm"),
+ "version": os.getenv("DD_VERSION", "unknown"),
+ "HOSTNAME": DataDogLogger._get_datadog_hostname(),
+ "POD_NAME": os.getenv("POD_NAME", "unknown"),
+ }
+
+ tags = [f"{k}:{v}" for k, v in base_tags.items()]
+
+ if standard_logging_object:
+ _request_tags: List[str] = (
+ standard_logging_object.get("request_tags", []) or []
+ )
+ request_tags = [f"request_tag:{tag}" for tag in _request_tags]
+ tags.extend(request_tags)
+
+ return ",".join(tags)
+
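+ # Example output (illustrative values): with DD_ENV=prod, DD_SERVICE=litellm, DD_VERSION=1.0.0,
+ # HOSTNAME=pod-1 and a request tagged ["team-a"], this returns
+ #     "env:prod,service:litellm,version:1.0.0,HOSTNAME:pod-1,POD_NAME:unknown,request_tag:team-a"
+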
+ @staticmethod
+ def _get_datadog_source():
+ return os.getenv("DD_SOURCE", "litellm")
+
+ @staticmethod
+ def _get_datadog_service():
+ return os.getenv("DD_SERVICE", "litellm-server")
+
+ @staticmethod
+ def _get_datadog_hostname():
+ return os.getenv("HOSTNAME", "")
+
+ @staticmethod
+ def _get_datadog_env():
+ return os.getenv("DD_ENV", "unknown")
+
+ @staticmethod
+ def _get_datadog_pod_name():
+ return os.getenv("POD_NAME", "unknown")
+
+ async def async_health_check(self) -> IntegrationHealthCheckStatus:
+ """
+ Check if the service is healthy
+ """
+ from litellm.litellm_core_utils.litellm_logging import (
+ create_dummy_standard_logging_payload,
+ )
+
+ standard_logging_object = create_dummy_standard_logging_payload()
+ dd_payload = self._create_datadog_logging_payload_helper(
+ standard_logging_object=standard_logging_object,
+ status=DataDogStatus.INFO,
+ )
+ log_queue = [dd_payload]
+ response = await self.async_send_compressed_data(log_queue)
+ try:
+ response.raise_for_status()
+ return IntegrationHealthCheckStatus(
+ status="healthy",
+ error_message=None,
+ )
+ except httpx.HTTPStatusError as e:
+ return IntegrationHealthCheckStatus(
+ status="unhealthy",
+ error_message=e.response.text,
+ )
+ except Exception as e:
+ return IntegrationHealthCheckStatus(
+ status="unhealthy",
+ error_message=str(e),
+ )
+
+ async def get_request_response_payload(
+ self,
+ request_id: str,
+ start_time_utc: Optional[datetimeObj],
+ end_time_utc: Optional[datetimeObj],
+ ) -> Optional[dict]:
+ pass
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/datadog/datadog_llm_obs.py b/.venv/lib/python3.12/site-packages/litellm/integrations/datadog/datadog_llm_obs.py
new file mode 100644
index 00000000..e4e074ba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/datadog/datadog_llm_obs.py
@@ -0,0 +1,203 @@
+"""
+Implements logging integration with Datadog's LLM Observability Service
+
+
+API Reference: https://docs.datadoghq.com/llm_observability/setup/api/?tab=example#api-standards
+
+"""
+
+import asyncio
+import json
+import os
+import uuid
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Union
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.integrations.datadog.datadog import DataDogLogger
+from litellm.llms.custom_httpx.http_handler import (
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+from litellm.types.integrations.datadog_llm_obs import *
+from litellm.types.utils import StandardLoggingPayload
+
+
+class DataDogLLMObsLogger(DataDogLogger, CustomBatchLogger):
+ def __init__(self, **kwargs):
+ try:
+ verbose_logger.debug("DataDogLLMObs: Initializing logger")
+ if os.getenv("DD_API_KEY", None) is None:
+ raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>'")
+ if os.getenv("DD_SITE", None) is None:
+ raise Exception(
+ "DD_SITE is not set, set 'DD_SITE=<>', example sit = `us5.datadoghq.com`"
+ )
+
+ self.async_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+ self.DD_API_KEY = os.getenv("DD_API_KEY")
+ self.DD_SITE = os.getenv("DD_SITE")
+ self.intake_url = (
+ f"https://api.{self.DD_SITE}/api/intake/llm-obs/v1/trace/spans"
+ )
+
+ # testing base url
+ dd_base_url = os.getenv("DD_BASE_URL")
+ if dd_base_url:
+ self.intake_url = f"{dd_base_url}/api/intake/llm-obs/v1/trace/spans"
+
+ asyncio.create_task(self.periodic_flush())
+ self.flush_lock = asyncio.Lock()
+ self.log_queue: List[LLMObsPayload] = []
+ CustomBatchLogger.__init__(self, **kwargs, flush_lock=self.flush_lock)
+ except Exception as e:
+ verbose_logger.exception(f"DataDogLLMObs: Error initializing - {str(e)}")
+ raise e
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ verbose_logger.debug(
+ f"DataDogLLMObs: Logging success event for model {kwargs.get('model', 'unknown')}"
+ )
+ payload = self.create_llm_obs_payload(
+ kwargs, response_obj, start_time, end_time
+ )
+ verbose_logger.debug(f"DataDogLLMObs: Payload: {payload}")
+ self.log_queue.append(payload)
+
+ if len(self.log_queue) >= self.batch_size:
+ await self.async_send_batch()
+ except Exception as e:
+ verbose_logger.exception(
+ f"DataDogLLMObs: Error logging success event - {str(e)}"
+ )
+
+ async def async_send_batch(self):
+ try:
+ if not self.log_queue:
+ return
+
+ verbose_logger.debug(
+ f"DataDogLLMObs: Flushing {len(self.log_queue)} events"
+ )
+
+ # Prepare the payload
+ payload = {
+ "data": DDIntakePayload(
+ type="span",
+ attributes=DDSpanAttributes(
+ ml_app=self._get_datadog_service(),
+ tags=[self._get_datadog_tags()],
+ spans=self.log_queue,
+ ),
+ ),
+ }
+ verbose_logger.debug("payload %s", json.dumps(payload, indent=4))
+ response = await self.async_client.post(
+ url=self.intake_url,
+ json=payload,
+ headers={
+ "DD-API-KEY": self.DD_API_KEY,
+ "Content-Type": "application/json",
+ },
+ )
+
+ response.raise_for_status()
+ if response.status_code != 202:
+ raise Exception(
+ f"DataDogLLMObs: Unexpected response - status_code: {response.status_code}, text: {response.text}"
+ )
+
+ verbose_logger.debug(
+ f"DataDogLLMObs: Successfully sent batch - status_code: {response.status_code}"
+ )
+ self.log_queue.clear()
+ except Exception as e:
+ verbose_logger.exception(f"DataDogLLMObs: Error sending batch - {str(e)}")
+
+ def create_llm_obs_payload(
+ self, kwargs: Dict, response_obj: Any, start_time: datetime, end_time: datetime
+ ) -> LLMObsPayload:
+ standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object"
+ )
+ if standard_logging_payload is None:
+ raise Exception("DataDogLLMObs: standard_logging_object is not set")
+
+ messages = standard_logging_payload["messages"]
+ messages = self._ensure_string_content(messages=messages)
+
+ metadata = kwargs.get("litellm_params", {}).get("metadata", {})
+
+ input_meta = InputMeta(messages=messages) # type: ignore
+ output_meta = OutputMeta(messages=self._get_response_messages(response_obj))
+
+ meta = Meta(
+ kind="llm",
+ input=input_meta,
+ output=output_meta,
+ metadata=self._get_dd_llm_obs_payload_metadata(standard_logging_payload),
+ )
+
+ # Calculate metrics (you may need to adjust these based on available data)
+ metrics = LLMMetrics(
+ input_tokens=float(standard_logging_payload.get("prompt_tokens", 0)),
+ output_tokens=float(standard_logging_payload.get("completion_tokens", 0)),
+ total_tokens=float(standard_logging_payload.get("total_tokens", 0)),
+ )
+
+ return LLMObsPayload(
+ parent_id=metadata.get("parent_id", "undefined"),
+ trace_id=metadata.get("trace_id", str(uuid.uuid4())),
+ span_id=metadata.get("span_id", str(uuid.uuid4())),
+ name=metadata.get("name", "litellm_llm_call"),
+ meta=meta,
+ start_ns=int(start_time.timestamp() * 1e9),
+ duration=int((end_time - start_time).total_seconds() * 1e9),
+ metrics=metrics,
+ tags=[
+ self._get_datadog_tags(standard_logging_object=standard_logging_payload)
+ ],
+ )
+
+ def _get_response_messages(self, response_obj: Any) -> List[Any]:
+ """
+ Get the messages from the response object
+
+        For now this only handles /chat/completions responses.
+ """
+ if isinstance(response_obj, litellm.ModelResponse):
+ return [response_obj["choices"][0]["message"].json()]
+ return []
+
+ def _ensure_string_content(
+ self, messages: Optional[Union[str, List[Any], Dict[Any, Any]]]
+ ) -> List[Any]:
+ if messages is None:
+ return []
+ if isinstance(messages, str):
+ return [messages]
+ elif isinstance(messages, list):
+ return [message for message in messages]
+ elif isinstance(messages, dict):
+ return [str(messages.get("content", ""))]
+ return []
+
+ def _get_dd_llm_obs_payload_metadata(
+ self, standard_logging_payload: StandardLoggingPayload
+ ) -> Dict:
+ _metadata = {
+ "model_name": standard_logging_payload.get("model", "unknown"),
+ "model_provider": standard_logging_payload.get(
+ "custom_llm_provider", "unknown"
+ ),
+ }
+ _standard_logging_metadata: dict = (
+ dict(standard_logging_payload.get("metadata", {})) or {}
+ )
+ _metadata.update(_standard_logging_metadata)
+ return _metadata
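Annotation: the handler above batches LLM Observability spans and posts them to the Datadog intake endpoint. A minimal, hypothetical wiring sketch for the SDK is below; the callback name `datadog_llm_observability` and the environment variables are assumptions to verify against your litellm version and Datadog account.

```python
# Hypothetical usage sketch - callback name and credentials are placeholders.
import os
import litellm

os.environ["DD_API_KEY"] = "<your-datadog-api-key>"   # placeholder
os.environ["DD_SITE"] = "datadoghq.com"               # or your regional Datadog site

litellm.callbacks = ["datadog_llm_observability"]      # assumed callback name

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
```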
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/dynamodb.py b/.venv/lib/python3.12/site-packages/litellm/integrations/dynamodb.py
new file mode 100644
index 00000000..2c527ea8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/dynamodb.py
@@ -0,0 +1,89 @@
+#### What this does ####
+# On success + failure, log events to DynamoDB
+
+import os
+import traceback
+import uuid
+from typing import Any
+
+import litellm
+
+
+class DyanmoDBLogger:
+ # Class variables or attributes
+
+ def __init__(self):
+ # Instance variables
+ import boto3
+
+ self.dynamodb: Any = boto3.resource(
+ "dynamodb", region_name=os.environ["AWS_REGION_NAME"]
+ )
+ if litellm.dynamodb_table_name is None:
+ raise ValueError(
+ "LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`"
+ )
+ self.table_name = litellm.dynamodb_table_name
+
+ async def _async_log_event(
+ self, kwargs, response_obj, start_time, end_time, print_verbose
+ ):
+ self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
+
+ def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
+ try:
+ print_verbose(
+ f"DynamoDB Logging - Enters logging function for model {kwargs}"
+ )
+
+ # construct payload to send to DynamoDB
+ # follows the same params as langfuse.py
+ litellm_params = kwargs.get("litellm_params", {})
+ metadata = (
+ litellm_params.get("metadata", {}) or {}
+ ) # if litellm_params['metadata'] == None
+ messages = kwargs.get("messages")
+ optional_params = kwargs.get("optional_params", {})
+ call_type = kwargs.get("call_type", "litellm.completion")
+ usage = response_obj["usage"]
+ id = response_obj.get("id", str(uuid.uuid4()))
+
+ # Build the initial payload
+ payload = {
+ "id": id,
+ "call_type": call_type,
+ "startTime": start_time,
+ "endTime": end_time,
+ "model": kwargs.get("model", ""),
+ "user": kwargs.get("user", ""),
+ "modelParameters": optional_params,
+ "messages": messages,
+ "response": response_obj,
+ "usage": usage,
+ "metadata": metadata,
+ }
+
+ # Ensure everything in the payload is converted to str
+ for key, value in payload.items():
+ try:
+ payload[key] = str(value)
+ except Exception:
+ # non blocking if it can't cast to a str
+ pass
+
+ print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}")
+
+            # put the payload into DynamoDB
+            table = self.dynamodb.Table(self.table_name)
+            # payload is a dict of stringified log fields
+ response = table.put_item(Item=payload)
+
+ print_verbose(f"Response from DynamoDB:{str(response)}")
+
+ print_verbose(
+ f"DynamoDB Layer Logging - final response object: {response_obj}"
+ )
+ return response
+ except Exception:
+ print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
+ pass
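Annotation: a minimal usage sketch for the logger above. The table name and region are placeholders; the DynamoDB table must already exist and AWS credentials must be available in the environment.

```python
# Sketch only - table name and region are placeholders.
import os
import litellm

os.environ["AWS_REGION_NAME"] = "us-east-1"
litellm.dynamodb_table_name = "litellm-request-logs"
litellm.success_callback = ["dynamodb"]

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
```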
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/email_alerting.py b/.venv/lib/python3.12/site-packages/litellm/integrations/email_alerting.py
new file mode 100644
index 00000000..b45b9aa7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/email_alerting.py
@@ -0,0 +1,136 @@
+"""
+Functions for sending Email Alerts
+"""
+
+import os
+from typing import List, Optional
+
+from litellm._logging import verbose_logger, verbose_proxy_logger
+from litellm.proxy._types import WebhookEvent
+
+# used in the email header; if you change this, send a test email and verify it renders correctly
+LITELLM_LOGO_URL = "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
+LITELLM_SUPPORT_CONTACT = "support@berri.ai"
+
+
+async def get_all_team_member_emails(team_id: Optional[str] = None) -> list:
+ verbose_logger.debug(
+ "Email Alerting: Getting all team members for team_id=%s", team_id
+ )
+ if team_id is None:
+ return []
+ from litellm.proxy.proxy_server import prisma_client
+
+ if prisma_client is None:
+ raise Exception("Not connected to DB!")
+
+ team_row = await prisma_client.db.litellm_teamtable.find_unique(
+ where={
+ "team_id": team_id,
+ }
+ )
+
+ if team_row is None:
+ return []
+
+ _team_members = team_row.members_with_roles
+ verbose_logger.debug(
+ "Email Alerting: Got team members for team_id=%s Team Members: %s",
+ team_id,
+ _team_members,
+ )
+ _team_member_user_ids: List[str] = []
+ for member in _team_members:
+ if member and isinstance(member, dict):
+ _user_id = member.get("user_id")
+ if _user_id and isinstance(_user_id, str):
+ _team_member_user_ids.append(_user_id)
+
+ sql_query = """
+ SELECT user_email
+ FROM "LiteLLM_UserTable"
+ WHERE user_id = ANY($1::TEXT[]);
+ """
+
+ _result = await prisma_client.db.query_raw(sql_query, _team_member_user_ids)
+
+ verbose_logger.debug("Email Alerting: Got all Emails for team, emails=%s", _result)
+
+ if _result is None:
+ return []
+
+ emails = []
+ for user in _result:
+ if user and isinstance(user, dict) and user.get("user_email", None) is not None:
+ emails.append(user.get("user_email"))
+ return emails
+
+
+async def send_team_budget_alert(webhook_event: WebhookEvent) -> bool:
+ """
+ Send an Email Alert to All Team Members when the Team Budget is crossed
+ Returns -> True if sent, False if not.
+ """
+ from litellm.proxy.utils import send_email
+
+ _team_id = webhook_event.team_id
+ team_alias = webhook_event.team_alias
+ verbose_logger.debug(
+ "Email Alerting: Sending Team Budget Alert for team=%s", team_alias
+ )
+
+ email_logo_url = os.getenv("SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None))
+ email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)
+
+ # await self._check_if_using_premium_email_feature(
+ # premium_user, email_logo_url, email_support_contact
+ # )
+
+ if email_logo_url is None:
+ email_logo_url = LITELLM_LOGO_URL
+ if email_support_contact is None:
+ email_support_contact = LITELLM_SUPPORT_CONTACT
+ recipient_emails = await get_all_team_member_emails(_team_id)
+ recipient_emails_str: str = ",".join(recipient_emails)
+ verbose_logger.debug(
+ "Email Alerting: Sending team budget alert to %s", recipient_emails_str
+ )
+
+ event_name = webhook_event.event_message
+ max_budget = webhook_event.max_budget
+ email_html_content = "Alert from LiteLLM Server"
+
+    if not recipient_emails_str:
+ verbose_proxy_logger.warning(
+ "Email Alerting: Trying to send email alert to no recipient, got recipient_emails=%s",
+ recipient_emails_str,
+ )
+
+ email_html_content = f"""
+ <img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" /> <br/><br/><br/>
+
+ Budget Crossed for Team <b> {team_alias} </b> <br/> <br/>
+
+    Your Team's LLM API usage has crossed its <b> budget of ${max_budget} </b>, current spend is <b>${webhook_event.spend}</b><br /> <br />
+
+ API requests will be rejected until either (a) you increase your budget or (b) your budget gets reset <br /> <br />
+
+ If you have any questions, please send an email to {email_support_contact} <br /> <br />
+
+ Best, <br />
+ The LiteLLM team <br />
+ """
+
+ email_event = {
+ "to": recipient_emails_str,
+ "subject": f"LiteLLM {event_name} for Team {team_alias}",
+ "html": email_html_content,
+ }
+
+ await send_email(
+ receiver_email=email_event["to"],
+ subject=email_event["subject"],
+ html=email_event["html"],
+ )
+
+    return True
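Annotation: a minimal sketch of the email helper above. It assumes the proxy's Prisma client is already connected; these functions are normally invoked from inside the running proxy, so the team id below is purely illustrative.

```python
# Sketch only - requires a running proxy / connected Prisma client.
import asyncio
from litellm.integrations.email_alerting import get_all_team_member_emails

emails = asyncio.run(get_all_team_member_emails(team_id="team-123"))
print(emails)  # e.g. ["alice@example.com", "bob@example.com"]
```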
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/email_templates/templates.py b/.venv/lib/python3.12/site-packages/litellm/integrations/email_templates/templates.py
new file mode 100644
index 00000000..7029e8ce
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/email_templates/templates.py
@@ -0,0 +1,62 @@
+"""
+Email Templates used by the LiteLLM Email Service in slack_alerting.py
+"""
+
+KEY_CREATED_EMAIL_TEMPLATE = """
+ <img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
+
+ <p> Hi {recipient_email}, <br/>
+
+ I'm happy to provide you with an OpenAI Proxy API Key, loaded with ${key_budget} per month. <br /> <br />
+
+ <b>
+ Key: <pre>{key_token}</pre> <br>
+ </b>
+
+ <h2>Usage Example</h2>
+
+ Detailed Documentation on <a href="https://docs.litellm.ai/docs/proxy/user_keys">Usage with OpenAI Python SDK, Langchain, LlamaIndex, Curl</a>
+
+ <pre>
+
+ import openai
+ client = openai.OpenAI(
+ api_key="{key_token}",
+ base_url={{base_url}}
+ )
+
+ response = client.chat.completions.create(
+ model="gpt-3.5-turbo", # model to send to the proxy
+ messages = [
+ {{
+ "role": "user",
+ "content": "this is a test request, write a short poem"
+ }}
+ ]
+ )
+
+ </pre>
+
+
+ If you have any questions, please send an email to {email_support_contact} <br /> <br />
+
+ Best, <br />
+ The LiteLLM team <br />
+"""
+
+
+USER_INVITED_EMAIL_TEMPLATE = """
+ <img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
+
+ <p> Hi {recipient_email}, <br/>
+
+ You were invited to use OpenAI Proxy API for team {team_name} <br /> <br />
+
+ <a href="{base_url}" style="display: inline-block; padding: 10px 20px; background-color: #87ceeb; color: #fff; text-decoration: none; border-radius: 20px;">Get Started here</a> <br /> <br />
+
+
+ If you have any questions, please send an email to {email_support_contact} <br /> <br />
+
+ Best, <br />
+ The LiteLLM team <br />
+"""
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/galileo.py b/.venv/lib/python3.12/site-packages/litellm/integrations/galileo.py
new file mode 100644
index 00000000..e99d5f23
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/galileo.py
@@ -0,0 +1,157 @@
+import os
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Field
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.llms.custom_httpx.http_handler import (
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+
+
+# from here: https://docs.rungalileo.io/galileo/gen-ai-studio-products/galileo-observe/how-to/logging-data-via-restful-apis#structuring-your-records
+class LLMResponse(BaseModel):
+ latency_ms: int
+ status_code: int
+ input_text: str
+ output_text: str
+ node_type: str
+ model: str
+ num_input_tokens: int
+ num_output_tokens: int
+ output_logprobs: Optional[Dict[str, Any]] = Field(
+ default=None,
+ description="Optional. When available, logprobs are used to compute Uncertainty.",
+ )
+ created_at: str = Field(
+ ..., description='timestamp constructed in "%Y-%m-%dT%H:%M:%S" format'
+ )
+ tags: Optional[List[str]] = None
+ user_metadata: Optional[Dict[str, Any]] = None
+
+
+class GalileoObserve(CustomLogger):
+ def __init__(self) -> None:
+ self.in_memory_records: List[dict] = []
+ self.batch_size = 1
+ self.base_url = os.getenv("GALILEO_BASE_URL", None)
+ self.project_id = os.getenv("GALILEO_PROJECT_ID", None)
+ self.headers: Optional[Dict[str, str]] = None
+ self.async_httpx_handler = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+
+ def set_galileo_headers(self):
+ # following https://docs.rungalileo.io/galileo/gen-ai-studio-products/galileo-observe/how-to/logging-data-via-restful-apis#logging-your-records
+
+ headers = {
+ "accept": "application/json",
+ "Content-Type": "application/x-www-form-urlencoded",
+ }
+ galileo_login_response = litellm.module_level_client.post(
+ url=f"{self.base_url}/login",
+ headers=headers,
+ data={
+ "username": os.getenv("GALILEO_USERNAME"),
+ "password": os.getenv("GALILEO_PASSWORD"),
+ },
+ )
+
+ access_token = galileo_login_response.json()["access_token"]
+
+ self.headers = {
+ "accept": "application/json",
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {access_token}",
+ }
+
+ def get_output_str_from_response(self, response_obj, kwargs):
+ output = None
+ if response_obj is not None and (
+ kwargs.get("call_type", None) == "embedding"
+ or isinstance(response_obj, litellm.EmbeddingResponse)
+ ):
+ output = None
+ elif response_obj is not None and isinstance(
+ response_obj, litellm.ModelResponse
+ ):
+ output = response_obj["choices"][0]["message"].json()
+ elif response_obj is not None and isinstance(
+ response_obj, litellm.TextCompletionResponse
+ ):
+ output = response_obj.choices[0].text
+ elif response_obj is not None and isinstance(
+ response_obj, litellm.ImageResponse
+ ):
+ output = response_obj["data"]
+
+ return output
+
+ async def async_log_success_event(
+ self, kwargs: Any, response_obj: Any, start_time: Any, end_time: Any
+ ):
+ verbose_logger.debug("On Async Success")
+
+ _latency_ms = int((end_time - start_time).total_seconds() * 1000)
+ _call_type = kwargs.get("call_type", "litellm")
+ input_text = litellm.utils.get_formatted_prompt(
+ data=kwargs, call_type=_call_type
+ )
+
+ _usage = response_obj.get("usage", {}) or {}
+ num_input_tokens = _usage.get("prompt_tokens", 0)
+ num_output_tokens = _usage.get("completion_tokens", 0)
+
+ output_text = self.get_output_str_from_response(
+ response_obj=response_obj, kwargs=kwargs
+ )
+
+ if output_text is not None:
+ request_record = LLMResponse(
+ latency_ms=_latency_ms,
+ status_code=200,
+ input_text=input_text,
+ output_text=output_text,
+ node_type=_call_type,
+ model=kwargs.get("model", "-"),
+ num_input_tokens=num_input_tokens,
+ num_output_tokens=num_output_tokens,
+ created_at=start_time.strftime(
+ "%Y-%m-%dT%H:%M:%S"
+ ), # timestamp str constructed in "%Y-%m-%dT%H:%M:%S" format
+ )
+
+ # dump to dict
+ request_dict = request_record.model_dump()
+ self.in_memory_records.append(request_dict)
+
+ if len(self.in_memory_records) >= self.batch_size:
+ await self.flush_in_memory_records()
+
+ async def flush_in_memory_records(self):
+ verbose_logger.debug("flushing in memory records")
+ response = await self.async_httpx_handler.post(
+ url=f"{self.base_url}/projects/{self.project_id}/observe/ingest",
+ headers=self.headers,
+ json={"records": self.in_memory_records},
+ )
+
+ if response.status_code == 200:
+ verbose_logger.debug(
+ "Galileo Logger:successfully flushed in memory records"
+ )
+ self.in_memory_records = []
+ else:
+ verbose_logger.debug("Galileo Logger: failed to flush in memory records")
+ verbose_logger.debug(
+ "Galileo Logger error=%s, status code=%s",
+ response.text,
+ response.status_code,
+ )
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ verbose_logger.debug("On Async Failure")
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/Readme.md b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/Readme.md
new file mode 100644
index 00000000..2ab0b233
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/Readme.md
@@ -0,0 +1,12 @@
+# GCS (Google Cloud Storage) Bucket Logging on LiteLLM Gateway
+
+This folder contains the GCS Bucket Logging integration for LiteLLM Gateway.
+
+## Folder Structure
+
+- `gcs_bucket.py`: This is the main file that handles failure/success logging to GCS Bucket
+- `gcs_bucket_base.py`: This file contains the GCSBucketBase class which handles Authentication for GCS Buckets
+
+## Further Reading
+- [Doc setting up GCS Bucket Logging on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/proxy/bucket)
+- [Doc on Key / Team Based logging with GCS](https://docs.litellm.ai/docs/proxy/team_logging) \ No newline at end of file
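Annotation: a minimal sketch of enabling this logger from Python, using the environment variables read by `GCSBucketBase`. The callback name `gcs_bucket` is an assumption, and note the premium check in `GCSBucketLogger.__init__`, so this presumes a premium proxy deployment.

```python
import os
import litellm

os.environ["GCS_BUCKET_NAME"] = "my-litellm-logs"            # placeholder
os.environ["GCS_PATH_SERVICE_ACCOUNT"] = "/path/to/sa.json"  # optional; omit to use IAM auth

litellm.success_callback = ["gcs_bucket"]   # assumed callback name
```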
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py
new file mode 100644
index 00000000..187ab779
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py
@@ -0,0 +1,237 @@
+import asyncio
+import json
+import os
+import uuid
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from urllib.parse import quote
+
+from litellm._logging import verbose_logger
+from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
+from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
+from litellm.proxy._types import CommonProxyErrors
+from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
+from litellm.types.integrations.gcs_bucket import *
+from litellm.types.utils import StandardLoggingPayload
+
+if TYPE_CHECKING:
+ from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
+else:
+ VertexBase = Any
+
+
+GCS_DEFAULT_BATCH_SIZE = 2048
+GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
+
+
+class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
+ def __init__(self, bucket_name: Optional[str] = None) -> None:
+ from litellm.proxy.proxy_server import premium_user
+
+ super().__init__(bucket_name=bucket_name)
+
+ # Init Batch logging settings
+ self.log_queue: List[GCSLogQueueItem] = []
+ self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
+ self.flush_interval = int(
+ os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
+ )
+ asyncio.create_task(self.periodic_flush())
+ self.flush_lock = asyncio.Lock()
+        super().__init__(
+            bucket_name=bucket_name,
+            flush_lock=self.flush_lock,
+            batch_size=self.batch_size,
+            flush_interval=self.flush_interval,
+        )
+ AdditionalLoggingUtils.__init__(self)
+
+ if premium_user is not True:
+ raise ValueError(
+ f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
+ )
+
+ #### ASYNC ####
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ from litellm.proxy.proxy_server import premium_user
+
+ if premium_user is not True:
+ raise ValueError(
+ f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
+ )
+ try:
+ verbose_logger.debug(
+ "GCS Logger: async_log_success_event logging kwargs: %s, response_obj: %s",
+ kwargs,
+ response_obj,
+ )
+ logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object", None
+ )
+ if logging_payload is None:
+ raise ValueError("standard_logging_object not found in kwargs")
+ # Add to logging queue - this will be flushed periodically
+ self.log_queue.append(
+ GCSLogQueueItem(
+ payload=logging_payload, kwargs=kwargs, response_obj=response_obj
+ )
+ )
+
+ except Exception as e:
+ verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ verbose_logger.debug(
+ "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
+ kwargs,
+ response_obj,
+ )
+
+ logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object", None
+ )
+ if logging_payload is None:
+ raise ValueError("standard_logging_object not found in kwargs")
+ # Add to logging queue - this will be flushed periodically
+ self.log_queue.append(
+ GCSLogQueueItem(
+ payload=logging_payload, kwargs=kwargs, response_obj=response_obj
+ )
+ )
+
+ except Exception as e:
+ verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
+
+ async def async_send_batch(self):
+ """
+ Process queued logs in batch - sends logs to GCS Bucket
+
+
+ GCS Bucket does not have a Batch endpoint to batch upload logs
+
+ Instead, we
+ - collect the logs to flush every `GCS_FLUSH_INTERVAL` seconds
+ - during async_send_batch, we make 1 POST request per log to GCS Bucket
+
+ """
+ if not self.log_queue:
+ return
+
+ for log_item in self.log_queue:
+ logging_payload = log_item["payload"]
+ kwargs = log_item["kwargs"]
+ response_obj = log_item.get("response_obj", None) or {}
+
+ gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
+ kwargs
+ )
+ headers = await self.construct_request_headers(
+ vertex_instance=gcs_logging_config["vertex_instance"],
+ service_account_json=gcs_logging_config["path_service_account"],
+ )
+ bucket_name = gcs_logging_config["bucket_name"]
+ object_name = self._get_object_name(kwargs, logging_payload, response_obj)
+
+ try:
+ await self._log_json_data_on_gcs(
+ headers=headers,
+ bucket_name=bucket_name,
+ object_name=object_name,
+ logging_payload=logging_payload,
+ )
+ except Exception as e:
+ # don't let one log item fail the entire batch
+ verbose_logger.exception(
+ f"GCS Bucket error logging payload to GCS bucket: {str(e)}"
+ )
+ pass
+
+ # Clear the queue after processing
+ self.log_queue.clear()
+
+ def _get_object_name(
+ self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
+ ) -> str:
+ """
+ Get the object name to use for the current payload
+ """
+ current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
+ if logging_payload.get("error_str", None) is not None:
+ object_name = self._generate_failure_object_name(
+ request_date_str=current_date,
+ )
+ else:
+ object_name = self._generate_success_object_name(
+ request_date_str=current_date,
+ response_id=response_obj.get("id", ""),
+ )
+
+ # used for testing
+ _litellm_params = kwargs.get("litellm_params", None) or {}
+ _metadata = _litellm_params.get("metadata", None) or {}
+ if "gcs_log_id" in _metadata:
+ object_name = _metadata["gcs_log_id"]
+
+ return object_name
+
+ async def get_request_response_payload(
+ self,
+ request_id: str,
+ start_time_utc: Optional[datetime],
+ end_time_utc: Optional[datetime],
+ ) -> Optional[dict]:
+ """
+ Get the request and response payload for a given `request_id`
+ Tries current day, next day, and previous day until it finds the payload
+ """
+ if start_time_utc is None:
+ raise ValueError(
+ "start_time_utc is required for getting a payload from GCS Bucket"
+ )
+
+ # Try current day, next day, and previous day
+ dates_to_try = [
+ start_time_utc,
+ start_time_utc + timedelta(days=1),
+ start_time_utc - timedelta(days=1),
+ ]
+ date_str = None
+ for date in dates_to_try:
+ try:
+ date_str = self._get_object_date_from_datetime(datetime_obj=date)
+ object_name = self._generate_success_object_name(
+ request_date_str=date_str,
+ response_id=request_id,
+ )
+ encoded_object_name = quote(object_name, safe="")
+ response = await self.download_gcs_object(encoded_object_name)
+
+ if response is not None:
+ loaded_response = json.loads(response)
+ return loaded_response
+ except Exception as e:
+ verbose_logger.debug(
+ f"Failed to fetch payload for date {date_str}: {str(e)}"
+ )
+ continue
+
+ return None
+
+ def _generate_success_object_name(
+ self,
+ request_date_str: str,
+ response_id: str,
+ ) -> str:
+ return f"{request_date_str}/{response_id}"
+
+ def _generate_failure_object_name(
+ self,
+ request_date_str: str,
+ ) -> str:
+ return f"{request_date_str}/failure-{uuid.uuid4().hex}"
+
+ def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
+ return datetime_obj.strftime("%Y-%m-%d")
+
+ async def async_health_check(self) -> IntegrationHealthCheckStatus:
+ raise NotImplementedError("GCS Bucket does not support health check")
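Annotation: the batching behaviour above is tunable through two optional environment variables read in `__init__`; the defaults are 2048 items and a 20 second flush interval.

```python
import os

# Must be set before GCSBucketLogger is constructed.
os.environ["GCS_BATCH_SIZE"] = "500"     # flush after 500 queued log items
os.environ["GCS_FLUSH_INTERVAL"] = "5"   # or every 5 seconds, whichever comes first
```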
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py
new file mode 100644
index 00000000..66995d84
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py
@@ -0,0 +1,326 @@
+import json
+import os
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.llms.custom_httpx.http_handler import (
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+from litellm.types.integrations.gcs_bucket import *
+from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
+
+if TYPE_CHECKING:
+ from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
+else:
+ VertexBase = Any
+IAM_AUTH_KEY = "IAM_AUTH"
+
+
+class GCSBucketBase(CustomBatchLogger):
+ def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
+ self.async_httpx_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+ _path_service_account = os.getenv("GCS_PATH_SERVICE_ACCOUNT")
+ _bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
+ self.path_service_account_json: Optional[str] = _path_service_account
+ self.BUCKET_NAME: Optional[str] = _bucket_name
+ self.vertex_instances: Dict[str, VertexBase] = {}
+ super().__init__(**kwargs)
+
+ async def construct_request_headers(
+ self,
+ service_account_json: Optional[str],
+ vertex_instance: Optional[VertexBase] = None,
+ ) -> Dict[str, str]:
+ from litellm import vertex_chat_completion
+
+ if vertex_instance is None:
+ vertex_instance = vertex_chat_completion
+
+ _auth_header, vertex_project = await vertex_instance._ensure_access_token_async(
+ credentials=service_account_json,
+ project_id=None,
+ custom_llm_provider="vertex_ai",
+ )
+
+ auth_header, _ = vertex_instance._get_token_and_url(
+ model="gcs-bucket",
+ auth_header=_auth_header,
+ vertex_credentials=service_account_json,
+ vertex_project=vertex_project,
+ vertex_location=None,
+ gemini_api_key=None,
+ stream=None,
+ custom_llm_provider="vertex_ai",
+ api_base=None,
+ )
+ verbose_logger.debug("constructed auth_header %s", auth_header)
+ headers = {
+ "Authorization": f"Bearer {auth_header}", # auth_header
+ "Content-Type": "application/json",
+ }
+
+ return headers
+
+ def sync_construct_request_headers(self) -> Dict[str, str]:
+ from litellm import vertex_chat_completion
+
+ _auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
+ credentials=self.path_service_account_json,
+ project_id=None,
+ custom_llm_provider="vertex_ai",
+ )
+
+ auth_header, _ = vertex_chat_completion._get_token_and_url(
+ model="gcs-bucket",
+ auth_header=_auth_header,
+ vertex_credentials=self.path_service_account_json,
+ vertex_project=vertex_project,
+ vertex_location=None,
+ gemini_api_key=None,
+ stream=None,
+ custom_llm_provider="vertex_ai",
+ api_base=None,
+ )
+ verbose_logger.debug("constructed auth_header %s", auth_header)
+ headers = {
+ "Authorization": f"Bearer {auth_header}", # auth_header
+ "Content-Type": "application/json",
+ }
+
+ return headers
+
+ def _handle_folders_in_bucket_name(
+ self,
+ bucket_name: str,
+ object_name: str,
+ ) -> Tuple[str, str]:
+ """
+ Handles when the user passes a bucket name with a folder postfix
+
+
+ Example:
+ - Bucket name: "my-bucket/my-folder/dev"
+ - Object name: "my-object"
+ - Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
+
+ """
+ if "/" in bucket_name:
+ bucket_name, prefix = bucket_name.split("/", 1)
+ object_name = f"{prefix}/{object_name}"
+ return bucket_name, object_name
+ return bucket_name, object_name
+
+ async def get_gcs_logging_config(
+ self, kwargs: Optional[Dict[str, Any]] = {}
+ ) -> GCSLoggingConfig:
+ """
+ This function is used to get the GCS logging config for the GCS Bucket Logger.
+ It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config.
+ If no dynamic parameters are provided, it uses the default values.
+ """
+ if kwargs is None:
+ kwargs = {}
+
+ standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
+ kwargs.get("standard_callback_dynamic_params", None)
+ )
+
+ bucket_name: str
+ path_service_account: Optional[str]
+ if standard_callback_dynamic_params is not None:
+ verbose_logger.debug("Using dynamic GCS logging")
+ verbose_logger.debug(
+ "standard_callback_dynamic_params: %s", standard_callback_dynamic_params
+ )
+
+ _bucket_name: Optional[str] = (
+ standard_callback_dynamic_params.get("gcs_bucket_name", None)
+ or self.BUCKET_NAME
+ )
+ _path_service_account: Optional[str] = (
+ standard_callback_dynamic_params.get("gcs_path_service_account", None)
+ or self.path_service_account_json
+ )
+
+ if _bucket_name is None:
+ raise ValueError(
+ "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
+ )
+ bucket_name = _bucket_name
+ path_service_account = _path_service_account
+ vertex_instance = await self.get_or_create_vertex_instance(
+ credentials=path_service_account
+ )
+ else:
+ # If no dynamic parameters, use the default instance
+ if self.BUCKET_NAME is None:
+ raise ValueError(
+ "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
+ )
+ bucket_name = self.BUCKET_NAME
+ path_service_account = self.path_service_account_json
+ vertex_instance = await self.get_or_create_vertex_instance(
+ credentials=path_service_account
+ )
+
+ return GCSLoggingConfig(
+ bucket_name=bucket_name,
+ vertex_instance=vertex_instance,
+ path_service_account=path_service_account,
+ )
+
+ async def get_or_create_vertex_instance(
+ self, credentials: Optional[str]
+ ) -> VertexBase:
+ """
+ This function is used to get the Vertex instance for the GCS Bucket Logger.
+ It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
+ """
+ from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
+
+ _in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
+ if _in_memory_key not in self.vertex_instances:
+ vertex_instance = VertexBase()
+ await vertex_instance._ensure_access_token_async(
+ credentials=credentials,
+ project_id=None,
+ custom_llm_provider="vertex_ai",
+ )
+ self.vertex_instances[_in_memory_key] = vertex_instance
+ return self.vertex_instances[_in_memory_key]
+
+ def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
+ """
+ Returns key to use for caching the Vertex instance in-memory.
+
+ When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.
+
+ - If a credentials string is provided, it is used as the key.
+ - If no credentials string is provided, "IAM_AUTH" is used as the key.
+ """
+ return credentials or IAM_AUTH_KEY
+
+ async def download_gcs_object(self, object_name: str, **kwargs):
+ """
+ Download an object from GCS.
+
+ https://cloud.google.com/storage/docs/downloading-objects#download-object-json
+ """
+ try:
+ gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
+ kwargs=kwargs
+ )
+ headers = await self.construct_request_headers(
+ vertex_instance=gcs_logging_config["vertex_instance"],
+ service_account_json=gcs_logging_config["path_service_account"],
+ )
+ bucket_name = gcs_logging_config["bucket_name"]
+ bucket_name, object_name = self._handle_folders_in_bucket_name(
+ bucket_name=bucket_name,
+ object_name=object_name,
+ )
+
+ url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
+
+ # Send the GET request to download the object
+ response = await self.async_httpx_client.get(url=url, headers=headers)
+
+ if response.status_code != 200:
+ verbose_logger.error(
+ "GCS object download error: %s", str(response.text)
+ )
+ return None
+
+ verbose_logger.debug(
+ "GCS object download response status code: %s", response.status_code
+ )
+
+ # Return the content of the downloaded object
+ return response.content
+
+ except Exception as e:
+ verbose_logger.error("GCS object download error: %s", str(e))
+ return None
+
+ async def delete_gcs_object(self, object_name: str, **kwargs):
+ """
+ Delete an object from GCS.
+ """
+ try:
+ gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
+ kwargs=kwargs
+ )
+ headers = await self.construct_request_headers(
+ vertex_instance=gcs_logging_config["vertex_instance"],
+ service_account_json=gcs_logging_config["path_service_account"],
+ )
+ bucket_name = gcs_logging_config["bucket_name"]
+ bucket_name, object_name = self._handle_folders_in_bucket_name(
+ bucket_name=bucket_name,
+ object_name=object_name,
+ )
+
+ url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"
+
+ # Send the DELETE request to delete the object
+ response = await self.async_httpx_client.delete(url=url, headers=headers)
+
+            if response.status_code not in (200, 204):
+ verbose_logger.error(
+ "GCS object delete error: %s, status code: %s",
+ str(response.text),
+ response.status_code,
+ )
+ return None
+
+ verbose_logger.debug(
+ "GCS object delete response status code: %s, response: %s",
+ response.status_code,
+ response.text,
+ )
+
+            # Return the response text for the deleted object
+ return response.text
+
+ except Exception as e:
+ verbose_logger.error("GCS object download error: %s", str(e))
+ return None
+
+ async def _log_json_data_on_gcs(
+ self,
+ headers: Dict[str, str],
+ bucket_name: str,
+ object_name: str,
+ logging_payload: Union[StandardLoggingPayload, str],
+ ):
+ """
+ Helper function to make POST request to GCS Bucket in the specified bucket.
+ """
+ if isinstance(logging_payload, str):
+ json_logged_payload = logging_payload
+ else:
+ json_logged_payload = json.dumps(logging_payload, default=str)
+
+ bucket_name, object_name = self._handle_folders_in_bucket_name(
+ bucket_name=bucket_name,
+ object_name=object_name,
+ )
+
+ response = await self.async_httpx_client.post(
+ headers=headers,
+ url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
+ data=json_logged_payload,
+ )
+
+ if response.status_code != 200:
+ verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
+
+ verbose_logger.debug("GCS Bucket response %s", response)
+ verbose_logger.debug("GCS Bucket status code %s", response.status_code)
+ verbose_logger.debug("GCS Bucket response.text %s", response.text)
+
+ return response.json()
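Annotation: a small sketch of the folder handling documented in `_handle_folders_in_bucket_name` above; the bucket and object names are illustrative.

```python
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase

base = GCSBucketBase(bucket_name="my-bucket/my-folder/dev")
bucket, obj = base._handle_folders_in_bucket_name(
    bucket_name="my-bucket/my-folder/dev",
    object_name="my-object",
)
assert (bucket, obj) == ("my-bucket", "my-folder/dev/my-object")
```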
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_pubsub/pub_sub.py b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_pubsub/pub_sub.py
new file mode 100644
index 00000000..e94c853f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_pubsub/pub_sub.py
@@ -0,0 +1,203 @@
+"""
+BETA
+
+This is the Pub/Sub logger for Google Cloud Pub/Sub; it sends LiteLLM SpendLogs payloads to a Pub/Sub topic.
+
+Users can use this instead of writing their SpendLogs to their Postgres database.
+"""
+
+import asyncio
+import json
+import os
+import traceback
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+if TYPE_CHECKING:
+ from litellm.proxy._types import SpendLogsPayload
+else:
+ SpendLogsPayload = Any
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.llms.custom_httpx.http_handler import (
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+
+
+class GcsPubSubLogger(CustomBatchLogger):
+ def __init__(
+ self,
+ project_id: Optional[str] = None,
+ topic_id: Optional[str] = None,
+ credentials_path: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Initialize Google Cloud Pub/Sub publisher
+
+ Args:
+ project_id (str): Google Cloud project ID
+ topic_id (str): Pub/Sub topic ID
+ credentials_path (str, optional): Path to Google Cloud credentials JSON file
+ """
+ from litellm.proxy.utils import _premium_user_check
+
+ _premium_user_check()
+
+ self.async_httpx_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+
+ self.project_id = project_id or os.getenv("GCS_PUBSUB_PROJECT_ID")
+ self.topic_id = topic_id or os.getenv("GCS_PUBSUB_TOPIC_ID")
+ self.path_service_account_json = credentials_path or os.getenv(
+ "GCS_PATH_SERVICE_ACCOUNT"
+ )
+
+ if not self.project_id or not self.topic_id:
+ raise ValueError("Both project_id and topic_id must be provided")
+
+ self.flush_lock = asyncio.Lock()
+ super().__init__(**kwargs, flush_lock=self.flush_lock)
+ asyncio.create_task(self.periodic_flush())
+ self.log_queue: List[SpendLogsPayload] = []
+
+ async def construct_request_headers(self) -> Dict[str, str]:
+ """Construct authorization headers using Vertex AI auth"""
+ from litellm import vertex_chat_completion
+
+ _auth_header, vertex_project = (
+ await vertex_chat_completion._ensure_access_token_async(
+ credentials=self.path_service_account_json,
+ project_id=None,
+ custom_llm_provider="vertex_ai",
+ )
+ )
+
+ auth_header, _ = vertex_chat_completion._get_token_and_url(
+ model="pub-sub",
+ auth_header=_auth_header,
+ vertex_credentials=self.path_service_account_json,
+ vertex_project=vertex_project,
+ vertex_location=None,
+ gemini_api_key=None,
+ stream=None,
+ custom_llm_provider="vertex_ai",
+ api_base=None,
+ )
+
+ headers = {
+ "Authorization": f"Bearer {auth_header}",
+ "Content-Type": "application/json",
+ }
+ return headers
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ """
+ Async Log success events to GCS PubSub Topic
+
+ - Creates a SpendLogsPayload
+ - Adds to batch queue
+ - Flushes based on CustomBatchLogger settings
+
+        Raises:
+            Nothing - errors are logged as non-blocking verbose_logger.exception calls
+ """
+ from litellm.proxy.spend_tracking.spend_tracking_utils import (
+ get_logging_payload,
+ )
+ from litellm.proxy.utils import _premium_user_check
+
+ _premium_user_check()
+
+ try:
+ verbose_logger.debug(
+ "PubSub: Logging - Enters logging function for model %s", kwargs
+ )
+ spend_logs_payload = get_logging_payload(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ start_time=start_time,
+ end_time=end_time,
+ )
+ self.log_queue.append(spend_logs_payload)
+
+ if len(self.log_queue) >= self.batch_size:
+ await self.async_send_batch()
+
+ except Exception as e:
+ verbose_logger.exception(
+ f"PubSub Layer Error - {str(e)}\n{traceback.format_exc()}"
+ )
+ pass
+
+ async def async_send_batch(self):
+ """
+ Sends the batch of messages to Pub/Sub
+ """
+ try:
+ if not self.log_queue:
+ return
+
+ verbose_logger.debug(
+ f"PubSub - about to flush {len(self.log_queue)} events"
+ )
+
+ for message in self.log_queue:
+ await self.publish_message(message)
+
+ except Exception as e:
+ verbose_logger.exception(
+ f"PubSub Error sending batch - {str(e)}\n{traceback.format_exc()}"
+ )
+ finally:
+ self.log_queue.clear()
+
+ async def publish_message(
+ self, message: SpendLogsPayload
+ ) -> Optional[Dict[str, Any]]:
+ """
+ Publish message to Google Cloud Pub/Sub using REST API
+
+ Args:
+ message: Message to publish (dict or string)
+
+ Returns:
+ dict: Published message response
+ """
+ try:
+ headers = await self.construct_request_headers()
+
+ # Prepare message data
+ if isinstance(message, str):
+ message_data = message
+ else:
+ message_data = json.dumps(message, default=str)
+
+ # Base64 encode the message
+ import base64
+
+ encoded_message = base64.b64encode(message_data.encode("utf-8")).decode(
+ "utf-8"
+ )
+
+ # Construct request body
+ request_body = {"messages": [{"data": encoded_message}]}
+
+ url = f"https://pubsub.googleapis.com/v1/projects/{self.project_id}/topics/{self.topic_id}:publish"
+
+ response = await self.async_httpx_client.post(
+ url=url, headers=headers, json=request_body
+ )
+
+ if response.status_code not in [200, 202]:
+ verbose_logger.error("Pub/Sub publish error: %s", str(response.text))
+ raise Exception(f"Failed to publish message: {response.text}")
+
+ verbose_logger.debug("Pub/Sub response: %s", response.text)
+ return response.json()
+
+ except Exception as e:
+ verbose_logger.error("Pub/Sub publish error: %s", str(e))
+ return None
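Annotation: a hypothetical construction sketch for the publisher above. Project, topic and credentials values are placeholders (the same settings can come from GCS_PUBSUB_PROJECT_ID / GCS_PUBSUB_TOPIC_ID / GCS_PATH_SERVICE_ACCOUNT), and the `_premium_user_check()` call gates this to premium deployments. The logger is built inside a running event loop because `__init__` schedules `periodic_flush()`.

```python
import asyncio

import litellm
from litellm.integrations.gcs_pubsub.pub_sub import GcsPubSubLogger


async def setup() -> None:
    # Constructed inside a running loop since __init__ calls asyncio.create_task().
    pubsub_logger = GcsPubSubLogger(
        project_id="my-gcp-project",               # placeholder
        topic_id="litellm-spend-logs",             # placeholder
        credentials_path="/path/to/service_account.json",
    )
    litellm.callbacks = [pubsub_logger]            # register the instance as a callback


asyncio.run(setup())
```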
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/greenscale.py b/.venv/lib/python3.12/site-packages/litellm/integrations/greenscale.py
new file mode 100644
index 00000000..430c3d0a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/greenscale.py
@@ -0,0 +1,72 @@
+import json
+import traceback
+from datetime import datetime, timezone
+
+import litellm
+
+
+class GreenscaleLogger:
+ def __init__(self):
+ import os
+
+ self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY")
+ self.headers = {
+ "api-key": self.greenscale_api_key,
+ "Content-Type": "application/json",
+ }
+ self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT")
+
+ def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
+ try:
+ response_json = response_obj.model_dump() if response_obj else {}
+ data = {
+ "modelId": kwargs.get("model"),
+ "inputTokenCount": response_json.get("usage", {}).get("prompt_tokens"),
+ "outputTokenCount": response_json.get("usage", {}).get(
+ "completion_tokens"
+ ),
+ }
+ data["timestamp"] = datetime.now(timezone.utc).strftime(
+ "%Y-%m-%dT%H:%M:%SZ"
+ )
+
+ if type(end_time) is datetime and type(start_time) is datetime:
+ data["invocationLatency"] = int(
+ (end_time - start_time).total_seconds() * 1000
+ )
+
+ # Add additional metadata keys to tags
+ tags = []
+ metadata = kwargs.get("litellm_params", {}).get("metadata", {})
+ for key, value in metadata.items():
+ if key.startswith("greenscale"):
+ if key == "greenscale_project":
+ data["project"] = value
+ elif key == "greenscale_application":
+ data["application"] = value
+ else:
+ tags.append(
+ {"key": key.replace("greenscale_", ""), "value": str(value)}
+ )
+
+ data["tags"] = tags
+
+ if self.greenscale_logging_url is None:
+ raise Exception("Greenscale Logger Error - No logging URL found")
+
+ response = litellm.module_level_client.post(
+ self.greenscale_logging_url,
+ headers=self.headers,
+ data=json.dumps(data, default=str),
+ )
+ if response.status_code != 200:
+ print_verbose(
+ f"Greenscale Logger Error - {response.text}, {response.status_code}"
+ )
+ else:
+ print_verbose(f"Greenscale Logger Succeeded - {response.text}")
+ except Exception as e:
+ print_verbose(
+ f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}"
+ )
+ pass
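Annotation: a sketch of passing Greenscale-specific metadata on a completion call. Keys prefixed with `greenscale_` are picked up by `GreenscaleLogger.log_event` above: `greenscale_project` and `greenscale_application` map to dedicated fields, everything else becomes a tag. The callback name `greenscale` is an assumption.

```python
import litellm

litellm.success_callback = ["greenscale"]   # assumed callback name

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    metadata={
        "greenscale_project": "acme-chatbot",
        "greenscale_application": "support",
        "greenscale_environment": "staging",   # becomes tag {"key": "environment", "value": "staging"}
    },
)
```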
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/helicone.py b/.venv/lib/python3.12/site-packages/litellm/integrations/helicone.py
new file mode 100644
index 00000000..a526a74f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/helicone.py
@@ -0,0 +1,188 @@
+#### What this does ####
+# On success, logs events to Helicone
+import os
+import traceback
+
+import litellm
+
+
+class HeliconeLogger:
+ # Class variables or attributes
+ helicone_model_list = [
+ "gpt",
+ "claude",
+ "command-r",
+ "command-r-plus",
+ "command-light",
+ "command-medium",
+ "command-medium-beta",
+ "command-xlarge-nightly",
+ "command-nightly",
+ ]
+
+ def __init__(self):
+ # Instance variables
+ self.provider_url = "https://api.openai.com/v1"
+ self.key = os.getenv("HELICONE_API_KEY")
+
+ def claude_mapping(self, model, messages, response_obj):
+ from anthropic import AI_PROMPT, HUMAN_PROMPT
+
+ prompt = f"{HUMAN_PROMPT}"
+ for message in messages:
+ if "role" in message:
+ if message["role"] == "user":
+ prompt += f"{HUMAN_PROMPT}{message['content']}"
+ else:
+ prompt += f"{AI_PROMPT}{message['content']}"
+ else:
+ prompt += f"{HUMAN_PROMPT}{message['content']}"
+ prompt += f"{AI_PROMPT}"
+
+ choice = response_obj["choices"][0]
+ message = choice["message"]
+
+ content = []
+ if "tool_calls" in message and message["tool_calls"]:
+ for tool_call in message["tool_calls"]:
+ content.append(
+ {
+ "type": "tool_use",
+ "id": tool_call["id"],
+ "name": tool_call["function"]["name"],
+ "input": tool_call["function"]["arguments"],
+ }
+ )
+ elif "content" in message and message["content"]:
+ content = [{"type": "text", "text": message["content"]}]
+
+ claude_response_obj = {
+ "id": response_obj["id"],
+ "type": "message",
+ "role": "assistant",
+ "model": model,
+ "content": content,
+ "stop_reason": choice["finish_reason"],
+ "stop_sequence": None,
+ "usage": {
+ "input_tokens": response_obj["usage"]["prompt_tokens"],
+ "output_tokens": response_obj["usage"]["completion_tokens"],
+ },
+ }
+
+ return claude_response_obj
+
+ @staticmethod
+ def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict:
+ """
+ Adds metadata from proxy request headers to Helicone logging if keys start with "helicone_"
+ and overwrites litellm_params.metadata if already included.
+
+        For example, to add a custom property to your request, send
+        `headers: { ..., helicone-property-something: 1234 }` via the proxy request.
+ """
+ if litellm_params is None:
+ return metadata
+
+ if litellm_params.get("proxy_server_request") is None:
+ return metadata
+
+ if metadata is None:
+ metadata = {}
+
+ proxy_headers = (
+ litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
+ )
+
+ for header_key in proxy_headers:
+ if header_key.startswith("helicone_"):
+ metadata[header_key] = proxy_headers.get(header_key)
+
+ return metadata
+
+ def log_success(
+ self, model, messages, response_obj, start_time, end_time, print_verbose, kwargs
+ ):
+ # Method definition
+ try:
+ print_verbose(
+ f"Helicone Logging - Enters logging function for model {model}"
+ )
+ litellm_params = kwargs.get("litellm_params", {})
+ kwargs.get("litellm_call_id", None)
+ metadata = litellm_params.get("metadata", {}) or {}
+ metadata = self.add_metadata_from_header(litellm_params, metadata)
+ model = (
+ model
+ if any(
+ accepted_model in model
+ for accepted_model in self.helicone_model_list
+ )
+ else "gpt-3.5-turbo"
+ )
+ provider_request = {"model": model, "messages": messages}
+ if isinstance(response_obj, litellm.EmbeddingResponse) or isinstance(
+ response_obj, litellm.ModelResponse
+ ):
+ response_obj = response_obj.json()
+
+ if "claude" in model:
+ response_obj = self.claude_mapping(
+ model=model, messages=messages, response_obj=response_obj
+ )
+
+ providerResponse = {
+ "json": response_obj,
+ "headers": {"openai-version": "2020-10-01"},
+ "status": 200,
+ }
+
+ # Code to be executed
+ provider_url = self.provider_url
+ url = "https://api.hconeai.com/oai/v1/log"
+ if "claude" in model:
+ url = "https://api.hconeai.com/anthropic/v1/log"
+ provider_url = "https://api.anthropic.com/v1/messages"
+ headers = {
+ "Authorization": f"Bearer {self.key}",
+ "Content-Type": "application/json",
+ }
+ start_time_seconds = int(start_time.timestamp())
+ start_time_milliseconds = int(
+ (start_time.timestamp() - start_time_seconds) * 1000
+ )
+ end_time_seconds = int(end_time.timestamp())
+ end_time_milliseconds = int(
+ (end_time.timestamp() - end_time_seconds) * 1000
+ )
+ meta = {"Helicone-Auth": f"Bearer {self.key}"}
+ meta.update(metadata)
+ data = {
+ "providerRequest": {
+ "url": provider_url,
+ "json": provider_request,
+ "meta": meta,
+ },
+ "providerResponse": providerResponse,
+ "timing": {
+ "startTime": {
+ "seconds": start_time_seconds,
+ "milliseconds": start_time_milliseconds,
+ },
+ "endTime": {
+ "seconds": end_time_seconds,
+ "milliseconds": end_time_milliseconds,
+ },
+ }, # {"seconds": .., "milliseconds": ..}
+ }
+ response = litellm.module_level_client.post(url, headers=headers, json=data)
+ if response.status_code == 200:
+ print_verbose("Helicone Logging - Success!")
+ else:
+ print_verbose(
+ f"Helicone Logging - Error Request was not successful. Status Code: {response.status_code}"
+ )
+ print_verbose(f"Helicone Logging - Error {response.text}")
+ except Exception:
+ print_verbose(f"Helicone Logging Error - {traceback.format_exc()}")
+ pass
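Annotation: a hypothetical setup sketch for the logger above. It assumes the callback is registered under the name `helicone` and that a valid key is exported as HELICONE_API_KEY (read in `__init__`).

```python
import os
import litellm

os.environ["HELICONE_API_KEY"] = "<your-helicone-api-key>"   # placeholder
litellm.success_callback = ["helicone"]                      # assumed callback name

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
```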
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/humanloop.py b/.venv/lib/python3.12/site-packages/litellm/integrations/humanloop.py
new file mode 100644
index 00000000..fd3463f9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/humanloop.py
@@ -0,0 +1,197 @@
+"""
+Humanloop integration
+
+https://humanloop.com/
+"""
+
+from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast
+
+import httpx
+
+import litellm
+from litellm.caching import DualCache
+from litellm.llms.custom_httpx.http_handler import _get_httpx_client
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import StandardCallbackDynamicParams
+
+from .custom_logger import CustomLogger
+
+
+class PromptManagementClient(TypedDict):
+ prompt_id: str
+ prompt_template: List[AllMessageValues]
+ model: Optional[str]
+ optional_params: Optional[Dict[str, Any]]
+
+
+class HumanLoopPromptManager(DualCache):
+ @property
+ def integration_name(self):
+ return "humanloop"
+
+ def _get_prompt_from_id_cache(
+ self, humanloop_prompt_id: str
+ ) -> Optional[PromptManagementClient]:
+ return cast(
+ Optional[PromptManagementClient], self.get_cache(key=humanloop_prompt_id)
+ )
+
+ def _compile_prompt_helper(
+ self, prompt_template: List[AllMessageValues], prompt_variables: Dict[str, Any]
+ ) -> List[AllMessageValues]:
+ """
+ Helper function to compile the prompt by substituting variables in the template.
+
+ Args:
+ prompt_template: List[AllMessageValues]
+ prompt_variables (dict): A dictionary of variables to substitute into the prompt template.
+
+ Returns:
+ list: A list of dictionaries with variables substituted.
+ """
+ compiled_prompts: List[AllMessageValues] = []
+
+ for template in prompt_template:
+ tc = template.get("content")
+ if tc and isinstance(tc, str):
+ formatted_template = tc.replace("{{", "{").replace("}}", "}")
+ compiled_content = formatted_template.format(**prompt_variables)
+ template["content"] = compiled_content
+ compiled_prompts.append(template)
+
+ return compiled_prompts
+
+ def _get_prompt_from_id_api(
+ self, humanloop_prompt_id: str, humanloop_api_key: str
+ ) -> PromptManagementClient:
+ client = _get_httpx_client()
+
+ base_url = "https://api.humanloop.com/v5/prompts/{}".format(humanloop_prompt_id)
+
+ response = client.get(
+ url=base_url,
+ headers={
+ "X-Api-Key": humanloop_api_key,
+ "Content-Type": "application/json",
+ },
+ )
+
+ try:
+ response.raise_for_status()
+ except httpx.HTTPStatusError as e:
+ raise Exception(f"Error getting prompt from Humanloop: {e.response.text}")
+
+ json_response = response.json()
+ template_message = json_response["template"]
+ if isinstance(template_message, dict):
+ template_messages = [template_message]
+ elif isinstance(template_message, list):
+ template_messages = template_message
+ else:
+ raise ValueError(f"Invalid template message type: {type(template_message)}")
+ template_model = json_response["model"]
+ optional_params = {}
+ for k, v in json_response.items():
+ if k in litellm.OPENAI_CHAT_COMPLETION_PARAMS:
+ optional_params[k] = v
+ return PromptManagementClient(
+ prompt_id=humanloop_prompt_id,
+ prompt_template=cast(List[AllMessageValues], template_messages),
+ model=template_model,
+ optional_params=optional_params,
+ )
+
+ def _get_prompt_from_id(
+ self, humanloop_prompt_id: str, humanloop_api_key: str
+ ) -> PromptManagementClient:
+ prompt = self._get_prompt_from_id_cache(humanloop_prompt_id)
+ if prompt is None:
+ prompt = self._get_prompt_from_id_api(
+ humanloop_prompt_id, humanloop_api_key
+ )
+ self.set_cache(
+ key=humanloop_prompt_id,
+ value=prompt,
+ ttl=litellm.HUMANLOOP_PROMPT_CACHE_TTL_SECONDS,
+ )
+ return prompt
+
+ def compile_prompt(
+ self,
+ prompt_template: List[AllMessageValues],
+ prompt_variables: Optional[dict],
+ ) -> List[AllMessageValues]:
+ compiled_prompt: Optional[Union[str, list]] = None
+
+ if prompt_variables is None:
+ prompt_variables = {}
+
+ compiled_prompt = self._compile_prompt_helper(
+ prompt_template=prompt_template,
+ prompt_variables=prompt_variables,
+ )
+
+ return compiled_prompt
+
+ def _get_model_from_prompt(
+ self, prompt_management_client: PromptManagementClient, model: str
+ ) -> str:
+ if prompt_management_client["model"] is not None:
+ return prompt_management_client["model"]
+ else:
+ return model.replace("{}/".format(self.integration_name), "")
+
+
+prompt_manager = HumanLoopPromptManager()
+
+
+class HumanloopLogger(CustomLogger):
+ def get_chat_completion_prompt(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ non_default_params: dict,
+ prompt_id: str,
+ prompt_variables: Optional[dict],
+ dynamic_callback_params: StandardCallbackDynamicParams,
+ ) -> Tuple[
+ str,
+ List[AllMessageValues],
+ dict,
+ ]:
+ humanloop_api_key = dynamic_callback_params.get(
+ "humanloop_api_key"
+ ) or get_secret_str("HUMANLOOP_API_KEY")
+
+ if humanloop_api_key is None:
+ return super().get_chat_completion_prompt(
+ model=model,
+ messages=messages,
+ non_default_params=non_default_params,
+ prompt_id=prompt_id,
+ prompt_variables=prompt_variables,
+ dynamic_callback_params=dynamic_callback_params,
+ )
+
+ prompt_template = prompt_manager._get_prompt_from_id(
+ humanloop_prompt_id=prompt_id, humanloop_api_key=humanloop_api_key
+ )
+
+ updated_messages = prompt_manager.compile_prompt(
+ prompt_template=prompt_template["prompt_template"],
+ prompt_variables=prompt_variables,
+ )
+
+ prompt_template_optional_params = prompt_template["optional_params"] or {}
+
+ updated_non_default_params = {
+ **non_default_params,
+ **prompt_template_optional_params,
+ }
+
+ model = prompt_manager._get_model_from_prompt(
+ prompt_management_client=prompt_template, model=model
+ )
+
+ return model, updated_messages, updated_non_default_params
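Annotation: a sketch of the template substitution performed by `HumanLoopPromptManager.compile_prompt` above; `{{var}}` placeholders are rewritten to `{var}` and filled via `str.format()`. The template and variables below are illustrative.

```python
from litellm.integrations.humanloop import prompt_manager

messages = prompt_manager.compile_prompt(
    prompt_template=[{"role": "user", "content": "Hello {{name}}, summarise {{topic}}."}],
    prompt_variables={"name": "Ada", "topic": "gradient descent"},
)
# -> [{"role": "user", "content": "Hello Ada, summarise gradient descent."}]
```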
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/lago.py b/.venv/lib/python3.12/site-packages/litellm/integrations/lago.py
new file mode 100644
index 00000000..5dfb1ce0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/lago.py
@@ -0,0 +1,202 @@
+# What is this?
+## On Success events log cost to Lago - https://github.com/BerriAI/litellm/issues/3639
+
+import json
+import os
+import uuid
+from typing import Literal, Optional
+
+import httpx
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.llms.custom_httpx.http_handler import (
+ HTTPHandler,
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+
+
+def get_utc_datetime():
+ import datetime as dt
+ from datetime import datetime
+
+ if hasattr(dt, "UTC"):
+ return datetime.now(dt.UTC) # type: ignore
+ else:
+ return datetime.utcnow() # type: ignore
+
+
+class LagoLogger(CustomLogger):
+ def __init__(self) -> None:
+ super().__init__()
+ self.validate_environment()
+ self.async_http_handler = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+ self.sync_http_handler = HTTPHandler()
+
+ def validate_environment(self):
+ """
+ Expects
+ LAGO_API_BASE,
+ LAGO_API_KEY,
+ LAGO_API_EVENT_CODE,
+
+ Optional:
+ LAGO_API_CHARGE_BY
+
+ in the environment
+ """
+ missing_keys = []
+ if os.getenv("LAGO_API_KEY", None) is None:
+ missing_keys.append("LAGO_API_KEY")
+
+ if os.getenv("LAGO_API_BASE", None) is None:
+ missing_keys.append("LAGO_API_BASE")
+
+ if os.getenv("LAGO_API_EVENT_CODE", None) is None:
+ missing_keys.append("LAGO_API_EVENT_CODE")
+
+ if len(missing_keys) > 0:
+ raise Exception("Missing keys={} in environment.".format(missing_keys))
+
+ def _common_logic(self, kwargs: dict, response_obj) -> dict:
+ response_obj.get("id", kwargs.get("litellm_call_id"))
+ get_utc_datetime().isoformat()
+ cost = kwargs.get("response_cost", None)
+ model = kwargs.get("model")
+ usage = {}
+
+ if (
+ isinstance(response_obj, litellm.ModelResponse)
+ or isinstance(response_obj, litellm.EmbeddingResponse)
+ ) and hasattr(response_obj, "usage"):
+ usage = {
+ "prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
+ "completion_tokens": response_obj["usage"].get("completion_tokens", 0),
+ "total_tokens": response_obj["usage"].get("total_tokens"),
+ }
+
+ litellm_params = kwargs.get("litellm_params", {}) or {}
+ proxy_server_request = litellm_params.get("proxy_server_request") or {}
+ end_user_id = proxy_server_request.get("body", {}).get("user", None)
+ user_id = litellm_params["metadata"].get("user_api_key_user_id", None)
+ team_id = litellm_params["metadata"].get("user_api_key_team_id", None)
+ litellm_params["metadata"].get("user_api_key_org_id", None)
+
+ charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id"
+ external_customer_id: Optional[str] = None
+
+ if os.getenv("LAGO_API_CHARGE_BY", None) is not None and isinstance(
+ os.environ["LAGO_API_CHARGE_BY"], str
+ ):
+ if os.environ["LAGO_API_CHARGE_BY"] in [
+ "end_user_id",
+ "user_id",
+ "team_id",
+ ]:
+ charge_by = os.environ["LAGO_API_CHARGE_BY"] # type: ignore
+ else:
+ raise Exception("invalid LAGO_API_CHARGE_BY set")
+
+ if charge_by == "end_user_id":
+ external_customer_id = end_user_id
+ elif charge_by == "team_id":
+ external_customer_id = team_id
+ elif charge_by == "user_id":
+ external_customer_id = user_id
+
+ if external_customer_id is None:
+ raise Exception(
+ "External Customer ID is not set. Charge_by={}. User_id={}. End_user_id={}. Team_id={}".format(
+ charge_by, user_id, end_user_id, team_id
+ )
+ )
+
+ returned_val = {
+ "event": {
+ "transaction_id": str(uuid.uuid4()),
+ "external_subscription_id": external_customer_id,
+ "code": os.getenv("LAGO_API_EVENT_CODE"),
+ "properties": {"model": model, "response_cost": cost, **usage},
+ }
+ }
+
+ verbose_logger.debug(
+ "\033[91mLogged Lago Object:\n{}\033[0m\n".format(returned_val)
+ )
+ return returned_val
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ _url = os.getenv("LAGO_API_BASE")
+ assert _url is not None and isinstance(
+ _url, str
+ ), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(_url)
+ if _url.endswith("/"):
+ _url += "api/v1/events"
+ else:
+ _url += "/api/v1/events"
+
+ api_key = os.getenv("LAGO_API_KEY")
+
+ _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
+ _headers = {
+ "Content-Type": "application/json",
+ "Authorization": "Bearer {}".format(api_key),
+ }
+
+ try:
+ response = self.sync_http_handler.post(
+ url=_url,
+ data=json.dumps(_data),
+ headers=_headers,
+ )
+
+ response.raise_for_status()
+ except Exception as e:
+ error_response = getattr(e, "response", None)
+ if error_response is not None and hasattr(error_response, "text"):
+ verbose_logger.debug(f"\nError Message: {error_response.text}")
+ raise e
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ verbose_logger.debug("ENTERS LAGO CALLBACK")
+ _url = os.getenv("LAGO_API_BASE")
+ assert _url is not None and isinstance(
+ _url, str
+ ), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(
+ _url
+ )
+ if _url.endswith("/"):
+ _url += "api/v1/events"
+ else:
+ _url += "/api/v1/events"
+
+ api_key = os.getenv("LAGO_API_KEY")
+
+ _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
+ _headers = {
+ "Content-Type": "application/json",
+ "Authorization": "Bearer {}".format(api_key),
+ }
+ except Exception as e:
+ raise e
+
+ response: Optional[httpx.Response] = None
+ try:
+ response = await self.async_http_handler.post(
+ url=_url,
+ data=json.dumps(_data),
+ headers=_headers,
+ )
+
+ response.raise_for_status()
+
+ verbose_logger.debug(f"Logged Lago Object: {response.text}")
+ except Exception as e:
+ if response is not None and hasattr(response, "text"):
+ verbose_logger.debug(f"\nError Message: {response.text}")
+ raise e
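
LagoLogger reads its configuration entirely from the environment and posts one event per successful call to <LAGO_API_BASE>/api/v1/events. A hedged wiring sketch, assuming the standard litellm pattern of registering CustomLogger instances via litellm.callbacks; the endpoint, key, and event code below are placeholders:

    import os

    import litellm
    from litellm.integrations.lago import LagoLogger

    # required by LagoLogger.validate_environment() - placeholder values
    os.environ["LAGO_API_BASE"] = "https://lago.example.com"
    os.environ["LAGO_API_KEY"] = "lago-api-key-placeholder"
    os.environ["LAGO_API_EVENT_CODE"] = "litellm_usage"
    # optional: one of end_user_id (default) | user_id | team_id
    os.environ["LAGO_API_CHARGE_BY"] = "end_user_id"

    # register the logger; the external customer id is resolved from proxy request
    # metadata (request body `user`, or the key's user/team id), so this is normally
    # wired up behind the litellm proxy rather than for one-off SDK calls
    litellm.callbacks = [LagoLogger()]
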
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse.py b/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse.py
new file mode 100644
index 00000000..f990a316
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse.py
@@ -0,0 +1,955 @@
+#### What this does ####
+# On success, logs events to Langfuse
+import copy
+import os
+import traceback
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
+
+from packaging.version import Version
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
+from litellm.llms.custom_httpx.http_handler import _get_httpx_client
+from litellm.secret_managers.main import str_to_bool
+from litellm.types.integrations.langfuse import *
+from litellm.types.llms.openai import HttpxBinaryResponseContent
+from litellm.types.utils import (
+ EmbeddingResponse,
+ ImageResponse,
+ ModelResponse,
+ RerankResponse,
+ StandardLoggingPayload,
+ StandardLoggingPromptManagementMetadata,
+ TextCompletionResponse,
+ TranscriptionResponse,
+)
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache
+else:
+ DynamicLoggingCache = Any
+
+
+class LangFuseLogger:
+ # Class variables or attributes
+ def __init__(
+ self,
+ langfuse_public_key=None,
+ langfuse_secret=None,
+ langfuse_host=None,
+ flush_interval=1,
+ ):
+ try:
+ import langfuse
+ from langfuse import Langfuse
+ except Exception as e:
+ raise Exception(
+ f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
+ )
+ # Instance variables
+ self.secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY")
+ self.public_key = langfuse_public_key or os.getenv("LANGFUSE_PUBLIC_KEY")
+ self.langfuse_host = langfuse_host or os.getenv(
+ "LANGFUSE_HOST", "https://cloud.langfuse.com"
+ )
+ if not (
+ self.langfuse_host.startswith("http://")
+ or self.langfuse_host.startswith("https://")
+ ):
+ # add http:// if unset, assume communicating over private network - e.g. render
+ self.langfuse_host = "http://" + self.langfuse_host
+ self.langfuse_release = os.getenv("LANGFUSE_RELEASE")
+ self.langfuse_debug = os.getenv("LANGFUSE_DEBUG")
+ self.langfuse_flush_interval = LangFuseLogger._get_langfuse_flush_interval(
+ flush_interval
+ )
+ http_client = _get_httpx_client()
+ self.langfuse_client = http_client.client
+
+ parameters = {
+ "public_key": self.public_key,
+ "secret_key": self.secret_key,
+ "host": self.langfuse_host,
+ "release": self.langfuse_release,
+ "debug": self.langfuse_debug,
+ "flush_interval": self.langfuse_flush_interval, # flush interval in seconds
+ "httpx_client": self.langfuse_client,
+ }
+ self.langfuse_sdk_version: str = langfuse.version.__version__
+
+ if Version(self.langfuse_sdk_version) >= Version("2.6.0"):
+ parameters["sdk_integration"] = "litellm"
+
+ self.Langfuse = Langfuse(**parameters)
+
+ # set the current langfuse project id in the environ
+ # this is used by Alerting to link to the correct project
+ try:
+ project_id = self.Langfuse.client.projects.get().data[0].id
+ os.environ["LANGFUSE_PROJECT_ID"] = project_id
+ except Exception:
+ project_id = None
+
+ if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
+            self.upstream_langfuse_secret_key = os.getenv(
+                "UPSTREAM_LANGFUSE_SECRET_KEY"
+            )
+            self.upstream_langfuse_public_key = os.getenv(
+                "UPSTREAM_LANGFUSE_PUBLIC_KEY"
+            )
+            self.upstream_langfuse_host = os.getenv("UPSTREAM_LANGFUSE_HOST")
+            self.upstream_langfuse_release = os.getenv("UPSTREAM_LANGFUSE_RELEASE")
+            self.upstream_langfuse_debug = os.getenv("UPSTREAM_LANGFUSE_DEBUG")
+            # resolve the debug flag only after reading it from the environment,
+            # so the attribute exists before it is referenced
+            upstream_langfuse_debug = (
+                str_to_bool(self.upstream_langfuse_debug)
+                if self.upstream_langfuse_debug is not None
+                else None
+            )
+ self.upstream_langfuse = Langfuse(
+ public_key=self.upstream_langfuse_public_key,
+ secret_key=self.upstream_langfuse_secret_key,
+ host=self.upstream_langfuse_host,
+ release=self.upstream_langfuse_release,
+ debug=(
+ upstream_langfuse_debug
+ if upstream_langfuse_debug is not None
+ else False
+ ),
+ )
+ else:
+ self.upstream_langfuse = None
+
+ @staticmethod
+ def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict:
+ """
+ Adds metadata from proxy request headers to Langfuse logging if keys start with "langfuse_"
+ and overwrites litellm_params.metadata if already included.
+
+ For example if you want to append your trace to an existing `trace_id` via header, send
+ `headers: { ..., langfuse_existing_trace_id: your-existing-trace-id }` via proxy request.
+ """
+ if litellm_params is None:
+ return metadata
+
+ if litellm_params.get("proxy_server_request") is None:
+ return metadata
+
+ if metadata is None:
+ metadata = {}
+
+ proxy_headers = (
+ litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
+ )
+
+ for metadata_param_key in proxy_headers:
+ if metadata_param_key.startswith("langfuse_"):
+ trace_param_key = metadata_param_key.replace("langfuse_", "", 1)
+ if trace_param_key in metadata:
+ verbose_logger.warning(
+ f"Overwriting Langfuse `{trace_param_key}` from request header"
+ )
+ else:
+ verbose_logger.debug(
+ f"Found Langfuse `{trace_param_key}` in request header"
+ )
+ metadata[trace_param_key] = proxy_headers.get(metadata_param_key)
+
+ return metadata
+
+ def log_event_on_langfuse(
+ self,
+ kwargs: dict,
+ response_obj: Union[
+ None,
+ dict,
+ EmbeddingResponse,
+ ModelResponse,
+ TextCompletionResponse,
+ ImageResponse,
+ TranscriptionResponse,
+ RerankResponse,
+ HttpxBinaryResponseContent,
+ ],
+ start_time: Optional[datetime] = None,
+ end_time: Optional[datetime] = None,
+ user_id: Optional[str] = None,
+ level: str = "DEFAULT",
+ status_message: Optional[str] = None,
+ ) -> dict:
+ """
+ Logs a success or error event on Langfuse
+ """
+ try:
+ verbose_logger.debug(
+ f"Langfuse Logging - Enters logging function for model {kwargs}"
+ )
+
+ # set default values for input/output for langfuse logging
+ input = None
+ output = None
+
+ litellm_params = kwargs.get("litellm_params", {})
+ litellm_call_id = kwargs.get("litellm_call_id", None)
+ metadata = (
+ litellm_params.get("metadata", {}) or {}
+            ) # handles litellm_params["metadata"] being None
+ metadata = self.add_metadata_from_header(litellm_params, metadata)
+ optional_params = copy.deepcopy(kwargs.get("optional_params", {}))
+
+ prompt = {"messages": kwargs.get("messages")}
+
+ functions = optional_params.pop("functions", None)
+ tools = optional_params.pop("tools", None)
+ if functions is not None:
+ prompt["functions"] = functions
+ if tools is not None:
+ prompt["tools"] = tools
+
+ # langfuse only accepts str, int, bool, float for logging
+ for param, value in optional_params.items():
+ if not isinstance(value, (str, int, bool, float)):
+ try:
+ optional_params[param] = str(value)
+ except Exception:
+ # if casting value to str fails don't block logging
+ pass
+
+ input, output = self._get_langfuse_input_output_content(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ prompt=prompt,
+ level=level,
+ status_message=status_message,
+ )
+ verbose_logger.debug(
+ f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}"
+ )
+ trace_id = None
+ generation_id = None
+ if self._is_langfuse_v2():
+ trace_id, generation_id = self._log_langfuse_v2(
+ user_id=user_id,
+ metadata=metadata,
+ litellm_params=litellm_params,
+ output=output,
+ start_time=start_time,
+ end_time=end_time,
+ kwargs=kwargs,
+ optional_params=optional_params,
+ input=input,
+ response_obj=response_obj,
+ level=level,
+ litellm_call_id=litellm_call_id,
+ )
+ elif response_obj is not None:
+ self._log_langfuse_v1(
+ user_id=user_id,
+ metadata=metadata,
+ output=output,
+ start_time=start_time,
+ end_time=end_time,
+ kwargs=kwargs,
+ optional_params=optional_params,
+ input=input,
+ response_obj=response_obj,
+ )
+ verbose_logger.debug(
+ f"Langfuse Layer Logging - final response object: {response_obj}"
+ )
+ verbose_logger.info("Langfuse Layer Logging - logging success")
+
+ return {"trace_id": trace_id, "generation_id": generation_id}
+ except Exception as e:
+ verbose_logger.exception(
+ "Langfuse Layer Error(): Exception occured - {}".format(str(e))
+ )
+ return {"trace_id": None, "generation_id": None}
+
+ def _get_langfuse_input_output_content(
+ self,
+ kwargs: dict,
+ response_obj: Union[
+ None,
+ dict,
+ EmbeddingResponse,
+ ModelResponse,
+ TextCompletionResponse,
+ ImageResponse,
+ TranscriptionResponse,
+ RerankResponse,
+ HttpxBinaryResponseContent,
+ ],
+ prompt: dict,
+ level: str,
+ status_message: Optional[str],
+ ) -> Tuple[Optional[dict], Optional[Union[str, dict, list]]]:
+ """
+ Get the input and output content for Langfuse logging
+
+ Args:
+ kwargs: The keyword arguments passed to the function
+ response_obj: The response object returned by the function
+ prompt: The prompt used to generate the response
+ level: The level of the log message
+ status_message: The status message of the log message
+
+ Returns:
+ input: The input content for Langfuse logging
+ output: The output content for Langfuse logging
+ """
+ input = None
+ output: Optional[Union[str, dict, List[Any]]] = None
+ if (
+ level == "ERROR"
+ and status_message is not None
+ and isinstance(status_message, str)
+ ):
+ input = prompt
+ output = status_message
+ elif response_obj is not None and (
+ kwargs.get("call_type", None) == "embedding"
+ or isinstance(response_obj, litellm.EmbeddingResponse)
+ ):
+ input = prompt
+ output = None
+ elif response_obj is not None and isinstance(
+ response_obj, litellm.ModelResponse
+ ):
+ input = prompt
+ output = self._get_chat_content_for_langfuse(response_obj)
+ elif response_obj is not None and isinstance(
+ response_obj, litellm.HttpxBinaryResponseContent
+ ):
+ input = prompt
+ output = "speech-output"
+ elif response_obj is not None and isinstance(
+ response_obj, litellm.TextCompletionResponse
+ ):
+ input = prompt
+ output = self._get_text_completion_content_for_langfuse(response_obj)
+ elif response_obj is not None and isinstance(
+ response_obj, litellm.ImageResponse
+ ):
+ input = prompt
+ output = response_obj.get("data", None)
+ elif response_obj is not None and isinstance(
+ response_obj, litellm.TranscriptionResponse
+ ):
+ input = prompt
+ output = response_obj.get("text", None)
+ elif response_obj is not None and isinstance(
+ response_obj, litellm.RerankResponse
+ ):
+ input = prompt
+ output = response_obj.results
+ elif (
+ kwargs.get("call_type") is not None
+ and kwargs.get("call_type") == "_arealtime"
+ and response_obj is not None
+ and isinstance(response_obj, list)
+ ):
+ input = kwargs.get("input")
+ output = response_obj
+ elif (
+ kwargs.get("call_type") is not None
+ and kwargs.get("call_type") == "pass_through_endpoint"
+ and response_obj is not None
+ and isinstance(response_obj, dict)
+ ):
+ input = prompt
+ output = response_obj.get("response", "")
+ return input, output
+
+ async def _async_log_event(
+ self, kwargs, response_obj, start_time, end_time, user_id
+ ):
+ """
+ Langfuse SDK uses a background thread to log events
+
+ This approach does not impact latency and runs in the background
+ """
+
+ def _is_langfuse_v2(self):
+ import langfuse
+
+ return Version(langfuse.version.__version__) >= Version("2.0.0")
+
+ def _log_langfuse_v1(
+ self,
+ user_id,
+ metadata,
+ output,
+ start_time,
+ end_time,
+ kwargs,
+ optional_params,
+ input,
+ response_obj,
+ ):
+ from langfuse.model import CreateGeneration, CreateTrace # type: ignore
+
+ verbose_logger.warning(
+ "Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1"
+ )
+
+ trace = self.Langfuse.trace( # type: ignore
+ CreateTrace( # type: ignore
+ name=metadata.get("generation_name", "litellm-completion"),
+ input=input,
+ output=output,
+ userId=user_id,
+ )
+ )
+
+ trace.generation(
+ CreateGeneration(
+ name=metadata.get("generation_name", "litellm-completion"),
+ startTime=start_time,
+ endTime=end_time,
+ model=kwargs["model"],
+ modelParameters=optional_params,
+ prompt=input,
+ completion=output,
+ usage={
+ "prompt_tokens": response_obj.usage.prompt_tokens,
+ "completion_tokens": response_obj.usage.completion_tokens,
+ },
+ metadata=metadata,
+ )
+ )
+
+ def _log_langfuse_v2( # noqa: PLR0915
+ self,
+ user_id: Optional[str],
+ metadata: dict,
+ litellm_params: dict,
+ output: Optional[Union[str, dict, list]],
+ start_time: Optional[datetime],
+ end_time: Optional[datetime],
+ kwargs: dict,
+ optional_params: dict,
+ input: Optional[dict],
+ response_obj,
+ level: str,
+ litellm_call_id: Optional[str],
+ ) -> tuple:
+ verbose_logger.debug("Langfuse Layer Logging - logging to langfuse v2")
+
+ try:
+ metadata = metadata or {}
+ standard_logging_object: Optional[StandardLoggingPayload] = cast(
+ Optional[StandardLoggingPayload],
+ kwargs.get("standard_logging_object", None),
+ )
+ tags = (
+ self._get_langfuse_tags(standard_logging_object=standard_logging_object)
+ if self._supports_tags()
+ else []
+ )
+
+ if standard_logging_object is None:
+ end_user_id = None
+ prompt_management_metadata: Optional[
+ StandardLoggingPromptManagementMetadata
+ ] = None
+ else:
+ end_user_id = standard_logging_object["metadata"].get(
+ "user_api_key_end_user_id", None
+ )
+
+ prompt_management_metadata = cast(
+ Optional[StandardLoggingPromptManagementMetadata],
+ standard_logging_object["metadata"].get(
+ "prompt_management_metadata", None
+ ),
+ )
+
+ # Clean Metadata before logging - never log raw metadata
+ # the raw metadata can contain circular references which leads to infinite recursion
+ # we clean out all extra litellm metadata params before logging
+ clean_metadata: Dict[str, Any] = {}
+ if prompt_management_metadata is not None:
+ clean_metadata["prompt_management_metadata"] = (
+ prompt_management_metadata
+ )
+ if isinstance(metadata, dict):
+ for key, value in metadata.items():
+ # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
+ if (
+ litellm.langfuse_default_tags is not None
+ and isinstance(litellm.langfuse_default_tags, list)
+ and key in litellm.langfuse_default_tags
+ ):
+ tags.append(f"{key}:{value}")
+
+ # clean litellm metadata before logging
+ if key in [
+ "headers",
+ "endpoint",
+ "caching_groups",
+ "previous_models",
+ ]:
+ continue
+ else:
+ clean_metadata[key] = value
+
+ # Add default langfuse tags
+ tags = self.add_default_langfuse_tags(
+ tags=tags, kwargs=kwargs, metadata=metadata
+ )
+
+ session_id = clean_metadata.pop("session_id", None)
+ trace_name = cast(Optional[str], clean_metadata.pop("trace_name", None))
+ trace_id = clean_metadata.pop("trace_id", litellm_call_id)
+ existing_trace_id = clean_metadata.pop("existing_trace_id", None)
+ update_trace_keys = cast(list, clean_metadata.pop("update_trace_keys", []))
+ debug = clean_metadata.pop("debug_langfuse", None)
+ mask_input = clean_metadata.pop("mask_input", False)
+ mask_output = clean_metadata.pop("mask_output", False)
+
+ clean_metadata = redact_user_api_key_info(metadata=clean_metadata)
+
+ if trace_name is None and existing_trace_id is None:
+ # just log `litellm-{call_type}` as the trace name
+ ## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
+ trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
+
+ if existing_trace_id is not None:
+ trace_params: Dict[str, Any] = {"id": existing_trace_id}
+
+ # Update the following keys for this trace
+ for metadata_param_key in update_trace_keys:
+ trace_param_key = metadata_param_key.replace("trace_", "")
+ if trace_param_key not in trace_params:
+ updated_trace_value = clean_metadata.pop(
+ metadata_param_key, None
+ )
+ if updated_trace_value is not None:
+ trace_params[trace_param_key] = updated_trace_value
+
+ # Pop the trace specific keys that would have been popped if there were a new trace
+ for key in list(
+ filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
+ ):
+ clean_metadata.pop(key, None)
+
+ # Special keys that are found in the function arguments and not the metadata
+ if "input" in update_trace_keys:
+ trace_params["input"] = (
+ input if not mask_input else "redacted-by-litellm"
+ )
+ if "output" in update_trace_keys:
+ trace_params["output"] = (
+ output if not mask_output else "redacted-by-litellm"
+ )
+ else: # don't overwrite an existing trace
+ trace_params = {
+ "id": trace_id,
+ "name": trace_name,
+ "session_id": session_id,
+ "input": input if not mask_input else "redacted-by-litellm",
+ "version": clean_metadata.pop(
+ "trace_version", clean_metadata.get("version", None)
+                    ), # if only `version` is provided, it is applied to the trace as well; a `trace_version` takes precedence
+ "user_id": end_user_id,
+ }
+ for key in list(
+ filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
+ ):
+ trace_params[key.replace("trace_", "")] = clean_metadata.pop(
+ key, None
+ )
+
+ if level == "ERROR":
+ trace_params["status_message"] = output
+ else:
+ trace_params["output"] = (
+ output if not mask_output else "redacted-by-litellm"
+ )
+
+ if debug is True or (isinstance(debug, str) and debug.lower() == "true"):
+ if "metadata" in trace_params:
+ # log the raw_metadata in the trace
+ trace_params["metadata"]["metadata_passed_to_litellm"] = metadata
+ else:
+ trace_params["metadata"] = {"metadata_passed_to_litellm": metadata}
+
+ cost = kwargs.get("response_cost", None)
+ verbose_logger.debug(f"trace: {cost}")
+
+ clean_metadata["litellm_response_cost"] = cost
+ if standard_logging_object is not None:
+ clean_metadata["hidden_params"] = standard_logging_object[
+ "hidden_params"
+ ]
+
+ if (
+ litellm.langfuse_default_tags is not None
+ and isinstance(litellm.langfuse_default_tags, list)
+ and "proxy_base_url" in litellm.langfuse_default_tags
+ ):
+ proxy_base_url = os.environ.get("PROXY_BASE_URL", None)
+ if proxy_base_url is not None:
+ tags.append(f"proxy_base_url:{proxy_base_url}")
+
+ api_base = litellm_params.get("api_base", None)
+ if api_base:
+ clean_metadata["api_base"] = api_base
+
+ vertex_location = kwargs.get("vertex_location", None)
+ if vertex_location:
+ clean_metadata["vertex_location"] = vertex_location
+
+ aws_region_name = kwargs.get("aws_region_name", None)
+ if aws_region_name:
+ clean_metadata["aws_region_name"] = aws_region_name
+
+ if self._supports_tags():
+ if "cache_hit" in kwargs:
+ if kwargs["cache_hit"] is None:
+ kwargs["cache_hit"] = False
+ clean_metadata["cache_hit"] = kwargs["cache_hit"]
+ if existing_trace_id is None:
+ trace_params.update({"tags": tags})
+
+ proxy_server_request = litellm_params.get("proxy_server_request", None)
+ if proxy_server_request:
+ proxy_server_request.get("method", None)
+ proxy_server_request.get("url", None)
+ headers = proxy_server_request.get("headers", None)
+ clean_headers = {}
+ if headers:
+ for key, value in headers.items():
+ # these headers can leak our API keys and/or JWT tokens
+ if key.lower() not in ["authorization", "cookie", "referer"]:
+ clean_headers[key] = value
+
+ # clean_metadata["request"] = {
+ # "method": method,
+ # "url": url,
+ # "headers": clean_headers,
+ # }
+ trace = self.Langfuse.trace(**trace_params)
+
+ # Log provider specific information as a span
+ log_provider_specific_information_as_span(trace, clean_metadata)
+
+ generation_id = None
+ usage = None
+ if response_obj is not None:
+ if (
+ hasattr(response_obj, "id")
+ and response_obj.get("id", None) is not None
+ ):
+ generation_id = litellm.utils.get_logging_id(
+ start_time, response_obj
+ )
+ _usage_obj = getattr(response_obj, "usage", None)
+
+ if _usage_obj:
+ usage = {
+ "prompt_tokens": _usage_obj.prompt_tokens,
+ "completion_tokens": _usage_obj.completion_tokens,
+ "total_cost": cost if self._supports_costs() else None,
+ }
+ generation_name = clean_metadata.pop("generation_name", None)
+ if generation_name is None:
+ # if `generation_name` is None, use sensible default values
+                # if using the litellm proxy, use the key's `key_alias` when it is not None
+ # If `key_alias` is None, just log `litellm-{call_type}` as the generation name
+ _user_api_key_alias = cast(
+ Optional[str], clean_metadata.get("user_api_key_alias", None)
+ )
+ generation_name = (
+ f"litellm-{cast(str, kwargs.get('call_type', 'completion'))}"
+ )
+ if _user_api_key_alias is not None:
+ generation_name = f"litellm:{_user_api_key_alias}"
+
+ if response_obj is not None:
+ system_fingerprint = getattr(response_obj, "system_fingerprint", None)
+ else:
+ system_fingerprint = None
+
+ if system_fingerprint is not None:
+ optional_params["system_fingerprint"] = system_fingerprint
+
+ generation_params = {
+ "name": generation_name,
+ "id": clean_metadata.pop("generation_id", generation_id),
+ "start_time": start_time,
+ "end_time": end_time,
+ "model": kwargs["model"],
+ "model_parameters": optional_params,
+ "input": input if not mask_input else "redacted-by-litellm",
+ "output": output if not mask_output else "redacted-by-litellm",
+ "usage": usage,
+ "metadata": log_requester_metadata(clean_metadata),
+ "level": level,
+ "version": clean_metadata.pop("version", None),
+ }
+
+ parent_observation_id = metadata.get("parent_observation_id", None)
+ if parent_observation_id is not None:
+ generation_params["parent_observation_id"] = parent_observation_id
+
+ if self._supports_prompt():
+ generation_params = _add_prompt_to_generation_params(
+ generation_params=generation_params,
+ clean_metadata=clean_metadata,
+ prompt_management_metadata=prompt_management_metadata,
+ langfuse_client=self.Langfuse,
+ )
+ if output is not None and isinstance(output, str) and level == "ERROR":
+ generation_params["status_message"] = output
+
+ if self._supports_completion_start_time():
+ generation_params["completion_start_time"] = kwargs.get(
+ "completion_start_time", None
+ )
+
+ generation_client = trace.generation(**generation_params)
+
+ return generation_client.trace_id, generation_id
+ except Exception:
+ verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}")
+ return None, None
+
+ @staticmethod
+ def _get_chat_content_for_langfuse(
+ response_obj: ModelResponse,
+ ):
+ """
+ Get the chat content for Langfuse logging
+ """
+ if response_obj.choices and len(response_obj.choices) > 0:
+ output = response_obj["choices"][0]["message"].json()
+ return output
+ else:
+ return None
+
+ @staticmethod
+ def _get_text_completion_content_for_langfuse(
+ response_obj: TextCompletionResponse,
+ ):
+ """
+ Get the text completion content for Langfuse logging
+ """
+ if response_obj.choices and len(response_obj.choices) > 0:
+ return response_obj.choices[0].text
+ else:
+ return None
+
+ @staticmethod
+ def _get_langfuse_tags(
+ standard_logging_object: Optional[StandardLoggingPayload],
+ ) -> List[str]:
+ if standard_logging_object is None:
+ return []
+ return standard_logging_object.get("request_tags", []) or []
+
+ def add_default_langfuse_tags(self, tags, kwargs, metadata):
+ """
+ Helper function to add litellm default langfuse tags
+
+ - Special LiteLLM tags:
+ - cache_hit
+ - cache_key
+
+ """
+ if litellm.langfuse_default_tags is not None and isinstance(
+ litellm.langfuse_default_tags, list
+ ):
+ if "cache_hit" in litellm.langfuse_default_tags:
+ _cache_hit_value = kwargs.get("cache_hit", False)
+ tags.append(f"cache_hit:{_cache_hit_value}")
+ if "cache_key" in litellm.langfuse_default_tags:
+ _hidden_params = metadata.get("hidden_params", {}) or {}
+ _cache_key = _hidden_params.get("cache_key", None)
+ if _cache_key is None and litellm.cache is not None:
+ # fallback to using "preset_cache_key"
+ _preset_cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
+ **kwargs
+ )
+ _cache_key = _preset_cache_key
+ tags.append(f"cache_key:{_cache_key}")
+ return tags
+
+ def _supports_tags(self):
+ """Check if current langfuse version supports tags"""
+ return Version(self.langfuse_sdk_version) >= Version("2.6.3")
+
+ def _supports_prompt(self):
+ """Check if current langfuse version supports prompt"""
+ return Version(self.langfuse_sdk_version) >= Version("2.7.3")
+
+ def _supports_costs(self):
+ """Check if current langfuse version supports costs"""
+ return Version(self.langfuse_sdk_version) >= Version("2.7.3")
+
+ def _supports_completion_start_time(self):
+ """Check if current langfuse version supports completion start time"""
+ return Version(self.langfuse_sdk_version) >= Version("2.7.3")
+
+ @staticmethod
+ def _get_langfuse_flush_interval(flush_interval: int) -> int:
+ """
+ Get the langfuse flush interval to initialize the Langfuse client
+
+ Reads `LANGFUSE_FLUSH_INTERVAL` from the environment variable.
+ If not set, uses the flush interval passed in as an argument.
+
+ Args:
+ flush_interval: The flush interval to use if LANGFUSE_FLUSH_INTERVAL is not set
+
+ Returns:
+ [int] The flush interval to use to initialize the Langfuse client
+ """
+ return int(os.getenv("LANGFUSE_FLUSH_INTERVAL") or flush_interval)
+
+
+def _add_prompt_to_generation_params(
+ generation_params: dict,
+ clean_metadata: dict,
+ prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata],
+ langfuse_client: Any,
+) -> dict:
+ from langfuse import Langfuse
+ from langfuse.model import (
+ ChatPromptClient,
+ Prompt_Chat,
+ Prompt_Text,
+ TextPromptClient,
+ )
+
+ langfuse_client = cast(Langfuse, langfuse_client)
+
+ user_prompt = clean_metadata.pop("prompt", None)
+ if user_prompt is None and prompt_management_metadata is None:
+ pass
+ elif isinstance(user_prompt, dict):
+ if user_prompt.get("type", "") == "chat":
+ _prompt_chat = Prompt_Chat(**user_prompt)
+ generation_params["prompt"] = ChatPromptClient(prompt=_prompt_chat)
+ elif user_prompt.get("type", "") == "text":
+ _prompt_text = Prompt_Text(**user_prompt)
+ generation_params["prompt"] = TextPromptClient(prompt=_prompt_text)
+ elif "version" in user_prompt and "prompt" in user_prompt:
+ # prompts
+ if isinstance(user_prompt["prompt"], str):
+ prompt_text_params = getattr(
+ Prompt_Text, "model_fields", Prompt_Text.__fields__
+ )
+ _data = {
+ "name": user_prompt["name"],
+ "prompt": user_prompt["prompt"],
+ "version": user_prompt["version"],
+ "config": user_prompt.get("config", None),
+ }
+ if "labels" in prompt_text_params and "tags" in prompt_text_params:
+ _data["labels"] = user_prompt.get("labels", []) or []
+ _data["tags"] = user_prompt.get("tags", []) or []
+ _prompt_obj = Prompt_Text(**_data) # type: ignore
+ generation_params["prompt"] = TextPromptClient(prompt=_prompt_obj)
+
+ elif isinstance(user_prompt["prompt"], list):
+ prompt_chat_params = getattr(
+ Prompt_Chat, "model_fields", Prompt_Chat.__fields__
+ )
+ _data = {
+ "name": user_prompt["name"],
+ "prompt": user_prompt["prompt"],
+ "version": user_prompt["version"],
+ "config": user_prompt.get("config", None),
+ }
+ if "labels" in prompt_chat_params and "tags" in prompt_chat_params:
+ _data["labels"] = user_prompt.get("labels", []) or []
+ _data["tags"] = user_prompt.get("tags", []) or []
+
+ _prompt_obj = Prompt_Chat(**_data) # type: ignore
+
+ generation_params["prompt"] = ChatPromptClient(prompt=_prompt_obj)
+ else:
+ verbose_logger.error(
+ "[Non-blocking] Langfuse Logger: Invalid prompt format"
+ )
+ else:
+ verbose_logger.error(
+ "[Non-blocking] Langfuse Logger: Invalid prompt format. No prompt logged to Langfuse"
+ )
+ elif (
+ prompt_management_metadata is not None
+ and prompt_management_metadata["prompt_integration"] == "langfuse"
+ ):
+ try:
+ generation_params["prompt"] = langfuse_client.get_prompt(
+ prompt_management_metadata["prompt_id"]
+ )
+ except Exception as e:
+ verbose_logger.debug(
+ f"[Non-blocking] Langfuse Logger: Error getting prompt client for logging: {e}"
+ )
+ pass
+
+ else:
+ generation_params["prompt"] = user_prompt
+
+ return generation_params
+
+
+def log_provider_specific_information_as_span(
+ trace,
+ clean_metadata,
+):
+ """
+ Logs provider-specific information as spans.
+
+ Parameters:
+ trace: The tracing object used to log spans.
+ clean_metadata: A dictionary containing metadata to be logged.
+
+ Returns:
+ None
+ """
+
+ _hidden_params = clean_metadata.get("hidden_params", None)
+ if _hidden_params is None:
+ return
+
+ vertex_ai_grounding_metadata = _hidden_params.get(
+ "vertex_ai_grounding_metadata", None
+ )
+
+ if vertex_ai_grounding_metadata is not None:
+ if isinstance(vertex_ai_grounding_metadata, list):
+ for elem in vertex_ai_grounding_metadata:
+ if isinstance(elem, dict):
+ for key, value in elem.items():
+ trace.span(
+ name=key,
+ input=value,
+ )
+ else:
+ trace.span(
+ name="vertex_ai_grounding_metadata",
+ input=elem,
+ )
+ else:
+ trace.span(
+ name="vertex_ai_grounding_metadata",
+ input=vertex_ai_grounding_metadata,
+ )
+
+
+def log_requester_metadata(clean_metadata: dict):
+ returned_metadata = {}
+ requester_metadata = clean_metadata.get("requester_metadata") or {}
+ for k, v in clean_metadata.items():
+ if k not in requester_metadata:
+ returned_metadata[k] = v
+
+ returned_metadata.update({"requester_metadata": requester_metadata})
+
+ return returned_metadata
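
The header-to-metadata mapping done by add_metadata_from_header is easiest to see with a hand-built litellm_params dict standing in for a real proxy request; the header values here are illustrative:

    from litellm.integrations.langfuse.langfuse import LangFuseLogger

    litellm_params = {
        "proxy_server_request": {
            "headers": {
                "authorization": "Bearer placeholder",         # ignored: no langfuse_ prefix
                "langfuse_existing_trace_id": "trace-abc123",  # copied, prefix stripped
                "langfuse_trace_user_id": "user-42",
            }
        }
    }

    metadata = LangFuseLogger.add_metadata_from_header(litellm_params, metadata={})
    # -> {"existing_trace_id": "trace-abc123", "trace_user_id": "user-42"}
    # `existing_trace_id` then makes _log_langfuse_v2 update that trace instead of
    # creating a new one
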
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse_handler.py b/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse_handler.py
new file mode 100644
index 00000000..aebe1461
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse_handler.py
@@ -0,0 +1,169 @@
+"""
+This file contains the LangFuseHandler class
+
+Used to get the LangFuseLogger for a given request
+
+Handles Key/Team Based Langfuse Logging
+"""
+
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from litellm.litellm_core_utils.litellm_logging import StandardCallbackDynamicParams
+
+from .langfuse import LangFuseLogger, LangfuseLoggingConfig
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache
+else:
+ DynamicLoggingCache = Any
+
+
+class LangFuseHandler:
+
+ @staticmethod
+ def get_langfuse_logger_for_request(
+ standard_callback_dynamic_params: StandardCallbackDynamicParams,
+ in_memory_dynamic_logger_cache: DynamicLoggingCache,
+ globalLangfuseLogger: Optional[LangFuseLogger] = None,
+ ) -> LangFuseLogger:
+ """
+ This function is used to get the LangFuseLogger for a given request
+
+ 1. If dynamic credentials are passed
+ - check if a LangFuseLogger is cached for the dynamic credentials
+ - if cached LangFuseLogger is not found, create a new LangFuseLogger and cache it
+
+ 2. If dynamic credentials are not passed return the globalLangfuseLogger
+
+ """
+ temp_langfuse_logger: Optional[LangFuseLogger] = globalLangfuseLogger
+ if (
+ LangFuseHandler._dynamic_langfuse_credentials_are_passed(
+ standard_callback_dynamic_params
+ )
+ is False
+ ):
+ return LangFuseHandler._return_global_langfuse_logger(
+ globalLangfuseLogger=globalLangfuseLogger,
+ in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
+ )
+
+ # get langfuse logging config to use for this request, based on standard_callback_dynamic_params
+ _credentials = LangFuseHandler.get_dynamic_langfuse_logging_config(
+ globalLangfuseLogger=globalLangfuseLogger,
+ standard_callback_dynamic_params=standard_callback_dynamic_params,
+ )
+ credentials_dict = dict(_credentials)
+
+ # check if langfuse logger is already cached
+ temp_langfuse_logger = in_memory_dynamic_logger_cache.get_cache(
+ credentials=credentials_dict, service_name="langfuse"
+ )
+
+ # if not cached, create a new langfuse logger and cache it
+ if temp_langfuse_logger is None:
+ temp_langfuse_logger = (
+ LangFuseHandler._create_langfuse_logger_from_credentials(
+ credentials=credentials_dict,
+ in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
+ )
+ )
+
+ return temp_langfuse_logger
+
+ @staticmethod
+ def _return_global_langfuse_logger(
+ globalLangfuseLogger: Optional[LangFuseLogger],
+ in_memory_dynamic_logger_cache: DynamicLoggingCache,
+ ) -> LangFuseLogger:
+ """
+ Returns the Global LangfuseLogger set on litellm
+
+ (this is the default langfuse logger - used when no dynamic credentials are passed)
+
+        If no global LangfuseLogger is set, this falls back to a LangFuseLogger cached in in_memory_dynamic_logger_cache, creating and caching one if none exists.
+ """
+ if globalLangfuseLogger is not None:
+ return globalLangfuseLogger
+
+ credentials_dict: Dict[str, Any] = (
+ {}
+ ) # the global langfuse logger uses Environment Variables, there are no dynamic credentials
+ globalLangfuseLogger = in_memory_dynamic_logger_cache.get_cache(
+ credentials=credentials_dict,
+ service_name="langfuse",
+ )
+ if globalLangfuseLogger is None:
+ globalLangfuseLogger = (
+ LangFuseHandler._create_langfuse_logger_from_credentials(
+ credentials=credentials_dict,
+ in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
+ )
+ )
+ return globalLangfuseLogger
+
+ @staticmethod
+ def _create_langfuse_logger_from_credentials(
+ credentials: Dict,
+ in_memory_dynamic_logger_cache: DynamicLoggingCache,
+ ) -> LangFuseLogger:
+ """
+ This function is used to
+ 1. create a LangFuseLogger from the credentials
+ 2. cache the LangFuseLogger to prevent re-creating it for the same credentials
+ """
+
+ langfuse_logger = LangFuseLogger(
+ langfuse_public_key=credentials.get("langfuse_public_key"),
+ langfuse_secret=credentials.get("langfuse_secret"),
+ langfuse_host=credentials.get("langfuse_host"),
+ )
+ in_memory_dynamic_logger_cache.set_cache(
+ credentials=credentials,
+ service_name="langfuse",
+ logging_obj=langfuse_logger,
+ )
+ return langfuse_logger
+
+ @staticmethod
+ def get_dynamic_langfuse_logging_config(
+ standard_callback_dynamic_params: StandardCallbackDynamicParams,
+ globalLangfuseLogger: Optional[LangFuseLogger] = None,
+ ) -> LangfuseLoggingConfig:
+ """
+ This function is used to get the Langfuse logging config to use for a given request.
+
+ It checks if the dynamic parameters are provided in the standard_callback_dynamic_params and uses them to get the Langfuse logging config.
+
+ If no dynamic parameters are provided, it uses the `globalLangfuseLogger` values
+ """
+ # only use dynamic params if langfuse credentials are passed dynamically
+ return LangfuseLoggingConfig(
+ langfuse_secret=standard_callback_dynamic_params.get("langfuse_secret")
+ or standard_callback_dynamic_params.get("langfuse_secret_key"),
+ langfuse_public_key=standard_callback_dynamic_params.get(
+ "langfuse_public_key"
+ ),
+ langfuse_host=standard_callback_dynamic_params.get("langfuse_host"),
+ )
+
+ @staticmethod
+ def _dynamic_langfuse_credentials_are_passed(
+ standard_callback_dynamic_params: StandardCallbackDynamicParams,
+ ) -> bool:
+ """
+ This function is used to check if the dynamic langfuse credentials are passed in standard_callback_dynamic_params
+
+ Returns:
+ bool: True if the dynamic langfuse credentials are passed, False otherwise
+ """
+
+ if (
+ standard_callback_dynamic_params.get("langfuse_host") is not None
+ or standard_callback_dynamic_params.get("langfuse_public_key") is not None
+ or standard_callback_dynamic_params.get("langfuse_secret") is not None
+ or standard_callback_dynamic_params.get("langfuse_secret_key") is not None
+ ):
+ return True
+ return False
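
A hedged sketch of calling the handler directly. In normal operation litellm's logging layer supplies the cache and the dynamic params; here both are built by hand, the keys are placeholders, the langfuse SDK is assumed installed, and the DynamicLoggingCache import path follows the one used by langfuse_prompt_management.py in this diff:

    from litellm.integrations.langfuse.langfuse_handler import LangFuseHandler
    from litellm.litellm_core_utils.specialty_caches.dynamic_logging_cache import (
        DynamicLoggingCache,
    )

    cache = DynamicLoggingCache()

    # request-scoped credentials, e.g. from a virtual key's logging settings
    dynamic_params = {
        "langfuse_public_key": "pk-lf-placeholder",
        "langfuse_secret_key": "sk-lf-placeholder",
        "langfuse_host": "https://cloud.langfuse.com",
    }

    # returns a LangFuseLogger built from the dynamic credentials and caches it,
    # so later requests with the same credentials reuse the same client
    logger = LangFuseHandler.get_langfuse_logger_for_request(
        standard_callback_dynamic_params=dynamic_params,
        in_memory_dynamic_logger_cache=cache,
        globalLangfuseLogger=None,
    )
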
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse_prompt_management.py b/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse_prompt_management.py
new file mode 100644
index 00000000..1f4ca84d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/langfuse/langfuse_prompt_management.py
@@ -0,0 +1,287 @@
+"""
+Call Hook for LiteLLM Proxy which allows Langfuse prompt management.
+"""
+
+import os
+from functools import lru_cache
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, cast
+
+from packaging.version import Version
+from typing_extensions import TypeAlias
+
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.prompt_management_base import PromptManagementClient
+from litellm.litellm_core_utils.asyncify import run_async_function
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionSystemMessage
+from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
+
+from ...litellm_core_utils.specialty_caches.dynamic_logging_cache import (
+ DynamicLoggingCache,
+)
+from ..prompt_management_base import PromptManagementBase
+from .langfuse import LangFuseLogger
+from .langfuse_handler import LangFuseHandler
+
+if TYPE_CHECKING:
+ from langfuse import Langfuse
+ from langfuse.client import ChatPromptClient, TextPromptClient
+
+ LangfuseClass: TypeAlias = Langfuse
+
+ PROMPT_CLIENT = Union[TextPromptClient, ChatPromptClient]
+else:
+ PROMPT_CLIENT = Any
+ LangfuseClass = Any
+
+in_memory_dynamic_logger_cache = DynamicLoggingCache()
+
+
+@lru_cache(maxsize=10)
+def langfuse_client_init(
+ langfuse_public_key=None,
+ langfuse_secret=None,
+ langfuse_secret_key=None,
+ langfuse_host=None,
+ flush_interval=1,
+) -> LangfuseClass:
+ """
+ Initialize Langfuse client with caching to prevent multiple initializations.
+
+ Args:
+ langfuse_public_key (str, optional): Public key for Langfuse. Defaults to None.
+        langfuse_secret (str, optional): Secret key for Langfuse. Defaults to None.
+        langfuse_secret_key (str, optional): Alternate name for the secret key. Defaults to None.
+ langfuse_host (str, optional): Host URL for Langfuse. Defaults to None.
+ flush_interval (int, optional): Flush interval in seconds. Defaults to 1.
+
+ Returns:
+ Langfuse: Initialized Langfuse client instance
+
+ Raises:
+ Exception: If langfuse package is not installed
+ """
+ try:
+ import langfuse
+ from langfuse import Langfuse
+ except Exception as e:
+ raise Exception(
+ f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n\033[0m"
+ )
+
+ # Instance variables
+
+ secret_key = (
+ langfuse_secret or langfuse_secret_key or os.getenv("LANGFUSE_SECRET_KEY")
+ )
+ public_key = langfuse_public_key or os.getenv("LANGFUSE_PUBLIC_KEY")
+ langfuse_host = langfuse_host or os.getenv(
+ "LANGFUSE_HOST", "https://cloud.langfuse.com"
+ )
+
+ if not (
+ langfuse_host.startswith("http://") or langfuse_host.startswith("https://")
+ ):
+ # add http:// if unset, assume communicating over private network - e.g. render
+ langfuse_host = "http://" + langfuse_host
+
+ langfuse_release = os.getenv("LANGFUSE_RELEASE")
+ langfuse_debug = os.getenv("LANGFUSE_DEBUG")
+
+ parameters = {
+ "public_key": public_key,
+ "secret_key": secret_key,
+ "host": langfuse_host,
+ "release": langfuse_release,
+ "debug": langfuse_debug,
+ "flush_interval": LangFuseLogger._get_langfuse_flush_interval(
+ flush_interval
+ ), # flush interval in seconds
+ }
+
+ if Version(langfuse.version.__version__) >= Version("2.6.0"):
+ parameters["sdk_integration"] = "litellm"
+
+ client = Langfuse(**parameters)
+
+ return client
+
+
+class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogger):
+ def __init__(
+ self,
+ langfuse_public_key=None,
+ langfuse_secret=None,
+ langfuse_host=None,
+ flush_interval=1,
+ ):
+ import langfuse
+
+ self.langfuse_sdk_version = langfuse.version.__version__
+ self.Langfuse = langfuse_client_init(
+ langfuse_public_key=langfuse_public_key,
+ langfuse_secret=langfuse_secret,
+ langfuse_host=langfuse_host,
+ flush_interval=flush_interval,
+ )
+
+ @property
+ def integration_name(self):
+ return "langfuse"
+
+ def _get_prompt_from_id(
+ self, langfuse_prompt_id: str, langfuse_client: LangfuseClass
+ ) -> PROMPT_CLIENT:
+ return langfuse_client.get_prompt(langfuse_prompt_id)
+
+ def _compile_prompt(
+ self,
+ langfuse_prompt_client: PROMPT_CLIENT,
+ langfuse_prompt_variables: Optional[dict],
+ call_type: Union[Literal["completion"], Literal["text_completion"]],
+ ) -> List[AllMessageValues]:
+ compiled_prompt: Optional[Union[str, list]] = None
+
+ if langfuse_prompt_variables is None:
+ langfuse_prompt_variables = {}
+
+ compiled_prompt = langfuse_prompt_client.compile(**langfuse_prompt_variables)
+
+ if isinstance(compiled_prompt, str):
+ compiled_prompt = [
+ ChatCompletionSystemMessage(role="system", content=compiled_prompt)
+ ]
+ else:
+ compiled_prompt = cast(List[AllMessageValues], compiled_prompt)
+
+ return compiled_prompt
+
+ def _get_optional_params_from_langfuse(
+ self, langfuse_prompt_client: PROMPT_CLIENT
+ ) -> dict:
+ config = langfuse_prompt_client.config
+ optional_params = {}
+ for k, v in config.items():
+ if k != "model":
+ optional_params[k] = v
+ return optional_params
+
+ async def async_get_chat_completion_prompt(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ non_default_params: dict,
+ prompt_id: str,
+ prompt_variables: Optional[dict],
+ dynamic_callback_params: StandardCallbackDynamicParams,
+ ) -> Tuple[
+ str,
+ List[AllMessageValues],
+ dict,
+ ]:
+ return self.get_chat_completion_prompt(
+ model,
+ messages,
+ non_default_params,
+ prompt_id,
+ prompt_variables,
+ dynamic_callback_params,
+ )
+
+ def should_run_prompt_management(
+ self,
+ prompt_id: str,
+ dynamic_callback_params: StandardCallbackDynamicParams,
+ ) -> bool:
+ langfuse_client = langfuse_client_init(
+ langfuse_public_key=dynamic_callback_params.get("langfuse_public_key"),
+ langfuse_secret=dynamic_callback_params.get("langfuse_secret"),
+ langfuse_secret_key=dynamic_callback_params.get("langfuse_secret_key"),
+ langfuse_host=dynamic_callback_params.get("langfuse_host"),
+ )
+ langfuse_prompt_client = self._get_prompt_from_id(
+ langfuse_prompt_id=prompt_id, langfuse_client=langfuse_client
+ )
+ return langfuse_prompt_client is not None
+
+ def _compile_prompt_helper(
+ self,
+ prompt_id: str,
+ prompt_variables: Optional[dict],
+ dynamic_callback_params: StandardCallbackDynamicParams,
+ ) -> PromptManagementClient:
+ langfuse_client = langfuse_client_init(
+ langfuse_public_key=dynamic_callback_params.get("langfuse_public_key"),
+ langfuse_secret=dynamic_callback_params.get("langfuse_secret"),
+ langfuse_secret_key=dynamic_callback_params.get("langfuse_secret_key"),
+ langfuse_host=dynamic_callback_params.get("langfuse_host"),
+ )
+ langfuse_prompt_client = self._get_prompt_from_id(
+ langfuse_prompt_id=prompt_id, langfuse_client=langfuse_client
+ )
+
+ ## SET PROMPT
+ compiled_prompt = self._compile_prompt(
+ langfuse_prompt_client=langfuse_prompt_client,
+ langfuse_prompt_variables=prompt_variables,
+ call_type="completion",
+ )
+
+ template_model = langfuse_prompt_client.config.get("model")
+
+ template_optional_params = self._get_optional_params_from_langfuse(
+ langfuse_prompt_client
+ )
+
+ return PromptManagementClient(
+ prompt_id=prompt_id,
+ prompt_template=compiled_prompt,
+ prompt_template_model=template_model,
+ prompt_template_optional_params=template_optional_params,
+ completed_messages=None,
+ )
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ return run_async_function(
+ self.async_log_success_event, kwargs, response_obj, start_time, end_time
+ )
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ standard_callback_dynamic_params = kwargs.get(
+ "standard_callback_dynamic_params"
+ )
+ langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request(
+ globalLangfuseLogger=self,
+ standard_callback_dynamic_params=standard_callback_dynamic_params,
+ in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
+ )
+ langfuse_logger_to_use.log_event_on_langfuse(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ start_time=start_time,
+ end_time=end_time,
+ user_id=kwargs.get("user", None),
+ )
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ standard_callback_dynamic_params = kwargs.get(
+ "standard_callback_dynamic_params"
+ )
+ langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request(
+ globalLangfuseLogger=self,
+ standard_callback_dynamic_params=standard_callback_dynamic_params,
+ in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
+ )
+ standard_logging_object = cast(
+ Optional[StandardLoggingPayload],
+ kwargs.get("standard_logging_object", None),
+ )
+ if standard_logging_object is None:
+ return
+ langfuse_logger_to_use.log_event_on_langfuse(
+ start_time=start_time,
+ end_time=end_time,
+ response_obj=None,
+ user_id=kwargs.get("user", None),
+ status_message=standard_logging_object["error_str"],
+ level="ERROR",
+ kwargs=kwargs,
+ )
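
A hedged usage sketch for the prompt-management hook above. Treat the langfuse/ model prefix and the prompt_id / prompt_variables kwargs as assumptions to confirm against litellm's prompt-management docs; the credentials, prompt name, and variables are placeholders:

    import os

    import litellm

    # Langfuse project that stores the managed prompt (placeholder credentials)
    os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-placeholder"
    os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-placeholder"

    response = litellm.completion(
        model="langfuse/gpt-4o-mini",               # assumed routing through LangfusePromptManagement
        prompt_id="support-triage-prompt",          # hypothetical Langfuse prompt name
        prompt_variables={"customer_name": "Ada"},  # substituted by the prompt client's compile()
        messages=[{"role": "user", "content": "My order never arrived."}],
    )
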
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/langsmith.py b/.venv/lib/python3.12/site-packages/litellm/integrations/langsmith.py
new file mode 100644
index 00000000..1ef90c18
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/langsmith.py
@@ -0,0 +1,500 @@
+#### What this does ####
+# On success, logs events to Langsmith
+import asyncio
+import os
+import random
+import traceback
+import types
+import uuid
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional
+
+import httpx
+from pydantic import BaseModel # type: ignore
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.llms.custom_httpx.http_handler import (
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+from litellm.types.integrations.langsmith import *
+from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
+
+
+def is_serializable(value):
+ non_serializable_types = (
+ types.CoroutineType,
+ types.FunctionType,
+ types.GeneratorType,
+ BaseModel,
+ )
+ return not isinstance(value, non_serializable_types)
+
+
+class LangsmithLogger(CustomBatchLogger):
+ def __init__(
+ self,
+ langsmith_api_key: Optional[str] = None,
+ langsmith_project: Optional[str] = None,
+ langsmith_base_url: Optional[str] = None,
+ **kwargs,
+ ):
+ self.default_credentials = self.get_credentials_from_env(
+ langsmith_api_key=langsmith_api_key,
+ langsmith_project=langsmith_project,
+ langsmith_base_url=langsmith_base_url,
+ )
+ self.sampling_rate: float = (
+ float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore
+ if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
+ and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore
+ else 1.0
+ )
+ self.langsmith_default_run_name = os.getenv(
+ "LANGSMITH_DEFAULT_RUN_NAME", "LLMRun"
+ )
+ self.async_httpx_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+ _batch_size = (
+ os.getenv("LANGSMITH_BATCH_SIZE", None) or litellm.langsmith_batch_size
+ )
+ if _batch_size:
+ self.batch_size = int(_batch_size)
+ self.log_queue: List[LangsmithQueueObject] = []
+ asyncio.create_task(self.periodic_flush())
+ self.flush_lock = asyncio.Lock()
+
+ super().__init__(**kwargs, flush_lock=self.flush_lock)
+
+ def get_credentials_from_env(
+ self,
+ langsmith_api_key: Optional[str] = None,
+ langsmith_project: Optional[str] = None,
+ langsmith_base_url: Optional[str] = None,
+ ) -> LangsmithCredentialsObject:
+
+ _credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY")
+ if _credentials_api_key is None:
+ raise Exception(
+ "Invalid Langsmith API Key given. _credentials_api_key=None."
+ )
+ _credentials_project = (
+ langsmith_project or os.getenv("LANGSMITH_PROJECT") or "litellm-completion"
+ )
+ if _credentials_project is None:
+ raise Exception(
+ "Invalid Langsmith API Key given. _credentials_project=None."
+ )
+ _credentials_base_url = (
+ langsmith_base_url
+ or os.getenv("LANGSMITH_BASE_URL")
+ or "https://api.smith.langchain.com"
+ )
+ if _credentials_base_url is None:
+ raise Exception(
+ "Invalid Langsmith API Key given. _credentials_base_url=None."
+ )
+
+ return LangsmithCredentialsObject(
+ LANGSMITH_API_KEY=_credentials_api_key,
+ LANGSMITH_BASE_URL=_credentials_base_url,
+ LANGSMITH_PROJECT=_credentials_project,
+ )
+
+ def _prepare_log_data(
+ self,
+ kwargs,
+ response_obj,
+ start_time,
+ end_time,
+ credentials: LangsmithCredentialsObject,
+ ):
+ try:
+ _litellm_params = kwargs.get("litellm_params", {}) or {}
+ metadata = _litellm_params.get("metadata", {}) or {}
+ project_name = metadata.get(
+ "project_name", credentials["LANGSMITH_PROJECT"]
+ )
+ run_name = metadata.get("run_name", self.langsmith_default_run_name)
+ run_id = metadata.get("id", metadata.get("run_id", None))
+ parent_run_id = metadata.get("parent_run_id", None)
+ trace_id = metadata.get("trace_id", None)
+ session_id = metadata.get("session_id", None)
+ dotted_order = metadata.get("dotted_order", None)
+ verbose_logger.debug(
+ f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
+ )
+
+ # Ensure everything in the payload is converted to str
+ payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object", None
+ )
+
+ if payload is None:
+ raise Exception("Error logging request payload. Payload=none.")
+
+ metadata = payload[
+ "metadata"
+ ] # ensure logged metadata is json serializable
+
+ data = {
+ "name": run_name,
+ "run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
+ "inputs": payload,
+ "outputs": payload["response"],
+ "session_name": project_name,
+ "start_time": payload["startTime"],
+ "end_time": payload["endTime"],
+ "tags": payload["request_tags"],
+ "extra": metadata,
+ }
+
+ if payload["error_str"] is not None and payload["status"] == "failure":
+ data["error"] = payload["error_str"]
+
+ if run_id:
+ data["id"] = run_id
+
+ if parent_run_id:
+ data["parent_run_id"] = parent_run_id
+
+ if trace_id:
+ data["trace_id"] = trace_id
+
+ if session_id:
+ data["session_id"] = session_id
+
+ if dotted_order:
+ data["dotted_order"] = dotted_order
+
+ run_id: Optional[str] = data.get("id") # type: ignore
+ if "id" not in data or data["id"] is None:
+ """
+                for /batch, Langsmith requires id, trace_id and dotted_order to be passed as params
+ """
+ run_id = str(uuid.uuid4())
+
+ data["id"] = run_id
+
+ if (
+ "trace_id" not in data
+ or data["trace_id"] is None
+ and (run_id is not None and isinstance(run_id, str))
+ ):
+ data["trace_id"] = run_id
+
+ if (
+ "dotted_order" not in data
+ or data["dotted_order"] is None
+ and (run_id is not None and isinstance(run_id, str))
+ ):
+ data["dotted_order"] = self.make_dot_order(run_id=run_id) # type: ignore
+
+ verbose_logger.debug("Langsmith Logging data on langsmith: %s", data)
+
+ return data
+ except Exception:
+ raise
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ sampling_rate = (
+ float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore
+ if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
+ and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore
+ else 1.0
+ )
+ random_sample = random.random()
+ if random_sample > sampling_rate:
+ verbose_logger.info(
+ "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
+ sampling_rate, random_sample
+ )
+ )
+ return # Skip logging
+ verbose_logger.debug(
+ "Langsmith Sync Layer Logging - kwargs: %s, response_obj: %s",
+ kwargs,
+ response_obj,
+ )
+ credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
+ data = self._prepare_log_data(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ start_time=start_time,
+ end_time=end_time,
+ credentials=credentials,
+ )
+ self.log_queue.append(
+ LangsmithQueueObject(
+ data=data,
+ credentials=credentials,
+ )
+ )
+ verbose_logger.debug(
+ f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
+ )
+
+ if len(self.log_queue) >= self.batch_size:
+ self._send_batch()
+
+ except Exception:
+ verbose_logger.exception("Langsmith Layer Error - log_success_event error")
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ sampling_rate = self.sampling_rate
+ random_sample = random.random()
+ if random_sample > sampling_rate:
+ verbose_logger.info(
+ "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
+ sampling_rate, random_sample
+ )
+ )
+ return # Skip logging
+ verbose_logger.debug(
+ "Langsmith Async Layer Logging - kwargs: %s, response_obj: %s",
+ kwargs,
+ response_obj,
+ )
+ credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
+ data = self._prepare_log_data(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ start_time=start_time,
+ end_time=end_time,
+ credentials=credentials,
+ )
+ self.log_queue.append(
+ LangsmithQueueObject(
+ data=data,
+ credentials=credentials,
+ )
+ )
+ verbose_logger.debug(
+ "Langsmith logging: queue length %s, batch size %s",
+ len(self.log_queue),
+ self.batch_size,
+ )
+ if len(self.log_queue) >= self.batch_size:
+ await self.flush_queue()
+ except Exception:
+ verbose_logger.exception(
+ "Langsmith Layer Error - error logging async success event."
+ )
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ sampling_rate = self.sampling_rate
+ random_sample = random.random()
+ if random_sample > sampling_rate:
+ verbose_logger.info(
+ "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
+ sampling_rate, random_sample
+ )
+ )
+ return # Skip logging
+ verbose_logger.info("Langsmith Failure Event Logging!")
+ try:
+ credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
+ data = self._prepare_log_data(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ start_time=start_time,
+ end_time=end_time,
+ credentials=credentials,
+ )
+ self.log_queue.append(
+ LangsmithQueueObject(
+ data=data,
+ credentials=credentials,
+ )
+ )
+ verbose_logger.debug(
+ "Langsmith logging: queue length %s, batch size %s",
+ len(self.log_queue),
+ self.batch_size,
+ )
+ if len(self.log_queue) >= self.batch_size:
+ await self.flush_queue()
+ except Exception:
+ verbose_logger.exception(
+ "Langsmith Layer Error - error logging async failure event."
+ )
+
+ async def async_send_batch(self):
+ """
+        Handles sending batches of runs to Langsmith.
+
+        self.log_queue contains LangsmithQueueObjects. Each LangsmithQueueObject has:
+        - "credentials" - credentials to use for the request (langsmith_api_key, langsmith_project, langsmith_base_url)
+        - "data" - the data to log to Langsmith for the request
+
+        This function:
+        - groups the queue objects by credentials
+        - loops through each unique set of credentials and sends its batch to Langsmith
+
+        This was added to support key/team based logging on Langsmith.
+ """
+ if not self.log_queue:
+ return
+
+ batch_groups = self._group_batches_by_credentials()
+ for batch_group in batch_groups.values():
+ await self._log_batch_on_langsmith(
+ credentials=batch_group.credentials,
+ queue_objects=batch_group.queue_objects,
+ )
+
+ def _add_endpoint_to_url(
+ self, url: str, endpoint: str, api_version: str = "/api/v1"
+ ) -> str:
+ if api_version not in url:
+ url = f"{url.rstrip('/')}{api_version}"
+
+ if url.endswith("/"):
+ return f"{url}{endpoint}"
+ return f"{url}/{endpoint}"
+
+ async def _log_batch_on_langsmith(
+ self,
+ credentials: LangsmithCredentialsObject,
+ queue_objects: List[LangsmithQueueObject],
+ ):
+ """
+        Logs a batch of runs to Langsmith by sending them to the /batch endpoint
+        for the given credentials.
+
+        Args:
+            credentials: LangsmithCredentialsObject
+            queue_objects: List[LangsmithQueueObject]
+
+        Returns: None
+
+        Raises: Nothing; errors are only logged via verbose_logger.exception()
+ """
+ langsmith_api_base = credentials["LANGSMITH_BASE_URL"]
+ langsmith_api_key = credentials["LANGSMITH_API_KEY"]
+ url = self._add_endpoint_to_url(langsmith_api_base, "runs/batch")
+ headers = {"x-api-key": langsmith_api_key}
+ elements_to_log = [queue_object["data"] for queue_object in queue_objects]
+
+ try:
+ verbose_logger.debug(
+ "Sending batch of %s runs to Langsmith", len(elements_to_log)
+ )
+ response = await self.async_httpx_client.post(
+ url=url,
+ json={"post": elements_to_log},
+ headers=headers,
+ )
+ response.raise_for_status()
+
+ if response.status_code >= 300:
+ verbose_logger.error(
+ f"Langsmith Error: {response.status_code} - {response.text}"
+ )
+ else:
+ verbose_logger.debug(
+ f"Batch of {len(self.log_queue)} runs successfully created"
+ )
+ except httpx.HTTPStatusError as e:
+ verbose_logger.exception(
+ f"Langsmith HTTP Error: {e.response.status_code} - {e.response.text}"
+ )
+ except Exception:
+ verbose_logger.exception(
+ f"Langsmith Layer Error - {traceback.format_exc()}"
+ )
+
+ def _group_batches_by_credentials(self) -> Dict[CredentialsKey, BatchGroup]:
+ """Groups queue objects by credentials using a proper key structure"""
+ log_queue_by_credentials: Dict[CredentialsKey, BatchGroup] = {}
+
+ for queue_object in self.log_queue:
+ credentials = queue_object["credentials"]
+ key = CredentialsKey(
+ api_key=credentials["LANGSMITH_API_KEY"],
+ project=credentials["LANGSMITH_PROJECT"],
+ base_url=credentials["LANGSMITH_BASE_URL"],
+ )
+
+ if key not in log_queue_by_credentials:
+ log_queue_by_credentials[key] = BatchGroup(
+ credentials=credentials, queue_objects=[]
+ )
+
+ log_queue_by_credentials[key].queue_objects.append(queue_object)
+
+ return log_queue_by_credentials
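+
+    # Example (illustrative; run ids and credential objects are hypothetical):
+    #   log_queue = [
+    #       {"data": {"id": "run-1"}, "credentials": creds_a},
+    #       {"data": {"id": "run-2"}, "credentials": creds_a},
+    #       {"data": {"id": "run-3"}, "credentials": creds_b},
+    #   ]
+    #   _group_batches_by_credentials() then yields two BatchGroups, so
+    #   async_send_batch() sends one /batch request for creds_a (run-1, run-2)
+    #   and one for creds_b (run-3).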
+
+ def _get_credentials_to_use_for_request(
+ self, kwargs: Dict[str, Any]
+ ) -> LangsmithCredentialsObject:
+ """
+ Handles key/team based logging
+
+ If standard_callback_dynamic_params are provided, use those credentials.
+
+ Otherwise, use the default credentials.
+ """
+ standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
+ kwargs.get("standard_callback_dynamic_params", None)
+ )
+ if standard_callback_dynamic_params is not None:
+ credentials = self.get_credentials_from_env(
+ langsmith_api_key=standard_callback_dynamic_params.get(
+ "langsmith_api_key", None
+ ),
+ langsmith_project=standard_callback_dynamic_params.get(
+ "langsmith_project", None
+ ),
+ langsmith_base_url=standard_callback_dynamic_params.get(
+ "langsmith_base_url", None
+ ),
+ )
+ else:
+ credentials = self.default_credentials
+ return credentials
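+
+    # Example (illustrative; the key value is hypothetical): per-request credentials
+    # from standard_callback_dynamic_params take precedence over the defaults.
+    #   kwargs = {"standard_callback_dynamic_params": {"langsmith_api_key": "lsv2-team-key"}}
+    #       -> credentials built via get_credentials_from_env(langsmith_api_key="lsv2-team-key", ...)
+    #   kwargs without dynamic params
+    #       -> self.default_credentials is returned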
+
+ def _send_batch(self):
+ """Calls async_send_batch in an event loop"""
+ if not self.log_queue:
+ return
+
+ try:
+ # Try to get the existing event loop
+ loop = asyncio.get_event_loop()
+ if loop.is_running():
+ # If we're already in an event loop, create a task
+ asyncio.create_task(self.async_send_batch())
+ else:
+ # If no event loop is running, run the coroutine directly
+ loop.run_until_complete(self.async_send_batch())
+ except RuntimeError:
+ # If we can't get an event loop, create a new one
+ asyncio.run(self.async_send_batch())
+
+    def get_run_by_id(self, run_id):
+        langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
+        langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
+
+        url = f"{langsmith_api_base}/runs/{run_id}"
+ response = litellm.module_level_client.get(
+ url=url,
+ headers={"x-api-key": langsmith_api_key},
+ )
+
+ return response.json()
+
+ def make_dot_order(self, run_id: str):
+ st = datetime.now(timezone.utc)
+ id_ = run_id
+ return st.strftime("%Y%m%dT%H%M%S%fZ") + str(id_)
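+
+    # Example (illustrative; the run id is a hypothetical UUID):
+    #   make_dot_order(run_id="0f7a...") at 2025-01-01 12:00:00 UTC
+    #       -> "20250101T120000000000Z0f7a..."  (UTC timestamp prefix + run id)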
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/langtrace.py b/.venv/lib/python3.12/site-packages/litellm/integrations/langtrace.py
new file mode 100644
index 00000000..51cd272f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/langtrace.py
@@ -0,0 +1,106 @@
+import json
+from typing import TYPE_CHECKING, Any
+
+from litellm.proxy._types import SpanAttributes
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+
+ Span = _Span
+else:
+ Span = Any
+
+
+class LangtraceAttributes:
+ """
+ This class is used to save trace attributes to Langtrace's spans
+ """
+
+ def set_langtrace_attributes(self, span: Span, kwargs, response_obj):
+ """
+ This function is used to log the event to Langtrace
+ """
+
+        vendor = (kwargs.get("litellm_params") or {}).get("custom_llm_provider")
+ optional_params = kwargs.get("optional_params", {})
+ options = {**kwargs, **optional_params}
+ self.set_request_attributes(span, options, vendor)
+ self.set_response_attributes(span, response_obj)
+ self.set_usage_attributes(span, response_obj)
+
+ def set_request_attributes(self, span: Span, kwargs, vendor):
+ """
+        This function sets span attributes for the LLM request
+ """
+ span_attributes = {
+ "gen_ai.operation.name": "chat",
+ "langtrace.service.name": vendor,
+ SpanAttributes.LLM_REQUEST_MODEL.value: kwargs.get("model"),
+ SpanAttributes.LLM_IS_STREAMING.value: kwargs.get("stream"),
+ SpanAttributes.LLM_REQUEST_TEMPERATURE.value: kwargs.get("temperature"),
+ SpanAttributes.LLM_TOP_K.value: kwargs.get("top_k"),
+ SpanAttributes.LLM_REQUEST_TOP_P.value: kwargs.get("top_p"),
+ SpanAttributes.LLM_USER.value: kwargs.get("user"),
+ SpanAttributes.LLM_REQUEST_MAX_TOKENS.value: kwargs.get("max_tokens"),
+ SpanAttributes.LLM_RESPONSE_STOP_REASON.value: kwargs.get("stop"),
+ SpanAttributes.LLM_FREQUENCY_PENALTY.value: kwargs.get("frequency_penalty"),
+ SpanAttributes.LLM_PRESENCE_PENALTY.value: kwargs.get("presence_penalty"),
+ }
+
+ prompts = kwargs.get("messages")
+
+ if prompts:
+ span.add_event(
+ name="gen_ai.content.prompt",
+ attributes={SpanAttributes.LLM_PROMPTS.value: json.dumps(prompts)},
+ )
+
+ self.set_span_attributes(span, span_attributes)
+
+ def set_response_attributes(self, span: Span, response_obj):
+ """
+        This function sets span attributes for the LLM response
+ """
+ response_attributes = {
+ "gen_ai.response_id": response_obj.get("id"),
+ "gen_ai.system_fingerprint": response_obj.get("system_fingerprint"),
+ SpanAttributes.LLM_RESPONSE_MODEL.value: response_obj.get("model"),
+ }
+ completions = []
+ for choice in response_obj.get("choices", []):
+ role = choice.get("message").get("role")
+ content = choice.get("message").get("content")
+ completions.append({"role": role, "content": content})
+
+ span.add_event(
+ name="gen_ai.content.completion",
+            attributes={SpanAttributes.LLM_COMPLETIONS.value: json.dumps(completions)},
+ )
+
+ self.set_span_attributes(span, response_attributes)
+
+ def set_usage_attributes(self, span: Span, response_obj):
+ """
+        This function sets span attributes for the LLM usage
+ """
+ usage = response_obj.get("usage")
+ if usage:
+ usage_attributes = {
+ SpanAttributes.LLM_USAGE_PROMPT_TOKENS.value: usage.get(
+ "prompt_tokens"
+ ),
+ SpanAttributes.LLM_USAGE_COMPLETION_TOKENS.value: usage.get(
+ "completion_tokens"
+ ),
+ SpanAttributes.LLM_USAGE_TOTAL_TOKENS.value: usage.get("total_tokens"),
+ }
+ self.set_span_attributes(span, usage_attributes)
+
+ def set_span_attributes(self, span: Span, attributes):
+ """
+ This function is used to set span attributes
+ """
+ for key, value in attributes.items():
+ if not value:
+ continue
+ span.set_attribute(key, value)
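+
+    # Usage sketch (illustrative): this class is not called directly by users; the
+    # OpenTelemetry logger dispatches to it when its callback_name is "langtrace":
+    #   LangtraceAttributes().set_langtrace_attributes(span, kwargs, response_obj)
+    # Note that set_span_attributes() above skips falsy values (None, "", 0, False).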
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py b/.venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py
new file mode 100644
index 00000000..5bf9afd7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py
@@ -0,0 +1,317 @@
+#### What this does ####
+# This file contains the LiteralAILogger class which is used to log steps to the LiteralAI observability platform.
+import asyncio
+import os
+import uuid
+from typing import List, Optional
+
+import httpx
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.llms.custom_httpx.http_handler import (
+ HTTPHandler,
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+from litellm.types.utils import StandardLoggingPayload
+
+
+class LiteralAILogger(CustomBatchLogger):
+ def __init__(
+ self,
+ literalai_api_key=None,
+ literalai_api_url="https://cloud.getliteral.ai",
+ env=None,
+ **kwargs,
+ ):
+ self.literalai_api_url = os.getenv("LITERAL_API_URL") or literalai_api_url
+ self.headers = {
+ "Content-Type": "application/json",
+ "x-api-key": literalai_api_key or os.getenv("LITERAL_API_KEY"),
+ "x-client-name": "litellm",
+ }
+ if env:
+ self.headers["x-env"] = env
+ self.async_httpx_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+ self.sync_http_handler = HTTPHandler()
+ batch_size = os.getenv("LITERAL_BATCH_SIZE", None)
+ self.flush_lock = asyncio.Lock()
+ super().__init__(
+ **kwargs,
+ flush_lock=self.flush_lock,
+ batch_size=int(batch_size) if batch_size else None,
+ )
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ verbose_logger.debug(
+ "Literal AI Layer Logging - kwargs: %s, response_obj: %s",
+ kwargs,
+ response_obj,
+ )
+ data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+ self.log_queue.append(data)
+ verbose_logger.debug(
+ "Literal AI logging: queue length %s, batch size %s",
+ len(self.log_queue),
+ self.batch_size,
+ )
+ if len(self.log_queue) >= self.batch_size:
+ self._send_batch()
+ except Exception:
+ verbose_logger.exception(
+ "Literal AI Layer Error - error logging success event."
+ )
+
+ def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ verbose_logger.info("Literal AI Failure Event Logging!")
+ try:
+ data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+ self.log_queue.append(data)
+ verbose_logger.debug(
+ "Literal AI logging: queue length %s, batch size %s",
+ len(self.log_queue),
+ self.batch_size,
+ )
+ if len(self.log_queue) >= self.batch_size:
+ self._send_batch()
+ except Exception:
+ verbose_logger.exception(
+ "Literal AI Layer Error - error logging failure event."
+ )
+
+ def _send_batch(self):
+ if not self.log_queue:
+ return
+
+ url = f"{self.literalai_api_url}/api/graphql"
+ query = self._steps_query_builder(self.log_queue)
+ variables = self._steps_variables_builder(self.log_queue)
+ try:
+ response = self.sync_http_handler.post(
+ url=url,
+ json={
+ "query": query,
+ "variables": variables,
+ },
+ headers=self.headers,
+ )
+
+ if response.status_code >= 300:
+ verbose_logger.error(
+ f"Literal AI Error: {response.status_code} - {response.text}"
+ )
+ else:
+ verbose_logger.debug(
+ f"Batch of {len(self.log_queue)} runs successfully created"
+ )
+ except Exception:
+ verbose_logger.exception("Literal AI Layer Error")
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ verbose_logger.debug(
+ "Literal AI Async Layer Logging - kwargs: %s, response_obj: %s",
+ kwargs,
+ response_obj,
+ )
+ data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+ self.log_queue.append(data)
+ verbose_logger.debug(
+ "Literal AI logging: queue length %s, batch size %s",
+ len(self.log_queue),
+ self.batch_size,
+ )
+ if len(self.log_queue) >= self.batch_size:
+ await self.flush_queue()
+ except Exception:
+ verbose_logger.exception(
+ "Literal AI Layer Error - error logging async success event."
+ )
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ verbose_logger.info("Literal AI Failure Event Logging!")
+ try:
+ data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+ self.log_queue.append(data)
+ verbose_logger.debug(
+ "Literal AI logging: queue length %s, batch size %s",
+ len(self.log_queue),
+ self.batch_size,
+ )
+ if len(self.log_queue) >= self.batch_size:
+ await self.flush_queue()
+ except Exception:
+ verbose_logger.exception(
+ "Literal AI Layer Error - error logging async failure event."
+ )
+
+ async def async_send_batch(self):
+ if not self.log_queue:
+ return
+
+ url = f"{self.literalai_api_url}/api/graphql"
+ query = self._steps_query_builder(self.log_queue)
+ variables = self._steps_variables_builder(self.log_queue)
+
+ try:
+ response = await self.async_httpx_client.post(
+ url=url,
+ json={
+ "query": query,
+ "variables": variables,
+ },
+ headers=self.headers,
+ )
+ if response.status_code >= 300:
+ verbose_logger.error(
+ f"Literal AI Error: {response.status_code} - {response.text}"
+ )
+ else:
+ verbose_logger.debug(
+ f"Batch of {len(self.log_queue)} runs successfully created"
+ )
+ except httpx.HTTPStatusError as e:
+ verbose_logger.exception(
+ f"Literal AI HTTP Error: {e.response.status_code} - {e.response.text}"
+ )
+ except Exception:
+ verbose_logger.exception("Literal AI Layer Error")
+
+ def _prepare_log_data(self, kwargs, response_obj, start_time, end_time) -> dict:
+ logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object", None
+ )
+
+ if logging_payload is None:
+ raise ValueError("standard_logging_object not found in kwargs")
+ clean_metadata = logging_payload["metadata"]
+ metadata = kwargs.get("litellm_params", {}).get("metadata", {})
+
+ settings = logging_payload["model_parameters"]
+ messages = logging_payload["messages"]
+ response = logging_payload["response"]
+ choices: List = []
+ if isinstance(response, dict) and "choices" in response:
+ choices = response["choices"]
+ message_completion = choices[0]["message"] if choices else None
+ prompt_id = None
+ variables = None
+
+ if messages and isinstance(messages, list) and isinstance(messages[0], dict):
+ for message in messages:
+ if literal_prompt := getattr(message, "__literal_prompt__", None):
+ prompt_id = literal_prompt.get("prompt_id")
+ variables = literal_prompt.get("variables")
+ message["uuid"] = literal_prompt.get("uuid")
+ message["templated"] = True
+
+ tools = settings.pop("tools", None)
+
+ step = {
+ "id": metadata.get("step_id", str(uuid.uuid4())),
+ "error": logging_payload["error_str"],
+ "name": kwargs.get("model", ""),
+ "threadId": metadata.get("literalai_thread_id", None),
+ "parentId": metadata.get("literalai_parent_id", None),
+ "rootRunId": metadata.get("literalai_root_run_id", None),
+ "input": None,
+ "output": None,
+ "type": "llm",
+ "tags": metadata.get("tags", metadata.get("literalai_tags", None)),
+ "startTime": str(start_time),
+ "endTime": str(end_time),
+ "metadata": clean_metadata,
+ "generation": {
+ "inputTokenCount": logging_payload["prompt_tokens"],
+ "outputTokenCount": logging_payload["completion_tokens"],
+ "tokenCount": logging_payload["total_tokens"],
+ "promptId": prompt_id,
+ "variables": variables,
+ "provider": kwargs.get("custom_llm_provider", "litellm"),
+ "model": kwargs.get("model", ""),
+ "duration": (end_time - start_time).total_seconds(),
+ "settings": settings,
+ "messages": messages,
+ "messageCompletion": message_completion,
+ "tools": tools,
+ },
+ }
+ return step
+
+ def _steps_query_variables_builder(self, steps):
+ generated = ""
+ for id in range(len(steps)):
+ generated += f"""$id_{id}: String!
+ $threadId_{id}: String
+ $rootRunId_{id}: String
+ $type_{id}: StepType
+ $startTime_{id}: DateTime
+ $endTime_{id}: DateTime
+ $error_{id}: String
+ $input_{id}: Json
+ $output_{id}: Json
+ $metadata_{id}: Json
+ $parentId_{id}: String
+ $name_{id}: String
+ $tags_{id}: [String!]
+ $generation_{id}: GenerationPayloadInput
+ $scores_{id}: [ScorePayloadInput!]
+ $attachments_{id}: [AttachmentPayloadInput!]
+ """
+ return generated
+
+ def _steps_ingest_steps_builder(self, steps):
+ generated = ""
+ for id in range(len(steps)):
+ generated += f"""
+ step{id}: ingestStep(
+ id: $id_{id}
+ threadId: $threadId_{id}
+ rootRunId: $rootRunId_{id}
+ startTime: $startTime_{id}
+ endTime: $endTime_{id}
+ type: $type_{id}
+ error: $error_{id}
+ input: $input_{id}
+ output: $output_{id}
+ metadata: $metadata_{id}
+ parentId: $parentId_{id}
+ name: $name_{id}
+ tags: $tags_{id}
+ generation: $generation_{id}
+ scores: $scores_{id}
+ attachments: $attachments_{id}
+ ) {{
+ ok
+ message
+ }}
+ """
+ return generated
+
+ def _steps_query_builder(self, steps):
+ return f"""
+ mutation AddStep({self._steps_query_variables_builder(steps)}) {{
+ {self._steps_ingest_steps_builder(steps)}
+ }}
+ """
+
+ def _steps_variables_builder(self, steps):
+ def serialize_step(event, id):
+ result = {}
+
+ for key, value in event.items():
+ # Only keep the keys that are not None to avoid overriding existing values
+ if value is not None:
+ result[f"{key}_{id}"] = value
+
+ return result
+
+ variables = {}
+ for i in range(len(steps)):
+ step = steps[i]
+ variables.update(serialize_step(step, i))
+ return variables
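+
+    # Example (illustrative; ids and the model name are hypothetical): for one queued
+    # step the builders above produce a GraphQL document roughly of the form
+    #   mutation AddStep($id_0: String! $type_0: StepType ... ) {
+    #     step0: ingestStep(id: $id_0 type: $type_0 ...) { ok message }
+    #   }
+    # with variables such as {"id_0": "<uuid>", "type_0": "llm", "name_0": "gpt-4o", ...}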
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/logfire_logger.py b/.venv/lib/python3.12/site-packages/litellm/integrations/logfire_logger.py
new file mode 100644
index 00000000..516bd4a8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/logfire_logger.py
@@ -0,0 +1,179 @@
+#### What this does ####
+# On success + failure, log events to Logfire
+
+import os
+import traceback
+import uuid
+from enum import Enum
+from typing import Any, Dict, NamedTuple
+
+from typing_extensions import LiteralString
+
+from litellm._logging import print_verbose, verbose_logger
+from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
+
+
+class SpanConfig(NamedTuple):
+ message_template: LiteralString
+ span_data: Dict[str, Any]
+
+
+class LogfireLevel(str, Enum):
+ INFO = "info"
+ ERROR = "error"
+
+
+class LogfireLogger:
+ # Class variables or attributes
+ def __init__(self):
+ try:
+ verbose_logger.debug("in init logfire logger")
+ import logfire
+
+ # only setting up logfire if we are sending to logfire
+ # in testing, we don't want to send to logfire
+ if logfire.DEFAULT_LOGFIRE_INSTANCE.config.send_to_logfire:
+ logfire.configure(token=os.getenv("LOGFIRE_TOKEN"))
+ except Exception as e:
+ print_verbose(f"Got exception on init logfire client {str(e)}")
+ raise e
+
+ def _get_span_config(self, payload) -> SpanConfig:
+ if (
+ payload["call_type"] == "completion"
+ or payload["call_type"] == "acompletion"
+ ):
+ return SpanConfig(
+ message_template="Chat Completion with {request_data[model]!r}",
+ span_data={"request_data": payload},
+ )
+ elif (
+ payload["call_type"] == "embedding" or payload["call_type"] == "aembedding"
+ ):
+ return SpanConfig(
+ message_template="Embedding Creation with {request_data[model]!r}",
+ span_data={"request_data": payload},
+ )
+ elif (
+ payload["call_type"] == "image_generation"
+ or payload["call_type"] == "aimage_generation"
+ ):
+ return SpanConfig(
+ message_template="Image Generation with {request_data[model]!r}",
+ span_data={"request_data": payload},
+ )
+ else:
+ return SpanConfig(
+ message_template="Litellm Call with {request_data[model]!r}",
+ span_data={"request_data": payload},
+ )
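+
+    # Example (illustrative; the model name is hypothetical): for
+    #   payload = {"call_type": "acompletion", "model": "gpt-4o", ...}
+    # the returned SpanConfig renders in Logfire roughly as
+    #   "Chat Completion with 'gpt-4o'"
+    # since {request_data[model]!r} is filled from span_data["request_data"].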
+
+ async def _async_log_event(
+ self,
+ kwargs,
+ response_obj,
+ start_time,
+ end_time,
+ print_verbose,
+ level: LogfireLevel,
+ ):
+ self.log_event(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ start_time=start_time,
+ end_time=end_time,
+ print_verbose=print_verbose,
+ level=level,
+ )
+
+ def log_event(
+ self,
+ kwargs,
+ start_time,
+ end_time,
+ print_verbose,
+ level: LogfireLevel,
+ response_obj,
+ ):
+ try:
+ import logfire
+
+ verbose_logger.debug(
+ f"logfire Logging - Enters logging function for model {kwargs}"
+ )
+
+ if not response_obj:
+ response_obj = {}
+ litellm_params = kwargs.get("litellm_params", {})
+            metadata = (
+                litellm_params.get("metadata", {}) or {}
+            )  # handle litellm_params["metadata"] being None
+ messages = kwargs.get("messages")
+ optional_params = kwargs.get("optional_params", {})
+ call_type = kwargs.get("call_type", "completion")
+ cache_hit = kwargs.get("cache_hit", False)
+ usage = response_obj.get("usage", {})
+ id = response_obj.get("id", str(uuid.uuid4()))
+ try:
+ response_time = (end_time - start_time).total_seconds()
+ except Exception:
+ response_time = None
+
+ # Clean Metadata before logging - never log raw metadata
+ # the raw metadata can contain circular references which leads to infinite recursion
+ # we clean out all extra litellm metadata params before logging
+ clean_metadata = {}
+ if isinstance(metadata, dict):
+ for key, value in metadata.items():
+ # clean litellm metadata before logging
+ if key in [
+ "endpoint",
+ "caching_groups",
+ "previous_models",
+ ]:
+ continue
+ else:
+ clean_metadata[key] = value
+
+ clean_metadata = redact_user_api_key_info(metadata=clean_metadata)
+
+ # Build the initial payload
+ payload = {
+ "id": id,
+ "call_type": call_type,
+ "cache_hit": cache_hit,
+ "startTime": start_time,
+ "endTime": end_time,
+ "responseTime (seconds)": response_time,
+ "model": kwargs.get("model", ""),
+ "user": kwargs.get("user", ""),
+ "modelParameters": optional_params,
+ "spend": kwargs.get("response_cost", 0),
+ "messages": messages,
+ "response": response_obj,
+ "usage": usage,
+ "metadata": clean_metadata,
+ }
+ logfire_openai = logfire.with_settings(custom_scope_suffix="openai")
+ message_template, span_data = self._get_span_config(payload)
+ if level == LogfireLevel.INFO:
+ logfire_openai.info(
+ message_template,
+ **span_data,
+ )
+ elif level == LogfireLevel.ERROR:
+ logfire_openai.error(
+ message_template,
+ **span_data,
+ _exc_info=True,
+ )
+ print_verbose(f"\ndd Logger - Logging payload = {payload}")
+
+ print_verbose(
+ f"Logfire Layer Logging - final response object: {response_obj}"
+ )
+ except Exception as e:
+ verbose_logger.debug(
+ f"Logfire Layer Error - {str(e)}\n{traceback.format_exc()}"
+ )
+ pass
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/lunary.py b/.venv/lib/python3.12/site-packages/litellm/integrations/lunary.py
new file mode 100644
index 00000000..fcd781e4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/lunary.py
@@ -0,0 +1,181 @@
+#### What this does ####
+# On success + failure, log events to lunary.ai
+import importlib.metadata
+import traceback
+from datetime import datetime, timezone
+
+import packaging.version
+
+
+# convert usage to {completion: xx, prompt: xx}
+def parse_usage(usage):
+ return {
+ "completion": usage["completion_tokens"] if "completion_tokens" in usage else 0,
+ "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
+ }
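+
+# Example (illustrative):
+#   parse_usage({"prompt_tokens": 12, "completion_tokens": 34})
+#       -> {"completion": 34, "prompt": 12}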
+
+
+def parse_tool_calls(tool_calls):
+ if tool_calls is None:
+ return None
+
+ def clean_tool_call(tool_call):
+
+ serialized = {
+ "type": tool_call.type,
+ "id": tool_call.id,
+ "function": {
+ "name": tool_call.function.name,
+ "arguments": tool_call.function.arguments,
+ },
+ }
+
+ return serialized
+
+ return [clean_tool_call(tool_call) for tool_call in tool_calls]
+
+
+def parse_messages(input):
+
+ if input is None:
+ return None
+
+ def clean_message(message):
+ # if is string, return as is
+ if isinstance(message, str):
+ return message
+
+ if "message" in message:
+ return clean_message(message["message"])
+
+ serialized = {
+ "role": message.get("role"),
+ "content": message.get("content"),
+ }
+
+ # Only add tool_calls and function_call to res if they are set
+ if message.get("tool_calls"):
+ serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls"))
+
+ return serialized
+
+ if isinstance(input, list):
+ if len(input) == 1:
+ return clean_message(input[0])
+ else:
+ return [clean_message(msg) for msg in input]
+ else:
+ return clean_message(input)
+
+
+class LunaryLogger:
+ # Class variables or attributes
+ def __init__(self):
+ try:
+ import lunary
+
+ version = importlib.metadata.version("lunary") # type: ignore
+ # if version < 0.1.43 then raise ImportError
+ if packaging.version.Version(version) < packaging.version.Version("0.1.43"): # type: ignore
+ print( # noqa
+ "Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
+ )
+ raise ImportError
+
+ self.lunary_client = lunary
+ except ImportError:
+ print( # noqa
+ "Lunary not installed. Please install it using 'pip install lunary'"
+ ) # noqa
+ raise ImportError
+
+ def log_event(
+ self,
+ kwargs,
+ type,
+ event,
+ run_id,
+ model,
+ print_verbose,
+ extra={},
+ input=None,
+ user_id=None,
+ response_obj=None,
+        start_time=None,
+        end_time=None,
+ error=None,
+ ):
+        try:
+            # default timestamps at call time; a datetime.now() default in the
+            # signature would be evaluated only once, at import time
+            start_time = start_time or datetime.now(timezone.utc)
+            end_time = end_time or datetime.now(timezone.utc)
+
+            print_verbose(f"Lunary Logging - Logging request for model {model}")
+
+ template_id = None
+ litellm_params = kwargs.get("litellm_params", {})
+ optional_params = kwargs.get("optional_params", {})
+ metadata = litellm_params.get("metadata", {}) or {}
+
+ if optional_params:
+ extra = {**extra, **optional_params}
+
+ tags = metadata.get("tags", None)
+
+ if extra:
+ extra.pop("extra_body", None)
+ extra.pop("user", None)
+ template_id = extra.pop("extra_headers", {}).get("Template-Id", None)
+
+ # keep only serializable types
+ for param, value in extra.items():
+ if not isinstance(value, (str, int, bool, float)) and param != "tools":
+ try:
+ extra[param] = str(value)
+ except Exception:
+ pass
+
+ if response_obj:
+ usage = (
+ parse_usage(response_obj["usage"])
+ if "usage" in response_obj
+ else None
+ )
+
+ output = response_obj["choices"] if "choices" in response_obj else None
+
+ else:
+ usage = None
+ output = None
+
+ if error:
+ error_obj = {"stack": error}
+ else:
+ error_obj = None
+
+ self.lunary_client.track_event( # type: ignore
+ type,
+ "start",
+ run_id,
+ parent_run_id=metadata.get("parent_run_id", None),
+ user_id=user_id,
+ name=model,
+ input=parse_messages(input),
+ timestamp=start_time.astimezone(timezone.utc).isoformat(),
+ template_id=template_id,
+ metadata=metadata,
+ runtime="litellm",
+ tags=tags,
+ params=extra,
+ )
+
+ self.lunary_client.track_event( # type: ignore
+ type,
+ event,
+ run_id,
+ timestamp=end_time.astimezone(timezone.utc).isoformat(),
+ runtime="litellm",
+ error=error_obj,
+ output=parse_messages(output),
+ token_usage=usage,
+ )
+
+ except Exception:
+ print_verbose(f"Lunary Logging Error - {traceback.format_exc()}")
+ pass
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py b/.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py
new file mode 100644
index 00000000..193d1c4e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py
@@ -0,0 +1,269 @@
+import json
+import threading
+from typing import Optional
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+
+
+class MlflowLogger(CustomLogger):
+ def __init__(self):
+ from mlflow.tracking import MlflowClient
+
+ self._client = MlflowClient()
+
+ self._stream_id_to_span = {}
+ self._lock = threading.Lock() # lock for _stream_id_to_span
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_success(kwargs, response_obj, start_time, end_time)
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_success(kwargs, response_obj, start_time, end_time)
+
+ def _handle_success(self, kwargs, response_obj, start_time, end_time):
+ """
+ Log the success event as an MLflow span.
+ Note that this method is called asynchronously in the background thread.
+ """
+ from mlflow.entities import SpanStatusCode
+
+ try:
+ verbose_logger.debug("MLflow logging start for success event")
+
+ if kwargs.get("stream"):
+ self._handle_stream_event(kwargs, response_obj, start_time, end_time)
+ else:
+ span = self._start_span_or_trace(kwargs, start_time)
+ end_time_ns = int(end_time.timestamp() * 1e9)
+ self._extract_and_set_chat_attributes(span, kwargs, response_obj)
+ self._end_span_or_trace(
+ span=span,
+ outputs=response_obj,
+ status=SpanStatusCode.OK,
+ end_time_ns=end_time_ns,
+ )
+ except Exception:
+ verbose_logger.debug("MLflow Logging Error", stack_info=True)
+
+ def _extract_and_set_chat_attributes(self, span, kwargs, response_obj):
+ try:
+ from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools
+ except ImportError:
+ return
+
+ inputs = self._construct_input(kwargs)
+ input_messages = inputs.get("messages", [])
+ output_messages = [c.message.model_dump(exclude_none=True)
+ for c in getattr(response_obj, "choices", [])]
+ if messages := [*input_messages, *output_messages]:
+ set_span_chat_messages(span, messages)
+ if tools := inputs.get("tools"):
+ set_span_chat_tools(span, tools)
+
+ def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+ def _handle_failure(self, kwargs, response_obj, start_time, end_time):
+ """
+ Log the failure event as an MLflow span.
+ Note that this method is called *synchronously* unlike the success handler.
+ """
+ from mlflow.entities import SpanEvent, SpanStatusCode
+
+ try:
+ span = self._start_span_or_trace(kwargs, start_time)
+
+ end_time_ns = int(end_time.timestamp() * 1e9)
+
+ # Record exception info as event
+ if exception := kwargs.get("exception"):
+ span.add_event(SpanEvent.from_exception(exception)) # type: ignore
+
+ self._extract_and_set_chat_attributes(span, kwargs, response_obj)
+ self._end_span_or_trace(
+ span=span,
+ outputs=response_obj,
+ status=SpanStatusCode.ERROR,
+ end_time_ns=end_time_ns,
+ )
+
+ except Exception as e:
+ verbose_logger.debug(f"MLflow Logging Error - {e}", stack_info=True)
+
+ def _handle_stream_event(self, kwargs, response_obj, start_time, end_time):
+ """
+ Handle the success event for a streaming response. For streaming calls,
+        the log_success_event handler is triggered for every chunk of the stream.
+ We create a single span for the entire stream request as follows:
+
+ 1. For the first chunk, start a new span and store it in the map.
+ 2. For subsequent chunks, add the chunk as an event to the span.
+ 3. For the final chunk, end the span and remove the span from the map.
+ """
+ from mlflow.entities import SpanStatusCode
+
+ litellm_call_id = kwargs.get("litellm_call_id")
+
+ if litellm_call_id not in self._stream_id_to_span:
+ with self._lock:
+ # Check again after acquiring lock
+ if litellm_call_id not in self._stream_id_to_span:
+ # Start a new span for the first chunk of the stream
+ span = self._start_span_or_trace(kwargs, start_time)
+ self._stream_id_to_span[litellm_call_id] = span
+
+ # Add chunk as event to the span
+ span = self._stream_id_to_span[litellm_call_id]
+ self._add_chunk_events(span, response_obj)
+
+ # If this is the final chunk, end the span. The final chunk
+ # has complete_streaming_response that gathers the full response.
+ if final_response := kwargs.get("complete_streaming_response"):
+ end_time_ns = int(end_time.timestamp() * 1e9)
+
+ self._extract_and_set_chat_attributes(span, kwargs, final_response)
+ self._end_span_or_trace(
+ span=span,
+ outputs=final_response,
+ status=SpanStatusCode.OK,
+ end_time_ns=end_time_ns,
+ )
+
+ # Remove the stream_id from the map
+ with self._lock:
+ self._stream_id_to_span.pop(litellm_call_id)
+
+ def _add_chunk_events(self, span, response_obj):
+ from mlflow.entities import SpanEvent
+
+ try:
+ for choice in response_obj.choices:
+ span.add_event(
+ SpanEvent(
+ name="streaming_chunk",
+ attributes={"delta": json.dumps(choice.delta.model_dump())},
+ )
+ )
+ except Exception:
+ verbose_logger.debug("Error adding chunk events to span", stack_info=True)
+
+ def _construct_input(self, kwargs):
+ """Construct span inputs with optional parameters"""
+ inputs = {"messages": kwargs.get("messages")}
+ if tools := kwargs.get("tools"):
+ inputs["tools"] = tools
+
+ for key in ["functions", "tools", "stream", "tool_choice", "user"]:
+ if value := kwargs.get("optional_params", {}).pop(key, None):
+ inputs[key] = value
+ return inputs
+
+ def _extract_attributes(self, kwargs):
+ """
+ Extract span attributes from kwargs.
+
+ With the latest version of litellm, the standard_logging_object contains
+ canonical information for logging. If it is not present, we extract
+        a subset of attributes from other kwargs.
+ """
+ attributes = {
+ "litellm_call_id": kwargs.get("litellm_call_id"),
+ "call_type": kwargs.get("call_type"),
+ "model": kwargs.get("model"),
+ }
+ standard_obj = kwargs.get("standard_logging_object")
+ if standard_obj:
+ attributes.update(
+ {
+ "api_base": standard_obj.get("api_base"),
+ "cache_hit": standard_obj.get("cache_hit"),
+ "usage": {
+ "completion_tokens": standard_obj.get("completion_tokens"),
+ "prompt_tokens": standard_obj.get("prompt_tokens"),
+ "total_tokens": standard_obj.get("total_tokens"),
+ },
+ "raw_llm_response": standard_obj.get("response"),
+ "response_cost": standard_obj.get("response_cost"),
+ "saved_cache_cost": standard_obj.get("saved_cache_cost"),
+ }
+ )
+ else:
+ litellm_params = kwargs.get("litellm_params", {})
+ attributes.update(
+ {
+ "model": kwargs.get("model"),
+ "cache_hit": kwargs.get("cache_hit"),
+ "custom_llm_provider": kwargs.get("custom_llm_provider"),
+ "api_base": litellm_params.get("api_base"),
+ "response_cost": kwargs.get("response_cost"),
+ }
+ )
+ return attributes
+
+ def _get_span_type(self, call_type: Optional[str]) -> str:
+ from mlflow.entities import SpanType
+
+ if call_type in ["completion", "acompletion"]:
+ return SpanType.LLM
+ elif call_type == "embeddings":
+ return SpanType.EMBEDDING
+ else:
+ return SpanType.LLM
+
+ def _start_span_or_trace(self, kwargs, start_time):
+ """
+ Start an MLflow span or a trace.
+
+ If there is an active span, we start a new span as a child of
+ that span. Otherwise, we start a new trace.
+ """
+ import mlflow
+
+ call_type = kwargs.get("call_type", "completion")
+ span_name = f"litellm-{call_type}"
+ span_type = self._get_span_type(call_type)
+ start_time_ns = int(start_time.timestamp() * 1e9)
+
+ inputs = self._construct_input(kwargs)
+ attributes = self._extract_attributes(kwargs)
+
+ if active_span := mlflow.get_current_active_span(): # type: ignore
+ return self._client.start_span(
+ name=span_name,
+ request_id=active_span.request_id,
+ parent_id=active_span.span_id,
+ span_type=span_type,
+ inputs=inputs,
+ attributes=attributes,
+ start_time_ns=start_time_ns,
+ )
+ else:
+ return self._client.start_trace(
+ name=span_name,
+ span_type=span_type,
+ inputs=inputs,
+ attributes=attributes,
+ start_time_ns=start_time_ns,
+ )
+
+ def _end_span_or_trace(self, span, outputs, end_time_ns, status):
+ """End an MLflow span or a trace."""
+ if span.parent_id is None:
+ self._client.end_trace(
+ request_id=span.request_id,
+ outputs=outputs,
+ status=status,
+ end_time_ns=end_time_ns,
+ )
+ else:
+ self._client.end_span(
+ request_id=span.request_id,
+ span_id=span.span_id,
+ outputs=outputs,
+ status=status,
+ end_time_ns=end_time_ns,
+ )
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/openmeter.py b/.venv/lib/python3.12/site-packages/litellm/integrations/openmeter.py
new file mode 100644
index 00000000..ebfed532
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/openmeter.py
@@ -0,0 +1,132 @@
+# What is this?
+## On Success events log cost to OpenMeter - https://github.com/BerriAI/litellm/issues/1268
+
+import json
+import os
+
+import httpx
+
+import litellm
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.llms.custom_httpx.http_handler import (
+ HTTPHandler,
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+
+
+def get_utc_datetime():
+ import datetime as dt
+ from datetime import datetime
+
+ if hasattr(dt, "UTC"):
+ return datetime.now(dt.UTC) # type: ignore
+ else:
+ return datetime.utcnow() # type: ignore
+
+
+class OpenMeterLogger(CustomLogger):
+ def __init__(self) -> None:
+ super().__init__()
+ self.validate_environment()
+ self.async_http_handler = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+ self.sync_http_handler = HTTPHandler()
+
+ def validate_environment(self):
+ """
+        Expects OPENMETER_API_KEY in the environment.
+
+        OPENMETER_API_ENDPOINT is optional and defaults to https://openmeter.cloud.
+ """
+ missing_keys = []
+ if os.getenv("OPENMETER_API_KEY", None) is None:
+ missing_keys.append("OPENMETER_API_KEY")
+
+ if len(missing_keys) > 0:
+ raise Exception("Missing keys={} in environment.".format(missing_keys))
+
+ def _common_logic(self, kwargs: dict, response_obj):
+ call_id = response_obj.get("id", kwargs.get("litellm_call_id"))
+ dt = get_utc_datetime().isoformat()
+ cost = kwargs.get("response_cost", None)
+ model = kwargs.get("model")
+ usage = {}
+ if (
+ isinstance(response_obj, litellm.ModelResponse)
+ or isinstance(response_obj, litellm.EmbeddingResponse)
+ ) and hasattr(response_obj, "usage"):
+ usage = {
+ "prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
+ "completion_tokens": response_obj["usage"].get("completion_tokens", 0),
+ "total_tokens": response_obj["usage"].get("total_tokens"),
+ }
+
+        subject = kwargs.get("user", None)  # end-user passed in via 'user' param
+ if not subject:
+ raise Exception("OpenMeter: user is required")
+
+ return {
+ "specversion": "1.0",
+ "type": os.getenv("OPENMETER_EVENT_TYPE", "litellm_tokens"),
+ "id": call_id,
+ "time": dt,
+ "subject": subject,
+ "source": "litellm-proxy",
+ "data": {"model": model, "cost": cost, **usage},
+ }
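+
+    # Example (illustrative; ids, timestamps and costs are hypothetical) of the
+    # CloudEvents payload built above:
+    #   {
+    #       "specversion": "1.0", "type": "litellm_tokens", "id": "chatcmpl-abc123",
+    #       "time": "2025-01-01T12:00:00+00:00", "subject": "user-123",
+    #       "source": "litellm-proxy",
+    #       "data": {"model": "gpt-4o", "cost": 0.0021, "prompt_tokens": 12,
+    #                "completion_tokens": 34, "total_tokens": 46},
+    #   }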
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ _url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud")
+ if _url.endswith("/"):
+ _url += "api/v1/events"
+ else:
+ _url += "/api/v1/events"
+
+ api_key = os.getenv("OPENMETER_API_KEY")
+
+ _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
+ _headers = {
+ "Content-Type": "application/cloudevents+json",
+ "Authorization": "Bearer {}".format(api_key),
+ }
+
+ try:
+ self.sync_http_handler.post(
+ url=_url,
+ data=json.dumps(_data),
+ headers=_headers,
+ )
+ except httpx.HTTPStatusError as e:
+ raise Exception(f"OpenMeter logging error: {e.response.text}")
+ except Exception as e:
+ raise e
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ _url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud")
+ if _url.endswith("/"):
+ _url += "api/v1/events"
+ else:
+ _url += "/api/v1/events"
+
+ api_key = os.getenv("OPENMETER_API_KEY")
+
+ _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
+ _headers = {
+ "Content-Type": "application/cloudevents+json",
+ "Authorization": "Bearer {}".format(api_key),
+ }
+
+ try:
+ await self.async_http_handler.post(
+ url=_url,
+ data=json.dumps(_data),
+ headers=_headers,
+ )
+ except httpx.HTTPStatusError as e:
+ raise Exception(f"OpenMeter logging error: {e.response.text}")
+ except Exception as e:
+ raise e
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/opentelemetry.py b/.venv/lib/python3.12/site-packages/litellm/integrations/opentelemetry.py
new file mode 100644
index 00000000..1572eb81
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/opentelemetry.py
@@ -0,0 +1,1023 @@
+import os
+from dataclasses import dataclass
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.services import ServiceLoggerPayload
+from litellm.types.utils import (
+ ChatCompletionMessageToolCall,
+ Function,
+ StandardCallbackDynamicParams,
+ StandardLoggingPayload,
+)
+
+if TYPE_CHECKING:
+ from opentelemetry.sdk.trace.export import SpanExporter as _SpanExporter
+ from opentelemetry.trace import Span as _Span
+
+ from litellm.proxy._types import (
+ ManagementEndpointLoggingPayload as _ManagementEndpointLoggingPayload,
+ )
+ from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
+
+ Span = _Span
+ SpanExporter = _SpanExporter
+ UserAPIKeyAuth = _UserAPIKeyAuth
+ ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
+else:
+ Span = Any
+ SpanExporter = Any
+ UserAPIKeyAuth = Any
+ ManagementEndpointLoggingPayload = Any
+
+
+LITELLM_TRACER_NAME = os.getenv("OTEL_TRACER_NAME", "litellm")
+LITELLM_RESOURCE: Dict[Any, Any] = {
+ "service.name": os.getenv("OTEL_SERVICE_NAME", "litellm"),
+ "deployment.environment": os.getenv("OTEL_ENVIRONMENT_NAME", "production"),
+ "model_id": os.getenv("OTEL_SERVICE_NAME", "litellm"),
+}
+RAW_REQUEST_SPAN_NAME = "raw_gen_ai_request"
+LITELLM_REQUEST_SPAN_NAME = "litellm_request"
+
+
+@dataclass
+class OpenTelemetryConfig:
+
+ exporter: Union[str, SpanExporter] = "console"
+ endpoint: Optional[str] = None
+ headers: Optional[str] = None
+
+ @classmethod
+ def from_env(cls):
+ """
+ OTEL_HEADERS=x-honeycomb-team=B85YgLm9****
+ OTEL_EXPORTER="otlp_http"
+ OTEL_ENDPOINT="https://api.honeycomb.io/v1/traces"
+
+ OTEL_HEADERS gets sent as headers = {"x-honeycomb-team": "B85YgLm96******"}
+ """
+ from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
+ InMemorySpanExporter,
+ )
+
+ if os.getenv("OTEL_EXPORTER") == "in_memory":
+ return cls(exporter=InMemorySpanExporter())
+ return cls(
+ exporter=os.getenv("OTEL_EXPORTER", "console"),
+ endpoint=os.getenv("OTEL_ENDPOINT"),
+ headers=os.getenv(
+ "OTEL_HEADERS"
+ ), # example: OTEL_HEADERS=x-honeycomb-team=B85YgLm96***"
+ )
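+
+    # Example (illustrative, reusing the values from the docstring above): with
+    #   OTEL_EXPORTER="otlp_http"
+    #   OTEL_ENDPOINT="https://api.honeycomb.io/v1/traces"
+    #   OTEL_HEADERS="x-honeycomb-team=B85YgLm9****"
+    # from_env() returns
+    #   OpenTelemetryConfig(exporter="otlp_http",
+    #                       endpoint="https://api.honeycomb.io/v1/traces",
+    #                       headers="x-honeycomb-team=B85YgLm9****")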
+
+
+class OpenTelemetry(CustomLogger):
+ def __init__(
+ self,
+ config: Optional[OpenTelemetryConfig] = None,
+ callback_name: Optional[str] = None,
+ **kwargs,
+ ):
+ from opentelemetry import trace
+ from opentelemetry.sdk.resources import Resource
+ from opentelemetry.sdk.trace import TracerProvider
+ from opentelemetry.trace import SpanKind
+
+ if config is None:
+ config = OpenTelemetryConfig.from_env()
+
+ self.config = config
+ self.OTEL_EXPORTER = self.config.exporter
+ self.OTEL_ENDPOINT = self.config.endpoint
+ self.OTEL_HEADERS = self.config.headers
+ provider = TracerProvider(resource=Resource(attributes=LITELLM_RESOURCE))
+ provider.add_span_processor(self._get_span_processor())
+ self.callback_name = callback_name
+
+ trace.set_tracer_provider(provider)
+ self.tracer = trace.get_tracer(LITELLM_TRACER_NAME)
+
+ self.span_kind = SpanKind
+
+ _debug_otel = str(os.getenv("DEBUG_OTEL", "False")).lower()
+
+ if _debug_otel == "true":
+ # Set up logging
+ import logging
+
+ logging.basicConfig(level=logging.DEBUG)
+ logging.getLogger(__name__)
+
+ # Enable OpenTelemetry logging
+ otel_exporter_logger = logging.getLogger("opentelemetry.sdk.trace.export")
+ otel_exporter_logger.setLevel(logging.DEBUG)
+
+ # init CustomLogger params
+ super().__init__(**kwargs)
+ self._init_otel_logger_on_litellm_proxy()
+
+ def _init_otel_logger_on_litellm_proxy(self):
+ """
+ Initializes OpenTelemetry for litellm proxy server
+
+ - Adds Otel as a service callback
+ - Sets `proxy_server.open_telemetry_logger` to self
+ """
+ from litellm.proxy import proxy_server
+
+ # Add Otel as a service callback
+ if "otel" not in litellm.service_callback:
+ litellm.service_callback.append("otel")
+ setattr(proxy_server, "open_telemetry_logger", self)
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_sucess(kwargs, response_obj, start_time, end_time)
+
+ def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_sucess(kwargs, response_obj, start_time, end_time)
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+ async def async_service_success_hook(
+ self,
+ payload: ServiceLoggerPayload,
+ parent_otel_span: Optional[Span] = None,
+ start_time: Optional[Union[datetime, float]] = None,
+ end_time: Optional[Union[datetime, float]] = None,
+ event_metadata: Optional[dict] = None,
+ ):
+
+ from opentelemetry import trace
+ from opentelemetry.trace import Status, StatusCode
+
+ _start_time_ns = 0
+ _end_time_ns = 0
+
+ if isinstance(start_time, float):
+ _start_time_ns = int(start_time * 1e9)
+ else:
+ _start_time_ns = self._to_ns(start_time)
+
+ if isinstance(end_time, float):
+ _end_time_ns = int(end_time * 1e9)
+ else:
+ _end_time_ns = self._to_ns(end_time)
+
+ if parent_otel_span is not None:
+ _span_name = payload.service
+ service_logging_span = self.tracer.start_span(
+ name=_span_name,
+ context=trace.set_span_in_context(parent_otel_span),
+ start_time=_start_time_ns,
+ )
+ self.safe_set_attribute(
+ span=service_logging_span,
+ key="call_type",
+ value=payload.call_type,
+ )
+ self.safe_set_attribute(
+ span=service_logging_span,
+ key="service",
+ value=payload.service.value,
+ )
+
+ if event_metadata:
+ for key, value in event_metadata.items():
+ if value is None:
+ value = "None"
+ if isinstance(value, dict):
+ try:
+ value = str(value)
+ except Exception:
+ value = "litellm logging error - could_not_json_serialize"
+ self.safe_set_attribute(
+ span=service_logging_span,
+ key=key,
+ value=value,
+ )
+ service_logging_span.set_status(Status(StatusCode.OK))
+ service_logging_span.end(end_time=_end_time_ns)
+
+ async def async_service_failure_hook(
+ self,
+ payload: ServiceLoggerPayload,
+ error: Optional[str] = "",
+ parent_otel_span: Optional[Span] = None,
+ start_time: Optional[Union[datetime, float]] = None,
+ end_time: Optional[Union[float, datetime]] = None,
+ event_metadata: Optional[dict] = None,
+ ):
+
+ from opentelemetry import trace
+ from opentelemetry.trace import Status, StatusCode
+
+ _start_time_ns = 0
+ _end_time_ns = 0
+
+ if isinstance(start_time, float):
+            _start_time_ns = int(start_time * 1e9)
+ else:
+ _start_time_ns = self._to_ns(start_time)
+
+ if isinstance(end_time, float):
+            _end_time_ns = int(end_time * 1e9)
+ else:
+ _end_time_ns = self._to_ns(end_time)
+
+ if parent_otel_span is not None:
+ _span_name = payload.service
+ service_logging_span = self.tracer.start_span(
+ name=_span_name,
+ context=trace.set_span_in_context(parent_otel_span),
+ start_time=_start_time_ns,
+ )
+ self.safe_set_attribute(
+ span=service_logging_span,
+ key="call_type",
+ value=payload.call_type,
+ )
+ self.safe_set_attribute(
+ span=service_logging_span,
+ key="service",
+ value=payload.service.value,
+ )
+ if error:
+ self.safe_set_attribute(
+ span=service_logging_span,
+ key="error",
+ value=error,
+ )
+ if event_metadata:
+ for key, value in event_metadata.items():
+ if isinstance(value, dict):
+ try:
+ value = str(value)
+ except Exception:
+ value = "litllm logging error - could_not_json_serialize"
+ self.safe_set_attribute(
+ span=service_logging_span,
+ key=key,
+ value=value,
+ )
+
+ service_logging_span.set_status(Status(StatusCode.ERROR))
+ service_logging_span.end(end_time=_end_time_ns)
+
+ async def async_post_call_failure_hook(
+ self,
+ request_data: dict,
+ original_exception: Exception,
+ user_api_key_dict: UserAPIKeyAuth,
+ ):
+ from opentelemetry import trace
+ from opentelemetry.trace import Status, StatusCode
+
+ parent_otel_span = user_api_key_dict.parent_otel_span
+ if parent_otel_span is not None:
+ parent_otel_span.set_status(Status(StatusCode.ERROR))
+ _span_name = "Failed Proxy Server Request"
+
+ # Exception Logging Child Span
+ exception_logging_span = self.tracer.start_span(
+ name=_span_name,
+ context=trace.set_span_in_context(parent_otel_span),
+ )
+ self.safe_set_attribute(
+ span=exception_logging_span,
+ key="exception",
+ value=str(original_exception),
+ )
+ exception_logging_span.set_status(Status(StatusCode.ERROR))
+ exception_logging_span.end(end_time=self._to_ns(datetime.now()))
+
+            # End parent OTEL span
+ parent_otel_span.end(end_time=self._to_ns(datetime.now()))
+
+ def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
+ from opentelemetry import trace
+ from opentelemetry.trace import Status, StatusCode
+
+ verbose_logger.debug(
+ "OpenTelemetry Logger: Logging kwargs: %s, OTEL config settings=%s",
+ kwargs,
+ self.config,
+ )
+ _parent_context, parent_otel_span = self._get_span_context(kwargs)
+
+ self._add_dynamic_span_processor_if_needed(kwargs)
+
+        # Span 1: Request sent to litellm SDK
+ span = self.tracer.start_span(
+ name=self._get_span_name(kwargs),
+ start_time=self._to_ns(start_time),
+ context=_parent_context,
+ )
+ span.set_status(Status(StatusCode.OK))
+ self.set_attributes(span, kwargs, response_obj)
+
+ if litellm.turn_off_message_logging is True:
+ pass
+ elif self.message_logging is not True:
+ pass
+ else:
+ # Span 2: Raw Request / Response to LLM
+ raw_request_span = self.tracer.start_span(
+ name=RAW_REQUEST_SPAN_NAME,
+ start_time=self._to_ns(start_time),
+ context=trace.set_span_in_context(span),
+ )
+
+ raw_request_span.set_status(Status(StatusCode.OK))
+ self.set_raw_request_attributes(raw_request_span, kwargs, response_obj)
+ raw_request_span.end(end_time=self._to_ns(end_time))
+
+ span.end(end_time=self._to_ns(end_time))
+
+ if parent_otel_span is not None:
+ parent_otel_span.end(end_time=self._to_ns(datetime.now()))
+
+ def _add_dynamic_span_processor_if_needed(self, kwargs):
+ """
+ Helper method to add a span processor with dynamic headers if needed.
+
+ This allows for per-request configuration of telemetry exporters by
+ extracting headers from standard_callback_dynamic_params.
+ """
+ from opentelemetry import trace
+
+ standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
+ kwargs.get("standard_callback_dynamic_params")
+ )
+ if not standard_callback_dynamic_params:
+ return
+
+ # Extract headers from dynamic params
+ dynamic_headers = {}
+
+ # Handle Arize headers
+ if standard_callback_dynamic_params.get("arize_space_key"):
+ dynamic_headers["space_key"] = standard_callback_dynamic_params.get(
+ "arize_space_key"
+ )
+ if standard_callback_dynamic_params.get("arize_api_key"):
+ dynamic_headers["api_key"] = standard_callback_dynamic_params.get(
+ "arize_api_key"
+ )
+
+ # Only create a span processor if we have headers to use
+ if len(dynamic_headers) > 0:
+ from opentelemetry.sdk.trace import TracerProvider
+
+ provider = trace.get_tracer_provider()
+ if isinstance(provider, TracerProvider):
+ span_processor = self._get_span_processor(
+ dynamic_headers=dynamic_headers
+ )
+ provider.add_span_processor(span_processor)
+
+ def _handle_failure(self, kwargs, response_obj, start_time, end_time):
+ from opentelemetry.trace import Status, StatusCode
+
+ verbose_logger.debug(
+ "OpenTelemetry Logger: Failure HandlerLogging kwargs: %s, OTEL config settings=%s",
+ kwargs,
+ self.config,
+ )
+ _parent_context, parent_otel_span = self._get_span_context(kwargs)
+
+        # Span 1: Request sent to litellm SDK
+ span = self.tracer.start_span(
+ name=self._get_span_name(kwargs),
+ start_time=self._to_ns(start_time),
+ context=_parent_context,
+ )
+ span.set_status(Status(StatusCode.ERROR))
+ self.set_attributes(span, kwargs, response_obj)
+ span.end(end_time=self._to_ns(end_time))
+
+ if parent_otel_span is not None:
+ parent_otel_span.end(end_time=self._to_ns(datetime.now()))
+
+ def set_tools_attributes(self, span: Span, tools):
+ import json
+
+ from litellm.proxy._types import SpanAttributes
+
+ if not tools:
+ return
+
+ try:
+ for i, tool in enumerate(tools):
+ function = tool.get("function")
+ if not function:
+ continue
+
+ prefix = f"{SpanAttributes.LLM_REQUEST_FUNCTIONS}.{i}"
+ self.safe_set_attribute(
+ span=span,
+ key=f"{prefix}.name",
+ value=function.get("name"),
+ )
+ self.safe_set_attribute(
+ span=span,
+ key=f"{prefix}.description",
+ value=function.get("description"),
+ )
+ self.safe_set_attribute(
+ span=span,
+ key=f"{prefix}.parameters",
+ value=json.dumps(function.get("parameters")),
+ )
+ except Exception as e:
+ verbose_logger.error(
+ "OpenTelemetry: Error setting tools attributes: %s", str(e)
+ )
+ pass
+
+ def cast_as_primitive_value_type(self, value) -> Union[str, bool, int, float]:
+ """
+ Casts the value to a primitive OTEL type if it is not already a primitive type.
+
+ OTEL supports - str, bool, int, float
+
+ If it's not a primitive type, then it's converted to a string
+ """
+ if value is None:
+ return ""
+ if isinstance(value, (str, bool, int, float)):
+ return value
+ try:
+ return str(value)
+ except Exception:
+ return ""
+
+ @staticmethod
+ def _tool_calls_kv_pair(
+ tool_calls: List[ChatCompletionMessageToolCall],
+ ) -> Dict[str, Any]:
+ from litellm.proxy._types import SpanAttributes
+
+ kv_pairs: Dict[str, Any] = {}
+ for idx, tool_call in enumerate(tool_calls):
+ _function = tool_call.get("function")
+ if not _function:
+ continue
+
+ keys = Function.__annotations__.keys()
+ for key in keys:
+ _value = _function.get(key)
+ if _value:
+ kv_pairs[
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.function_call.{key}"
+ ] = _value
+
+ return kv_pairs
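+
+    # Example (illustrative; the function name and arguments are hypothetical): a single
+    # tool call with function name "get_weather" and arguments '{"city": "Paris"}' yields
+    #   f"{SpanAttributes.LLM_COMPLETIONS}.0.function_call.name"      -> "get_weather"
+    #   f"{SpanAttributes.LLM_COMPLETIONS}.0.function_call.arguments" -> '{"city": "Paris"}'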
+
+ def set_attributes( # noqa: PLR0915
+ self, span: Span, kwargs, response_obj: Optional[Any]
+ ):
+ try:
+ if self.callback_name == "arize_phoenix":
+ from litellm.integrations.arize.arize_phoenix import ArizePhoenixLogger
+
+ ArizePhoenixLogger.set_arize_phoenix_attributes(
+ span, kwargs, response_obj
+ )
+ return
+ elif self.callback_name == "langtrace":
+ from litellm.integrations.langtrace import LangtraceAttributes
+
+ LangtraceAttributes().set_langtrace_attributes(
+ span, kwargs, response_obj
+ )
+ return
+ from litellm.proxy._types import SpanAttributes
+
+ optional_params = kwargs.get("optional_params", {})
+ litellm_params = kwargs.get("litellm_params", {}) or {}
+ standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object"
+ )
+ if standard_logging_payload is None:
+ raise ValueError("standard_logging_object not found in kwargs")
+
+ # https://github.com/open-telemetry/semantic-conventions/blob/main/model/registry/gen-ai.yaml
+ # Following Conventions here: https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md
+ #############################################
+ ############ LLM CALL METADATA ##############
+ #############################################
+ metadata = standard_logging_payload["metadata"]
+ for key, value in metadata.items():
+ self.safe_set_attribute(
+ span=span, key="metadata.{}".format(key), value=value
+ )
+
+ #############################################
+ ########## LLM Request Attributes ###########
+ #############################################
+
+ # The name of the LLM a request is being made to
+ if kwargs.get("model"):
+ self.safe_set_attribute(
+ span=span,
+ key=SpanAttributes.LLM_REQUEST_MODEL,
+ value=kwargs.get("model"),
+ )
+
+ # The LLM request type
+ self.safe_set_attribute(
+ span=span,
+ key=SpanAttributes.LLM_REQUEST_TYPE,
+ value=standard_logging_payload["call_type"],
+ )
+
+ # The Generative AI Provider: Azure, OpenAI, etc.
+ self.safe_set_attribute(
+ span=span,
+ key=SpanAttributes.LLM_SYSTEM,
+ value=litellm_params.get("custom_llm_provider", "Unknown"),
+ )
+
+ # The maximum number of tokens the LLM generates for a request.
+ if optional_params.get("max_tokens"):
+ self.safe_set_attribute(
+ span=span,
+ key=SpanAttributes.LLM_REQUEST_MAX_TOKENS,
+ value=optional_params.get("max_tokens"),
+ )
+
+ # The temperature setting for the LLM request.
+ if optional_params.get("temperature"):
+ self.safe_set_attribute(
+ span=span,
+ key=SpanAttributes.LLM_REQUEST_TEMPERATURE,
+ value=optional_params.get("temperature"),
+ )
+
+ # The top_p sampling setting for the LLM request.
+ if optional_params.get("top_p"):
+ self.safe_set_attribute(
+ span=span,
+ key=SpanAttributes.LLM_REQUEST_TOP_P,
+ value=optional_params.get("top_p"),
+ )
+
+ self.safe_set_attribute(
+ span=span,
+ key=SpanAttributes.LLM_IS_STREAMING,
+ value=str(optional_params.get("stream", False)),
+ )
+
+ if optional_params.get("user"):
+ self.safe_set_attribute(
+ span=span,
+ key=SpanAttributes.LLM_USER,
+ value=optional_params.get("user"),
+ )
+
+ # The unique identifier for the completion.
+ if response_obj and response_obj.get("id"):
+ self.safe_set_attribute(
+ span=span, key="gen_ai.response.id", value=response_obj.get("id")
+ )
+
+ # The model used to generate the response.
+ if response_obj and response_obj.get("model"):
+ self.safe_set_attribute(
+ span=span,
+ key=SpanAttributes.LLM_RESPONSE_MODEL,
+ value=response_obj.get("model"),
+ )
+
+ usage = response_obj and response_obj.get("usage")
+ if usage:
+ self.safe_set_attribute(
+ span=span,
+ key=SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
+ value=usage.get("total_tokens"),
+ )
+
+ # The number of tokens used in the LLM response (completion).
+ self.safe_set_attribute(
+ span=span,
+ key=SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
+ value=usage.get("completion_tokens"),
+ )
+
+ # The number of tokens used in the LLM prompt.
+ self.safe_set_attribute(
+ span=span,
+ key=SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
+ value=usage.get("prompt_tokens"),
+ )
+
+ ########################################################################
+            ########## LLM Request Messages / tools / content Attributes ###########
+ #########################################################################
+
+ if litellm.turn_off_message_logging is True:
+ return
+ if self.message_logging is not True:
+ return
+
+ if optional_params.get("tools"):
+ tools = optional_params["tools"]
+ self.set_tools_attributes(span, tools)
+
+ if kwargs.get("messages"):
+ for idx, prompt in enumerate(kwargs.get("messages")):
+ if prompt.get("role"):
+ self.safe_set_attribute(
+ span=span,
+ key=f"{SpanAttributes.LLM_PROMPTS}.{idx}.role",
+ value=prompt.get("role"),
+ )
+
+ if prompt.get("content"):
+ if not isinstance(prompt.get("content"), str):
+ prompt["content"] = str(prompt.get("content"))
+ self.safe_set_attribute(
+ span=span,
+ key=f"{SpanAttributes.LLM_PROMPTS}.{idx}.content",
+ value=prompt.get("content"),
+ )
+ #############################################
+ ########## LLM Response Attributes ##########
+ #############################################
+ if response_obj is not None:
+ if response_obj.get("choices"):
+ for idx, choice in enumerate(response_obj.get("choices")):
+ if choice.get("finish_reason"):
+ self.safe_set_attribute(
+ span=span,
+ key=f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.finish_reason",
+ value=choice.get("finish_reason"),
+ )
+ if choice.get("message"):
+ if choice.get("message").get("role"):
+ self.safe_set_attribute(
+ span=span,
+ key=f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.role",
+ value=choice.get("message").get("role"),
+ )
+ if choice.get("message").get("content"):
+ if not isinstance(
+ choice.get("message").get("content"), str
+ ):
+ choice["message"]["content"] = str(
+ choice.get("message").get("content")
+ )
+ self.safe_set_attribute(
+ span=span,
+ key=f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.content",
+ value=choice.get("message").get("content"),
+ )
+
+ message = choice.get("message")
+ tool_calls = message.get("tool_calls")
+ if tool_calls:
+ kv_pairs = OpenTelemetry._tool_calls_kv_pair(tool_calls) # type: ignore
+ for key, value in kv_pairs.items():
+ self.safe_set_attribute(
+ span=span,
+ key=key,
+ value=value,
+ )
+
+ except Exception as e:
+ verbose_logger.exception(
+ "OpenTelemetry logging error in set_attributes %s", str(e)
+ )
+
+ def _cast_as_primitive_value_type(self, value) -> Union[str, bool, int, float]:
+ """
+ Casts the value to a primitive OTEL type if it is not already a primitive type.
+
+ OTEL supports - str, bool, int, float
+
+ If it's not a primitive type, then it's converted to a string
+ """
+ if value is None:
+ return ""
+ if isinstance(value, (str, bool, int, float)):
+ return value
+ try:
+ return str(value)
+ except Exception:
+ return ""
+
+ def safe_set_attribute(self, span: Span, key: str, value: Any):
+ """
+ Safely sets an attribute on the span, ensuring the value is a primitive type.
+ """
+ primitive_value = self._cast_as_primitive_value_type(value)
+ span.set_attribute(key, primitive_value)
+
+ def set_raw_request_attributes(self, span: Span, kwargs, response_obj):
+
+ kwargs.get("optional_params", {})
+ litellm_params = kwargs.get("litellm_params", {}) or {}
+ custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown")
+
+ _raw_response = kwargs.get("original_response")
+ _additional_args = kwargs.get("additional_args", {}) or {}
+ complete_input_dict = _additional_args.get("complete_input_dict")
+ #############################################
+ ########## LLM Request Attributes ###########
+ #############################################
+
+ # OTEL Attributes for the RAW Request to https://docs.anthropic.com/en/api/messages
+ if complete_input_dict and isinstance(complete_input_dict, dict):
+ for param, val in complete_input_dict.items():
+ self.safe_set_attribute(
+ span=span, key=f"llm.{custom_llm_provider}.{param}", value=val
+ )
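+            # Illustrative example (assumed values): for an Anthropic request with
+            # complete_input_dict = {"max_tokens": 256}, this records the attribute
+            #   "llm.anthropic.max_tokens" = 256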
+
+ #############################################
+ ########## LLM Response Attributes ##########
+ #############################################
+ if _raw_response and isinstance(_raw_response, str):
+            # cast str -> dict
+ import json
+
+ try:
+ _raw_response = json.loads(_raw_response)
+ for param, val in _raw_response.items():
+ self.safe_set_attribute(
+ span=span,
+ key=f"llm.{custom_llm_provider}.{param}",
+ value=val,
+ )
+ except json.JSONDecodeError:
+ verbose_logger.debug(
+ "litellm.integrations.opentelemetry.py::set_raw_request_attributes() - raw_response not json string - {}".format(
+ _raw_response
+ )
+ )
+
+ self.safe_set_attribute(
+ span=span,
+ key=f"llm.{custom_llm_provider}.stringified_raw_response",
+ value=_raw_response,
+ )
+
+ def _to_ns(self, dt):
+ return int(dt.timestamp() * 1e9)
+
+ def _get_span_name(self, kwargs):
+ return LITELLM_REQUEST_SPAN_NAME
+
+ def get_traceparent_from_header(self, headers):
+ if headers is None:
+ return None
+ _traceparent = headers.get("traceparent", None)
+ if _traceparent is None:
+ return None
+
+ from opentelemetry.trace.propagation.tracecontext import (
+ TraceContextTextMapPropagator,
+ )
+
+ propagator = TraceContextTextMapPropagator()
+ carrier = {"traceparent": _traceparent}
+ _parent_context = propagator.extract(carrier=carrier)
+
+ return _parent_context
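+        # Example W3C traceparent header value (illustrative):
+        #   "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
+        # i.e. version-traceid-parentid-traceflags, which is what
+        # TraceContextTextMapPropagator extracts above.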
+
+ def _get_span_context(self, kwargs):
+ from opentelemetry import trace
+ from opentelemetry.trace.propagation.tracecontext import (
+ TraceContextTextMapPropagator,
+ )
+
+ litellm_params = kwargs.get("litellm_params", {}) or {}
+ proxy_server_request = litellm_params.get("proxy_server_request", {}) or {}
+ headers = proxy_server_request.get("headers", {}) or {}
+ traceparent = headers.get("traceparent", None)
+ _metadata = litellm_params.get("metadata", {}) or {}
+ parent_otel_span = _metadata.get("litellm_parent_otel_span", None)
+
+ """
+ Two way to use parents in opentelemetry
+ - using the traceparent header
+ - using the parent_otel_span in the [metadata][parent_otel_span]
+ """
+ if parent_otel_span is not None:
+ return trace.set_span_in_context(parent_otel_span), parent_otel_span
+
+ if traceparent is None:
+ return None, None
+ else:
+ carrier = {"traceparent": traceparent}
+ return TraceContextTextMapPropagator().extract(carrier=carrier), None
+
+ def _get_span_processor(self, dynamic_headers: Optional[dict] = None):
+ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
+ OTLPSpanExporter as OTLPSpanExporterGRPC,
+ )
+ from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
+ OTLPSpanExporter as OTLPSpanExporterHTTP,
+ )
+ from opentelemetry.sdk.trace.export import (
+ BatchSpanProcessor,
+ ConsoleSpanExporter,
+ SimpleSpanProcessor,
+ SpanExporter,
+ )
+
+ verbose_logger.debug(
+ "OpenTelemetry Logger, initializing span processor \nself.OTEL_EXPORTER: %s\nself.OTEL_ENDPOINT: %s\nself.OTEL_HEADERS: %s",
+ self.OTEL_EXPORTER,
+ self.OTEL_ENDPOINT,
+ self.OTEL_HEADERS,
+ )
+ _split_otel_headers = OpenTelemetry._get_headers_dictionary(
+ headers=dynamic_headers or self.OTEL_HEADERS
+ )
+
+ if isinstance(self.OTEL_EXPORTER, SpanExporter):
+ verbose_logger.debug(
+ "OpenTelemetry: intiializing SpanExporter. Value of OTEL_EXPORTER: %s",
+ self.OTEL_EXPORTER,
+ )
+ return SimpleSpanProcessor(self.OTEL_EXPORTER)
+
+ if self.OTEL_EXPORTER == "console":
+ verbose_logger.debug(
+ "OpenTelemetry: intiializing console exporter. Value of OTEL_EXPORTER: %s",
+ self.OTEL_EXPORTER,
+ )
+ return BatchSpanProcessor(ConsoleSpanExporter())
+ elif self.OTEL_EXPORTER == "otlp_http":
+ verbose_logger.debug(
+ "OpenTelemetry: intiializing http exporter. Value of OTEL_EXPORTER: %s",
+ self.OTEL_EXPORTER,
+ )
+ return BatchSpanProcessor(
+ OTLPSpanExporterHTTP(
+ endpoint=self.OTEL_ENDPOINT, headers=_split_otel_headers
+ ),
+ )
+ elif self.OTEL_EXPORTER == "otlp_grpc":
+ verbose_logger.debug(
+ "OpenTelemetry: intiializing grpc exporter. Value of OTEL_EXPORTER: %s",
+ self.OTEL_EXPORTER,
+ )
+ return BatchSpanProcessor(
+ OTLPSpanExporterGRPC(
+ endpoint=self.OTEL_ENDPOINT, headers=_split_otel_headers
+ ),
+ )
+ else:
+ verbose_logger.debug(
+ "OpenTelemetry: intiializing console exporter. Value of OTEL_EXPORTER: %s",
+ self.OTEL_EXPORTER,
+ )
+ return BatchSpanProcessor(ConsoleSpanExporter())
+
+ @staticmethod
+ def _get_headers_dictionary(headers: Optional[Union[str, dict]]) -> Dict[str, str]:
+ """
+ Convert a string or dictionary of headers into a dictionary of headers.
+ """
+ _split_otel_headers: Dict[str, str] = {}
+ if headers:
+ if isinstance(headers, str):
+ # when passed HEADERS="x-honeycomb-team=B85YgLm96******"
+ # Split only on first '=' occurrence
+ parts = headers.split("=", 1)
+ if len(parts) == 2:
+ _split_otel_headers = {parts[0]: parts[1]}
+ else:
+ _split_otel_headers = {}
+ elif isinstance(headers, dict):
+ _split_otel_headers = headers
+ return _split_otel_headers
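+        # Illustrative conversion (values are made up):
+        #   "x-honeycomb-team=B85Yg..."   -> {"x-honeycomb-team": "B85Yg..."}
+        #   {"Authorization": "Bearer t"} -> returned unchanged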
+
+ async def async_management_endpoint_success_hook(
+ self,
+ logging_payload: ManagementEndpointLoggingPayload,
+ parent_otel_span: Optional[Span] = None,
+ ):
+
+ from opentelemetry import trace
+ from opentelemetry.trace import Status, StatusCode
+
+ _start_time_ns = 0
+ _end_time_ns = 0
+
+ start_time = logging_payload.start_time
+ end_time = logging_payload.end_time
+
+ if isinstance(start_time, float):
+ _start_time_ns = int(start_time * 1e9)
+ else:
+ _start_time_ns = self._to_ns(start_time)
+
+ if isinstance(end_time, float):
+ _end_time_ns = int(end_time * 1e9)
+ else:
+ _end_time_ns = self._to_ns(end_time)
+
+ if parent_otel_span is not None:
+ _span_name = logging_payload.route
+ management_endpoint_span = self.tracer.start_span(
+ name=_span_name,
+ context=trace.set_span_in_context(parent_otel_span),
+ start_time=_start_time_ns,
+ )
+
+ _request_data = logging_payload.request_data
+ if _request_data is not None:
+ for key, value in _request_data.items():
+ self.safe_set_attribute(
+ span=management_endpoint_span,
+ key=f"request.{key}",
+ value=value,
+ )
+
+ _response = logging_payload.response
+ if _response is not None:
+ for key, value in _response.items():
+ self.safe_set_attribute(
+ span=management_endpoint_span,
+ key=f"response.{key}",
+ value=value,
+ )
+
+ management_endpoint_span.set_status(Status(StatusCode.OK))
+ management_endpoint_span.end(end_time=_end_time_ns)
+
+ async def async_management_endpoint_failure_hook(
+ self,
+ logging_payload: ManagementEndpointLoggingPayload,
+ parent_otel_span: Optional[Span] = None,
+ ):
+
+ from opentelemetry import trace
+ from opentelemetry.trace import Status, StatusCode
+
+ _start_time_ns = 0
+ _end_time_ns = 0
+
+ start_time = logging_payload.start_time
+ end_time = logging_payload.end_time
+
+ if isinstance(start_time, float):
+            _start_time_ns = int(start_time * 1e9)
+ else:
+ _start_time_ns = self._to_ns(start_time)
+
+ if isinstance(end_time, float):
+            _end_time_ns = int(end_time * 1e9)
+ else:
+ _end_time_ns = self._to_ns(end_time)
+
+ if parent_otel_span is not None:
+ _span_name = logging_payload.route
+ management_endpoint_span = self.tracer.start_span(
+ name=_span_name,
+ context=trace.set_span_in_context(parent_otel_span),
+ start_time=_start_time_ns,
+ )
+
+ _request_data = logging_payload.request_data
+ if _request_data is not None:
+ for key, value in _request_data.items():
+ self.safe_set_attribute(
+ span=management_endpoint_span,
+ key=f"request.{key}",
+ value=value,
+ )
+
+ _exception = logging_payload.exception
+ self.safe_set_attribute(
+ span=management_endpoint_span,
+ key="exception",
+ value=str(_exception),
+ )
+ management_endpoint_span.set_status(Status(StatusCode.ERROR))
+ management_endpoint_span.end(end_time=_end_time_ns)
+
+ def create_litellm_proxy_request_started_span(
+ self,
+ start_time: datetime,
+ headers: dict,
+ ) -> Optional[Span]:
+ """
+ Create a span for the received proxy server request.
+ """
+ return self.tracer.start_span(
+ name="Received Proxy Server Request",
+ start_time=self._to_ns(start_time),
+ context=self.get_traceparent_from_header(headers=headers),
+ kind=self.span_kind.SERVER,
+ )
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/opik/opik.py b/.venv/lib/python3.12/site-packages/litellm/integrations/opik/opik.py
new file mode 100644
index 00000000..1f7f18f3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/opik/opik.py
@@ -0,0 +1,326 @@
+"""
+Opik Logger that logs LLM events to an Opik server
+"""
+
+import asyncio
+import json
+import traceback
+from typing import Dict, List
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.llms.custom_httpx.http_handler import (
+ _get_httpx_client,
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+
+from .utils import (
+ create_usage_object,
+ create_uuid7,
+ get_opik_config_variable,
+ get_traces_and_spans_from_payload,
+)
+
+
+class OpikLogger(CustomBatchLogger):
+ """
+ Opik Logger for logging events to an Opik Server
+ """
+
+ def __init__(self, **kwargs):
+ self.async_httpx_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+ self.sync_httpx_client = _get_httpx_client()
+
+ self.opik_project_name = get_opik_config_variable(
+ "project_name",
+ user_value=kwargs.get("project_name", None),
+ default_value="Default Project",
+ )
+
+ opik_base_url = get_opik_config_variable(
+ "url_override",
+ user_value=kwargs.get("url", None),
+ default_value="https://www.comet.com/opik/api",
+ )
+ opik_api_key = get_opik_config_variable(
+ "api_key", user_value=kwargs.get("api_key", None), default_value=None
+ )
+ opik_workspace = get_opik_config_variable(
+ "workspace", user_value=kwargs.get("workspace", None), default_value=None
+ )
+
+ self.trace_url = f"{opik_base_url}/v1/private/traces/batch"
+ self.span_url = f"{opik_base_url}/v1/private/spans/batch"
+
+ self.headers = {}
+ if opik_workspace:
+ self.headers["Comet-Workspace"] = opik_workspace
+
+ if opik_api_key:
+ self.headers["authorization"] = opik_api_key
+
+ self.opik_workspace = opik_workspace
+ self.opik_api_key = opik_api_key
+ try:
+ asyncio.create_task(self.periodic_flush())
+ self.flush_lock = asyncio.Lock()
+ except Exception as e:
+ verbose_logger.exception(
+ f"OpikLogger - Asynchronous processing not initialized as we are not running in an async context {str(e)}"
+ )
+ self.flush_lock = None
+
+ super().__init__(**kwargs, flush_lock=self.flush_lock)
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ opik_payload = self._create_opik_payload(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ start_time=start_time,
+ end_time=end_time,
+ )
+
+ self.log_queue.extend(opik_payload)
+ verbose_logger.debug(
+ f"OpikLogger added event to log_queue - Will flush in {self.flush_interval} seconds..."
+ )
+
+ if len(self.log_queue) >= self.batch_size:
+ verbose_logger.debug("OpikLogger - Flushing batch")
+ await self.flush_queue()
+ except Exception as e:
+ verbose_logger.exception(
+ f"OpikLogger failed to log success event - {str(e)}\n{traceback.format_exc()}"
+ )
+
+ def _sync_send(self, url: str, headers: Dict[str, str], batch: Dict):
+ try:
+ response = self.sync_httpx_client.post(
+ url=url, headers=headers, json=batch # type: ignore
+ )
+ response.raise_for_status()
+ if response.status_code != 204:
+ raise Exception(
+ f"Response from opik API status_code: {response.status_code}, text: {response.text}"
+ )
+ except Exception as e:
+ verbose_logger.exception(
+ f"OpikLogger failed to send batch - {str(e)}\n{traceback.format_exc()}"
+ )
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ opik_payload = self._create_opik_payload(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ start_time=start_time,
+ end_time=end_time,
+ )
+
+ traces, spans = get_traces_and_spans_from_payload(opik_payload)
+ if len(traces) > 0:
+ self._sync_send(
+ url=self.trace_url, headers=self.headers, batch={"traces": traces}
+ )
+ if len(spans) > 0:
+ self._sync_send(
+ url=self.span_url, headers=self.headers, batch={"spans": spans}
+ )
+ except Exception as e:
+ verbose_logger.exception(
+ f"OpikLogger failed to log success event - {str(e)}\n{traceback.format_exc()}"
+ )
+
+ async def _submit_batch(self, url: str, headers: Dict[str, str], batch: Dict):
+ try:
+ response = await self.async_httpx_client.post(
+ url=url, headers=headers, json=batch # type: ignore
+ )
+ response.raise_for_status()
+
+ if response.status_code >= 300:
+ verbose_logger.error(
+ f"OpikLogger - Error: {response.status_code} - {response.text}"
+ )
+ else:
+ verbose_logger.info(
+ f"OpikLogger - {len(self.log_queue)} Opik events submitted"
+ )
+ except Exception as e:
+ verbose_logger.exception(f"OpikLogger failed to send batch - {str(e)}")
+
+ def _create_opik_headers(self):
+ headers = {}
+ if self.opik_workspace:
+ headers["Comet-Workspace"] = self.opik_workspace
+
+ if self.opik_api_key:
+ headers["authorization"] = self.opik_api_key
+ return headers
+
+ async def async_send_batch(self):
+ verbose_logger.info("Calling async_send_batch")
+ if not self.log_queue:
+ return
+
+ # Split the log_queue into traces and spans
+ traces, spans = get_traces_and_spans_from_payload(self.log_queue)
+
+ # Send trace batch
+ if len(traces) > 0:
+ await self._submit_batch(
+ url=self.trace_url, headers=self.headers, batch={"traces": traces}
+ )
+ verbose_logger.info(f"Sent {len(traces)} traces")
+ if len(spans) > 0:
+ await self._submit_batch(
+ url=self.span_url, headers=self.headers, batch={"spans": spans}
+ )
+ verbose_logger.info(f"Sent {len(spans)} spans")
+
+ def _create_opik_payload( # noqa: PLR0915
+ self, kwargs, response_obj, start_time, end_time
+ ) -> List[Dict]:
+
+ # Get metadata
+ _litellm_params = kwargs.get("litellm_params", {}) or {}
+ litellm_params_metadata = _litellm_params.get("metadata", {}) or {}
+
+ # Extract opik metadata
+ litellm_opik_metadata = litellm_params_metadata.get("opik", {})
+ verbose_logger.debug(
+ f"litellm_opik_metadata - {json.dumps(litellm_opik_metadata, default=str)}"
+ )
+ project_name = litellm_opik_metadata.get("project_name", self.opik_project_name)
+
+ # Extract trace_id and parent_span_id
+ current_span_data = litellm_opik_metadata.get("current_span_data", None)
+ if isinstance(current_span_data, dict):
+ trace_id = current_span_data.get("trace_id", None)
+ parent_span_id = current_span_data.get("id", None)
+ elif current_span_data:
+ trace_id = current_span_data.trace_id
+ parent_span_id = current_span_data.id
+ else:
+ trace_id = None
+ parent_span_id = None
+ # Create Opik tags
+ opik_tags = litellm_opik_metadata.get("tags", [])
+ if kwargs.get("custom_llm_provider"):
+ opik_tags.append(kwargs["custom_llm_provider"])
+
+ # Use standard_logging_object to create metadata and input/output data
+ standard_logging_object = kwargs.get("standard_logging_object", None)
+ if standard_logging_object is None:
+ verbose_logger.debug(
+ "OpikLogger skipping event; no standard_logging_object found"
+ )
+ return []
+
+ # Create input and output data
+ input_data = standard_logging_object.get("messages", {})
+ output_data = standard_logging_object.get("response", {})
+
+ # Create usage object
+ usage = create_usage_object(response_obj["usage"])
+
+ # Define span and trace names
+ span_name = "%s_%s_%s" % (
+ response_obj.get("model", "unknown-model"),
+ response_obj.get("object", "unknown-object"),
+ response_obj.get("created", 0),
+ )
+ trace_name = response_obj.get("object", "unknown type")
+
+ # Create metadata object, we add the opik metadata first and then
+ # update it with the standard_logging_object metadata
+ metadata = litellm_opik_metadata
+ if "current_span_data" in metadata:
+ del metadata["current_span_data"]
+ metadata["created_from"] = "litellm"
+
+ metadata.update(standard_logging_object.get("metadata", {}))
+ if "call_type" in standard_logging_object:
+ metadata["type"] = standard_logging_object["call_type"]
+ if "status" in standard_logging_object:
+ metadata["status"] = standard_logging_object["status"]
+ if "response_cost" in kwargs:
+ metadata["cost"] = {
+ "total_tokens": kwargs["response_cost"],
+ "currency": "USD",
+ }
+ if "response_cost_failure_debug_info" in kwargs:
+ metadata["response_cost_failure_debug_info"] = kwargs[
+ "response_cost_failure_debug_info"
+ ]
+ if "model_map_information" in standard_logging_object:
+ metadata["model_map_information"] = standard_logging_object[
+ "model_map_information"
+ ]
+ if "model" in standard_logging_object:
+ metadata["model"] = standard_logging_object["model"]
+ if "model_id" in standard_logging_object:
+ metadata["model_id"] = standard_logging_object["model_id"]
+ if "model_group" in standard_logging_object:
+ metadata["model_group"] = standard_logging_object["model_group"]
+ if "api_base" in standard_logging_object:
+ metadata["api_base"] = standard_logging_object["api_base"]
+ if "cache_hit" in standard_logging_object:
+ metadata["cache_hit"] = standard_logging_object["cache_hit"]
+ if "saved_cache_cost" in standard_logging_object:
+ metadata["saved_cache_cost"] = standard_logging_object["saved_cache_cost"]
+ if "error_str" in standard_logging_object:
+ metadata["error_str"] = standard_logging_object["error_str"]
+ if "model_parameters" in standard_logging_object:
+ metadata["model_parameters"] = standard_logging_object["model_parameters"]
+ if "hidden_params" in standard_logging_object:
+ metadata["hidden_params"] = standard_logging_object["hidden_params"]
+
+ payload = []
+ if trace_id is None:
+ trace_id = create_uuid7()
+ verbose_logger.debug(
+ f"OpikLogger creating payload for trace with id {trace_id}"
+ )
+
+ payload.append(
+ {
+ "project_name": project_name,
+ "id": trace_id,
+ "name": trace_name,
+ "start_time": start_time.isoformat() + "Z",
+ "end_time": end_time.isoformat() + "Z",
+ "input": input_data,
+ "output": output_data,
+ "metadata": metadata,
+ "tags": opik_tags,
+ }
+ )
+
+ span_id = create_uuid7()
+ verbose_logger.debug(
+ f"OpikLogger creating payload for trace with id {trace_id} and span with id {span_id}"
+ )
+ payload.append(
+ {
+ "id": span_id,
+ "project_name": project_name,
+ "trace_id": trace_id,
+ "parent_span_id": parent_span_id,
+ "name": span_name,
+ "type": "llm",
+ "start_time": start_time.isoformat() + "Z",
+ "end_time": end_time.isoformat() + "Z",
+ "input": input_data,
+ "output": output_data,
+ "metadata": metadata,
+ "tags": opik_tags,
+ "usage": usage,
+ }
+ )
+ verbose_logger.debug(f"Payload: {payload}")
+ return payload
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/opik/utils.py b/.venv/lib/python3.12/site-packages/litellm/integrations/opik/utils.py
new file mode 100644
index 00000000..7b3b64dc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/opik/utils.py
@@ -0,0 +1,110 @@
+import configparser
+import os
+import time
+from typing import Dict, Final, List, Optional
+
+CONFIG_FILE_PATH_DEFAULT: Final[str] = "~/.opik.config"
+
+
+def create_uuid7():
+ ns = time.time_ns()
+ last = [0, 0, 0, 0]
+
+ # Simple uuid7 implementation
+ sixteen_secs = 16_000_000_000
+ t1, rest1 = divmod(ns, sixteen_secs)
+ t2, rest2 = divmod(rest1 << 16, sixteen_secs)
+ t3, _ = divmod(rest2 << 12, sixteen_secs)
+ t3 |= 7 << 12 # Put uuid version in top 4 bits, which are 0 in t3
+
+ # The next two bytes are an int (t4) with two bits for
+ # the variant 2 and a 14 bit sequence counter which increments
+ # if the time is unchanged.
+ if t1 == last[0] and t2 == last[1] and t3 == last[2]:
+ # Stop the seq counter wrapping past 0x3FFF.
+ # This won't happen in practice, but if it does,
+ # uuids after the 16383rd with that same timestamp
+        # will no longer be correctly ordered but
+ # are still unique due to the 6 random bytes.
+ if last[3] < 0x3FFF:
+ last[3] += 1
+ else:
+ last[:] = (t1, t2, t3, 0)
+ t4 = (2 << 14) | last[3] # Put variant 0b10 in top two bits
+
+ # Six random bytes for the lower part of the uuid
+ rand = os.urandom(6)
+ return f"{t1:>08x}-{t2:>04x}-{t3:>04x}-{t4:>04x}-{rand.hex()}"
+
+
+def _read_opik_config_file() -> Dict[str, str]:
+ config_path = os.path.expanduser(CONFIG_FILE_PATH_DEFAULT)
+
+ config = configparser.ConfigParser()
+ config.read(config_path)
+
+ config_values = {
+ section: dict(config.items(section)) for section in config.sections()
+ }
+
+ if "opik" in config_values:
+ return config_values["opik"]
+
+ return {}
+
+
+def _get_env_variable(key: str) -> Optional[str]:
+ env_prefix = "opik_"
+ return os.getenv((env_prefix + key).upper(), None)
+
+
+def get_opik_config_variable(
+ key: str, user_value: Optional[str] = None, default_value: Optional[str] = None
+) -> Optional[str]:
+ """
+ Get the configuration value of a variable, order priority is:
+ 1. user provided value
+ 2. environment variable
+ 3. Opik configuration file
+ 4. default value
+ """
+ # Return user provided value if it is not None
+ if user_value is not None:
+ return user_value
+
+ # Return environment variable if it is not None
+ env_value = _get_env_variable(key)
+ if env_value is not None:
+ return env_value
+
+ # Return value from Opik configuration file if it is not None
+ config_values = _read_opik_config_file()
+
+ if key in config_values:
+ return config_values[key]
+
+ # Return default value if it is not None
+ return default_value
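+    # Illustrative resolution order for key="api_key": a user-supplied value wins;
+    # otherwise the OPIK_API_KEY environment variable; otherwise the [opik] section
+    # of ~/.opik.config; otherwise the supplied default_value.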
+
+
+def create_usage_object(usage):
+ usage_dict = {}
+
+ if usage.completion_tokens is not None:
+ usage_dict["completion_tokens"] = usage.completion_tokens
+ if usage.prompt_tokens is not None:
+ usage_dict["prompt_tokens"] = usage.prompt_tokens
+ if usage.total_tokens is not None:
+ usage_dict["total_tokens"] = usage.total_tokens
+ return usage_dict
+
+
+def _remove_nulls(x):
+ x_ = {k: v for k, v in x.items() if v is not None}
+ return x_
+
+
+def get_traces_and_spans_from_payload(payload: List):
+ traces = [_remove_nulls(x) for x in payload if "type" not in x]
+ spans = [_remove_nulls(x) for x in payload if "type" in x]
+ return traces, spans
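+    # Illustrative split (payload entries are hypothetical):
+    #   {"id": ..., "name": "chat.completion", ...}      -> trace
+    #   {"id": ..., "trace_id": ..., "type": "llm", ...} -> span
+    # i.e. only span payloads carry a "type" key (see _create_opik_payload).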
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/pagerduty/pagerduty.py b/.venv/lib/python3.12/site-packages/litellm/integrations/pagerduty/pagerduty.py
new file mode 100644
index 00000000..6085bc23
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/pagerduty/pagerduty.py
@@ -0,0 +1,305 @@
+"""
+PagerDuty Alerting Integration
+
+Handles two types of alerts:
+- High LLM API Failure Rate. Configure X failures in Y seconds to trigger an alert.
+- High Number of Hanging LLM Requests. Configure X hanging requests in Y seconds to trigger an alert.
+"""
+
+import asyncio
+import os
+from datetime import datetime, timedelta, timezone
+from typing import List, Literal, Optional, Union
+
+from litellm._logging import verbose_logger
+from litellm.caching import DualCache
+from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.types.integrations.pagerduty import (
+ AlertingConfig,
+ PagerDutyInternalEvent,
+ PagerDutyPayload,
+ PagerDutyRequestBody,
+)
+from litellm.types.utils import (
+ StandardLoggingPayload,
+ StandardLoggingPayloadErrorInformation,
+)
+
+PAGERDUTY_DEFAULT_FAILURE_THRESHOLD = 60
+PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS = 60
+PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS = 60
+PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS = 600
+
+
+class PagerDutyAlerting(SlackAlerting):
+ """
+ Tracks failed requests and hanging requests separately.
+ If threshold is crossed for either type, triggers a PagerDuty alert.
+ """
+
+ def __init__(
+ self, alerting_args: Optional[Union[AlertingConfig, dict]] = None, **kwargs
+ ):
+ from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
+
+ super().__init__()
+ _api_key = os.getenv("PAGERDUTY_API_KEY")
+ if not _api_key:
+ raise ValueError("PAGERDUTY_API_KEY is not set")
+
+ self.api_key: str = _api_key
+ alerting_args = alerting_args or {}
+ self.alerting_args: AlertingConfig = AlertingConfig(
+ failure_threshold=alerting_args.get(
+ "failure_threshold", PAGERDUTY_DEFAULT_FAILURE_THRESHOLD
+ ),
+ failure_threshold_window_seconds=alerting_args.get(
+ "failure_threshold_window_seconds",
+ PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS,
+ ),
+ hanging_threshold_seconds=alerting_args.get(
+ "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
+ ),
+ hanging_threshold_window_seconds=alerting_args.get(
+ "hanging_threshold_window_seconds",
+ PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS,
+ ),
+ )
+
+ # Separate storage for failures vs. hangs
+ self._failure_events: List[PagerDutyInternalEvent] = []
+ self._hanging_events: List[PagerDutyInternalEvent] = []
+
+ # premium user check
+ if premium_user is not True:
+ raise ValueError(
+ f"PagerDutyAlerting is only available for LiteLLM Enterprise users. {CommonProxyErrors.not_premium_user.value}"
+ )
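+        # Illustrative alerting_args (keys match the AlertingConfig built above;
+        # the numbers are made up):
+        #   PagerDutyAlerting(alerting_args={
+        #       "failure_threshold": 10,
+        #       "failure_threshold_window_seconds": 60,
+        #       "hanging_threshold_seconds": 120,
+        #       "hanging_threshold_window_seconds": 600,
+        #   })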
+
+ # ------------------ MAIN LOGIC ------------------ #
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ """
+ Record a failure event. Only send an alert to PagerDuty if the
+ configured *failure* threshold is exceeded in the specified window.
+ """
+ now = datetime.now(timezone.utc)
+ standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object"
+ )
+ if not standard_logging_payload:
+ raise ValueError(
+ "standard_logging_object is required for PagerDutyAlerting"
+ )
+
+ # Extract error details
+ error_info: Optional[StandardLoggingPayloadErrorInformation] = (
+ standard_logging_payload.get("error_information") or {}
+ )
+ _meta = standard_logging_payload.get("metadata") or {}
+
+ self._failure_events.append(
+ PagerDutyInternalEvent(
+ failure_event_type="failed_response",
+ timestamp=now,
+ error_class=error_info.get("error_class"),
+ error_code=error_info.get("error_code"),
+ error_llm_provider=error_info.get("llm_provider"),
+ user_api_key_hash=_meta.get("user_api_key_hash"),
+ user_api_key_alias=_meta.get("user_api_key_alias"),
+ user_api_key_org_id=_meta.get("user_api_key_org_id"),
+ user_api_key_team_id=_meta.get("user_api_key_team_id"),
+ user_api_key_user_id=_meta.get("user_api_key_user_id"),
+ user_api_key_team_alias=_meta.get("user_api_key_team_alias"),
+ user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"),
+ user_api_key_user_email=_meta.get("user_api_key_user_email"),
+ )
+ )
+
+ # Prune + Possibly alert
+ window_seconds = self.alerting_args.get("failure_threshold_window_seconds", 60)
+ threshold = self.alerting_args.get("failure_threshold", 1)
+
+ # If threshold is crossed, send PD alert for failures
+ await self._send_alert_if_thresholds_crossed(
+ events=self._failure_events,
+ window_seconds=window_seconds,
+ threshold=threshold,
+ alert_prefix="High LLM API Failure Rate",
+ )
+
+ async def async_pre_call_hook(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ cache: DualCache,
+ data: dict,
+ call_type: Literal[
+ "completion",
+ "text_completion",
+ "embeddings",
+ "image_generation",
+ "moderation",
+ "audio_transcription",
+ "pass_through_endpoint",
+ "rerank",
+ ],
+ ) -> Optional[Union[Exception, str, dict]]:
+ """
+ Example of detecting hanging requests by waiting a given threshold.
+ If the request didn't finish by then, we treat it as 'hanging'.
+ """
+ verbose_logger.info("Inside Proxy Logging Pre-call hook!")
+ asyncio.create_task(
+ self.hanging_response_handler(
+ request_data=data, user_api_key_dict=user_api_key_dict
+ )
+ )
+ return None
+
+ async def hanging_response_handler(
+ self, request_data: Optional[dict], user_api_key_dict: UserAPIKeyAuth
+ ):
+ """
+ Checks if request completed by the time 'hanging_threshold_seconds' elapses.
+ If not, we classify it as a hanging request.
+ """
+ verbose_logger.debug(
+ f"Inside Hanging Response Handler!..sleeping for {self.alerting_args.get('hanging_threshold_seconds', PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS)} seconds"
+ )
+ await asyncio.sleep(
+ self.alerting_args.get(
+ "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
+ )
+ )
+
+ if await self._request_is_completed(request_data=request_data):
+ return # It's not hanging if completed
+
+ # Otherwise, record it as hanging
+ self._hanging_events.append(
+ PagerDutyInternalEvent(
+ failure_event_type="hanging_response",
+ timestamp=datetime.now(timezone.utc),
+ error_class="HangingRequest",
+ error_code="HangingRequest",
+ error_llm_provider="HangingRequest",
+ user_api_key_hash=user_api_key_dict.api_key,
+ user_api_key_alias=user_api_key_dict.key_alias,
+ user_api_key_org_id=user_api_key_dict.org_id,
+ user_api_key_team_id=user_api_key_dict.team_id,
+ user_api_key_user_id=user_api_key_dict.user_id,
+ user_api_key_team_alias=user_api_key_dict.team_alias,
+ user_api_key_end_user_id=user_api_key_dict.end_user_id,
+ user_api_key_user_email=user_api_key_dict.user_email,
+ )
+ )
+
+ # Prune + Possibly alert
+ window_seconds = self.alerting_args.get(
+ "hanging_threshold_window_seconds",
+ PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS,
+ )
+ threshold: int = self.alerting_args.get(
+ "hanging_threshold_fails", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
+ )
+
+ # If threshold is crossed, send PD alert for hangs
+ await self._send_alert_if_thresholds_crossed(
+ events=self._hanging_events,
+ window_seconds=window_seconds,
+ threshold=threshold,
+ alert_prefix="High Number of Hanging LLM Requests",
+ )
+
+ # ------------------ HELPERS ------------------ #
+
+ async def _send_alert_if_thresholds_crossed(
+ self,
+ events: List[PagerDutyInternalEvent],
+ window_seconds: int,
+ threshold: int,
+ alert_prefix: str,
+ ):
+ """
+ 1. Prune old events
+ 2. If threshold is reached, build alert, send to PagerDuty
+ 3. Clear those events
+ """
+ cutoff = datetime.now(timezone.utc) - timedelta(seconds=window_seconds)
+ pruned = [e for e in events if e.get("timestamp", datetime.min) > cutoff]
+
+ # Update the reference list
+ events.clear()
+ events.extend(pruned)
+
+ # Check threshold
+ verbose_logger.debug(
+ f"Have {len(events)} events in the last {window_seconds} seconds. Threshold is {threshold}"
+ )
+ if len(events) >= threshold:
+ # Build short summary of last N events
+ error_summaries = self._build_error_summaries(events, max_errors=5)
+ alert_message = (
+ f"{alert_prefix}: {len(events)} in the last {window_seconds} seconds."
+ )
+ custom_details = {"recent_errors": error_summaries}
+
+ await self.send_alert_to_pagerduty(
+ alert_message=alert_message,
+ custom_details=custom_details,
+ )
+
+ # Clear them after sending an alert, so we don't spam
+ events.clear()
+
+ def _build_error_summaries(
+ self, events: List[PagerDutyInternalEvent], max_errors: int = 5
+ ) -> List[PagerDutyInternalEvent]:
+ """
+        Build short summaries for the last `max_errors` events.
+        Each summary is the event dict with its timestamp removed,
+        e.g. error_class="ValueError", error_code="500", error_llm_provider="openai".
+ """
+ recent = events[-max_errors:]
+ summaries = []
+ for fe in recent:
+            # Drop the raw timestamp so the remaining fields can be sent
+            # as JSON-serializable custom details
+            fe.pop("timestamp")
+ summaries.append(fe)
+ return summaries
+
+ async def send_alert_to_pagerduty(self, alert_message: str, custom_details: dict):
+ """
+ Send [critical] Alert to PagerDuty
+
+ https://developer.pagerduty.com/api-reference/YXBpOjI3NDgyNjU-pager-duty-v2-events-api
+ """
+ try:
+ verbose_logger.debug(f"Sending alert to PagerDuty: {alert_message}")
+ async_client: AsyncHTTPHandler = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+ payload: PagerDutyRequestBody = PagerDutyRequestBody(
+ payload=PagerDutyPayload(
+ summary=alert_message,
+ severity="critical",
+ source="LiteLLM Alert",
+ component="LiteLLM",
+ custom_details=custom_details,
+ ),
+ routing_key=self.api_key,
+ event_action="trigger",
+ )
+
+ return await async_client.post(
+ url="https://events.pagerduty.com/v2/enqueue",
+ json=dict(payload),
+ headers={"Content-Type": "application/json"},
+ )
+ except Exception as e:
+ verbose_logger.exception(f"Error sending alert to PagerDuty: {e}")
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus.py b/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus.py
new file mode 100644
index 00000000..d6e47b87
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus.py
@@ -0,0 +1,1789 @@
+# used for /metrics endpoint on LiteLLM Proxy
+#### What this does ####
+# On success, log events to Prometheus
+import asyncio
+import sys
+from datetime import datetime, timedelta
+from typing import Any, Awaitable, Callable, List, Literal, Optional, Tuple, cast
+
+import litellm
+from litellm._logging import print_verbose, verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import LiteLLM_TeamTable, UserAPIKeyAuth
+from litellm.types.integrations.prometheus import *
+from litellm.types.utils import StandardLoggingPayload
+from litellm.utils import get_end_user_id_for_cost_tracking
+
+
+class PrometheusLogger(CustomLogger):
+ # Class variables or attributes
+ def __init__(
+ self,
+ **kwargs,
+ ):
+ try:
+ from prometheus_client import Counter, Gauge, Histogram
+
+ from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
+
+ if premium_user is not True:
+ verbose_logger.warning(
+ f"🚨🚨🚨 Prometheus Metrics is on LiteLLM Enterprise\n🚨 {CommonProxyErrors.not_premium_user.value}"
+ )
+ self.litellm_not_a_premium_user_metric = Counter(
+ name="litellm_not_a_premium_user_metric",
+ documentation=f"🚨🚨🚨 Prometheus Metrics is on LiteLLM Enterprise. 🚨 {CommonProxyErrors.not_premium_user.value}",
+ )
+ return
+
+ self.litellm_proxy_failed_requests_metric = Counter(
+ name="litellm_proxy_failed_requests_metric",
+ documentation="Total number of failed responses from proxy - the client did not get a success response from litellm proxy",
+ labelnames=PrometheusMetricLabels.get_labels(
+ label_name="litellm_proxy_failed_requests_metric"
+ ),
+ )
+
+ self.litellm_proxy_total_requests_metric = Counter(
+ name="litellm_proxy_total_requests_metric",
+ documentation="Total number of requests made to the proxy server - track number of client side requests",
+ labelnames=PrometheusMetricLabels.get_labels(
+ label_name="litellm_proxy_total_requests_metric"
+ ),
+ )
+
+ # request latency metrics
+ self.litellm_request_total_latency_metric = Histogram(
+ "litellm_request_total_latency_metric",
+ "Total latency (seconds) for a request to LiteLLM",
+ labelnames=PrometheusMetricLabels.get_labels(
+ label_name="litellm_request_total_latency_metric"
+ ),
+ buckets=LATENCY_BUCKETS,
+ )
+
+ self.litellm_llm_api_latency_metric = Histogram(
+ "litellm_llm_api_latency_metric",
+ "Total latency (seconds) for a models LLM API call",
+ labelnames=PrometheusMetricLabels.get_labels(
+ label_name="litellm_llm_api_latency_metric"
+ ),
+ buckets=LATENCY_BUCKETS,
+ )
+
+ self.litellm_llm_api_time_to_first_token_metric = Histogram(
+ "litellm_llm_api_time_to_first_token_metric",
+ "Time to first token for a models LLM API call",
+ labelnames=[
+ "model",
+ "hashed_api_key",
+ "api_key_alias",
+ "team",
+ "team_alias",
+ ],
+ buckets=LATENCY_BUCKETS,
+ )
+
+ # Counter for spend
+ self.litellm_spend_metric = Counter(
+ "litellm_spend_metric",
+ "Total spend on LLM requests",
+ labelnames=[
+ "end_user",
+ "hashed_api_key",
+ "api_key_alias",
+ "model",
+ "team",
+ "team_alias",
+ "user",
+ ],
+ )
+
+ # Counter for total_output_tokens
+ self.litellm_tokens_metric = Counter(
+ "litellm_total_tokens",
+ "Total number of input + output tokens from LLM requests",
+ labelnames=[
+ "end_user",
+ "hashed_api_key",
+ "api_key_alias",
+ "model",
+ "team",
+ "team_alias",
+ "user",
+ ],
+ )
+
+ self.litellm_input_tokens_metric = Counter(
+ "litellm_input_tokens",
+ "Total number of input tokens from LLM requests",
+ labelnames=PrometheusMetricLabels.get_labels(
+ label_name="litellm_input_tokens_metric"
+ ),
+ )
+
+ self.litellm_output_tokens_metric = Counter(
+ "litellm_output_tokens",
+ "Total number of output tokens from LLM requests",
+ labelnames=PrometheusMetricLabels.get_labels(
+ label_name="litellm_output_tokens_metric"
+ ),
+ )
+
+ # Remaining Budget for Team
+ self.litellm_remaining_team_budget_metric = Gauge(
+ "litellm_remaining_team_budget_metric",
+ "Remaining budget for team",
+ labelnames=PrometheusMetricLabels.get_labels(
+ label_name="litellm_remaining_team_budget_metric"
+ ),
+ )
+
+ # Max Budget for Team
+ self.litellm_team_max_budget_metric = Gauge(
+ "litellm_team_max_budget_metric",
+ "Maximum budget set for team",
+ labelnames=PrometheusMetricLabels.get_labels(
+ label_name="litellm_team_max_budget_metric"
+ ),
+ )
+
+ # Team Budget Reset At
+ self.litellm_team_budget_remaining_hours_metric = Gauge(
+ "litellm_team_budget_remaining_hours_metric",
+ "Remaining days for team budget to be reset",
+ labelnames=PrometheusMetricLabels.get_labels(
+ label_name="litellm_team_budget_remaining_hours_metric"
+ ),
+ )
+
+ # Remaining Budget for API Key
+ self.litellm_remaining_api_key_budget_metric = Gauge(
+ "litellm_remaining_api_key_budget_metric",
+ "Remaining budget for api key",
+ labelnames=PrometheusMetricLabels.get_labels(
+ label_name="litellm_remaining_api_key_budget_metric"
+ ),
+ )
+
+ # Max Budget for API Key
+ self.litellm_api_key_max_budget_metric = Gauge(
+ "litellm_api_key_max_budget_metric",
+ "Maximum budget set for api key",
+ labelnames=PrometheusMetricLabels.get_labels(
+ label_name="litellm_api_key_max_budget_metric"
+ ),
+ )
+
+ self.litellm_api_key_budget_remaining_hours_metric = Gauge(
+ "litellm_api_key_budget_remaining_hours_metric",
+ "Remaining hours for api key budget to be reset",
+ labelnames=PrometheusMetricLabels.get_labels(
+ label_name="litellm_api_key_budget_remaining_hours_metric"
+ ),
+ )
+
+ ########################################
+ # LiteLLM Virtual API KEY metrics
+ ########################################
+ # Remaining MODEL RPM limit for API Key
+ self.litellm_remaining_api_key_requests_for_model = Gauge(
+ "litellm_remaining_api_key_requests_for_model",
+ "Remaining Requests API Key can make for model (model based rpm limit on key)",
+ labelnames=["hashed_api_key", "api_key_alias", "model"],
+ )
+
+ # Remaining MODEL TPM limit for API Key
+ self.litellm_remaining_api_key_tokens_for_model = Gauge(
+ "litellm_remaining_api_key_tokens_for_model",
+ "Remaining Tokens API Key can make for model (model based tpm limit on key)",
+ labelnames=["hashed_api_key", "api_key_alias", "model"],
+ )
+
+ ########################################
+ # LLM API Deployment Metrics / analytics
+ ########################################
+
+ # Remaining Rate Limit for model
+ self.litellm_remaining_requests_metric = Gauge(
+ "litellm_remaining_requests",
+ "LLM Deployment Analytics - remaining requests for model, returned from LLM API Provider",
+ labelnames=[
+ "model_group",
+ "api_provider",
+ "api_base",
+ "litellm_model_name",
+ "hashed_api_key",
+ "api_key_alias",
+ ],
+ )
+
+ self.litellm_remaining_tokens_metric = Gauge(
+ "litellm_remaining_tokens",
+ "remaining tokens for model, returned from LLM API Provider",
+ labelnames=[
+ "model_group",
+ "api_provider",
+ "api_base",
+ "litellm_model_name",
+ "hashed_api_key",
+ "api_key_alias",
+ ],
+ )
+
+ self.litellm_overhead_latency_metric = Histogram(
+ "litellm_overhead_latency_metric",
+ "Latency overhead (milliseconds) added by LiteLLM processing",
+ labelnames=[
+ "model_group",
+ "api_provider",
+ "api_base",
+ "litellm_model_name",
+ "hashed_api_key",
+ "api_key_alias",
+ ],
+ buckets=LATENCY_BUCKETS,
+ )
+ # llm api provider budget metrics
+ self.litellm_provider_remaining_budget_metric = Gauge(
+ "litellm_provider_remaining_budget_metric",
+ "Remaining budget for provider - used when you set provider budget limits",
+ labelnames=["api_provider"],
+ )
+
+ # Get all keys
+ _logged_llm_labels = [
+ UserAPIKeyLabelNames.v2_LITELLM_MODEL_NAME.value,
+ UserAPIKeyLabelNames.MODEL_ID.value,
+ UserAPIKeyLabelNames.API_BASE.value,
+ UserAPIKeyLabelNames.API_PROVIDER.value,
+ ]
+ team_and_key_labels = [
+ "hashed_api_key",
+ "api_key_alias",
+ "team",
+ "team_alias",
+ ]
+
+ # Metric for deployment state
+ self.litellm_deployment_state = Gauge(
+ "litellm_deployment_state",
+ "LLM Deployment Analytics - The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage",
+ labelnames=_logged_llm_labels,
+ )
+
+ self.litellm_deployment_cooled_down = Counter(
+ "litellm_deployment_cooled_down",
+ "LLM Deployment Analytics - Number of times a deployment has been cooled down by LiteLLM load balancing logic. exception_status is the status of the exception that caused the deployment to be cooled down",
+ labelnames=_logged_llm_labels + [EXCEPTION_STATUS],
+ )
+
+ self.litellm_deployment_success_responses = Counter(
+ name="litellm_deployment_success_responses",
+ documentation="LLM Deployment Analytics - Total number of successful LLM API calls via litellm",
+ labelnames=[REQUESTED_MODEL] + _logged_llm_labels + team_and_key_labels,
+ )
+ self.litellm_deployment_failure_responses = Counter(
+ name="litellm_deployment_failure_responses",
+ documentation="LLM Deployment Analytics - Total number of failed LLM API calls for a specific LLM deploymeny. exception_status is the status of the exception from the llm api",
+ labelnames=[REQUESTED_MODEL]
+ + _logged_llm_labels
+ + EXCEPTION_LABELS
+ + team_and_key_labels,
+ )
+ self.litellm_deployment_failure_by_tag_responses = Counter(
+ "litellm_deployment_failure_by_tag_responses",
+ "Total number of failed LLM API calls for a specific LLM deploymeny by custom metadata tags",
+ labelnames=[
+ UserAPIKeyLabelNames.REQUESTED_MODEL.value,
+ UserAPIKeyLabelNames.TAG.value,
+ ]
+ + _logged_llm_labels
+ + EXCEPTION_LABELS,
+ )
+ self.litellm_deployment_total_requests = Counter(
+ name="litellm_deployment_total_requests",
+ documentation="LLM Deployment Analytics - Total number of LLM API calls via litellm - success + failure",
+ labelnames=[REQUESTED_MODEL] + _logged_llm_labels + team_and_key_labels,
+ )
+
+ # Deployment Latency tracking
+ team_and_key_labels = [
+ "hashed_api_key",
+ "api_key_alias",
+ "team",
+ "team_alias",
+ ]
+ self.litellm_deployment_latency_per_output_token = Histogram(
+ name="litellm_deployment_latency_per_output_token",
+ documentation="LLM Deployment Analytics - Latency per output token",
+ labelnames=PrometheusMetricLabels.get_labels(
+ label_name="litellm_deployment_latency_per_output_token"
+ ),
+ )
+
+ self.litellm_deployment_successful_fallbacks = Counter(
+ "litellm_deployment_successful_fallbacks",
+ "LLM Deployment Analytics - Number of successful fallback requests from primary model -> fallback model",
+ PrometheusMetricLabels.get_labels(
+ "litellm_deployment_successful_fallbacks"
+ ),
+ )
+
+ self.litellm_deployment_failed_fallbacks = Counter(
+ "litellm_deployment_failed_fallbacks",
+ "LLM Deployment Analytics - Number of failed fallback requests from primary model -> fallback model",
+ PrometheusMetricLabels.get_labels(
+ "litellm_deployment_failed_fallbacks"
+ ),
+ )
+
+ self.litellm_llm_api_failed_requests_metric = Counter(
+ name="litellm_llm_api_failed_requests_metric",
+ documentation="deprecated - use litellm_proxy_failed_requests_metric",
+ labelnames=[
+ "end_user",
+ "hashed_api_key",
+ "api_key_alias",
+ "model",
+ "team",
+ "team_alias",
+ "user",
+ ],
+ )
+
+ self.litellm_requests_metric = Counter(
+ name="litellm_requests_metric",
+ documentation="deprecated - use litellm_proxy_total_requests_metric. Total number of LLM calls to litellm - track total per API Key, team, user",
+ labelnames=PrometheusMetricLabels.get_labels(
+ label_name="litellm_requests_metric"
+ ),
+ )
+ self._initialize_prometheus_startup_metrics()
+
+ except Exception as e:
+ print_verbose(f"Got exception on init prometheus client {str(e)}")
+ raise e
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ # Define prometheus client
+ from litellm.types.utils import StandardLoggingPayload
+
+ verbose_logger.debug(
+ f"prometheus Logging - Enters success logging function for kwargs {kwargs}"
+ )
+
+ # unpack kwargs
+ standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object"
+ )
+
+ if standard_logging_payload is None or not isinstance(
+ standard_logging_payload, dict
+ ):
+ raise ValueError(
+ f"standard_logging_object is required, got={standard_logging_payload}"
+ )
+
+ model = kwargs.get("model", "")
+ litellm_params = kwargs.get("litellm_params", {}) or {}
+ _metadata = litellm_params.get("metadata", {})
+ end_user_id = get_end_user_id_for_cost_tracking(
+ litellm_params, service_type="prometheus"
+ )
+ user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
+ user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
+ user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
+ user_api_team = standard_logging_payload["metadata"]["user_api_key_team_id"]
+ user_api_team_alias = standard_logging_payload["metadata"][
+ "user_api_key_team_alias"
+ ]
+ output_tokens = standard_logging_payload["completion_tokens"]
+ tokens_used = standard_logging_payload["total_tokens"]
+ response_cost = standard_logging_payload["response_cost"]
+ _requester_metadata = standard_logging_payload["metadata"].get(
+ "requester_metadata"
+ )
+ if standard_logging_payload is not None and isinstance(
+ standard_logging_payload, dict
+ ):
+ _tags = standard_logging_payload["request_tags"]
+ else:
+ _tags = []
+
+ print_verbose(
+ f"inside track_prometheus_metrics, model {model}, response_cost {response_cost}, tokens_used {tokens_used}, end_user_id {end_user_id}, user_api_key {user_api_key}"
+ )
+
+ enum_values = UserAPIKeyLabelValues(
+ end_user=end_user_id,
+ hashed_api_key=user_api_key,
+ api_key_alias=user_api_key_alias,
+ requested_model=standard_logging_payload["model_group"],
+ team=user_api_team,
+ team_alias=user_api_team_alias,
+ user=user_id,
+ user_email=standard_logging_payload["metadata"]["user_api_key_user_email"],
+ status_code="200",
+ model=model,
+ litellm_model_name=model,
+ tags=_tags,
+ model_id=standard_logging_payload["model_id"],
+ api_base=standard_logging_payload["api_base"],
+ api_provider=standard_logging_payload["custom_llm_provider"],
+ exception_status=None,
+ exception_class=None,
+ custom_metadata_labels=get_custom_labels_from_metadata(
+ metadata=standard_logging_payload["metadata"].get("requester_metadata")
+ or {}
+ ),
+ )
+
+ if (
+ user_api_key is not None
+ and isinstance(user_api_key, str)
+ and user_api_key.startswith("sk-")
+ ):
+ from litellm.proxy.utils import hash_token
+
+ user_api_key = hash_token(user_api_key)
+
+ # increment total LLM requests and spend metric
+ self._increment_top_level_request_and_spend_metrics(
+ end_user_id=end_user_id,
+ user_api_key=user_api_key,
+ user_api_key_alias=user_api_key_alias,
+ model=model,
+ user_api_team=user_api_team,
+ user_api_team_alias=user_api_team_alias,
+ user_id=user_id,
+ response_cost=response_cost,
+ enum_values=enum_values,
+ )
+
+ # input, output, total token metrics
+ self._increment_token_metrics(
+ # why type ignore below?
+ # 1. We just checked if isinstance(standard_logging_payload, dict). Pyright complains.
+ # 2. Pyright does not allow us to run isinstance(standard_logging_payload, StandardLoggingPayload) <- this would be ideal
+ standard_logging_payload=standard_logging_payload, # type: ignore
+ end_user_id=end_user_id,
+ user_api_key=user_api_key,
+ user_api_key_alias=user_api_key_alias,
+ model=model,
+ user_api_team=user_api_team,
+ user_api_team_alias=user_api_team_alias,
+ user_id=user_id,
+ enum_values=enum_values,
+ )
+
+ # remaining budget metrics
+ await self._increment_remaining_budget_metrics(
+ user_api_team=user_api_team,
+ user_api_team_alias=user_api_team_alias,
+ user_api_key=user_api_key,
+ user_api_key_alias=user_api_key_alias,
+ litellm_params=litellm_params,
+ response_cost=response_cost,
+ )
+
+ # set proxy virtual key rpm/tpm metrics
+ self._set_virtual_key_rate_limit_metrics(
+ user_api_key=user_api_key,
+ user_api_key_alias=user_api_key_alias,
+ kwargs=kwargs,
+ metadata=_metadata,
+ )
+
+ # set latency metrics
+ self._set_latency_metrics(
+ kwargs=kwargs,
+ model=model,
+ user_api_key=user_api_key,
+ user_api_key_alias=user_api_key_alias,
+ user_api_team=user_api_team,
+ user_api_team_alias=user_api_team_alias,
+ # why type ignore below?
+ # 1. We just checked if isinstance(standard_logging_payload, dict). Pyright complains.
+ # 2. Pyright does not allow us to run isinstance(standard_logging_payload, StandardLoggingPayload) <- this would be ideal
+ enum_values=enum_values,
+ )
+
+ # set x-ratelimit headers
+ self.set_llm_deployment_success_metrics(
+ kwargs, start_time, end_time, enum_values, output_tokens
+ )
+
+ if (
+ standard_logging_payload["stream"] is True
+ ): # log successful streaming requests from logging event hook.
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_proxy_total_requests_metric"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_proxy_total_requests_metric.labels(**_labels).inc()
+
+ def _increment_token_metrics(
+ self,
+ standard_logging_payload: StandardLoggingPayload,
+ end_user_id: Optional[str],
+ user_api_key: Optional[str],
+ user_api_key_alias: Optional[str],
+ model: Optional[str],
+ user_api_team: Optional[str],
+ user_api_team_alias: Optional[str],
+ user_id: Optional[str],
+ enum_values: UserAPIKeyLabelValues,
+ ):
+ # token metrics
+ self.litellm_tokens_metric.labels(
+ end_user_id,
+ user_api_key,
+ user_api_key_alias,
+ model,
+ user_api_team,
+ user_api_team_alias,
+ user_id,
+ ).inc(standard_logging_payload["total_tokens"])
+
+ if standard_logging_payload is not None and isinstance(
+ standard_logging_payload, dict
+ ):
+ _tags = standard_logging_payload["request_tags"]
+
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_input_tokens_metric"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_input_tokens_metric.labels(**_labels).inc(
+ standard_logging_payload["prompt_tokens"]
+ )
+
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_output_tokens_metric"
+ ),
+ enum_values=enum_values,
+ )
+
+ self.litellm_output_tokens_metric.labels(**_labels).inc(
+ standard_logging_payload["completion_tokens"]
+ )
+
+ async def _increment_remaining_budget_metrics(
+ self,
+ user_api_team: Optional[str],
+ user_api_team_alias: Optional[str],
+ user_api_key: Optional[str],
+ user_api_key_alias: Optional[str],
+ litellm_params: dict,
+ response_cost: float,
+ ):
+ _team_spend = litellm_params.get("metadata", {}).get(
+ "user_api_key_team_spend", None
+ )
+ _team_max_budget = litellm_params.get("metadata", {}).get(
+ "user_api_key_team_max_budget", None
+ )
+
+ _api_key_spend = litellm_params.get("metadata", {}).get(
+ "user_api_key_spend", None
+ )
+ _api_key_max_budget = litellm_params.get("metadata", {}).get(
+ "user_api_key_max_budget", None
+ )
+ await self._set_api_key_budget_metrics_after_api_request(
+ user_api_key=user_api_key,
+ user_api_key_alias=user_api_key_alias,
+ response_cost=response_cost,
+ key_max_budget=_api_key_max_budget,
+ key_spend=_api_key_spend,
+ )
+
+ await self._set_team_budget_metrics_after_api_request(
+ user_api_team=user_api_team,
+ user_api_team_alias=user_api_team_alias,
+ team_spend=_team_spend,
+ team_max_budget=_team_max_budget,
+ response_cost=response_cost,
+ )
+
+ def _increment_top_level_request_and_spend_metrics(
+ self,
+ end_user_id: Optional[str],
+ user_api_key: Optional[str],
+ user_api_key_alias: Optional[str],
+ model: Optional[str],
+ user_api_team: Optional[str],
+ user_api_team_alias: Optional[str],
+ user_id: Optional[str],
+ response_cost: float,
+ enum_values: UserAPIKeyLabelValues,
+ ):
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_requests_metric"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_requests_metric.labels(**_labels).inc()
+
+ self.litellm_spend_metric.labels(
+ end_user_id,
+ user_api_key,
+ user_api_key_alias,
+ model,
+ user_api_team,
+ user_api_team_alias,
+ user_id,
+ ).inc(response_cost)
+
+ def _set_virtual_key_rate_limit_metrics(
+ self,
+ user_api_key: Optional[str],
+ user_api_key_alias: Optional[str],
+ kwargs: dict,
+ metadata: dict,
+ ):
+ from litellm.proxy.common_utils.callback_utils import (
+ get_model_group_from_litellm_kwargs,
+ )
+
+ # Set remaining rpm/tpm for API Key + model
+ # see parallel_request_limiter.py - variables are set there
+ model_group = get_model_group_from_litellm_kwargs(kwargs)
+ remaining_requests_variable_name = (
+ f"litellm-key-remaining-requests-{model_group}"
+ )
+ remaining_tokens_variable_name = f"litellm-key-remaining-tokens-{model_group}"
+
+ remaining_requests = (
+ metadata.get(remaining_requests_variable_name, sys.maxsize) or sys.maxsize
+ )
+ remaining_tokens = (
+ metadata.get(remaining_tokens_variable_name, sys.maxsize) or sys.maxsize
+ )
+
+ self.litellm_remaining_api_key_requests_for_model.labels(
+ user_api_key, user_api_key_alias, model_group
+ ).set(remaining_requests)
+
+ self.litellm_remaining_api_key_tokens_for_model.labels(
+ user_api_key, user_api_key_alias, model_group
+ ).set(remaining_tokens)
+
+ def _set_latency_metrics(
+ self,
+ kwargs: dict,
+ model: Optional[str],
+ user_api_key: Optional[str],
+ user_api_key_alias: Optional[str],
+ user_api_team: Optional[str],
+ user_api_team_alias: Optional[str],
+ enum_values: UserAPIKeyLabelValues,
+ ):
+ # latency metrics
+ end_time: datetime = kwargs.get("end_time") or datetime.now()
+ start_time: Optional[datetime] = kwargs.get("start_time")
+ api_call_start_time = kwargs.get("api_call_start_time", None)
+ completion_start_time = kwargs.get("completion_start_time", None)
+ time_to_first_token_seconds = self._safe_duration_seconds(
+ start_time=api_call_start_time,
+ end_time=completion_start_time,
+ )
+ if (
+ time_to_first_token_seconds is not None
+ and kwargs.get("stream", False) is True # only emit for streaming requests
+ ):
+ self.litellm_llm_api_time_to_first_token_metric.labels(
+ model,
+ user_api_key,
+ user_api_key_alias,
+ user_api_team,
+ user_api_team_alias,
+ ).observe(time_to_first_token_seconds)
+ else:
+ verbose_logger.debug(
+ "Time to first token metric not emitted, stream option in model_parameters is not True"
+ )
+
+ api_call_total_time_seconds = self._safe_duration_seconds(
+ start_time=api_call_start_time,
+ end_time=end_time,
+ )
+ if api_call_total_time_seconds is not None:
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_llm_api_latency_metric"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_llm_api_latency_metric.labels(**_labels).observe(
+ api_call_total_time_seconds
+ )
+
+ # total request latency
+ total_time_seconds = self._safe_duration_seconds(
+ start_time=start_time,
+ end_time=end_time,
+ )
+ if total_time_seconds is not None:
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_request_total_latency_metric"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_request_total_latency_metric.labels(**_labels).observe(
+ total_time_seconds
+ )
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ from litellm.types.utils import StandardLoggingPayload
+
+ verbose_logger.debug(
+ f"prometheus Logging - Enters failure logging function for kwargs {kwargs}"
+ )
+
+ # unpack kwargs
+ model = kwargs.get("model", "")
+ standard_logging_payload: StandardLoggingPayload = kwargs.get(
+ "standard_logging_object", {}
+ )
+ litellm_params = kwargs.get("litellm_params", {}) or {}
+ end_user_id = get_end_user_id_for_cost_tracking(
+ litellm_params, service_type="prometheus"
+ )
+ user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
+ user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
+ user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
+ user_api_team = standard_logging_payload["metadata"]["user_api_key_team_id"]
+ user_api_team_alias = standard_logging_payload["metadata"][
+ "user_api_key_team_alias"
+ ]
+ kwargs.get("exception", None)
+
+ try:
+ self.litellm_llm_api_failed_requests_metric.labels(
+ end_user_id,
+ user_api_key,
+ user_api_key_alias,
+ model,
+ user_api_team,
+ user_api_team_alias,
+ user_id,
+ ).inc()
+ self.set_llm_deployment_failure_metrics(kwargs)
+ except Exception as e:
+            verbose_logger.exception(
+                "prometheus Layer Error(): Exception occurred - {}".format(str(e))
+            )
+
+ async def async_post_call_failure_hook(
+ self,
+ request_data: dict,
+ original_exception: Exception,
+ user_api_key_dict: UserAPIKeyAuth,
+ ):
+ """
+ Track client side failures
+
+ Proxy level tracking - failed client side requests
+
+ labelnames=[
+ "end_user",
+ "hashed_api_key",
+ "api_key_alias",
+ REQUESTED_MODEL,
+ "team",
+ "team_alias",
+ ] + EXCEPTION_LABELS,
+ """
+ try:
+ _tags = cast(List[str], request_data.get("tags") or [])
+ enum_values = UserAPIKeyLabelValues(
+ end_user=user_api_key_dict.end_user_id,
+ user=user_api_key_dict.user_id,
+ user_email=user_api_key_dict.user_email,
+ hashed_api_key=user_api_key_dict.api_key,
+ api_key_alias=user_api_key_dict.key_alias,
+ team=user_api_key_dict.team_id,
+ team_alias=user_api_key_dict.team_alias,
+ requested_model=request_data.get("model", ""),
+ status_code=str(getattr(original_exception, "status_code", None)),
+ exception_status=str(getattr(original_exception, "status_code", None)),
+ exception_class=str(original_exception.__class__.__name__),
+ tags=_tags,
+ )
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_proxy_failed_requests_metric"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_proxy_failed_requests_metric.labels(**_labels).inc()
+
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_proxy_total_requests_metric"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_proxy_total_requests_metric.labels(**_labels).inc()
+
+ except Exception as e:
+            verbose_logger.exception(
+                "prometheus Layer Error(): Exception occurred - {}".format(str(e))
+            )
+
+ async def async_post_call_success_hook(
+ self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
+ ):
+ """
+ Proxy level tracking - triggered when the proxy responds with a success response to the client
+ """
+ try:
+ enum_values = UserAPIKeyLabelValues(
+ end_user=user_api_key_dict.end_user_id,
+ hashed_api_key=user_api_key_dict.api_key,
+ api_key_alias=user_api_key_dict.key_alias,
+ requested_model=data.get("model", ""),
+ team=user_api_key_dict.team_id,
+ team_alias=user_api_key_dict.team_alias,
+ user=user_api_key_dict.user_id,
+ user_email=user_api_key_dict.user_email,
+ status_code="200",
+ )
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_proxy_total_requests_metric"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_proxy_total_requests_metric.labels(**_labels).inc()
+
+ except Exception as e:
+            verbose_logger.exception(
+                "prometheus Layer Error(): Exception occurred - {}".format(str(e))
+            )
+
+ def set_llm_deployment_failure_metrics(self, request_kwargs: dict):
+ """
+ Sets Failure metrics when an LLM API call fails
+
+ - mark the deployment as partial outage
+ - increment deployment failure responses metric
+ - increment deployment total requests metric
+
+ Args:
+ request_kwargs: dict
+
+ """
+ try:
+ verbose_logger.debug("setting remaining tokens requests metric")
+ standard_logging_payload: StandardLoggingPayload = request_kwargs.get(
+ "standard_logging_object", {}
+ )
+ _litellm_params = request_kwargs.get("litellm_params", {}) or {}
+ litellm_model_name = request_kwargs.get("model", None)
+ model_group = standard_logging_payload.get("model_group", None)
+ api_base = standard_logging_payload.get("api_base", None)
+ model_id = standard_logging_payload.get("model_id", None)
+ exception: Exception = request_kwargs.get("exception", None)
+
+ llm_provider = _litellm_params.get("custom_llm_provider", None)
+
+ """
+ log these labels
+ ["litellm_model_name", "model_id", "api_base", "api_provider"]
+ """
+ self.set_deployment_partial_outage(
+ litellm_model_name=litellm_model_name,
+ model_id=model_id,
+ api_base=api_base,
+ api_provider=llm_provider,
+ )
+ self.litellm_deployment_failure_responses.labels(
+ litellm_model_name=litellm_model_name,
+ model_id=model_id,
+ api_base=api_base,
+ api_provider=llm_provider,
+ exception_status=str(getattr(exception, "status_code", None)),
+ exception_class=exception.__class__.__name__,
+ requested_model=model_group,
+ hashed_api_key=standard_logging_payload["metadata"][
+ "user_api_key_hash"
+ ],
+ api_key_alias=standard_logging_payload["metadata"][
+ "user_api_key_alias"
+ ],
+ team=standard_logging_payload["metadata"]["user_api_key_team_id"],
+ team_alias=standard_logging_payload["metadata"][
+ "user_api_key_team_alias"
+ ],
+ ).inc()
+
+ # tag based tracking
+ if standard_logging_payload is not None and isinstance(
+ standard_logging_payload, dict
+ ):
+ _tags = standard_logging_payload["request_tags"]
+ for tag in _tags:
+ self.litellm_deployment_failure_by_tag_responses.labels(
+ **{
+ UserAPIKeyLabelNames.REQUESTED_MODEL.value: model_group,
+ UserAPIKeyLabelNames.TAG.value: tag,
+ UserAPIKeyLabelNames.v2_LITELLM_MODEL_NAME.value: litellm_model_name,
+ UserAPIKeyLabelNames.MODEL_ID.value: model_id,
+ UserAPIKeyLabelNames.API_BASE.value: api_base,
+ UserAPIKeyLabelNames.API_PROVIDER.value: llm_provider,
+ UserAPIKeyLabelNames.EXCEPTION_CLASS.value: exception.__class__.__name__,
+ UserAPIKeyLabelNames.EXCEPTION_STATUS.value: str(
+ getattr(exception, "status_code", None)
+ ),
+ }
+ ).inc()
+
+ self.litellm_deployment_total_requests.labels(
+ litellm_model_name=litellm_model_name,
+ model_id=model_id,
+ api_base=api_base,
+ api_provider=llm_provider,
+ requested_model=model_group,
+ hashed_api_key=standard_logging_payload["metadata"][
+ "user_api_key_hash"
+ ],
+ api_key_alias=standard_logging_payload["metadata"][
+ "user_api_key_alias"
+ ],
+ team=standard_logging_payload["metadata"]["user_api_key_team_id"],
+ team_alias=standard_logging_payload["metadata"][
+ "user_api_key_team_alias"
+ ],
+ ).inc()
+
+        except Exception as e:
+            verbose_logger.debug(
+                "Prometheus Error: set_llm_deployment_failure_metrics. Exception occurred - {}".format(
+                    str(e)
+                )
+            )
+
+ def set_llm_deployment_success_metrics(
+ self,
+ request_kwargs: dict,
+ start_time,
+ end_time,
+ enum_values: UserAPIKeyLabelValues,
+ output_tokens: float = 1.0,
+ ):
+ try:
+ verbose_logger.debug("setting remaining tokens requests metric")
+ standard_logging_payload: Optional[StandardLoggingPayload] = (
+ request_kwargs.get("standard_logging_object")
+ )
+
+ if standard_logging_payload is None:
+ return
+
+ model_group = standard_logging_payload["model_group"]
+ api_base = standard_logging_payload["api_base"]
+ _response_headers = request_kwargs.get("response_headers")
+ _litellm_params = request_kwargs.get("litellm_params", {}) or {}
+ _metadata = _litellm_params.get("metadata", {})
+ litellm_model_name = request_kwargs.get("model", None)
+ llm_provider = _litellm_params.get("custom_llm_provider", None)
+ _model_info = _metadata.get("model_info") or {}
+ model_id = _model_info.get("id", None)
+
+ remaining_requests: Optional[int] = None
+ remaining_tokens: Optional[int] = None
+ if additional_headers := standard_logging_payload["hidden_params"][
+ "additional_headers"
+ ]:
+ # OpenAI / OpenAI Compatible headers
+ remaining_requests = additional_headers.get(
+ "x_ratelimit_remaining_requests", None
+ )
+ remaining_tokens = additional_headers.get(
+ "x_ratelimit_remaining_tokens", None
+ )
+
+ if litellm_overhead_time_ms := standard_logging_payload[
+ "hidden_params"
+ ].get("litellm_overhead_time_ms"):
+ self.litellm_overhead_latency_metric.labels(
+ model_group,
+ llm_provider,
+ api_base,
+ litellm_model_name,
+ standard_logging_payload["metadata"]["user_api_key_hash"],
+ standard_logging_payload["metadata"]["user_api_key_alias"],
+ ).observe(
+ litellm_overhead_time_ms / 1000
+ ) # set as seconds
+
+ if remaining_requests:
+ """
+ "model_group",
+ "api_provider",
+ "api_base",
+ "litellm_model_name"
+ """
+ self.litellm_remaining_requests_metric.labels(
+ model_group,
+ llm_provider,
+ api_base,
+ litellm_model_name,
+ standard_logging_payload["metadata"]["user_api_key_hash"],
+ standard_logging_payload["metadata"]["user_api_key_alias"],
+ ).set(remaining_requests)
+
+ if remaining_tokens:
+ self.litellm_remaining_tokens_metric.labels(
+ model_group,
+ llm_provider,
+ api_base,
+ litellm_model_name,
+ standard_logging_payload["metadata"]["user_api_key_hash"],
+ standard_logging_payload["metadata"]["user_api_key_alias"],
+ ).set(remaining_tokens)
+
+ """
+ log these labels
+ ["litellm_model_name", "requested_model", model_id", "api_base", "api_provider"]
+ """
+ self.set_deployment_healthy(
+ litellm_model_name=litellm_model_name,
+ model_id=model_id,
+ api_base=api_base,
+ api_provider=llm_provider,
+ )
+
+ self.litellm_deployment_success_responses.labels(
+ litellm_model_name=litellm_model_name,
+ model_id=model_id,
+ api_base=api_base,
+ api_provider=llm_provider,
+ requested_model=model_group,
+ hashed_api_key=standard_logging_payload["metadata"][
+ "user_api_key_hash"
+ ],
+ api_key_alias=standard_logging_payload["metadata"][
+ "user_api_key_alias"
+ ],
+ team=standard_logging_payload["metadata"]["user_api_key_team_id"],
+ team_alias=standard_logging_payload["metadata"][
+ "user_api_key_team_alias"
+ ],
+ ).inc()
+
+ self.litellm_deployment_total_requests.labels(
+ litellm_model_name=litellm_model_name,
+ model_id=model_id,
+ api_base=api_base,
+ api_provider=llm_provider,
+ requested_model=model_group,
+ hashed_api_key=standard_logging_payload["metadata"][
+ "user_api_key_hash"
+ ],
+ api_key_alias=standard_logging_payload["metadata"][
+ "user_api_key_alias"
+ ],
+ team=standard_logging_payload["metadata"]["user_api_key_team_id"],
+ team_alias=standard_logging_payload["metadata"][
+ "user_api_key_team_alias"
+ ],
+ ).inc()
+
+ # Track deployment Latency
+ response_ms: timedelta = end_time - start_time
+ time_to_first_token_response_time: Optional[timedelta] = None
+
+ if (
+ request_kwargs.get("stream", None) is not None
+ and request_kwargs["stream"] is True
+ ):
+ # only log ttft for streaming request
+ time_to_first_token_response_time = (
+ request_kwargs.get("completion_start_time", end_time) - start_time
+ )
+
+ # use the metric that is not None
+ # if streaming - use time_to_first_token_response
+ # if not streaming - use response_ms
+ _latency: timedelta = time_to_first_token_response_time or response_ms
+ _latency_seconds = _latency.total_seconds()
+
+ # latency per output token
+ latency_per_token = None
+ if output_tokens is not None and output_tokens > 0:
+ latency_per_token = _latency_seconds / output_tokens
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_deployment_latency_per_output_token"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_deployment_latency_per_output_token.labels(
+ **_labels
+ ).observe(latency_per_token)
+
+ except Exception as e:
+ verbose_logger.error(
+ "Prometheus Error: set_llm_deployment_success_metrics. Exception occured - {}".format(
+ str(e)
+ )
+ )
+ return
+
+ async def log_success_fallback_event(
+ self, original_model_group: str, kwargs: dict, original_exception: Exception
+ ):
+ """
+
+ Logs a successful LLM fallback event on prometheus
+
+ """
+ from litellm.litellm_core_utils.litellm_logging import (
+ StandardLoggingMetadata,
+ StandardLoggingPayloadSetup,
+ )
+
+ verbose_logger.debug(
+ "Prometheus: log_success_fallback_event, original_model_group: %s, kwargs: %s",
+ original_model_group,
+ kwargs,
+ )
+ _metadata = kwargs.get("metadata", {})
+ standard_metadata: StandardLoggingMetadata = (
+ StandardLoggingPayloadSetup.get_standard_logging_metadata(
+ metadata=_metadata
+ )
+ )
+ _new_model = kwargs.get("model")
+ _tags = cast(List[str], kwargs.get("tags") or [])
+
+ enum_values = UserAPIKeyLabelValues(
+ requested_model=original_model_group,
+ fallback_model=_new_model,
+ hashed_api_key=standard_metadata["user_api_key_hash"],
+ api_key_alias=standard_metadata["user_api_key_alias"],
+ team=standard_metadata["user_api_key_team_id"],
+ team_alias=standard_metadata["user_api_key_team_alias"],
+ exception_status=str(getattr(original_exception, "status_code", None)),
+ exception_class=str(original_exception.__class__.__name__),
+ tags=_tags,
+ )
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_deployment_successful_fallbacks"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_deployment_successful_fallbacks.labels(**_labels).inc()
+
+ async def log_failure_fallback_event(
+ self, original_model_group: str, kwargs: dict, original_exception: Exception
+ ):
+ """
+ Logs a failed LLM fallback event on prometheus
+ """
+ from litellm.litellm_core_utils.litellm_logging import (
+ StandardLoggingMetadata,
+ StandardLoggingPayloadSetup,
+ )
+
+ verbose_logger.debug(
+ "Prometheus: log_failure_fallback_event, original_model_group: %s, kwargs: %s",
+ original_model_group,
+ kwargs,
+ )
+ _new_model = kwargs.get("model")
+ _metadata = kwargs.get("metadata", {})
+ _tags = cast(List[str], kwargs.get("tags") or [])
+ standard_metadata: StandardLoggingMetadata = (
+ StandardLoggingPayloadSetup.get_standard_logging_metadata(
+ metadata=_metadata
+ )
+ )
+
+ enum_values = UserAPIKeyLabelValues(
+ requested_model=original_model_group,
+ fallback_model=_new_model,
+ hashed_api_key=standard_metadata["user_api_key_hash"],
+ api_key_alias=standard_metadata["user_api_key_alias"],
+ team=standard_metadata["user_api_key_team_id"],
+ team_alias=standard_metadata["user_api_key_team_alias"],
+ exception_status=str(getattr(original_exception, "status_code", None)),
+ exception_class=str(original_exception.__class__.__name__),
+ tags=_tags,
+ )
+
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_deployment_failed_fallbacks"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_deployment_failed_fallbacks.labels(**_labels).inc()
+
+ def set_litellm_deployment_state(
+ self,
+ state: int,
+ litellm_model_name: str,
+ model_id: Optional[str],
+ api_base: Optional[str],
+ api_provider: str,
+ ):
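+        # gauge value encoding used by the helpers below:
+        #   0 = healthy, 1 = partial outage, 2 = complete outage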
+ self.litellm_deployment_state.labels(
+ litellm_model_name, model_id, api_base, api_provider
+ ).set(state)
+
+ def set_deployment_healthy(
+ self,
+ litellm_model_name: str,
+ model_id: str,
+ api_base: str,
+ api_provider: str,
+ ):
+ self.set_litellm_deployment_state(
+ 0, litellm_model_name, model_id, api_base, api_provider
+ )
+
+ def set_deployment_partial_outage(
+ self,
+ litellm_model_name: str,
+ model_id: Optional[str],
+ api_base: Optional[str],
+ api_provider: str,
+ ):
+ self.set_litellm_deployment_state(
+ 1, litellm_model_name, model_id, api_base, api_provider
+ )
+
+ def set_deployment_complete_outage(
+ self,
+ litellm_model_name: str,
+ model_id: Optional[str],
+ api_base: Optional[str],
+ api_provider: str,
+ ):
+ self.set_litellm_deployment_state(
+ 2, litellm_model_name, model_id, api_base, api_provider
+ )
+
+ def increment_deployment_cooled_down(
+ self,
+ litellm_model_name: str,
+ model_id: str,
+ api_base: str,
+ api_provider: str,
+ exception_status: str,
+ ):
+ """
+ increment metric when litellm.Router / load balancing logic places a deployment in cool down
+ """
+ self.litellm_deployment_cooled_down.labels(
+ litellm_model_name, model_id, api_base, api_provider, exception_status
+ ).inc()
+
+ def track_provider_remaining_budget(
+ self, provider: str, spend: float, budget_limit: float
+ ):
+ """
+ Track provider remaining budget in Prometheus
+ """
+ self.litellm_provider_remaining_budget_metric.labels(provider).set(
+ self._safe_get_remaining_budget(
+ max_budget=budget_limit,
+ spend=spend,
+ )
+ )
+
+ def _safe_get_remaining_budget(
+ self, max_budget: Optional[float], spend: Optional[float]
+ ) -> float:
+ if max_budget is None:
+ return float("inf")
+
+ if spend is None:
+ return max_budget
+
+ return max_budget - spend
+
+ def _initialize_prometheus_startup_metrics(self):
+ """
+ Initialize prometheus startup metrics
+
+        Helper to create tasks for initializing metrics that are required on startup - e.g. remaining budget metrics
+ """
+ if litellm.prometheus_initialize_budget_metrics is not True:
+ verbose_logger.debug("Prometheus: skipping budget metrics initialization")
+ return
+
+ try:
+ if asyncio.get_running_loop():
+ asyncio.create_task(self._initialize_remaining_budget_metrics())
+ except RuntimeError as e: # no running event loop
+ verbose_logger.exception(
+ f"No running event loop - skipping budget metrics initialization: {str(e)}"
+ )
+
+ async def _initialize_budget_metrics(
+ self,
+ data_fetch_function: Callable[..., Awaitable[Tuple[List[Any], Optional[int]]]],
+ set_metrics_function: Callable[[List[Any]], Awaitable[None]],
+ data_type: Literal["teams", "keys"],
+ ):
+ """
+ Generic method to initialize budget metrics for teams or API keys.
+
+ Args:
+ data_fetch_function: Function to fetch data with pagination.
+ set_metrics_function: Function to set metrics for the fetched data.
+ data_type: String representing the type of data ("teams" or "keys") for logging purposes.
+ """
+ from litellm.proxy.proxy_server import prisma_client
+
+ if prisma_client is None:
+ return
+
+ try:
+ page = 1
+ page_size = 50
+ data, total_count = await data_fetch_function(
+ page_size=page_size, page=page
+ )
+
+ if total_count is None:
+ total_count = len(data)
+
+ # Calculate total pages needed
+ total_pages = (total_count + page_size - 1) // page_size
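+            # (ceiling division, e.g. 101 records at page_size=50 -> 3 pages)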
+
+ # Set metrics for first page of data
+ await set_metrics_function(data)
+
+ # Get and set metrics for remaining pages
+ for page in range(2, total_pages + 1):
+ data, _ = await data_fetch_function(page_size=page_size, page=page)
+ await set_metrics_function(data)
+
+ except Exception as e:
+ verbose_logger.exception(
+ f"Error initializing {data_type} budget metrics: {str(e)}"
+ )
+
+ async def _initialize_team_budget_metrics(self):
+ """
+ Initialize team budget metrics by reusing the generic pagination logic.
+ """
+ from litellm.proxy.management_endpoints.team_endpoints import (
+ get_paginated_teams,
+ )
+ from litellm.proxy.proxy_server import prisma_client
+
+ if prisma_client is None:
+ verbose_logger.debug(
+ "Prometheus: skipping team metrics initialization, DB not initialized"
+ )
+ return
+
+ async def fetch_teams(
+ page_size: int, page: int
+ ) -> Tuple[List[LiteLLM_TeamTable], Optional[int]]:
+ teams, total_count = await get_paginated_teams(
+ prisma_client=prisma_client, page_size=page_size, page=page
+ )
+ if total_count is None:
+ total_count = len(teams)
+ return teams, total_count
+
+ await self._initialize_budget_metrics(
+ data_fetch_function=fetch_teams,
+ set_metrics_function=self._set_team_list_budget_metrics,
+ data_type="teams",
+ )
+
+ async def _initialize_api_key_budget_metrics(self):
+ """
+ Initialize API key budget metrics by reusing the generic pagination logic.
+ """
+ from typing import Union
+
+ from litellm.constants import UI_SESSION_TOKEN_TEAM_ID
+ from litellm.proxy.management_endpoints.key_management_endpoints import (
+ _list_key_helper,
+ )
+ from litellm.proxy.proxy_server import prisma_client
+
+ if prisma_client is None:
+ verbose_logger.debug(
+ "Prometheus: skipping key metrics initialization, DB not initialized"
+ )
+ return
+
+ async def fetch_keys(
+ page_size: int, page: int
+ ) -> Tuple[List[Union[str, UserAPIKeyAuth]], Optional[int]]:
+ key_list_response = await _list_key_helper(
+ prisma_client=prisma_client,
+ page=page,
+ size=page_size,
+ user_id=None,
+ team_id=None,
+ key_alias=None,
+ exclude_team_id=UI_SESSION_TOKEN_TEAM_ID,
+ return_full_object=True,
+ organization_id=None,
+ )
+ keys = key_list_response.get("keys", [])
+ total_count = key_list_response.get("total_count")
+ if total_count is None:
+ total_count = len(keys)
+ return keys, total_count
+
+ await self._initialize_budget_metrics(
+ data_fetch_function=fetch_keys,
+ set_metrics_function=self._set_key_list_budget_metrics,
+ data_type="keys",
+ )
+
+ async def _initialize_remaining_budget_metrics(self):
+ """
+ Initialize remaining budget metrics for all teams to avoid metric discrepancies.
+
+ Runs when prometheus logger starts up.
+ """
+ await self._initialize_team_budget_metrics()
+ await self._initialize_api_key_budget_metrics()
+
+ async def _set_key_list_budget_metrics(
+ self, keys: List[Union[str, UserAPIKeyAuth]]
+ ):
+ """Helper function to set budget metrics for a list of keys"""
+ for key in keys:
+ if isinstance(key, UserAPIKeyAuth):
+ self._set_key_budget_metrics(key)
+
+ async def _set_team_list_budget_metrics(self, teams: List[LiteLLM_TeamTable]):
+ """Helper function to set budget metrics for a list of teams"""
+ for team in teams:
+ self._set_team_budget_metrics(team)
+
+ async def _set_team_budget_metrics_after_api_request(
+ self,
+ user_api_team: Optional[str],
+ user_api_team_alias: Optional[str],
+ team_spend: float,
+ team_max_budget: float,
+ response_cost: float,
+ ):
+ """
+ Set team budget metrics after an LLM API request
+
+ - Assemble a LiteLLM_TeamTable object
+ - looks up team info from db if not available in metadata
+ - Set team budget metrics
+ """
+ if user_api_team:
+ team_object = await self._assemble_team_object(
+ team_id=user_api_team,
+ team_alias=user_api_team_alias or "",
+ spend=team_spend,
+ max_budget=team_max_budget,
+ response_cost=response_cost,
+ )
+
+ self._set_team_budget_metrics(team_object)
+
+ async def _assemble_team_object(
+ self,
+ team_id: str,
+ team_alias: str,
+ spend: Optional[float],
+ max_budget: Optional[float],
+ response_cost: float,
+ ) -> LiteLLM_TeamTable:
+ """
+ Assemble a LiteLLM_TeamTable object
+
+ for fields not available in metadata, we fetch from db
+ Fields not available in metadata:
+ - `budget_reset_at`
+ """
+ from litellm.proxy.auth.auth_checks import get_team_object
+ from litellm.proxy.proxy_server import prisma_client, user_api_key_cache
+
+ _total_team_spend = (spend or 0) + response_cost
+ team_object = LiteLLM_TeamTable(
+ team_id=team_id,
+ team_alias=team_alias,
+ spend=_total_team_spend,
+ max_budget=max_budget,
+ )
+ try:
+ team_info = await get_team_object(
+ team_id=team_id,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ )
+ except Exception as e:
+ verbose_logger.debug(
+ f"[Non-Blocking] Prometheus: Error getting team info: {str(e)}"
+ )
+ return team_object
+
+ if team_info:
+ team_object.budget_reset_at = team_info.budget_reset_at
+
+ return team_object
+
+ def _set_team_budget_metrics(
+ self,
+ team: LiteLLM_TeamTable,
+ ):
+ """
+ Set team budget metrics for a single team
+
+ - Remaining Budget
+ - Max Budget
+ - Budget Reset At
+ """
+ enum_values = UserAPIKeyLabelValues(
+ team=team.team_id,
+ team_alias=team.team_alias or "",
+ )
+
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_remaining_team_budget_metric"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_remaining_team_budget_metric.labels(**_labels).set(
+ self._safe_get_remaining_budget(
+ max_budget=team.max_budget,
+ spend=team.spend,
+ )
+ )
+
+ if team.max_budget is not None:
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_team_max_budget_metric"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_team_max_budget_metric.labels(**_labels).set(team.max_budget)
+
+ if team.budget_reset_at is not None:
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_team_budget_remaining_hours_metric"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_team_budget_remaining_hours_metric.labels(**_labels).set(
+ self._get_remaining_hours_for_budget_reset(
+ budget_reset_at=team.budget_reset_at
+ )
+ )
+
+ def _set_key_budget_metrics(self, user_api_key_dict: UserAPIKeyAuth):
+ """
+ Set virtual key budget metrics
+
+ - Remaining Budget
+ - Max Budget
+ - Budget Reset At
+ """
+ enum_values = UserAPIKeyLabelValues(
+ hashed_api_key=user_api_key_dict.token,
+ api_key_alias=user_api_key_dict.key_alias or "",
+ )
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_remaining_api_key_budget_metric"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_remaining_api_key_budget_metric.labels(**_labels).set(
+ self._safe_get_remaining_budget(
+ max_budget=user_api_key_dict.max_budget,
+ spend=user_api_key_dict.spend,
+ )
+ )
+
+ if user_api_key_dict.max_budget is not None:
+ _labels = prometheus_label_factory(
+ supported_enum_labels=PrometheusMetricLabels.get_labels(
+ label_name="litellm_api_key_max_budget_metric"
+ ),
+ enum_values=enum_values,
+ )
+ self.litellm_api_key_max_budget_metric.labels(**_labels).set(
+ user_api_key_dict.max_budget
+ )
+
+ if user_api_key_dict.budget_reset_at is not None:
+ self.litellm_api_key_budget_remaining_hours_metric.labels(**_labels).set(
+ self._get_remaining_hours_for_budget_reset(
+ budget_reset_at=user_api_key_dict.budget_reset_at
+ )
+ )
+
+ async def _set_api_key_budget_metrics_after_api_request(
+ self,
+ user_api_key: Optional[str],
+ user_api_key_alias: Optional[str],
+ response_cost: float,
+ key_max_budget: float,
+ key_spend: Optional[float],
+ ):
+ if user_api_key:
+ user_api_key_dict = await self._assemble_key_object(
+ user_api_key=user_api_key,
+ user_api_key_alias=user_api_key_alias or "",
+ key_max_budget=key_max_budget,
+ key_spend=key_spend,
+ response_cost=response_cost,
+ )
+ self._set_key_budget_metrics(user_api_key_dict)
+
+ async def _assemble_key_object(
+ self,
+ user_api_key: str,
+ user_api_key_alias: str,
+ key_max_budget: float,
+ key_spend: Optional[float],
+ response_cost: float,
+ ) -> UserAPIKeyAuth:
+ """
+ Assemble a UserAPIKeyAuth object
+ """
+ from litellm.proxy.auth.auth_checks import get_key_object
+ from litellm.proxy.proxy_server import prisma_client, user_api_key_cache
+
+ _total_key_spend = (key_spend or 0) + response_cost
+ user_api_key_dict = UserAPIKeyAuth(
+ token=user_api_key,
+ key_alias=user_api_key_alias,
+ max_budget=key_max_budget,
+ spend=_total_key_spend,
+ )
+ try:
+ if user_api_key_dict.token:
+ key_object = await get_key_object(
+ hashed_token=user_api_key_dict.token,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ )
+ if key_object:
+ user_api_key_dict.budget_reset_at = key_object.budget_reset_at
+ except Exception as e:
+ verbose_logger.debug(
+ f"[Non-Blocking] Prometheus: Error getting key info: {str(e)}"
+ )
+
+ return user_api_key_dict
+
+ def _get_remaining_hours_for_budget_reset(self, budget_reset_at: datetime) -> float:
+ """
+ Get remaining hours for budget reset
+ """
+ return (
+ budget_reset_at - datetime.now(budget_reset_at.tzinfo)
+ ).total_seconds() / 3600
+
+ def _safe_duration_seconds(
+ self,
+ start_time: Any,
+ end_time: Any,
+ ) -> Optional[float]:
+ """
+ Compute the duration in seconds between two objects.
+
+ Returns the duration as a float if both start and end are instances of datetime,
+ otherwise returns None.
+ """
+ if isinstance(start_time, datetime) and isinstance(end_time, datetime):
+ return (end_time - start_time).total_seconds()
+ return None
+
+
+def prometheus_label_factory(
+ supported_enum_labels: List[str],
+ enum_values: UserAPIKeyLabelValues,
+ tag: Optional[str] = None,
+) -> dict:
+ """
+ Returns a dictionary of label + values for prometheus.
+
+ Ensures end_user param is not sent to prometheus if it is not supported.
+ """
+ # Extract dictionary from Pydantic object
+ enum_dict = enum_values.model_dump()
+
+ # Filter supported labels
+ filtered_labels = {
+ label: value
+ for label, value in enum_dict.items()
+ if label in supported_enum_labels
+ }
+
+ if UserAPIKeyLabelNames.END_USER.value in filtered_labels:
+ filtered_labels["end_user"] = get_end_user_id_for_cost_tracking(
+ litellm_params={"user_api_key_end_user_id": enum_values.end_user},
+ service_type="prometheus",
+ )
+
+ if enum_values.custom_metadata_labels is not None:
+ for key, value in enum_values.custom_metadata_labels.items():
+ if key in supported_enum_labels:
+ filtered_labels[key] = value
+
+ for label in supported_enum_labels:
+ if label not in filtered_labels:
+ filtered_labels[label] = None
+
+ return filtered_labels
+
+
+def get_custom_labels_from_metadata(metadata: dict) -> Dict[str, str]:
+ """
+ Get custom labels from metadata
+ """
+ keys = litellm.custom_prometheus_metadata_labels
+ if keys is None or len(keys) == 0:
+ return {}
+
+ result: Dict[str, str] = {}
+
+ for key in keys:
+ # Split the dot notation key into parts
+ original_key = key
+ key = key.replace("metadata.", "", 1) if key.startswith("metadata.") else key
+
+ keys_parts = key.split(".")
+ # Traverse through the dictionary using the parts
+        value = metadata
+        for part in keys_parts:
+            if not isinstance(value, dict):
+                value = None
+                break
+            value = value.get(part, None)  # None if the key is missing at this level
+            if value is None:
+                break
+
+ if value is not None and isinstance(value, str):
+ result[original_key.replace(".", "_")] = value
+
+ return result
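+
+
+if __name__ == "__main__":
+    # Illustrative sketch only (not shipped with litellm): shows how the two helpers
+    # above behave. The label names and metadata keys here are made-up examples.
+    example_labels = prometheus_label_factory(
+        supported_enum_labels=["team", "team_alias", "hashed_api_key"],
+        enum_values=UserAPIKeyLabelValues(team="team-1", team_alias="demo-team"),
+    )
+    # unsupported fields are dropped; supported labels without a value come back as None
+    print(example_labels)
+
+    litellm.custom_prometheus_metadata_labels = ["metadata.customer_tier"]
+    # dot-notation keys are resolved against the request metadata and "." becomes "_"
+    print(get_custom_labels_from_metadata({"customer_tier": "enterprise"}))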
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus_helpers/prometheus_api.py b/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus_helpers/prometheus_api.py
new file mode 100644
index 00000000..b25da577
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus_helpers/prometheus_api.py
@@ -0,0 +1,137 @@
+"""
+Helper functions to query prometheus API
+"""
+
+import time
+from datetime import datetime, timedelta
+from typing import Optional
+
+from litellm import get_secret
+from litellm._logging import verbose_logger
+from litellm.llms.custom_httpx.http_handler import (
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+
+PROMETHEUS_URL: Optional[str] = get_secret("PROMETHEUS_URL") # type: ignore
+PROMETHEUS_SELECTED_INSTANCE: Optional[str] = get_secret("PROMETHEUS_SELECTED_INSTANCE") # type: ignore
+async_http_handler = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+)
+
+
+async def get_metric_from_prometheus(
+ metric_name: str,
+):
+    if PROMETHEUS_URL is None:
+        raise ValueError(
+            "PROMETHEUS_URL not set. Please set 'PROMETHEUS_URL=<>' in .env"
+        )
+
+    # query the metric over the last 24 hours, evaluated at the current time
+    query = f"{metric_name}[24h]"
+    now = int(time.time())
+    response = await async_http_handler.get(
+        f"{PROMETHEUS_URL}/api/v1/query", params={"query": query, "time": now}
+    )
+    _json_response = response.json()
+    verbose_logger.debug("json response from prometheus /query api %s", _json_response)
+    results = _json_response["data"]["result"]
+    return results
+
+
+async def get_fallback_metric_from_prometheus():
+ """
+ Gets fallback metrics from prometheus for the last 24 hours
+ """
+ response_message = ""
+ relevant_metrics = [
+ "litellm_deployment_successful_fallbacks_total",
+ "litellm_deployment_failed_fallbacks_total",
+ ]
+    for metric_name in relevant_metrics:
+        response_json = await get_metric_from_prometheus(
+            metric_name=metric_name,
+        )
+
+        if response_json:
+            verbose_logger.debug("response json %s", response_json)
+            for result in response_json:
+                verbose_logger.debug("result= %s", result)
+                metric_labels = result["metric"]
+                metric_values = result["values"]
+                most_recent_value = metric_values[0]
+
+                if PROMETHEUS_SELECTED_INSTANCE is not None:
+                    if metric_labels.get("instance") != PROMETHEUS_SELECTED_INSTANCE:
+                        continue
+
+                value = int(float(most_recent_value[1]))  # convert value to integer
+                primary_model = metric_labels.get("primary_model", "Unknown")
+                fallback_model = metric_labels.get("fallback_model", "Unknown")
+                fallback_type = (
+                    "successful" if "successful" in metric_name else "failed"
+                )
+                response_message += f"`{value} {fallback_type} fallback requests` with primary model=`{primary_model}` -> fallback model=`{fallback_model}`"
+                response_message += "\n"
+ verbose_logger.debug("response message %s", response_message)
+ return response_message
+
+
+def is_prometheus_connected() -> bool:
+ if PROMETHEUS_URL is not None:
+ return True
+ return False
+
+
+async def get_daily_spend_from_prometheus(api_key: Optional[str]):
+ """
+ Expected Response Format:
+ [
+ {
+ "date": "2024-08-18T00:00:00+00:00",
+ "spend": 1.001818099998933
+ },
+ ...]
+ """
+ if PROMETHEUS_URL is None:
+ raise ValueError(
+ "PROMETHEUS_URL not set please set 'PROMETHEUS_URL=<>' in .env"
+ )
+
+ # Calculate the start and end dates for the last 30 days
+ end_date = datetime.utcnow()
+ start_date = end_date - timedelta(days=30)
+
+ # Format dates as ISO 8601 strings with UTC offset
+ start_str = start_date.isoformat() + "+00:00"
+ end_str = end_date.isoformat() + "+00:00"
+
+ url = f"{PROMETHEUS_URL}/api/v1/query_range"
+
+ if api_key is None:
+ query = "sum(delta(litellm_spend_metric_total[1d]))"
+ else:
+ query = (
+ f'sum(delta(litellm_spend_metric_total{{hashed_api_key="{api_key}"}}[1d]))'
+ )
+
+ params = {
+ "query": query,
+ "start": start_str,
+ "end": end_str,
+ "step": "86400", # Step size of 1 day in seconds
+ }
+
+ response = await async_http_handler.get(url, params=params)
+ _json_response = response.json()
+ verbose_logger.debug("json response from prometheus /query api %s", _json_response)
+    results = _json_response["data"]["result"]
+ formatted_results = []
+
+ for result in results:
+ metric_data = result["values"]
+ for timestamp, value in metric_data:
+ # Convert timestamp to ISO 8601 string with UTC offset
+ date = datetime.fromtimestamp(float(timestamp)).isoformat() + "+00:00"
+ spend = float(value)
+ formatted_results.append({"date": date, "spend": spend})
+
+ return formatted_results
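+
+
+if __name__ == "__main__":
+    # Illustrative sketch only (not shipped with litellm): pull the last 30 days of
+    # per-day spend, assuming PROMETHEUS_URL points at a reachable Prometheus server.
+    import asyncio
+
+    async def _demo():
+        if not is_prometheus_connected():
+            print("PROMETHEUS_URL not set - skipping demo")
+            return
+        for row in await get_daily_spend_from_prometheus(api_key=None):
+            print(row["date"], row["spend"])
+
+    asyncio.run(_demo())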
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus_services.py b/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus_services.py
new file mode 100644
index 00000000..4bf293fb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/prometheus_services.py
@@ -0,0 +1,222 @@
+# used for monitoring litellm services health on the `/metrics` endpoint of the LiteLLM Proxy
+#### What this does ####
+# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
+
+
+from typing import List, Optional, Union
+
+from litellm._logging import print_verbose, verbose_logger
+from litellm.types.integrations.prometheus import LATENCY_BUCKETS
+from litellm.types.services import ServiceLoggerPayload, ServiceTypes
+
+FAILED_REQUESTS_LABELS = ["error_class", "function_name"]
+
+
+class PrometheusServicesLogger:
+ # Class variables or attributes
+ litellm_service_latency = None # Class-level attribute to store the Histogram
+
+ def __init__(
+ self,
+ mock_testing: bool = False,
+ **kwargs,
+ ):
+ try:
+ try:
+ from prometheus_client import REGISTRY, Counter, Histogram
+ except ImportError:
+ raise Exception(
+ "Missing prometheus_client. Run `pip install prometheus-client`"
+ )
+
+ self.Histogram = Histogram
+ self.Counter = Counter
+ self.REGISTRY = REGISTRY
+
+ verbose_logger.debug("in init prometheus services metrics")
+
+ self.services = [item.value for item in ServiceTypes]
+
+ self.payload_to_prometheus_map = (
+ {}
+ ) # store the prometheus histogram/counter we need to call for each field in payload
+
+ for service in self.services:
+ histogram = self.create_histogram(service, type_of_request="latency")
+ counter_failed_request = self.create_counter(
+ service,
+ type_of_request="failed_requests",
+ additional_labels=FAILED_REQUESTS_LABELS,
+ )
+ counter_total_requests = self.create_counter(
+ service, type_of_request="total_requests"
+ )
+ self.payload_to_prometheus_map[service] = [
+ histogram,
+ counter_failed_request,
+ counter_total_requests,
+ ]
+
+ self.prometheus_to_amount_map: dict = (
+ {}
+ ) # the field / value in ServiceLoggerPayload the object needs to be incremented by
+
+ ### MOCK TESTING ###
+ self.mock_testing = mock_testing
+ self.mock_testing_success_calls = 0
+ self.mock_testing_failure_calls = 0
+
+ except Exception as e:
+ print_verbose(f"Got exception on init prometheus client {str(e)}")
+ raise e
+
+ def is_metric_registered(self, metric_name) -> bool:
+ for metric in self.REGISTRY.collect():
+ if metric_name == metric.name:
+ return True
+ return False
+
+ def _get_metric(self, metric_name):
+ """
+ Helper function to get a metric from the registry by name.
+ """
+ return self.REGISTRY._names_to_collectors.get(metric_name)
+
+ def create_histogram(self, service: str, type_of_request: str):
+ metric_name = "litellm_{}_{}".format(service, type_of_request)
+ is_registered = self.is_metric_registered(metric_name)
+ if is_registered:
+ return self._get_metric(metric_name)
+ return self.Histogram(
+ metric_name,
+ "Latency for {} service".format(service),
+ labelnames=[service],
+ buckets=LATENCY_BUCKETS,
+ )
+
+ def create_counter(
+ self,
+ service: str,
+ type_of_request: str,
+ additional_labels: Optional[List[str]] = None,
+ ):
+ metric_name = "litellm_{}_{}".format(service, type_of_request)
+ is_registered = self.is_metric_registered(metric_name)
+ if is_registered:
+ return self._get_metric(metric_name)
+ return self.Counter(
+ metric_name,
+ "Total {} for {} service".format(type_of_request, service),
+ labelnames=[service] + (additional_labels or []),
+ )
+
+ def observe_histogram(
+ self,
+ histogram,
+ labels: str,
+ amount: float,
+ ):
+ assert isinstance(histogram, self.Histogram)
+
+ histogram.labels(labels).observe(amount)
+
+ def increment_counter(
+ self,
+ counter,
+ labels: str,
+ amount: float,
+        additional_labels: Optional[List[str]] = None,
+ ):
+ assert isinstance(counter, self.Counter)
+
+ if additional_labels:
+ counter.labels(labels, *additional_labels).inc(amount)
+ else:
+ counter.labels(labels).inc(amount)
+
+ def service_success_hook(self, payload: ServiceLoggerPayload):
+ if self.mock_testing:
+ self.mock_testing_success_calls += 1
+
+ if payload.service.value in self.payload_to_prometheus_map:
+ prom_objects = self.payload_to_prometheus_map[payload.service.value]
+ for obj in prom_objects:
+ if isinstance(obj, self.Histogram):
+ self.observe_histogram(
+ histogram=obj,
+ labels=payload.service.value,
+ amount=payload.duration,
+ )
+ elif isinstance(obj, self.Counter) and "total_requests" in obj._name:
+ self.increment_counter(
+ counter=obj,
+ labels=payload.service.value,
+ amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
+ )
+
+ def service_failure_hook(self, payload: ServiceLoggerPayload):
+ if self.mock_testing:
+ self.mock_testing_failure_calls += 1
+
+ if payload.service.value in self.payload_to_prometheus_map:
+ prom_objects = self.payload_to_prometheus_map[payload.service.value]
+ for obj in prom_objects:
+ if isinstance(obj, self.Counter):
+ self.increment_counter(
+ counter=obj,
+ labels=payload.service.value,
+ amount=1, # LOG ERROR COUNT / TOTAL REQUESTS TO PROMETHEUS
+ )
+
+ async def async_service_success_hook(self, payload: ServiceLoggerPayload):
+ """
+ Log successful call to prometheus
+ """
+ if self.mock_testing:
+ self.mock_testing_success_calls += 1
+
+ if payload.service.value in self.payload_to_prometheus_map:
+ prom_objects = self.payload_to_prometheus_map[payload.service.value]
+ for obj in prom_objects:
+ if isinstance(obj, self.Histogram):
+ self.observe_histogram(
+ histogram=obj,
+ labels=payload.service.value,
+ amount=payload.duration,
+ )
+ elif isinstance(obj, self.Counter) and "total_requests" in obj._name:
+ self.increment_counter(
+ counter=obj,
+ labels=payload.service.value,
+ amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
+ )
+
+ async def async_service_failure_hook(
+ self,
+ payload: ServiceLoggerPayload,
+ error: Union[str, Exception],
+ ):
+ if self.mock_testing:
+ self.mock_testing_failure_calls += 1
+ error_class = error.__class__.__name__
+ function_name = payload.call_type
+
+ if payload.service.value in self.payload_to_prometheus_map:
+ prom_objects = self.payload_to_prometheus_map[payload.service.value]
+ for obj in prom_objects:
+ # increment both failed and total requests
+ if isinstance(obj, self.Counter):
+ if "failed_requests" in obj._name:
+ self.increment_counter(
+ counter=obj,
+ labels=payload.service.value,
+ # log additional_labels=["error_class", "function_name"], used for debugging what's going wrong with the DB
+ additional_labels=[error_class, function_name],
+ amount=1, # LOG ERROR COUNT TO PROMETHEUS
+ )
+ else:
+ self.increment_counter(
+ counter=obj,
+ labels=payload.service.value,
+ amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
+ )
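+
+
+if __name__ == "__main__":
+    # Illustrative sketch only (not shipped with litellm): instantiating the logger
+    # registers one latency histogram plus total/failed request counters per service,
+    # all named "litellm_{service}_{type_of_request}".
+    logger = PrometheusServicesLogger(mock_testing=True)
+    for service_name, prom_objects in logger.payload_to_prometheus_map.items():
+        # _name is a prometheus_client internal, mirrored from the checks used above
+        print(service_name, [obj._name for obj in prom_objects])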
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_layer.py b/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_layer.py
new file mode 100644
index 00000000..190b995f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_layer.py
@@ -0,0 +1,91 @@
+#### What this does ####
+# On success, logs events to Promptlayer
+import os
+import traceback
+
+from pydantic import BaseModel
+
+import litellm
+
+
+class PromptLayerLogger:
+ # Class variables or attributes
+ def __init__(self):
+ # Instance variables
+ self.key = os.getenv("PROMPTLAYER_API_KEY")
+
+ def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
+ # Method definition
+ try:
+ new_kwargs = {}
+ new_kwargs["model"] = kwargs["model"]
+ new_kwargs["messages"] = kwargs["messages"]
+
+ # add kwargs["optional_params"] to new_kwargs
+ for optional_param in kwargs["optional_params"]:
+ new_kwargs[optional_param] = kwargs["optional_params"][optional_param]
+
+ # Extract PromptLayer tags from metadata, if such exists
+ tags = []
+ metadata = {}
+ if "metadata" in kwargs["litellm_params"]:
+ if "pl_tags" in kwargs["litellm_params"]["metadata"]:
+ tags = kwargs["litellm_params"]["metadata"]["pl_tags"]
+
+ # Remove "pl_tags" from metadata
+ metadata = {
+ k: v
+ for k, v in kwargs["litellm_params"]["metadata"].items()
+ if k != "pl_tags"
+ }
+
+ print_verbose(
+ f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
+ )
+
+ # python-openai >= 1.0.0 returns Pydantic objects instead of jsons
+ if isinstance(response_obj, BaseModel):
+ response_obj = response_obj.model_dump()
+
+ request_response = litellm.module_level_client.post(
+ "https://api.promptlayer.com/rest/track-request",
+ json={
+ "function_name": "openai.ChatCompletion.create",
+ "kwargs": new_kwargs,
+ "tags": tags,
+ "request_response": dict(response_obj),
+ "request_start_time": int(start_time.timestamp()),
+ "request_end_time": int(end_time.timestamp()),
+ "api_key": self.key,
+ # Optional params for PromptLayer
+ # "prompt_id": "<PROMPT ID>",
+ # "prompt_input_variables": "<Dictionary of variables for prompt>",
+ # "prompt_version":1,
+ },
+ )
+
+ response_json = request_response.json()
+            if not response_json.get("success", False):
+ raise Exception("Promptlayer did not successfully log the response!")
+
+ print_verbose(
+ f"Prompt Layer Logging: success - final response object: {request_response.text}"
+ )
+
+ if "request_id" in response_json:
+ if metadata:
+ response = litellm.module_level_client.post(
+ "https://api.promptlayer.com/rest/track-metadata",
+ json={
+ "request_id": response_json["request_id"],
+ "api_key": self.key,
+ "metadata": metadata,
+ },
+ )
+ print_verbose(
+ f"Prompt Layer Logging: success - metadata post response object: {response.text}"
+ )
+
+ except Exception:
+ print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}")
+ pass
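+
+
+if __name__ == "__main__":
+    # Illustrative sketch only (not shipped with litellm): a minimal manual call,
+    # assuming PROMPTLAYER_API_KEY is set. This performs a real HTTP request to
+    # PromptLayer's track-request endpoint; the values below are made-up examples.
+    from datetime import datetime, timedelta
+
+    now = datetime.now()
+    PromptLayerLogger().log_event(
+        kwargs={
+            "model": "gpt-3.5-turbo",
+            "messages": [{"role": "user", "content": "hi"}],
+            "optional_params": {"temperature": 0.0},
+            "litellm_params": {"metadata": {"pl_tags": ["demo"]}},
+        },
+        response_obj={"choices": [{"message": {"role": "assistant", "content": "hello"}}]},
+        start_time=now - timedelta(seconds=1),
+        end_time=now,
+        print_verbose=print,
+    )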
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_management_base.py b/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_management_base.py
new file mode 100644
index 00000000..3fe3b31e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_management_base.py
@@ -0,0 +1,118 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Tuple, TypedDict
+
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import StandardCallbackDynamicParams
+
+
+class PromptManagementClient(TypedDict):
+ prompt_id: str
+ prompt_template: List[AllMessageValues]
+ prompt_template_model: Optional[str]
+ prompt_template_optional_params: Optional[Dict[str, Any]]
+ completed_messages: Optional[List[AllMessageValues]]
+
+
+class PromptManagementBase(ABC):
+
+ @property
+ @abstractmethod
+ def integration_name(self) -> str:
+ pass
+
+ @abstractmethod
+ def should_run_prompt_management(
+ self,
+ prompt_id: str,
+ dynamic_callback_params: StandardCallbackDynamicParams,
+ ) -> bool:
+ pass
+
+ @abstractmethod
+ def _compile_prompt_helper(
+ self,
+ prompt_id: str,
+ prompt_variables: Optional[dict],
+ dynamic_callback_params: StandardCallbackDynamicParams,
+ ) -> PromptManagementClient:
+ pass
+
+ def merge_messages(
+ self,
+ prompt_template: List[AllMessageValues],
+ client_messages: List[AllMessageValues],
+ ) -> List[AllMessageValues]:
+ return prompt_template + client_messages
+
+ def compile_prompt(
+ self,
+ prompt_id: str,
+ prompt_variables: Optional[dict],
+ client_messages: List[AllMessageValues],
+ dynamic_callback_params: StandardCallbackDynamicParams,
+ ) -> PromptManagementClient:
+ compiled_prompt_client = self._compile_prompt_helper(
+ prompt_id=prompt_id,
+ prompt_variables=prompt_variables,
+ dynamic_callback_params=dynamic_callback_params,
+ )
+
+ try:
+ messages = compiled_prompt_client["prompt_template"] + client_messages
+ except Exception as e:
+ raise ValueError(
+ f"Error compiling prompt: {e}. Prompt id={prompt_id}, prompt_variables={prompt_variables}, client_messages={client_messages}, dynamic_callback_params={dynamic_callback_params}"
+ )
+
+ compiled_prompt_client["completed_messages"] = messages
+ return compiled_prompt_client
+
+ def _get_model_from_prompt(
+ self, prompt_management_client: PromptManagementClient, model: str
+ ) -> str:
+ if prompt_management_client["prompt_template_model"] is not None:
+ return prompt_management_client["prompt_template_model"]
+ else:
+ return model.replace("{}/".format(self.integration_name), "")
+
+ def get_chat_completion_prompt(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ non_default_params: dict,
+ prompt_id: str,
+ prompt_variables: Optional[dict],
+ dynamic_callback_params: StandardCallbackDynamicParams,
+ ) -> Tuple[
+ str,
+ List[AllMessageValues],
+ dict,
+ ]:
+ if not self.should_run_prompt_management(
+ prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
+ ):
+ return model, messages, non_default_params
+
+ prompt_template = self.compile_prompt(
+ prompt_id=prompt_id,
+ prompt_variables=prompt_variables,
+ client_messages=messages,
+ dynamic_callback_params=dynamic_callback_params,
+ )
+
+ completed_messages = prompt_template["completed_messages"] or messages
+
+ prompt_template_optional_params = (
+ prompt_template["prompt_template_optional_params"] or {}
+ )
+
+ updated_non_default_params = {
+ **non_default_params,
+ **prompt_template_optional_params,
+ }
+
+ model = self._get_model_from_prompt(
+ prompt_management_client=prompt_template, model=model
+ )
+
+ return model, completed_messages, updated_non_default_params
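+
+
+if __name__ == "__main__":
+    # Illustrative sketch only (not shipped with litellm): a minimal subclass that
+    # serves a hard-coded prompt template, just to exercise get_chat_completion_prompt().
+    # The integration name, model, and prompt id below are made-up examples.
+    class _StaticPromptManager(PromptManagementBase):
+        @property
+        def integration_name(self) -> str:
+            return "static_demo"
+
+        def should_run_prompt_management(self, prompt_id, dynamic_callback_params) -> bool:
+            return True
+
+        def _compile_prompt_helper(
+            self, prompt_id, prompt_variables, dynamic_callback_params
+        ) -> PromptManagementClient:
+            return PromptManagementClient(
+                prompt_id=prompt_id,
+                prompt_template=[{"role": "system", "content": "You are a helpful assistant."}],
+                prompt_template_model="gpt-4o-mini",
+                prompt_template_optional_params={"temperature": 0.2},
+                completed_messages=None,
+            )
+
+    model, messages, params = _StaticPromptManager().get_chat_completion_prompt(
+        model="static_demo/gpt-4o-mini",
+        messages=[{"role": "user", "content": "hi"}],
+        non_default_params={},
+        prompt_id="demo-prompt",
+        prompt_variables=None,
+        dynamic_callback_params={},
+    )
+    # prompt template messages are prepended to the client messages
+    print(model, messages, params)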
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/s3.py b/.venv/lib/python3.12/site-packages/litellm/integrations/s3.py
new file mode 100644
index 00000000..4a0c2735
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/s3.py
@@ -0,0 +1,196 @@
+#### What this does ####
+# On success + failure, log events to an S3 bucket
+
+from datetime import datetime
+from typing import Optional, cast
+
+import litellm
+from litellm._logging import print_verbose, verbose_logger
+from litellm.types.utils import StandardLoggingPayload
+
+
+class S3Logger:
+ # Class variables or attributes
+ def __init__(
+ self,
+ s3_bucket_name=None,
+ s3_path=None,
+ s3_region_name=None,
+ s3_api_version=None,
+ s3_use_ssl=True,
+ s3_verify=None,
+ s3_endpoint_url=None,
+ s3_aws_access_key_id=None,
+ s3_aws_secret_access_key=None,
+ s3_aws_session_token=None,
+ s3_config=None,
+ **kwargs,
+ ):
+ import boto3
+
+ try:
+ verbose_logger.debug(
+ f"in init s3 logger - s3_callback_params {litellm.s3_callback_params}"
+ )
+
+ s3_use_team_prefix = False
+
+ if litellm.s3_callback_params is not None:
+ # read in .env variables - example os.environ/AWS_BUCKET_NAME
+ for key, value in litellm.s3_callback_params.items():
+                    if isinstance(value, str) and value.startswith("os.environ/"):
+ litellm.s3_callback_params[key] = litellm.get_secret(value)
+ # now set s3 params from litellm.s3_logger_params
+ s3_bucket_name = litellm.s3_callback_params.get("s3_bucket_name")
+ s3_region_name = litellm.s3_callback_params.get("s3_region_name")
+ s3_api_version = litellm.s3_callback_params.get("s3_api_version")
+ s3_use_ssl = litellm.s3_callback_params.get("s3_use_ssl", True)
+ s3_verify = litellm.s3_callback_params.get("s3_verify")
+ s3_endpoint_url = litellm.s3_callback_params.get("s3_endpoint_url")
+ s3_aws_access_key_id = litellm.s3_callback_params.get(
+ "s3_aws_access_key_id"
+ )
+ s3_aws_secret_access_key = litellm.s3_callback_params.get(
+ "s3_aws_secret_access_key"
+ )
+ s3_aws_session_token = litellm.s3_callback_params.get(
+ "s3_aws_session_token"
+ )
+ s3_config = litellm.s3_callback_params.get("s3_config")
+ s3_path = litellm.s3_callback_params.get("s3_path")
+ # done reading litellm.s3_callback_params
+ s3_use_team_prefix = bool(
+ litellm.s3_callback_params.get("s3_use_team_prefix", False)
+ )
+ self.s3_use_team_prefix = s3_use_team_prefix
+ self.bucket_name = s3_bucket_name
+ self.s3_path = s3_path
+ verbose_logger.debug(f"s3 logger using endpoint url {s3_endpoint_url}")
+ # Create an S3 client with custom endpoint URL
+ self.s3_client = boto3.client(
+ "s3",
+ region_name=s3_region_name,
+ endpoint_url=s3_endpoint_url,
+ api_version=s3_api_version,
+ use_ssl=s3_use_ssl,
+ verify=s3_verify,
+ aws_access_key_id=s3_aws_access_key_id,
+ aws_secret_access_key=s3_aws_secret_access_key,
+ aws_session_token=s3_aws_session_token,
+ config=s3_config,
+ **kwargs,
+ )
+ except Exception as e:
+ print_verbose(f"Got exception on init s3 client {str(e)}")
+ raise e
+
+ async def _async_log_event(
+ self, kwargs, response_obj, start_time, end_time, print_verbose
+ ):
+ self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
+
+ def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
+ try:
+ verbose_logger.debug(
+ f"s3 Logging - Enters logging function for model {kwargs}"
+ )
+
+ # construct payload to send to s3
+ # follows the same params as langfuse.py
+ litellm_params = kwargs.get("litellm_params", {})
+ metadata = (
+ litellm_params.get("metadata", {}) or {}
+ ) # if litellm_params['metadata'] == None
+
+ # Clean Metadata before logging - never log raw metadata
+ # the raw metadata can contain circular references which leads to infinite recursion
+ # we clean out all extra litellm metadata params before logging
+ clean_metadata = {}
+ if isinstance(metadata, dict):
+ for key, value in metadata.items():
+ # clean litellm metadata before logging
+ if key in [
+ "headers",
+ "endpoint",
+ "caching_groups",
+ "previous_models",
+ ]:
+ continue
+ else:
+ clean_metadata[key] = value
+
+ # Ensure everything in the payload is converted to str
+ payload: Optional[StandardLoggingPayload] = cast(
+ Optional[StandardLoggingPayload],
+ kwargs.get("standard_logging_object", None),
+ )
+
+ if payload is None:
+ return
+
+ team_alias = payload["metadata"].get("user_api_key_team_alias")
+
+ team_alias_prefix = ""
+ if (
+ litellm.enable_preview_features
+ and self.s3_use_team_prefix
+ and team_alias is not None
+ ):
+ team_alias_prefix = f"{team_alias}/"
+
+ s3_file_name = litellm.utils.get_logging_id(start_time, payload) or ""
+ s3_object_key = get_s3_object_key(
+ cast(Optional[str], self.s3_path) or "",
+ team_alias_prefix,
+ start_time,
+ s3_file_name,
+ )
+
+ s3_object_download_filename = (
+ "time-"
+ + start_time.strftime("%Y-%m-%dT%H-%M-%S-%f")
+ + "_"
+ + payload["id"]
+ + ".json"
+ )
+
+ import json
+
+ payload_str = json.dumps(payload)
+
+ print_verbose(f"\ns3 Logger - Logging payload = {payload_str}")
+
+ response = self.s3_client.put_object(
+ Bucket=self.bucket_name,
+ Key=s3_object_key,
+ Body=payload_str,
+ ContentType="application/json",
+ ContentLanguage="en",
+ ContentDisposition=f'inline; filename="{s3_object_download_filename}"',
+ CacheControl="private, immutable, max-age=31536000, s-maxage=0",
+ )
+
+ print_verbose(f"Response from s3:{str(response)}")
+
+ print_verbose(f"s3 Layer Logging - final response object: {response_obj}")
+ return response
+ except Exception as e:
+ verbose_logger.exception(f"s3 Layer Error - {str(e)}")
+ pass
+
+
+def get_s3_object_key(
+ s3_path: str,
+ team_alias_prefix: str,
+ start_time: datetime,
+ s3_file_name: str,
+) -> str:
+ s3_object_key = (
+ (s3_path.rstrip("/") + "/" if s3_path else "")
+ + team_alias_prefix
+ + start_time.strftime("%Y-%m-%d")
+ + "/"
+ + s3_file_name
+ ) # we need the s3 key to include the time, so we log cache hits too
+ s3_object_key += ".json"
+ return s3_object_key
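+
+
+# Illustrative example of the resulting key (hypothetical values):
+#   >>> from datetime import datetime
+#   >>> get_s3_object_key("logs", "team-a/", datetime(2024, 1, 15), "abc123")
+#   'logs/team-a/2024-01-15/abc123.json'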
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/supabase.py b/.venv/lib/python3.12/site-packages/litellm/integrations/supabase.py
new file mode 100644
index 00000000..7eb007f8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/supabase.py
@@ -0,0 +1,120 @@
+#### What this does ####
+# On success + failure, log events to Supabase
+
+import os
+import subprocess
+import sys
+import traceback
+
+import litellm
+
+
+class Supabase:
+ # Class variables or attributes
+ supabase_table_name = "request_logs"
+
+ def __init__(self):
+ # Instance variables
+ self.supabase_url = os.getenv("SUPABASE_URL")
+ self.supabase_key = os.getenv("SUPABASE_KEY")
+ try:
+ import supabase
+ except ImportError:
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "supabase"])
+ import supabase
+
+ if self.supabase_url is None or self.supabase_key is None:
+ raise ValueError(
+                "LiteLLM Error: trying to use Supabase but the URL or key was not found. Create a `request_logs` table and set the `SUPABASE_URL` and `SUPABASE_KEY` environment variables."
+ )
+ self.supabase_client = supabase.create_client( # type: ignore
+ self.supabase_url, self.supabase_key
+ )
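+
+    # Illustrative setup sketch (assumes a `request_logs` table exists in the
+    # project and that "supabase" is wired up as a litellm success callback;
+    # adjust names to your deployment):
+    #   os.environ["SUPABASE_URL"] = "https://<project>.supabase.co"
+    #   os.environ["SUPABASE_KEY"] = "<service-role-or-anon-key>"
+    #   litellm.success_callback = ["supabase"]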
+
+ def input_log_event(
+ self, model, messages, end_user, litellm_call_id, print_verbose
+ ):
+ try:
+ print_verbose(
+ f"Supabase Logging - Enters input logging function for model {model}"
+ )
+ supabase_data_obj = {
+ "model": model,
+ "messages": messages,
+ "end_user": end_user,
+ "status": "initiated",
+ "litellm_call_id": litellm_call_id,
+ }
+ data, count = (
+ self.supabase_client.table(self.supabase_table_name)
+ .insert(supabase_data_obj)
+ .execute()
+ )
+ print_verbose(f"data: {data}")
+ except Exception:
+ print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
+ pass
+
+ def log_event(
+ self,
+ model,
+ messages,
+ end_user,
+ response_obj,
+ start_time,
+ end_time,
+ litellm_call_id,
+ print_verbose,
+ ):
+ try:
+ print_verbose(
+ f"Supabase Logging - Enters logging function for model {model}, response_obj: {response_obj}"
+ )
+
+ total_cost = litellm.completion_cost(completion_response=response_obj)
+
+ response_time = (end_time - start_time).total_seconds()
+ if "choices" in response_obj:
+ supabase_data_obj = {
+ "response_time": response_time,
+ "model": response_obj["model"],
+ "total_cost": total_cost,
+ "messages": messages,
+ "response": response_obj["choices"][0]["message"]["content"],
+ "end_user": end_user,
+ "litellm_call_id": litellm_call_id,
+ "status": "success",
+ }
+ print_verbose(
+ f"Supabase Logging - final data object: {supabase_data_obj}"
+ )
+ data, count = (
+ self.supabase_client.table(self.supabase_table_name)
+ .upsert(supabase_data_obj, on_conflict="litellm_call_id")
+ .execute()
+ )
+ elif "error" in response_obj:
+ if "Unable to map your input to a model." in response_obj["error"]:
+ total_cost = 0
+ supabase_data_obj = {
+ "response_time": response_time,
+ "model": response_obj["model"],
+ "total_cost": total_cost,
+ "messages": messages,
+ "error": response_obj["error"],
+ "end_user": end_user,
+ "litellm_call_id": litellm_call_id,
+ "status": "failure",
+ }
+ print_verbose(
+ f"Supabase Logging - final data object: {supabase_data_obj}"
+ )
+ data, count = (
+ self.supabase_client.table(self.supabase_table_name)
+ .upsert(supabase_data_obj, on_conflict="litellm_call_id")
+ .execute()
+ )
+
+ except Exception:
+ print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
+ pass
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/test_httpx.py b/.venv/lib/python3.12/site-packages/litellm/integrations/test_httpx.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/test_httpx.py
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/traceloop.py b/.venv/lib/python3.12/site-packages/litellm/integrations/traceloop.py
new file mode 100644
index 00000000..b4f3905c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/traceloop.py
@@ -0,0 +1,152 @@
+import traceback
+
+from litellm._logging import verbose_logger
+
+
+class TraceloopLogger:
+ """
+ WARNING: DEPRECATED
+ Use the OpenTelemetry standard integration instead
+ """
+
+ def __init__(self):
+ try:
+ from traceloop.sdk import Traceloop
+ from traceloop.sdk.tracing.tracing import TracerWrapper
+ except ModuleNotFoundError as e:
+ verbose_logger.error(
+ f"Traceloop not installed, try running 'pip install traceloop-sdk' to fix this error: {e}\n{traceback.format_exc()}"
+ )
+ raise e
+
+ Traceloop.init(
+ app_name="Litellm-Server",
+ disable_batch=True,
+ )
+ self.tracer_wrapper = TracerWrapper()
+
+ def log_event(
+ self,
+ kwargs,
+ response_obj,
+ start_time,
+ end_time,
+ user_id,
+ print_verbose,
+ level="DEFAULT",
+ status_message=None,
+ ):
+ from opentelemetry.semconv.ai import SpanAttributes
+ from opentelemetry.trace import SpanKind, Status, StatusCode
+
+ try:
+ print_verbose(
+ f"Traceloop Logging - Enters logging function for model {kwargs}"
+ )
+
+ tracer = self.tracer_wrapper.get_tracer()
+
+ optional_params = kwargs.get("optional_params", {})
+            # OpenTelemetry span APIs expect epoch timestamps in nanoseconds
+            start_time = int(start_time.timestamp() * 1e9)
+            end_time = int(end_time.timestamp() * 1e9)
+ span = tracer.start_span(
+ "litellm.completion", kind=SpanKind.CLIENT, start_time=start_time
+ )
+
+ if span.is_recording():
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_MODEL, kwargs.get("model")
+ )
+ if "stop" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_CHAT_STOP_SEQUENCES,
+ optional_params.get("stop"),
+ )
+ if "frequency_penalty" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_FREQUENCY_PENALTY,
+ optional_params.get("frequency_penalty"),
+ )
+ if "presence_penalty" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_PRESENCE_PENALTY,
+ optional_params.get("presence_penalty"),
+ )
+ if "top_p" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_TOP_P, optional_params.get("top_p")
+ )
+ if "tools" in optional_params or "functions" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_FUNCTIONS,
+ optional_params.get("tools", optional_params.get("functions")),
+ )
+ if "user" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_USER, optional_params.get("user")
+ )
+                if "max_tokens" in optional_params:
+                    span.set_attribute(
+                        SpanAttributes.LLM_REQUEST_MAX_TOKENS,
+                        optional_params.get("max_tokens"),
+                    )
+                if "temperature" in optional_params:
+                    span.set_attribute(
+                        SpanAttributes.LLM_REQUEST_TEMPERATURE,  # type: ignore
+                        optional_params.get("temperature"),
+                    )
+
+ for idx, prompt in enumerate(kwargs.get("messages")):
+ span.set_attribute(
+ f"{SpanAttributes.LLM_PROMPTS}.{idx}.role",
+ prompt.get("role"),
+ )
+ span.set_attribute(
+ f"{SpanAttributes.LLM_PROMPTS}.{idx}.content",
+ prompt.get("content"),
+ )
+
+ span.set_attribute(
+ SpanAttributes.LLM_RESPONSE_MODEL, response_obj.get("model")
+ )
+ usage = response_obj.get("usage")
+ if usage:
+ span.set_attribute(
+ SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
+ usage.get("total_tokens"),
+ )
+ span.set_attribute(
+ SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
+ usage.get("completion_tokens"),
+ )
+ span.set_attribute(
+ SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
+ usage.get("prompt_tokens"),
+ )
+
+ for idx, choice in enumerate(response_obj.get("choices")):
+ span.set_attribute(
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.finish_reason",
+ choice.get("finish_reason"),
+ )
+ span.set_attribute(
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.role",
+ choice.get("message").get("role"),
+ )
+ span.set_attribute(
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.content",
+ choice.get("message").get("content"),
+ )
+
+ if (
+ level == "ERROR"
+ and status_message is not None
+ and isinstance(status_message, str)
+ ):
+ span.record_exception(Exception(status_message))
+ span.set_status(Status(StatusCode.ERROR, status_message))
+
+ span.end(end_time)
+
+ except Exception as e:
+ print_verbose(f"Traceloop Layer Error - {e}")
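+
+
+# Illustrative usage sketch (deprecated path; the "traceloop" callback name and
+# the model are assumptions for the example - prefer the standard OpenTelemetry
+# integration):
+#   import litellm
+#   litellm.success_callback = ["traceloop"]
+#   litellm.completion(
+#       model="gpt-3.5-turbo",
+#       messages=[{"role": "user", "content": "hi"}],
+#   )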
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/weights_biases.py b/.venv/lib/python3.12/site-packages/litellm/integrations/weights_biases.py
new file mode 100644
index 00000000..5fcbab04
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/weights_biases.py
@@ -0,0 +1,217 @@
+imported_openAIResponse = True
+try:
+ import io
+ import logging
+ import sys
+ from typing import Any, Dict, List, Optional, TypeVar
+
+ from wandb.sdk.data_types import trace_tree
+
+ if sys.version_info >= (3, 8):
+ from typing import Literal, Protocol
+ else:
+ from typing_extensions import Literal, Protocol
+
+ logger = logging.getLogger(__name__)
+
+ K = TypeVar("K", bound=str)
+ V = TypeVar("V")
+
+ class OpenAIResponse(Protocol[K, V]): # type: ignore
+ # contains a (known) object attribute
+ object: Literal["chat.completion", "edit", "text_completion"]
+
+ def __getitem__(self, key: K) -> V: ... # noqa
+
+ def get( # noqa
+ self, key: K, default: Optional[V] = None
+ ) -> Optional[V]: ... # pragma: no cover
+
+ class OpenAIRequestResponseResolver:
+ def __call__(
+ self,
+ request: Dict[str, Any],
+ response: OpenAIResponse,
+ time_elapsed: float,
+ ) -> Optional[trace_tree.WBTraceTree]:
+ try:
+ if response["object"] == "edit":
+ return self._resolve_edit(request, response, time_elapsed)
+ elif response["object"] == "text_completion":
+ return self._resolve_completion(request, response, time_elapsed)
+ elif response["object"] == "chat.completion":
+ return self._resolve_chat_completion(
+ request, response, time_elapsed
+ )
+ else:
+ logger.info(f"Unknown OpenAI response object: {response['object']}")
+ except Exception as e:
+ logger.warning(f"Failed to resolve request/response: {e}")
+ return None
+
+ @staticmethod
+ def results_to_trace_tree(
+ request: Dict[str, Any],
+ response: OpenAIResponse,
+ results: List[trace_tree.Result],
+ time_elapsed: float,
+ ) -> trace_tree.WBTraceTree:
+ """Converts the request, response, and results into a trace tree.
+
+ params:
+ request: The request dictionary
+ response: The response object
+                results: A list of Result objects
+ time_elapsed: The time elapsed in seconds
+ returns:
+ A wandb trace tree object.
+ """
+ start_time_ms = int(round(response["created"] * 1000))
+ end_time_ms = start_time_ms + int(round(time_elapsed * 1000))
+ span = trace_tree.Span(
+ name=f"{response.get('model', 'openai')}_{response['object']}_{response.get('created')}",
+ attributes=dict(response), # type: ignore
+ start_time_ms=start_time_ms,
+ end_time_ms=end_time_ms,
+ span_kind=trace_tree.SpanKind.LLM,
+ results=results,
+ )
+ model_obj = {"request": request, "response": response, "_kind": "openai"}
+ return trace_tree.WBTraceTree(root_span=span, model_dict=model_obj)
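+            # Worked example (hypothetical values): response["created"] = 1700000000
+            # and time_elapsed = 1.25 give start_time_ms = 1700000000000 and
+            # end_time_ms = 1700000001250.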
+
+ def _resolve_edit(
+ self,
+ request: Dict[str, Any],
+ response: OpenAIResponse,
+ time_elapsed: float,
+ ) -> trace_tree.WBTraceTree:
+ """Resolves the request and response objects for `openai.Edit`."""
+ request_str = (
+ f"\n\n**Instruction**: {request['instruction']}\n\n"
+ f"**Input**: {request['input']}\n"
+ )
+ choices = [
+ f"\n\n**Edited**: {choice['text']}\n" for choice in response["choices"]
+ ]
+
+ return self._request_response_result_to_trace(
+ request=request,
+ response=response,
+ request_str=request_str,
+ choices=choices,
+ time_elapsed=time_elapsed,
+ )
+
+ def _resolve_completion(
+ self,
+ request: Dict[str, Any],
+ response: OpenAIResponse,
+ time_elapsed: float,
+ ) -> trace_tree.WBTraceTree:
+ """Resolves the request and response objects for `openai.Completion`."""
+ request_str = f"\n\n**Prompt**: {request['prompt']}\n"
+ choices = [
+ f"\n\n**Completion**: {choice['text']}\n"
+ for choice in response["choices"]
+ ]
+
+ return self._request_response_result_to_trace(
+ request=request,
+ response=response,
+ request_str=request_str,
+ choices=choices,
+ time_elapsed=time_elapsed,
+ )
+
+ def _resolve_chat_completion(
+ self,
+ request: Dict[str, Any],
+ response: OpenAIResponse,
+ time_elapsed: float,
+ ) -> trace_tree.WBTraceTree:
+            """Resolves the request and response objects for `openai.ChatCompletion`."""
+ prompt = io.StringIO()
+ for message in request["messages"]:
+ prompt.write(f"\n\n**{message['role']}**: {message['content']}\n")
+ request_str = prompt.getvalue()
+
+ choices = [
+ f"\n\n**{choice['message']['role']}**: {choice['message']['content']}\n"
+ for choice in response["choices"]
+ ]
+
+ return self._request_response_result_to_trace(
+ request=request,
+ response=response,
+ request_str=request_str,
+ choices=choices,
+ time_elapsed=time_elapsed,
+ )
+
+ def _request_response_result_to_trace(
+ self,
+ request: Dict[str, Any],
+ response: OpenAIResponse,
+ request_str: str,
+ choices: List[str],
+ time_elapsed: float,
+ ) -> trace_tree.WBTraceTree:
+            """Builds a trace tree from the rendered request string and choices."""
+ results = [
+ trace_tree.Result(
+ inputs={"request": request_str},
+ outputs={"response": choice},
+ )
+ for choice in choices
+ ]
+ trace = self.results_to_trace_tree(request, response, results, time_elapsed)
+ return trace
+
+except Exception:
+ imported_openAIResponse = False
+
+
+#### What this does ####
+# On success, logs events to Weights & Biases
+import traceback
+
+
+class WeightsBiasesLogger:
+ # Class variables or attributes
+ def __init__(self):
+ try:
+            import wandb  # noqa: F401 - fail fast in init if wandb isn't installed
+ except Exception:
+ raise Exception(
+ "\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
+ )
+ if imported_openAIResponse is False:
+ raise Exception(
+ "\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
+ )
+ self.resolver = OpenAIRequestResponseResolver()
+
+ def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
+ # Method definition
+ import wandb
+
+ try:
+ print_verbose(f"W&B Logging - Enters logging function for model {kwargs}")
+ run = wandb.init()
+ print_verbose(response_obj)
+
+ trace = self.resolver(
+ kwargs, response_obj, (end_time - start_time).total_seconds()
+ )
+
+ if trace is not None and run is not None:
+ run.log({"trace": trace})
+
+ if run is not None:
+ run.finish()
+ print_verbose(
+                f"W&B Logging - final response object: {response_obj}"
+ )
+ except Exception:
+ print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}")
+ pass
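+
+
+# Illustrative usage sketch (assumes "wandb" is registered as a litellm success
+# callback and that W&B credentials are already configured, e.g. via `wandb login`):
+#   import litellm
+#   litellm.success_callback = ["wandb"]
+#   litellm.completion(
+#       model="gpt-3.5-turbo",
+#       messages=[{"role": "user", "content": "hi"}],
+#   )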