Diffstat (limited to '.venv/lib/python3.12/site-packages/sentry_sdk/integrations/huggingface_hub.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/sentry_sdk/integrations/huggingface_hub.py  175
1 file changed, 175 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/sentry_sdk/integrations/huggingface_hub.py b/.venv/lib/python3.12/site-packages/sentry_sdk/integrations/huggingface_hub.py
new file mode 100644
index 00000000..d09f6e21
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sentry_sdk/integrations/huggingface_hub.py
@@ -0,0 +1,175 @@
+from functools import wraps
+
+from sentry_sdk import consts
+from sentry_sdk.ai.monitoring import record_token_usage
+from sentry_sdk.ai.utils import set_data_normalized
+from sentry_sdk.consts import SPANDATA
+
+from typing import Any, Iterable, Callable
+
+import sentry_sdk
+from sentry_sdk.scope import should_send_default_pii
+from sentry_sdk.integrations import DidNotEnable, Integration
+from sentry_sdk.utils import (
+    capture_internal_exceptions,
+    event_from_exception,
+)
+
+try:
+    import huggingface_hub.inference._client
+
+    from huggingface_hub import TextGenerationOutput, TextGenerationStreamOutput
+except ImportError:
+    raise DidNotEnable("Huggingface not installed")
+
+
+class HuggingfaceHubIntegration(Integration):
+    identifier = "huggingface_hub"
+    origin = f"auto.ai.{identifier}"
+
+    def __init__(self, include_prompts=True):
+        # type: (HuggingfaceHubIntegration, bool) -> None
+        self.include_prompts = include_prompts
+
+    @staticmethod
+    def setup_once():
+        # type: () -> None
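+        # Replace InferenceClient.text_generation with a wrapped version
+        # that traces each call and records prompts, responses, and
+        # token usage on a span.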
+        huggingface_hub.inference._client.InferenceClient.text_generation = (
+            _wrap_text_generation(
+                huggingface_hub.inference._client.InferenceClient.text_generation
+            )
+        )
+
+
+def _capture_exception(exc):
+    # type: (Any) -> None
+    event, hint = event_from_exception(
+        exc,
+        client_options=sentry_sdk.get_client().options,
+        mechanism={"type": "huggingface_hub", "handled": False},
+    )
+    sentry_sdk.capture_event(event, hint=hint)
+
+
+def _wrap_text_generation(f):
+    # type: (Callable[..., Any]) -> Callable[..., Any]
+    @wraps(f)
+    def new_text_generation(*args, **kwargs):
+        # type: (*Any, **Any) -> Any
+        integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration)
+        if integration is None:
+            return f(*args, **kwargs)
+
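+        # The prompt may arrive as a keyword or as the first positional
+        # argument after self; normalize it into kwargs so it is captured
+        # regardless of call style.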
+        if "prompt" in kwargs:
+            prompt = kwargs["prompt"]
+        elif len(args) >= 2:
+            kwargs["prompt"] = args[1]
+            prompt = kwargs["prompt"]
+            args = (args[0],) + args[2:]
+        else:
+            # invalid call; let the wrapped function raise its own error
+            return f(*args, **kwargs)
+
+        model = kwargs.get("model")
+        streaming = kwargs.get("stream")
+
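+        # The span is entered manually rather than via `with`, because for
+        # streaming responses it must stay open until the returned iterator
+        # is exhausted, long after this function has returned.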
+        span = sentry_sdk.start_span(
+            op=consts.OP.HUGGINGFACE_HUB_CHAT_COMPLETIONS_CREATE,
+            name="Text Generation",
+            origin=HuggingfaceHubIntegration.origin,
+        )
+        span.__enter__()
+        try:
+            res = f(*args, **kwargs)
+        except Exception as e:
+            _capture_exception(e)
+            span.__exit__(None, None, None)
+            raise e from None
+
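+        # Attach metadata inside capture_internal_exceptions so that any
+        # SDK bug here is swallowed rather than breaking the user's call.
+        # Prompts are PII and are only recorded when send_default_pii is
+        # enabled and include_prompts has not been turned off.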
+        with capture_internal_exceptions():
+            if should_send_default_pii() and integration.include_prompts:
+                set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompt)
+
+            set_data_normalized(span, SPANDATA.AI_MODEL_ID, model)
+            set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)
+
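+            # text_generation returns one of four shapes depending on the
+            # `stream` and `details` flags: a plain string, a
+            # TextGenerationOutput, an iterable of detailed stream chunks,
+            # or an iterable of strings. Each branch closes the span itself.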
+            if isinstance(res, str):
+                if should_send_default_pii() and integration.include_prompts:
+                    set_data_normalized(
+                        span,
+                        "ai.responses",
+                        [res],
+                    )
+                span.__exit__(None, None, None)
+                return res
+
+            if isinstance(res, TextGenerationOutput):
+                if should_send_default_pii() and integration.include_prompts:
+                    set_data_normalized(
+                        span,
+                        "ai.responses",
+                        [res.generated_text],
+                    )
+                if res.details is not None and res.details.generated_tokens > 0:
+                    record_token_usage(span, total_tokens=res.details.generated_tokens)
+                span.__exit__(None, None, None)
+                return res
+
+            if not isinstance(res, Iterable):
+                # we only know how to deal with strings and iterables, ignore
+                set_data_normalized(span, "unknown_response", True)
+                span.__exit__(None, None, None)
+                return res
+
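+            # Streaming with details=True: wrap the chunk iterator so the
+            # generated text and final token count can be recorded before
+            # the span is closed.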
+            if kwargs.get("details", False):
+                # res is Iterable[TextGenerationStreamOutput]
+                def new_details_iterator():
+                    # type: () -> Iterable[TextGenerationStreamOutput]
+                    with capture_internal_exceptions():
+                        tokens_used = 0
+                        data_buf: list[str] = []
+                        for x in res:
+                            if hasattr(x, "token") and hasattr(x.token, "text"):
+                                data_buf.append(x.token.text)
+                            if hasattr(x, "details") and hasattr(
+                                x.details, "generated_tokens"
+                            ):
+                                tokens_used = x.details.generated_tokens
+                            yield x
+                        if (
+                            len(data_buf) > 0
+                            and should_send_default_pii()
+                            and integration.include_prompts
+                        ):
+                            set_data_normalized(
+                                span, SPANDATA.AI_RESPONSES, "".join(data_buf)
+                            )
+                        if tokens_used > 0:
+                            record_token_usage(span, total_tokens=tokens_used)
+                    span.__exit__(None, None, None)
+
+                return new_details_iterator()
+            else:
+                # res is Iterable[str]
+
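+                # Plain streaming: accumulate the string chunks so the full
+                # response text can be attached once the stream ends.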
+                def new_iterator():
+                    # type: () -> Iterable[str]
+                    data_buf: list[str] = []
+                    with capture_internal_exceptions():
+                        for s in res:
+                            if isinstance(s, str):
+                                data_buf.append(s)
+                            yield s
+                        if (
+                            len(data_buf) > 0
+                            and should_send_default_pii()
+                            and integration.include_prompts
+                        ):
+                            set_data_normalized(
+                                span, SPANDATA.AI_RESPONSES, "".join(data_buf)
+                            )
+                    span.__exit__(None, None, None)
+
+                return new_iterator()
+
+    return new_text_generation
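
For reference, a minimal usage sketch (the DSN and model id are
placeholders; `send_default_pii=True` is required for prompt and response
text to be recorded, per the `should_send_default_pii()` checks above):

    import sentry_sdk
    from sentry_sdk.integrations.huggingface_hub import HuggingfaceHubIntegration

    sentry_sdk.init(
        dsn="https://<key>@<org>.ingest.sentry.io/<project>",  # placeholder
        send_default_pii=True,  # allow prompt/response text on spans
        # include_prompts=False opts out of prompt/response capture even
        # when send_default_pii is enabled.
        integrations=[HuggingfaceHubIntegration(include_prompts=True)],
    )

    # Any subsequent InferenceClient.text_generation call is then traced:
    # from huggingface_hub import InferenceClient
    # InferenceClient(model="<model-id>").text_generation("Hello")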