about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py269
1 files changed, 269 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py b/.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py
new file mode 100644
index 00000000..193d1c4e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/mlflow.py
@@ -0,0 +1,269 @@
+import json
+import threading
+from typing import Optional
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+
+
+class MlflowLogger(CustomLogger):
+    def __init__(self):
+        from mlflow.tracking import MlflowClient
+
+        self._client = MlflowClient()
+
+        self._stream_id_to_span = {}
+        self._lock = threading.Lock()  # lock for _stream_id_to_span
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        self._handle_success(kwargs, response_obj, start_time, end_time)
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        self._handle_success(kwargs, response_obj, start_time, end_time)
+
+    def _handle_success(self, kwargs, response_obj, start_time, end_time):
+        """
+        Log the success event as an MLflow span.
+        Note that this method is called asynchronously in the background thread.
+        """
+        from mlflow.entities import SpanStatusCode
+
+        try:
+            verbose_logger.debug("MLflow logging start for success event")
+
+            if kwargs.get("stream"):
+                self._handle_stream_event(kwargs, response_obj, start_time, end_time)
+            else:
+                span = self._start_span_or_trace(kwargs, start_time)
+                end_time_ns = int(end_time.timestamp() * 1e9)
+                self._extract_and_set_chat_attributes(span, kwargs, response_obj)
+                self._end_span_or_trace(
+                    span=span,
+                    outputs=response_obj,
+                    status=SpanStatusCode.OK,
+                    end_time_ns=end_time_ns,
+                )
+        except Exception:
+            verbose_logger.debug("MLflow Logging Error", stack_info=True)
+
+    def _extract_and_set_chat_attributes(self, span, kwargs, response_obj):
+        try:
+            from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools
+        except ImportError:
+            return
+
+        inputs = self._construct_input(kwargs)
+        input_messages = inputs.get("messages", [])
+        output_messages = [c.message.model_dump(exclude_none=True)
+                           for c in getattr(response_obj, "choices", [])]
+        if messages := [*input_messages, *output_messages]:
+            set_span_chat_messages(span, messages)
+        if tools := inputs.get("tools"):
+            set_span_chat_tools(span, tools)
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+    def _handle_failure(self, kwargs, response_obj, start_time, end_time):
+        """
+        Log the failure event as an MLflow span.
+        Note that this method is called *synchronously* unlike the success handler.
+        """
+        from mlflow.entities import SpanEvent, SpanStatusCode
+
+        try:
+            span = self._start_span_or_trace(kwargs, start_time)
+
+            end_time_ns = int(end_time.timestamp() * 1e9)
+
+            # Record exception info as event
+            if exception := kwargs.get("exception"):
+                span.add_event(SpanEvent.from_exception(exception))  # type: ignore
+
+            self._extract_and_set_chat_attributes(span, kwargs, response_obj)
+            self._end_span_or_trace(
+                span=span,
+                outputs=response_obj,
+                status=SpanStatusCode.ERROR,
+                end_time_ns=end_time_ns,
+            )
+
+        except Exception as e:
+            verbose_logger.debug(f"MLflow Logging Error - {e}", stack_info=True)
+
+    def _handle_stream_event(self, kwargs, response_obj, start_time, end_time):
+        """
+        Handle the success event for a streaming response. For streaming calls,
+        log_success_event handle is triggered for every chunk of the stream.
+        We create a single span for the entire stream request as follows:
+
+        1. For the first chunk, start a new span and store it in the map.
+        2. For subsequent chunks, add the chunk as an event to the span.
+        3. For the final chunk, end the span and remove the span from the map.
+        """
+        from mlflow.entities import SpanStatusCode
+
+        litellm_call_id = kwargs.get("litellm_call_id")
+
+        if litellm_call_id not in self._stream_id_to_span:
+            with self._lock:
+                # Check again after acquiring lock
+                if litellm_call_id not in self._stream_id_to_span:
+                    # Start a new span for the first chunk of the stream
+                    span = self._start_span_or_trace(kwargs, start_time)
+                    self._stream_id_to_span[litellm_call_id] = span
+
+        # Add chunk as event to the span
+        span = self._stream_id_to_span[litellm_call_id]
+        self._add_chunk_events(span, response_obj)
+
+        # If this is the final chunk, end the span. The final chunk
+        # has complete_streaming_response that gathers the full response.
+        if final_response := kwargs.get("complete_streaming_response"):
+            end_time_ns = int(end_time.timestamp() * 1e9)
+
+            self._extract_and_set_chat_attributes(span, kwargs, final_response)
+            self._end_span_or_trace(
+                span=span,
+                outputs=final_response,
+                status=SpanStatusCode.OK,
+                end_time_ns=end_time_ns,
+            )
+
+            # Remove the stream_id from the map
+            with self._lock:
+                self._stream_id_to_span.pop(litellm_call_id)
+
+    def _add_chunk_events(self, span, response_obj):
+        from mlflow.entities import SpanEvent
+
+        try:
+            for choice in response_obj.choices:
+                span.add_event(
+                    SpanEvent(
+                        name="streaming_chunk",
+                        attributes={"delta": json.dumps(choice.delta.model_dump())},
+                    )
+                )
+        except Exception:
+            verbose_logger.debug("Error adding chunk events to span", stack_info=True)
+
+    def _construct_input(self, kwargs):
+        """Construct span inputs with optional parameters"""
+        inputs = {"messages": kwargs.get("messages")}
+        if tools := kwargs.get("tools"):
+            inputs["tools"] = tools
+
+        for key in ["functions", "tools", "stream", "tool_choice", "user"]:
+            if value := kwargs.get("optional_params", {}).pop(key, None):
+                inputs[key] = value
+        return inputs
+
+    def _extract_attributes(self, kwargs):
+        """
+        Extract span attributes from kwargs.
+
+        With the latest version of litellm, the standard_logging_object contains
+        canonical information for logging. If it is not present, we extract
+        subset of attributes from other kwargs.
+        """
+        attributes = {
+            "litellm_call_id": kwargs.get("litellm_call_id"),
+            "call_type": kwargs.get("call_type"),
+            "model": kwargs.get("model"),
+        }
+        standard_obj = kwargs.get("standard_logging_object")
+        if standard_obj:
+            attributes.update(
+                {
+                    "api_base": standard_obj.get("api_base"),
+                    "cache_hit": standard_obj.get("cache_hit"),
+                    "usage": {
+                        "completion_tokens": standard_obj.get("completion_tokens"),
+                        "prompt_tokens": standard_obj.get("prompt_tokens"),
+                        "total_tokens": standard_obj.get("total_tokens"),
+                    },
+                    "raw_llm_response": standard_obj.get("response"),
+                    "response_cost": standard_obj.get("response_cost"),
+                    "saved_cache_cost": standard_obj.get("saved_cache_cost"),
+                }
+            )
+        else:
+            litellm_params = kwargs.get("litellm_params", {})
+            attributes.update(
+                {
+                    "model": kwargs.get("model"),
+                    "cache_hit": kwargs.get("cache_hit"),
+                    "custom_llm_provider": kwargs.get("custom_llm_provider"),
+                    "api_base": litellm_params.get("api_base"),
+                    "response_cost": kwargs.get("response_cost"),
+                }
+            )
+        return attributes
+
+    def _get_span_type(self, call_type: Optional[str]) -> str:
+        from mlflow.entities import SpanType
+
+        if call_type in ["completion", "acompletion"]:
+            return SpanType.LLM
+        elif call_type == "embeddings":
+            return SpanType.EMBEDDING
+        else:
+            return SpanType.LLM
+
+    def _start_span_or_trace(self, kwargs, start_time):
+        """
+        Start an MLflow span or a trace.
+
+        If there is an active span, we start a new span as a child of
+        that span. Otherwise, we start a new trace.
+        """
+        import mlflow
+
+        call_type = kwargs.get("call_type", "completion")
+        span_name = f"litellm-{call_type}"
+        span_type = self._get_span_type(call_type)
+        start_time_ns = int(start_time.timestamp() * 1e9)
+
+        inputs = self._construct_input(kwargs)
+        attributes = self._extract_attributes(kwargs)
+
+        if active_span := mlflow.get_current_active_span():  # type: ignore
+            return self._client.start_span(
+                name=span_name,
+                request_id=active_span.request_id,
+                parent_id=active_span.span_id,
+                span_type=span_type,
+                inputs=inputs,
+                attributes=attributes,
+                start_time_ns=start_time_ns,
+            )
+        else:
+            return self._client.start_trace(
+                name=span_name,
+                span_type=span_type,
+                inputs=inputs,
+                attributes=attributes,
+                start_time_ns=start_time_ns,
+            )
+
+    def _end_span_or_trace(self, span, outputs, end_time_ns, status):
+        """End an MLflow span or a trace."""
+        if span.parent_id is None:
+            self._client.end_trace(
+                request_id=span.request_id,
+                outputs=outputs,
+                status=status,
+                end_time_ns=end_time_ns,
+            )
+        else:
+            self._client.end_span(
+                request_id=span.request_id,
+                span_id=span.span_id,
+                outputs=outputs,
+                status=status,
+                end_time_ns=end_time_ns,
+            )