path: root/.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py
author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed  /.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py  269
1 file changed, 269 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py b/.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py
new file mode 100644
index 00000000..193d1c4e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py
@@ -0,0 +1,269 @@
+import json
+import threading
+from typing import Optional
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+
+
+class MlflowLogger(CustomLogger):
+    def __init__(self):
+        from mlflow.tracking import MlflowClient
+
+        self._client = MlflowClient()
+
+        self._stream_id_to_span = {}
+        self._lock = threading.Lock()  # lock for _stream_id_to_span
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        self._handle_success(kwargs, response_obj, start_time, end_time)
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        self._handle_success(kwargs, response_obj, start_time, end_time)
+
+    def _handle_success(self, kwargs, response_obj, start_time, end_time):
+        """
+        Log the success event as an MLflow span.
+        Note that this method is called asynchronously on a background thread.
+        """
+        from mlflow.entities import SpanStatusCode
+
+        try:
+            verbose_logger.debug("MLflow logging start for success event")
+
+            if kwargs.get("stream"):
+                self._handle_stream_event(kwargs, response_obj, start_time, end_time)
+            else:
+                span = self._start_span_or_trace(kwargs, start_time)
+                end_time_ns = int(end_time.timestamp() * 1e9)
+                self._extract_and_set_chat_attributes(span, kwargs, response_obj)
+                self._end_span_or_trace(
+                    span=span,
+                    outputs=response_obj,
+                    status=SpanStatusCode.OK,
+                    end_time_ns=end_time_ns,
+                )
+        except Exception:
+            verbose_logger.debug("MLflow Logging Error", stack_info=True)
+
+    def _extract_and_set_chat_attributes(self, span, kwargs, response_obj):
+        try:
+            from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools
+        except ImportError:
+            return
+
+        inputs = self._construct_input(kwargs)
+        input_messages = inputs.get("messages", [])
+        output_messages = [c.message.model_dump(exclude_none=True)
+                           for c in getattr(response_obj, "choices", [])]
+        if messages := [*input_messages, *output_messages]:
+            set_span_chat_messages(span, messages)
+        if tools := inputs.get("tools"):
+            set_span_chat_tools(span, tools)
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+    def _handle_failure(self, kwargs, response_obj, start_time, end_time):
+        """
+        Log the failure event as an MLflow span.
+        Note that this method is called *synchronously*, unlike the success handler.
+        """
+        from mlflow.entities import SpanEvent, SpanStatusCode
+
+        try:
+            span = self._start_span_or_trace(kwargs, start_time)
+
+            end_time_ns = int(end_time.timestamp() * 1e9)
+
+            # Record exception info as a span event
+            if exception := kwargs.get("exception"):
+                span.add_event(SpanEvent.from_exception(exception))  # type: ignore
+
+            self._extract_and_set_chat_attributes(span, kwargs, response_obj)
+            self._end_span_or_trace(
+                span=span,
+                outputs=response_obj,
+                status=SpanStatusCode.ERROR,
+                end_time_ns=end_time_ns,
+            )
+
+        except Exception as e:
+            verbose_logger.debug(f"MLflow Logging Error - {e}", stack_info=True)
+
+    def _handle_stream_event(self, kwargs, response_obj, start_time, end_time):
+        """
+        Handle the success event for a streaming response. For streaming calls,
+        the log_success_event handler is invoked for every chunk of the stream.
+        We create a single span for the entire stream request as follows:
+
+        1. For the first chunk, start a new span and store it in the map.
+        2. For subsequent chunks, add the chunk as an event to the span.
+        3. For the final chunk, end the span and remove it from the map.
+        """
+        from mlflow.entities import SpanStatusCode
+
+        litellm_call_id = kwargs.get("litellm_call_id")
+
+        if litellm_call_id not in self._stream_id_to_span:
+            with self._lock:
+                # Check again after acquiring the lock
+                if litellm_call_id not in self._stream_id_to_span:
+                    # Start a new span for the first chunk of the stream
+                    span = self._start_span_or_trace(kwargs, start_time)
+                    self._stream_id_to_span[litellm_call_id] = span
+
+        # Add the chunk as an event to the span
+        span = self._stream_id_to_span[litellm_call_id]
+        self._add_chunk_events(span, response_obj)
+
+        # If this is the final chunk, end the span. The final chunk carries
+        # complete_streaming_response, which aggregates the full response.
+        if final_response := kwargs.get("complete_streaming_response"):
+            end_time_ns = int(end_time.timestamp() * 1e9)
+
+            self._extract_and_set_chat_attributes(span, kwargs, final_response)
+            self._end_span_or_trace(
+                span=span,
+                outputs=final_response,
+                status=SpanStatusCode.OK,
+                end_time_ns=end_time_ns,
+            )
+
+            # Remove the stream_id from the map
+            with self._lock:
+                self._stream_id_to_span.pop(litellm_call_id)
+
+    def _add_chunk_events(self, span, response_obj):
+        from mlflow.entities import SpanEvent
+
+        try:
+            for choice in response_obj.choices:
+                span.add_event(
+                    SpanEvent(
+                        name="streaming_chunk",
+                        attributes={"delta": json.dumps(choice.delta.model_dump())},
+                    )
+                )
+        except Exception:
+            verbose_logger.debug("Error adding chunk events to span", stack_info=True)
+
+    def _construct_input(self, kwargs):
+        """Construct span inputs with optional parameters"""
+        inputs = {"messages": kwargs.get("messages")}
+        if tools := kwargs.get("tools"):
+            inputs["tools"] = tools
+
+        for key in ["functions", "tools", "stream", "tool_choice", "user"]:
+            if value := kwargs.get("optional_params", {}).pop(key, None):
+                inputs[key] = value
+        return inputs
+
+    def _extract_attributes(self, kwargs):
+        """
+        Extract span attributes from kwargs.
+
+        With the latest version of litellm, the standard_logging_object contains
+        canonical information for logging. If it is not present, we extract a
+        subset of attributes from other kwargs.
+        """
+        attributes = {
+            "litellm_call_id": kwargs.get("litellm_call_id"),
+            "call_type": kwargs.get("call_type"),
+            "model": kwargs.get("model"),
+        }
+        standard_obj = kwargs.get("standard_logging_object")
+        if standard_obj:
+            attributes.update(
+                {
+                    "api_base": standard_obj.get("api_base"),
+                    "cache_hit": standard_obj.get("cache_hit"),
+                    "usage": {
+                        "completion_tokens": standard_obj.get("completion_tokens"),
+                        "prompt_tokens": standard_obj.get("prompt_tokens"),
+                        "total_tokens": standard_obj.get("total_tokens"),
+                    },
+                    "raw_llm_response": standard_obj.get("response"),
+                    "response_cost": standard_obj.get("response_cost"),
+                    "saved_cache_cost": standard_obj.get("saved_cache_cost"),
+                }
+            )
+        else:
+            litellm_params = kwargs.get("litellm_params", {})
+            attributes.update(
+                {
+                    "model": kwargs.get("model"),
+                    "cache_hit": kwargs.get("cache_hit"),
+                    "custom_llm_provider": kwargs.get("custom_llm_provider"),
+                    "api_base": litellm_params.get("api_base"),
+                    "response_cost": kwargs.get("response_cost"),
+                }
+            )
+        return attributes
+
+    def _get_span_type(self, call_type: Optional[str]) -> str:
+        from mlflow.entities import SpanType
+
+        if call_type in ["completion", "acompletion"]:
+            return SpanType.LLM
+        elif call_type == "embeddings":
+            return SpanType.EMBEDDING
+        else:
+            return SpanType.LLM
+
+    def _start_span_or_trace(self, kwargs, start_time):
+        """
+        Start an MLflow span or a trace.
+
+        If there is an active span, we start a new span as a child of
+        that span. Otherwise, we start a new trace.
+        """
+        import mlflow
+
+        call_type = kwargs.get("call_type", "completion")
+        span_name = f"litellm-{call_type}"
+        span_type = self._get_span_type(call_type)
+        start_time_ns = int(start_time.timestamp() * 1e9)
+
+        inputs = self._construct_input(kwargs)
+        attributes = self._extract_attributes(kwargs)
+
+        if active_span := mlflow.get_current_active_span():  # type: ignore
+            return self._client.start_span(
+                name=span_name,
+                request_id=active_span.request_id,
+                parent_id=active_span.span_id,
+                span_type=span_type,
+                inputs=inputs,
+                attributes=attributes,
+                start_time_ns=start_time_ns,
+            )
+        else:
+            return self._client.start_trace(
+                name=span_name,
+                span_type=span_type,
+                inputs=inputs,
+                attributes=attributes,
+                start_time_ns=start_time_ns,
+            )
+
+    def _end_span_or_trace(self, span, outputs, end_time_ns, status):
+        """End an MLflow span or a trace."""
+        if span.parent_id is None:
+            self._client.end_trace(
+                request_id=span.request_id,
+                outputs=outputs,
+                status=status,
+                end_time_ns=end_time_ns,
+            )
+        else:
+            self._client.end_span(
+                request_id=span.request_id,
+                span_id=span.span_id,
+                outputs=outputs,
+                status=status,
+                end_time_ns=end_time_ns,
+            )
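
For context, a minimal usage sketch of this integration (not part of the diff above; the model name and prompt are illustrative, and the @mlflow.trace decorator assumes an MLflow version with tracing support):

    import mlflow
    import litellm
    from litellm.integrations.mlflow import MlflowLogger

    # Register the logger so litellm invokes its success/failure hooks.
    litellm.callbacks = [MlflowLogger()]

    @mlflow.trace  # optional parent span; the litellm span then nests under it
    def answer(question: str):
        return litellm.completion(
            model="gpt-4o-mini",  # illustrative model name
            messages=[{"role": "user", "content": question}],
        )

    answer("What is MLflow tracing?")

Without the decorator, _start_span_or_trace finds no active span and starts a standalone trace instead; with it, the call appears as a child span of answer's trace.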