author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/sentry_sdk/integrations/openai.py
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)

    two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/sentry_sdk/integrations/openai.py')

 .venv/lib/python3.12/site-packages/sentry_sdk/integrations/openai.py | 429 +++
 1 file changed, 429 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/sentry_sdk/integrations/openai.py b/.venv/lib/python3.12/site-packages/sentry_sdk/integrations/openai.py
new file mode 100644
index 00000000..61d335b1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sentry_sdk/integrations/openai.py
@@ -0,0 +1,429 @@
+from functools import wraps
+
+import sentry_sdk
+from sentry_sdk import consts
+from sentry_sdk.ai.monitoring import record_token_usage
+from sentry_sdk.ai.utils import set_data_normalized
+from sentry_sdk.consts import SPANDATA
+from sentry_sdk.integrations import DidNotEnable, Integration
+from sentry_sdk.scope import should_send_default_pii
+from sentry_sdk.utils import (
+    capture_internal_exceptions,
+    event_from_exception,
+)
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from typing import Any, Iterable, List, Optional, Callable, AsyncIterator, Iterator
+    from sentry_sdk.tracing import Span
+
+try:
+    from openai.resources.chat.completions import Completions, AsyncCompletions
+    from openai.resources import Embeddings, AsyncEmbeddings
+
+    if TYPE_CHECKING:
+        from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
+except ImportError:
+    raise DidNotEnable("OpenAI not installed")
+
+
+class OpenAIIntegration(Integration):
+    identifier = "openai"
+    origin = f"auto.ai.{identifier}"
+
+    def __init__(self, include_prompts=True, tiktoken_encoding_name=None):
+        # type: (OpenAIIntegration, bool, Optional[str]) -> None
+        self.include_prompts = include_prompts
+
+        self.tiktoken_encoding = None
+        if tiktoken_encoding_name is not None:
+            import tiktoken  # type: ignore
+
+            self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name)
+
+    @staticmethod
+    def setup_once():
+        # type: () -> None
+        Completions.create = _wrap_chat_completion_create(Completions.create)
+        Embeddings.create = _wrap_embeddings_create(Embeddings.create)
+
+        AsyncCompletions.create = _wrap_async_chat_completion_create(
+            AsyncCompletions.create
+        )
+        AsyncEmbeddings.create = _wrap_async_embeddings_create(AsyncEmbeddings.create)
+
+    def count_tokens(self, s):
+        # type: (OpenAIIntegration, str) -> int
+        if self.tiktoken_encoding is not None:
+            return len(self.tiktoken_encoding.encode_ordinary(s))
+        return 0
+
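+# Usage sketch (not part of this module; the DSN is hypothetical): the
+# integration is enabled at init time. `send_default_pii=True` together with
+# `include_prompts=True` (the default) is required for prompts and responses
+# to be recorded; `tiktoken_encoding_name` only matters when token usage has
+# to be estimated locally, e.g. for streamed responses without a `usage` block.
+#
+#     import sentry_sdk
+#     from sentry_sdk.integrations.openai import OpenAIIntegration
+#
+#     sentry_sdk.init(
+#         dsn="https://examplePublicKey@o0.ingest.sentry.io/0",
+#         send_default_pii=True,
+#         integrations=[OpenAIIntegration(tiktoken_encoding_name="cl100k_base")],
+#     )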
+
+def _capture_exception(exc):
+    # type: (Any) -> None
+    event, hint = event_from_exception(
+        exc,
+        client_options=sentry_sdk.get_client().options,
+        mechanism={"type": "openai", "handled": False},
+    )
+    sentry_sdk.capture_event(event, hint=hint)
+
+
+def _calculate_chat_completion_usage(
+    messages, response, span, streaming_message_responses, count_tokens
+):
+    # type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]], Callable[..., Any]) -> None
+    completion_tokens = 0  # type: Optional[int]
+    prompt_tokens = 0  # type: Optional[int]
+    total_tokens = 0  # type: Optional[int]
+    if hasattr(response, "usage"):
+        if hasattr(response.usage, "completion_tokens") and isinstance(
+            response.usage.completion_tokens, int
+        ):
+            completion_tokens = response.usage.completion_tokens
+        if hasattr(response.usage, "prompt_tokens") and isinstance(
+            response.usage.prompt_tokens, int
+        ):
+            prompt_tokens = response.usage.prompt_tokens
+        if hasattr(response.usage, "total_tokens") and isinstance(
+            response.usage.total_tokens, int
+        ):
+            total_tokens = response.usage.total_tokens
+
+    if prompt_tokens == 0:
+        for message in messages:
+            if "content" in message:
+                prompt_tokens += count_tokens(message["content"])
+
+    if completion_tokens == 0:
+        if streaming_message_responses is not None:
+            for message in streaming_message_responses:
+                completion_tokens += count_tokens(message)
+        elif hasattr(response, "choices"):
+            for choice in response.choices:
+                if hasattr(choice, "message") and hasattr(
+                    choice.message, "content"
+                ):
+                    completion_tokens += count_tokens(choice.message.content or "")
+
+    if prompt_tokens == 0:
+        prompt_tokens = None
+    if completion_tokens == 0:
+        completion_tokens = None
+    if total_tokens == 0:
+        total_tokens = None
+    record_token_usage(span, prompt_tokens, completion_tokens, total_tokens)
+
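+# Illustration of the fallback above (hypothetical values): when a streamed
+# response carries no `usage` block, prompt tokens are counted from the
+# request messages and completion tokens from the accumulated streamed text.
+#
+#     _calculate_chat_completion_usage(
+#         messages=[{"role": "user", "content": "Hi"}],
+#         response=streamed_response,  # assumed: no `.usage` attribute
+#         span=span,
+#         streaming_message_responses=["Hello there!"],
+#         count_tokens=integration.count_tokens,
+#     )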
+
+def _new_chat_completion_common(f, *args, **kwargs):
+    # type: (Any, *Any, **Any) -> Any
+    integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
+    if integration is None:
+        return f(*args, **kwargs)
+
+    if "messages" not in kwargs:
+        # invalid call (in all versions of openai); let the original call raise
+        return f(*args, **kwargs)
+
+    try:
+        iter(kwargs["messages"])
+    except TypeError:
+        # invalid call (in all versions of openai); messages must be iterable
+        return f(*args, **kwargs)
+
+    kwargs["messages"] = list(kwargs["messages"])
+    messages = kwargs["messages"]
+    model = kwargs.get("model")
+    streaming = kwargs.get("stream")
+
+    span = sentry_sdk.start_span(
+        op=consts.OP.OPENAI_CHAT_COMPLETIONS_CREATE,
+        name="Chat Completion",
+        origin=OpenAIIntegration.origin,
+    )
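+    # The span is entered manually rather than via `with`, because for
+    # streamed responses it must stay open until the iterator is drained.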
+    span.__enter__()
+
+    res = yield f, args, kwargs
+
+    with capture_internal_exceptions():
+        if should_send_default_pii() and integration.include_prompts:
+            set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, messages)
+
+        set_data_normalized(span, SPANDATA.AI_MODEL_ID, model)
+        set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)
+
+        if hasattr(res, "choices"):
+            if should_send_default_pii() and integration.include_prompts:
+                set_data_normalized(
+                    span,
+                    "ai.responses",
+                    list(map(lambda x: x.message, res.choices)),
+                )
+            _calculate_chat_completion_usage(
+                messages, res, span, None, integration.count_tokens
+            )
+            span.__exit__(None, None, None)
+        elif hasattr(res, "_iterator"):
+            data_buf: list[list[str]] = []  # one for each choice
+
+            old_iterator = res._iterator
+
+            def new_iterator():
+                # type: () -> Iterator[ChatCompletionChunk]
+                with capture_internal_exceptions():
+                    for x in old_iterator:
+                        if hasattr(x, "choices"):
+                            choice_index = 0
+                            for choice in x.choices:
+                                if hasattr(choice, "delta") and hasattr(
+                                    choice.delta, "content"
+                                ):
+                                    content = choice.delta.content
+                                    if len(data_buf) <= choice_index:
+                                        data_buf.append([])
+                                    data_buf[choice_index].append(content or "")
+                                choice_index += 1
+                        yield x
+                    if len(data_buf) > 0:
+                        all_responses = list(
+                            map(lambda chunk: "".join(chunk), data_buf)
+                        )
+                        if should_send_default_pii() and integration.include_prompts:
+                            set_data_normalized(
+                                span, SPANDATA.AI_RESPONSES, all_responses
+                            )
+                        _calculate_chat_completion_usage(
+                            messages,
+                            res,
+                            span,
+                            all_responses,
+                            integration.count_tokens,
+                        )
+                span.__exit__(None, None, None)
+
+            async def new_iterator_async():
+                # type: () -> AsyncIterator[ChatCompletionChunk]
+                with capture_internal_exceptions():
+                    async for x in old_iterator:
+                        if hasattr(x, "choices"):
+                            choice_index = 0
+                            for choice in x.choices:
+                                if hasattr(choice, "delta") and hasattr(
+                                    choice.delta, "content"
+                                ):
+                                    content = choice.delta.content
+                                    if len(data_buf) <= choice_index:
+                                        data_buf.append([])
+                                    data_buf[choice_index].append(content or "")
+                                choice_index += 1
+                        yield x
+                    if len(data_buf) > 0:
+                        all_responses = list(
+                            map(lambda chunk: "".join(chunk), data_buf)
+                        )
+                        if should_send_default_pii() and integration.include_prompts:
+                            set_data_normalized(
+                                span, SPANDATA.AI_RESPONSES, all_responses
+                            )
+                        _calculate_chat_completion_usage(
+                            messages,
+                            res,
+                            span,
+                            all_responses,
+                            integration.count_tokens,
+                        )
+                span.__exit__(None, None, None)
+
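+            # Streamed responses from the async client wrap an async
+            # generator; openai exposes no public marker for this, so the
+            # type name is compared as a string.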
+            if str(type(res._iterator)) == "<class 'async_generator'>":
+                res._iterator = new_iterator_async()
+            else:
+                res._iterator = new_iterator()
+
+        else:
+            set_data_normalized(span, "unknown_response", True)
+            span.__exit__(None, None, None)
+    return res
+
+
+def _wrap_chat_completion_create(f):
+    # type: (Callable[..., Any]) -> Callable[..., Any]
+    def _execute_sync(f, *args, **kwargs):
+        # type: (Any, *Any, **Any) -> Any
+        gen = _new_chat_completion_common(f, *args, **kwargs)
+
+        try:
+            f, args, kwargs = next(gen)
+        except StopIteration as e:
+            return e.value
+
+        try:
+            try:
+                result = f(*args, **kwargs)
+            except Exception as e:
+                _capture_exception(e)
+                raise e from None
+
+            return gen.send(result)
+        except StopIteration as e:
+            return e.value
+
+    @wraps(f)
+    def _sentry_patched_create_sync(*args, **kwargs):
+        # type: (*Any, **Any) -> Any
+        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
+        if integration is None or "messages" not in kwargs:
+            # no "messages" means invalid call (in all versions of openai), let it return error
+            return f(*args, **kwargs)
+
+        return _execute_sync(f, *args, **kwargs)
+
+    return _sentry_patched_create_sync
+
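+# A minimal standalone sketch (hypothetical names) of the generator-driver
+# pattern used by the wrappers in this module: the shared body yields the
+# call triple once, the driver executes the real function, and `gen.send()`
+# resumes the body so it can inspect the result before `StopIteration.value`
+# hands it back.
+#
+#     def _common(f, *args, **kwargs):
+#         # ... pre-call instrumentation ...
+#         result = yield f, args, kwargs
+#         # ... post-call instrumentation ...
+#         return result
+#
+#     def _driver(f, *args, **kwargs):
+#         gen = _common(f, *args, **kwargs)
+#         f, args, kwargs = next(gen)
+#         try:
+#             return gen.send(f(*args, **kwargs))
+#         except StopIteration as e:
+#             return e.value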
+
+def _wrap_async_chat_completion_create(f):
+    # type: (Callable[..., Any]) -> Callable[..., Any]
+    async def _execute_async(f, *args, **kwargs):
+        # type: (Any, *Any, **Any) -> Any
+        gen = _new_chat_completion_common(f, *args, **kwargs)
+
+        try:
+            f, args, kwargs = next(gen)
+        except StopIteration as e:
+            return await e.value
+
+        try:
+            try:
+                result = await f(*args, **kwargs)
+            except Exception as e:
+                _capture_exception(e)
+                raise e from None
+
+            return gen.send(result)
+        except StopIteration as e:
+            return e.value
+
+    @wraps(f)
+    async def _sentry_patched_create_async(*args, **kwargs):
+        # type: (*Any, **Any) -> Any
+        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
+        if integration is None or "messages" not in kwargs:
+            # no "messages" means invalid call (in all versions of openai), let it return error
+            return await f(*args, **kwargs)
+
+        return await _execute_async(f, *args, **kwargs)
+
+    return _sentry_patched_create_async
+
+
+def _new_embeddings_create_common(f, *args, **kwargs):
+    # type: (Any, *Any, **Any) -> Any
+    integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
+    if integration is None:
+        return f(*args, **kwargs)
+
+    with sentry_sdk.start_span(
+        op=consts.OP.OPENAI_EMBEDDINGS_CREATE,
+        description="OpenAI Embedding Creation",
+        origin=OpenAIIntegration.origin,
+    ) as span:
+        if "input" in kwargs and (
+            should_send_default_pii() and integration.include_prompts
+        ):
+            if isinstance(kwargs["input"], str):
+                set_data_normalized(
+                    span, SPANDATA.AI_INPUT_MESSAGES, [kwargs["input"]]
+                )
+            elif (
+                isinstance(kwargs["input"], list)
+                and len(kwargs["input"]) > 0
+                and isinstance(kwargs["input"][0], str)
+            ):
+                set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, kwargs["input"])
+        if "model" in kwargs:
+            set_data_normalized(span, SPANDATA.AI_MODEL_ID, kwargs["model"])
+
+        response = yield f, args, kwargs
+
+        prompt_tokens = 0
+        total_tokens = 0
+        if hasattr(response, "usage"):
+            if hasattr(response.usage, "prompt_tokens") and isinstance(
+                response.usage.prompt_tokens, int
+            ):
+                prompt_tokens = response.usage.prompt_tokens
+            if hasattr(response.usage, "total_tokens") and isinstance(
+                response.usage.total_tokens, int
+            ):
+                total_tokens = response.usage.total_tokens
+
+        if prompt_tokens == 0 and isinstance(kwargs.get("input"), str):
+            prompt_tokens = integration.count_tokens(kwargs["input"])
+
+        record_token_usage(span, prompt_tokens, None, total_tokens or prompt_tokens)
+
+        return response
+
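+# Usage sketch (hypothetical model and input): once `setup_once` has patched
+# `Embeddings.create`, ordinary OpenAI calls are traced transparently.
+#
+#     from openai import OpenAI
+#
+#     client = OpenAI()
+#     client.embeddings.create(
+#         model="text-embedding-3-small",
+#         input="The food was delicious.",
+#     )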
+
+def _wrap_embeddings_create(f):
+    # type: (Any) -> Any
+    def _execute_sync(f, *args, **kwargs):
+        # type: (Any, *Any, **Any) -> Any
+        gen = _new_embeddings_create_common(f, *args, **kwargs)
+
+        try:
+            f, args, kwargs = next(gen)
+        except StopIteration as e:
+            return e.value
+
+        try:
+            try:
+                result = f(*args, **kwargs)
+            except Exception as e:
+                _capture_exception(e)
+                raise e from None
+
+            return gen.send(result)
+        except StopIteration as e:
+            return e.value
+
+    @wraps(f)
+    def _sentry_patched_create_sync(*args, **kwargs):
+        # type: (*Any, **Any) -> Any
+        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
+        if integration is None:
+            return f(*args, **kwargs)
+
+        return _execute_sync(f, *args, **kwargs)
+
+    return _sentry_patched_create_sync
+
+
+def _wrap_async_embeddings_create(f):
+    # type: (Any) -> Any
+    async def _execute_async(f, *args, **kwargs):
+        # type: (Any, *Any, **Any) -> Any
+        gen = _new_embeddings_create_common(f, *args, **kwargs)
+
+        try:
+            f, args, kwargs = next(gen)
+        except StopIteration as e:
+            return await e.value
+
+        try:
+            try:
+                result = await f(*args, **kwargs)
+            except Exception as e:
+                _capture_exception(e)
+                raise e from None
+
+            return gen.send(result)
+        except StopIteration as e:
+            return e.value
+
+    @wraps(f)
+    async def _sentry_patched_create_async(*args, **kwargs):
+        # type: (*Any, **Any) -> Any
+        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
+        if integration is None:
+            return await f(*args, **kwargs)
+
+        return await _execute_async(f, *args, **kwargs)
+
+    return _sentry_patched_create_async