aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/integrations/weights_biases.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/weights_biases.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/integrations/weights_biases.py217
1 files changed, 217 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/weights_biases.py b/.venv/lib/python3.12/site-packages/litellm/integrations/weights_biases.py
new file mode 100644
index 00000000..5fcbab04
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/weights_biases.py
@@ -0,0 +1,217 @@
+imported_openAIResponse = True
+try:
+ import io
+ import logging
+ import sys
+ from typing import Any, Dict, List, Optional, TypeVar
+
+ from wandb.sdk.data_types import trace_tree
+
+ if sys.version_info >= (3, 8):
+ from typing import Literal, Protocol
+ else:
+ from typing_extensions import Literal, Protocol
+
+ logger = logging.getLogger(__name__)
+
+ K = TypeVar("K", bound=str)
+ V = TypeVar("V")
+
+ class OpenAIResponse(Protocol[K, V]): # type: ignore
+ # contains a (known) object attribute
+ object: Literal["chat.completion", "edit", "text_completion"]
+
+ def __getitem__(self, key: K) -> V: ... # noqa
+
+ def get( # noqa
+ self, key: K, default: Optional[V] = None
+ ) -> Optional[V]: ... # pragma: no cover
+
+ class OpenAIRequestResponseResolver:
+ def __call__(
+ self,
+ request: Dict[str, Any],
+ response: OpenAIResponse,
+ time_elapsed: float,
+ ) -> Optional[trace_tree.WBTraceTree]:
+ try:
+ if response["object"] == "edit":
+ return self._resolve_edit(request, response, time_elapsed)
+ elif response["object"] == "text_completion":
+ return self._resolve_completion(request, response, time_elapsed)
+ elif response["object"] == "chat.completion":
+ return self._resolve_chat_completion(
+ request, response, time_elapsed
+ )
+ else:
+ logger.info(f"Unknown OpenAI response object: {response['object']}")
+ except Exception as e:
+ logger.warning(f"Failed to resolve request/response: {e}")
+ return None
+
+ @staticmethod
+ def results_to_trace_tree(
+ request: Dict[str, Any],
+ response: OpenAIResponse,
+ results: List[trace_tree.Result],
+ time_elapsed: float,
+ ) -> trace_tree.WBTraceTree:
+ """Converts the request, response, and results into a trace tree.
+
+ params:
+ request: The request dictionary
+ response: The response object
+ results: A list of results object
+ time_elapsed: The time elapsed in seconds
+ returns:
+ A wandb trace tree object.
+ """
+ start_time_ms = int(round(response["created"] * 1000))
+ end_time_ms = start_time_ms + int(round(time_elapsed * 1000))
+ span = trace_tree.Span(
+ name=f"{response.get('model', 'openai')}_{response['object']}_{response.get('created')}",
+ attributes=dict(response), # type: ignore
+ start_time_ms=start_time_ms,
+ end_time_ms=end_time_ms,
+ span_kind=trace_tree.SpanKind.LLM,
+ results=results,
+ )
+ model_obj = {"request": request, "response": response, "_kind": "openai"}
+ return trace_tree.WBTraceTree(root_span=span, model_dict=model_obj)
+
+ def _resolve_edit(
+ self,
+ request: Dict[str, Any],
+ response: OpenAIResponse,
+ time_elapsed: float,
+ ) -> trace_tree.WBTraceTree:
+ """Resolves the request and response objects for `openai.Edit`."""
+ request_str = (
+ f"\n\n**Instruction**: {request['instruction']}\n\n"
+ f"**Input**: {request['input']}\n"
+ )
+ choices = [
+ f"\n\n**Edited**: {choice['text']}\n" for choice in response["choices"]
+ ]
+
+ return self._request_response_result_to_trace(
+ request=request,
+ response=response,
+ request_str=request_str,
+ choices=choices,
+ time_elapsed=time_elapsed,
+ )
+
+ def _resolve_completion(
+ self,
+ request: Dict[str, Any],
+ response: OpenAIResponse,
+ time_elapsed: float,
+ ) -> trace_tree.WBTraceTree:
+ """Resolves the request and response objects for `openai.Completion`."""
+ request_str = f"\n\n**Prompt**: {request['prompt']}\n"
+ choices = [
+ f"\n\n**Completion**: {choice['text']}\n"
+ for choice in response["choices"]
+ ]
+
+ return self._request_response_result_to_trace(
+ request=request,
+ response=response,
+ request_str=request_str,
+ choices=choices,
+ time_elapsed=time_elapsed,
+ )
+
+ def _resolve_chat_completion(
+ self,
+ request: Dict[str, Any],
+ response: OpenAIResponse,
+ time_elapsed: float,
+ ) -> trace_tree.WBTraceTree:
+ """Resolves the request and response objects for `openai.Completion`."""
+ prompt = io.StringIO()
+ for message in request["messages"]:
+ prompt.write(f"\n\n**{message['role']}**: {message['content']}\n")
+ request_str = prompt.getvalue()
+
+ choices = [
+ f"\n\n**{choice['message']['role']}**: {choice['message']['content']}\n"
+ for choice in response["choices"]
+ ]
+
+ return self._request_response_result_to_trace(
+ request=request,
+ response=response,
+ request_str=request_str,
+ choices=choices,
+ time_elapsed=time_elapsed,
+ )
+
+ def _request_response_result_to_trace(
+ self,
+ request: Dict[str, Any],
+ response: OpenAIResponse,
+ request_str: str,
+ choices: List[str],
+ time_elapsed: float,
+ ) -> trace_tree.WBTraceTree:
+ """Resolves the request and response objects for `openai.Completion`."""
+ results = [
+ trace_tree.Result(
+ inputs={"request": request_str},
+ outputs={"response": choice},
+ )
+ for choice in choices
+ ]
+ trace = self.results_to_trace_tree(request, response, results, time_elapsed)
+ return trace
+
+except Exception:
+ imported_openAIResponse = False
+
+
+#### What this does ####
+# On success, logs events to Langfuse
+import traceback
+
+
+class WeightsBiasesLogger:
+ # Class variables or attributes
+ def __init__(self):
+ try:
+ pass
+ except Exception:
+ raise Exception(
+ "\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
+ )
+ if imported_openAIResponse is False:
+ raise Exception(
+ "\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
+ )
+ self.resolver = OpenAIRequestResponseResolver()
+
+ def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
+ # Method definition
+ import wandb
+
+ try:
+ print_verbose(f"W&B Logging - Enters logging function for model {kwargs}")
+ run = wandb.init()
+ print_verbose(response_obj)
+
+ trace = self.resolver(
+ kwargs, response_obj, (end_time - start_time).total_seconds()
+ )
+
+ if trace is not None and run is not None:
+ run.log({"trace": trace})
+
+ if run is not None:
+ run.finish()
+ print_verbose(
+ f"W&B Logging Logging - final response object: {response_obj}"
+ )
+ except Exception:
+ print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}")
+ pass