# Flipped to False by the guarded import block below when wandb (or, on
# Python < 3.8, typing_extensions) cannot be imported; WeightsBiasesLogger
# checks it before constructing its resolver.
imported_openAIResponse = True
try:
    import io
    import logging
    import sys
    from typing import Any, Dict, List, Optional, TypeVar

    from wandb.sdk.data_types import trace_tree

    if sys.version_info >= (3, 8):
        from typing import Literal, Protocol
    else:
        from typing_extensions import Literal, Protocol

    # Module-level logger used by the resolver for unknown-kind and failure messages.
    logger = logging.getLogger(__name__)

    # Key type (constrained to str subclasses) and value type of the response mapping.
    K = TypeVar("K", bound=str)
    V = TypeVar("V")

    class OpenAIResponse(Protocol[K, V]):  # type: ignore
        """Structural type for the OpenAI response objects the resolver reads.

        Only mapping-style access (``[]`` and ``get``) plus the ``object``
        discriminator are required; no concrete class is imported.
        """

        # contains a (known) object attribute
        # (the response-kind discriminator the resolver dispatches on)
        object: Literal["chat.completion", "edit", "text_completion"]

        # Mapping-style item access, e.g. response["choices"].
        def __getitem__(self, key: K) -> V: ...  # noqa

        # dict-like lookup returning `default` when the key is absent.
        def get(  # noqa
            self, key: K, default: Optional[V] = None
        ) -> Optional[V]: ...  # pragma: no cover

    class OpenAIRequestResponseResolver:
        """Convert an OpenAI request/response pair into a W&B trace tree.

        Dispatches on ``response["object"]``: ``"edit"``, ``"text_completion"``
        and ``"chat.completion"`` are supported; any other kind is logged at
        INFO level and produces no trace.
        """

        def __call__(
            self,
            request: Dict[str, Any],
            response: OpenAIResponse,
            time_elapsed: float,
        ) -> Optional[trace_tree.WBTraceTree]:
            """Resolve ``request``/``response`` into a trace tree.

            params:
                request: The request dictionary
                response: The response object
                time_elapsed: The time elapsed in seconds
            returns:
                A wandb trace tree, or ``None`` when the response kind is not
                supported or resolution fails. Failures are logged as warnings
                and never raised, so tracing cannot break the caller.
            """
            try:
                kind = response["object"]
                if kind == "edit":
                    return self._resolve_edit(request, response, time_elapsed)
                if kind == "text_completion":
                    return self._resolve_completion(request, response, time_elapsed)
                if kind == "chat.completion":
                    return self._resolve_chat_completion(
                        request, response, time_elapsed
                    )
                logger.info(f"Unknown OpenAI response object: {kind}")
            except Exception as e:
                logger.warning(f"Failed to resolve request/response: {e}")
            return None

        @staticmethod
        def results_to_trace_tree(
            request: Dict[str, Any],
            response: OpenAIResponse,
            results: List[trace_tree.Result],
            time_elapsed: float,
        ) -> trace_tree.WBTraceTree:
            """Converts the request, response, and results into a trace tree.

            The span starts at ``response["created"]`` (seconds, converted to
            ms) and lasts ``time_elapsed`` seconds.

            params:
                request: The request dictionary
                response: The response object
                results: A list of results object
                time_elapsed: The time elapsed in seconds
            returns:
                A wandb trace tree object.
            raises:
                KeyError: if ``response`` has no ``"created"`` or ``"object"`` key.
            """
            start_time_ms = int(round(response["created"] * 1000))
            end_time_ms = start_time_ms + int(round(time_elapsed * 1000))
            span = trace_tree.Span(
                name=f"{response.get('model', 'openai')}_{response['object']}_{response.get('created')}",
                attributes=dict(response),  # type: ignore
                start_time_ms=start_time_ms,
                end_time_ms=end_time_ms,
                span_kind=trace_tree.SpanKind.LLM,
                results=results,
            )
            model_obj = {"request": request, "response": response, "_kind": "openai"}
            return trace_tree.WBTraceTree(root_span=span, model_dict=model_obj)

        def _resolve_edit(
            self,
            request: Dict[str, Any],
            response: OpenAIResponse,
            time_elapsed: float,
        ) -> trace_tree.WBTraceTree:
            """Resolves the request and response objects for `openai.Edit`."""
            # `input` is optional for edit requests; render it as empty rather
            # than letting a KeyError discard the whole trace.
            request_str = (
                f"\n\n**Instruction**: {request['instruction']}\n\n"
                f"**Input**: {request.get('input', '')}\n"
            )
            choices = [
                f"\n\n**Edited**: {choice['text']}\n" for choice in response["choices"]
            ]

            return self._request_response_result_to_trace(
                request=request,
                response=response,
                request_str=request_str,
                choices=choices,
                time_elapsed=time_elapsed,
            )

        def _resolve_completion(
            self,
            request: Dict[str, Any],
            response: OpenAIResponse,
            time_elapsed: float,
        ) -> trace_tree.WBTraceTree:
            """Resolves the request and response objects for `openai.Completion`."""
            request_str = f"\n\n**Prompt**: {request['prompt']}\n"
            choices = [
                f"\n\n**Completion**: {choice['text']}\n"
                for choice in response["choices"]
            ]

            return self._request_response_result_to_trace(
                request=request,
                response=response,
                request_str=request_str,
                choices=choices,
                time_elapsed=time_elapsed,
            )

        def _resolve_chat_completion(
            self,
            request: Dict[str, Any],
            response: OpenAIResponse,
            time_elapsed: float,
        ) -> trace_tree.WBTraceTree:
            """Resolves the request and response objects for `openai.ChatCompletion`."""
            prompt = io.StringIO()
            for message in request["messages"]:
                # A message without a `content` key renders with empty content
                # instead of aborting the trace with a KeyError.
                prompt.write(
                    f"\n\n**{message['role']}**: {message.get('content', '')}\n"
                )
            request_str = prompt.getvalue()

            choices = [
                f"\n\n**{choice['message']['role']}**: {choice['message']['content']}\n"
                for choice in response["choices"]
            ]

            return self._request_response_result_to_trace(
                request=request,
                response=response,
                request_str=request_str,
                choices=choices,
                time_elapsed=time_elapsed,
            )

        def _request_response_result_to_trace(
            self,
            request: Dict[str, Any],
            response: OpenAIResponse,
            request_str: str,
            choices: List[str],
            time_elapsed: float,
        ) -> trace_tree.WBTraceTree:
            """Pair the rendered request with each rendered choice and build the trace.

            One ``trace_tree.Result`` is created per choice, all sharing the same
            ``request_str`` input.
            """
            results = [
                trace_tree.Result(
                    inputs={"request": request_str},
                    outputs={"response": choice},
                )
                for choice in choices
            ]
            return self.results_to_trace_tree(request, response, results, time_elapsed)

except Exception:
    imported_openAIResponse = False


#### What this does ####
#    On success, logs events to Weights & Biases (W&B)
import traceback


class WeightsBiasesLogger:
    """Success callback that logs request/response pairs to Weights & Biases.

    Each ``log_event`` call starts a W&B run, logs the resolved trace tree
    (when the resolver can build one) and finishes the run.
    """

    def __init__(self):
        """Create the logger and its request/response resolver.

        raises:
            Exception: if the guarded wandb import at module load failed, in
                which case the resolver class is unavailable.
        """
        if imported_openAIResponse is False:
            raise Exception(
                "\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
            )
        self.resolver = OpenAIRequestResponseResolver()

    def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
        """Log one request/response pair to W&B as a trace.

        Logging is best-effort: any error (including a failed ``wandb`` import)
        is reported through ``print_verbose`` and never raised to the caller.

        params:
            kwargs: The request dictionary handed to the resolver
            response_obj: The response object handed to the resolver
            start_time: Call start; ``end_time - start_time`` must support
                ``total_seconds()`` (e.g. ``datetime`` values)
            end_time: Call end
            print_verbose: Callable used for debug and error output
        """
        try:
            import wandb

            print_verbose(f"W&B Logging - Enters logging function for model {kwargs}")
            run = wandb.init()
            try:
                print_verbose(response_obj)

                trace = self.resolver(
                    kwargs, response_obj, (end_time - start_time).total_seconds()
                )

                if trace is not None and run is not None:
                    run.log({"trace": trace})
            finally:
                # Always close the run so a failing log call cannot leave it open.
                if run is not None:
                    run.finish()

            if run is not None:
                print_verbose(
                    f"W&B Logging Logging - final response object: {response_obj}"
                )
        except Exception:
            print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}")