about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/sentry_sdk/integrations/langchain.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/sentry_sdk/integrations/langchain.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/sentry_sdk/integrations/langchain.py')
-rw-r--r--.venv/lib/python3.12/site-packages/sentry_sdk/integrations/langchain.py465
1 files changed, 465 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/sentry_sdk/integrations/langchain.py b/.venv/lib/python3.12/site-packages/sentry_sdk/integrations/langchain.py
new file mode 100644
index 00000000..431fc46b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sentry_sdk/integrations/langchain.py
@@ -0,0 +1,465 @@
+from collections import OrderedDict
+from functools import wraps
+
+import sentry_sdk
+from sentry_sdk.ai.monitoring import set_ai_pipeline_name, record_token_usage
+from sentry_sdk.consts import OP, SPANDATA
+from sentry_sdk.ai.utils import set_data_normalized
+from sentry_sdk.scope import should_send_default_pii
+from sentry_sdk.tracing import Span
+from sentry_sdk.integrations import DidNotEnable, Integration
+from sentry_sdk.utils import logger, capture_internal_exceptions
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from typing import Any, List, Callable, Dict, Union, Optional
+    from uuid import UUID
+
+try:
+    from langchain_core.messages import BaseMessage
+    from langchain_core.outputs import LLMResult
+    from langchain_core.callbacks import (
+        manager,
+        BaseCallbackHandler,
+    )
+    from langchain_core.agents import AgentAction, AgentFinish
+except ImportError:
+    raise DidNotEnable("langchain not installed")
+
+
+DATA_FIELDS = {
+    "temperature": SPANDATA.AI_TEMPERATURE,
+    "top_p": SPANDATA.AI_TOP_P,
+    "top_k": SPANDATA.AI_TOP_K,
+    "function_call": SPANDATA.AI_FUNCTION_CALL,
+    "tool_calls": SPANDATA.AI_TOOL_CALLS,
+    "tools": SPANDATA.AI_TOOLS,
+    "response_format": SPANDATA.AI_RESPONSE_FORMAT,
+    "logit_bias": SPANDATA.AI_LOGIT_BIAS,
+    "tags": SPANDATA.AI_TAGS,
+}
+
+# To avoid double collecting tokens, we do *not* measure
+# token counts for models for which we have an explicit integration
+NO_COLLECT_TOKEN_MODELS = [
+    "openai-chat",
+    "anthropic-chat",
+    "cohere-chat",
+    "huggingface_endpoint",
+]
+
+
+class LangchainIntegration(Integration):
+    identifier = "langchain"
+    origin = f"auto.ai.{identifier}"
+
+    # The most number of spans (e.g., LLM calls) that can be processed at the same time.
+    max_spans = 1024
+
+    def __init__(
+        self, include_prompts=True, max_spans=1024, tiktoken_encoding_name=None
+    ):
+        # type: (LangchainIntegration, bool, int, Optional[str]) -> None
+        self.include_prompts = include_prompts
+        self.max_spans = max_spans
+        self.tiktoken_encoding_name = tiktoken_encoding_name
+
+    @staticmethod
+    def setup_once():
+        # type: () -> None
+        manager._configure = _wrap_configure(manager._configure)
+
+
+class WatchedSpan:
+    span = None  # type: Span
+    num_completion_tokens = 0  # type: int
+    num_prompt_tokens = 0  # type: int
+    no_collect_tokens = False  # type: bool
+    children = []  # type: List[WatchedSpan]
+    is_pipeline = False  # type: bool
+
+    def __init__(self, span):
+        # type: (Span) -> None
+        self.span = span
+
+
+class SentryLangchainCallback(BaseCallbackHandler):  # type: ignore[misc]
+    """Base callback handler that can be used to handle callbacks from langchain."""
+
+    span_map = OrderedDict()  # type: OrderedDict[UUID, WatchedSpan]
+
+    max_span_map_size = 0
+
+    def __init__(self, max_span_map_size, include_prompts, tiktoken_encoding_name=None):
+        # type: (int, bool, Optional[str]) -> None
+        self.max_span_map_size = max_span_map_size
+        self.include_prompts = include_prompts
+
+        self.tiktoken_encoding = None
+        if tiktoken_encoding_name is not None:
+            import tiktoken  # type: ignore
+
+            self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name)
+
+    def count_tokens(self, s):
+        # type: (str) -> int
+        if self.tiktoken_encoding is not None:
+            return len(self.tiktoken_encoding.encode_ordinary(s))
+        return 0
+
+    def gc_span_map(self):
+        # type: () -> None
+
+        while len(self.span_map) > self.max_span_map_size:
+            run_id, watched_span = self.span_map.popitem(last=False)
+            self._exit_span(watched_span, run_id)
+
+    def _handle_error(self, run_id, error):
+        # type: (UUID, Any) -> None
+        if not run_id or run_id not in self.span_map:
+            return
+
+        span_data = self.span_map[run_id]
+        if not span_data:
+            return
+        sentry_sdk.capture_exception(error, span_data.span.scope)
+        span_data.span.__exit__(None, None, None)
+        del self.span_map[run_id]
+
+    def _normalize_langchain_message(self, message):
+        # type: (BaseMessage) -> Any
+        parsed = {"content": message.content, "role": message.type}
+        parsed.update(message.additional_kwargs)
+        return parsed
+
+    def _create_span(self, run_id, parent_id, **kwargs):
+        # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan
+
+        watched_span = None  # type: Optional[WatchedSpan]
+        if parent_id:
+            parent_span = self.span_map.get(parent_id)  # type: Optional[WatchedSpan]
+            if parent_span:
+                watched_span = WatchedSpan(parent_span.span.start_child(**kwargs))
+                parent_span.children.append(watched_span)
+        if watched_span is None:
+            watched_span = WatchedSpan(sentry_sdk.start_span(**kwargs))
+
+        if kwargs.get("op", "").startswith("ai.pipeline."):
+            if kwargs.get("name"):
+                set_ai_pipeline_name(kwargs.get("name"))
+            watched_span.is_pipeline = True
+
+        watched_span.span.__enter__()
+        self.span_map[run_id] = watched_span
+        self.gc_span_map()
+        return watched_span
+
+    def _exit_span(self, span_data, run_id):
+        # type: (SentryLangchainCallback, WatchedSpan, UUID) -> None
+
+        if span_data.is_pipeline:
+            set_ai_pipeline_name(None)
+
+        span_data.span.__exit__(None, None, None)
+        del self.span_map[run_id]
+
+    def on_llm_start(
+        self,
+        serialized,
+        prompts,
+        *,
+        run_id,
+        tags=None,
+        parent_run_id=None,
+        metadata=None,
+        **kwargs,
+    ):
+        # type: (SentryLangchainCallback, Dict[str, Any], List[str], UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Any) -> Any
+        """Run when LLM starts running."""
+        with capture_internal_exceptions():
+            if not run_id:
+                return
+            all_params = kwargs.get("invocation_params", {})
+            all_params.update(serialized.get("kwargs", {}))
+            watched_span = self._create_span(
+                run_id,
+                kwargs.get("parent_run_id"),
+                op=OP.LANGCHAIN_RUN,
+                name=kwargs.get("name") or "Langchain LLM call",
+                origin=LangchainIntegration.origin,
+            )
+            span = watched_span.span
+            if should_send_default_pii() and self.include_prompts:
+                set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompts)
+            for k, v in DATA_FIELDS.items():
+                if k in all_params:
+                    set_data_normalized(span, v, all_params[k])
+
+    def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Dict[str, Any], List[List[BaseMessage]], UUID, Any) -> Any
+        """Run when Chat Model starts running."""
+        with capture_internal_exceptions():
+            if not run_id:
+                return
+            all_params = kwargs.get("invocation_params", {})
+            all_params.update(serialized.get("kwargs", {}))
+            watched_span = self._create_span(
+                run_id,
+                kwargs.get("parent_run_id"),
+                op=OP.LANGCHAIN_CHAT_COMPLETIONS_CREATE,
+                name=kwargs.get("name") or "Langchain Chat Model",
+                origin=LangchainIntegration.origin,
+            )
+            span = watched_span.span
+            model = all_params.get(
+                "model", all_params.get("model_name", all_params.get("model_id"))
+            )
+            watched_span.no_collect_tokens = any(
+                x in all_params.get("_type", "") for x in NO_COLLECT_TOKEN_MODELS
+            )
+
+            if not model and "anthropic" in all_params.get("_type"):
+                model = "claude-2"
+            if model:
+                span.set_data(SPANDATA.AI_MODEL_ID, model)
+            if should_send_default_pii() and self.include_prompts:
+                set_data_normalized(
+                    span,
+                    SPANDATA.AI_INPUT_MESSAGES,
+                    [
+                        [self._normalize_langchain_message(x) for x in list_]
+                        for list_ in messages
+                    ],
+                )
+            for k, v in DATA_FIELDS.items():
+                if k in all_params:
+                    set_data_normalized(span, v, all_params[k])
+            if not watched_span.no_collect_tokens:
+                for list_ in messages:
+                    for message in list_:
+                        self.span_map[run_id].num_prompt_tokens += self.count_tokens(
+                            message.content
+                        ) + self.count_tokens(message.type)
+
+    def on_llm_new_token(self, token, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, str, UUID, Any) -> Any
+        """Run on new LLM token. Only available when streaming is enabled."""
+        with capture_internal_exceptions():
+            if not run_id or run_id not in self.span_map:
+                return
+            span_data = self.span_map[run_id]
+            if not span_data or span_data.no_collect_tokens:
+                return
+            span_data.num_completion_tokens += self.count_tokens(token)
+
+    def on_llm_end(self, response, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
+        """Run when LLM ends running."""
+        with capture_internal_exceptions():
+            if not run_id:
+                return
+
+            token_usage = (
+                response.llm_output.get("token_usage") if response.llm_output else None
+            )
+
+            span_data = self.span_map[run_id]
+            if not span_data:
+                return
+
+            if should_send_default_pii() and self.include_prompts:
+                set_data_normalized(
+                    span_data.span,
+                    SPANDATA.AI_RESPONSES,
+                    [[x.text for x in list_] for list_ in response.generations],
+                )
+
+            if not span_data.no_collect_tokens:
+                if token_usage:
+                    record_token_usage(
+                        span_data.span,
+                        token_usage.get("prompt_tokens"),
+                        token_usage.get("completion_tokens"),
+                        token_usage.get("total_tokens"),
+                    )
+                else:
+                    record_token_usage(
+                        span_data.span,
+                        span_data.num_prompt_tokens,
+                        span_data.num_completion_tokens,
+                    )
+
+            self._exit_span(span_data, run_id)
+
+    def on_llm_error(self, error, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
+        """Run when LLM errors."""
+        with capture_internal_exceptions():
+            self._handle_error(run_id, error)
+
+    def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Dict[str, Any], Dict[str, Any], UUID, Any) -> Any
+        """Run when chain starts running."""
+        with capture_internal_exceptions():
+            if not run_id:
+                return
+            watched_span = self._create_span(
+                run_id,
+                kwargs.get("parent_run_id"),
+                op=(
+                    OP.LANGCHAIN_RUN
+                    if kwargs.get("parent_run_id") is not None
+                    else OP.LANGCHAIN_PIPELINE
+                ),
+                name=kwargs.get("name") or "Chain execution",
+                origin=LangchainIntegration.origin,
+            )
+            metadata = kwargs.get("metadata")
+            if metadata:
+                set_data_normalized(watched_span.span, SPANDATA.AI_METADATA, metadata)
+
+    def on_chain_end(self, outputs, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Dict[str, Any], UUID, Any) -> Any
+        """Run when chain ends running."""
+        with capture_internal_exceptions():
+            if not run_id or run_id not in self.span_map:
+                return
+
+            span_data = self.span_map[run_id]
+            if not span_data:
+                return
+            self._exit_span(span_data, run_id)
+
+    def on_chain_error(self, error, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
+        """Run when chain errors."""
+        self._handle_error(run_id, error)
+
+    def on_agent_action(self, action, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, AgentAction, UUID, Any) -> Any
+        with capture_internal_exceptions():
+            if not run_id:
+                return
+            watched_span = self._create_span(
+                run_id,
+                kwargs.get("parent_run_id"),
+                op=OP.LANGCHAIN_AGENT,
+                name=action.tool or "AI tool usage",
+                origin=LangchainIntegration.origin,
+            )
+            if action.tool_input and should_send_default_pii() and self.include_prompts:
+                set_data_normalized(
+                    watched_span.span, SPANDATA.AI_INPUT_MESSAGES, action.tool_input
+                )
+
+    def on_agent_finish(self, finish, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, AgentFinish, UUID, Any) -> Any
+        with capture_internal_exceptions():
+            if not run_id:
+                return
+
+            span_data = self.span_map[run_id]
+            if not span_data:
+                return
+            if should_send_default_pii() and self.include_prompts:
+                set_data_normalized(
+                    span_data.span, SPANDATA.AI_RESPONSES, finish.return_values.items()
+                )
+            self._exit_span(span_data, run_id)
+
+    def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Dict[str, Any], str, UUID, Any) -> Any
+        """Run when tool starts running."""
+        with capture_internal_exceptions():
+            if not run_id:
+                return
+            watched_span = self._create_span(
+                run_id,
+                kwargs.get("parent_run_id"),
+                op=OP.LANGCHAIN_TOOL,
+                name=serialized.get("name") or kwargs.get("name") or "AI tool usage",
+                origin=LangchainIntegration.origin,
+            )
+            if should_send_default_pii() and self.include_prompts:
+                set_data_normalized(
+                    watched_span.span,
+                    SPANDATA.AI_INPUT_MESSAGES,
+                    kwargs.get("inputs", [input_str]),
+                )
+                if kwargs.get("metadata"):
+                    set_data_normalized(
+                        watched_span.span, SPANDATA.AI_METADATA, kwargs.get("metadata")
+                    )
+
+    def on_tool_end(self, output, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, str, UUID, Any) -> Any
+        """Run when tool ends running."""
+        with capture_internal_exceptions():
+            if not run_id or run_id not in self.span_map:
+                return
+
+            span_data = self.span_map[run_id]
+            if not span_data:
+                return
+            if should_send_default_pii() and self.include_prompts:
+                set_data_normalized(span_data.span, SPANDATA.AI_RESPONSES, output)
+            self._exit_span(span_data, run_id)
+
+    def on_tool_error(self, error, *args, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
+        """Run when tool errors."""
+        self._handle_error(run_id, error)
+
+
+def _wrap_configure(f):
+    # type: (Callable[..., Any]) -> Callable[..., Any]
+
+    @wraps(f)
+    def new_configure(*args, **kwargs):
+        # type: (Any, Any) -> Any
+
+        integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
+        if integration is None:
+            return f(*args, **kwargs)
+
+        with capture_internal_exceptions():
+            new_callbacks = []  # type: List[BaseCallbackHandler]
+            if "local_callbacks" in kwargs:
+                existing_callbacks = kwargs["local_callbacks"]
+                kwargs["local_callbacks"] = new_callbacks
+            elif len(args) > 2:
+                existing_callbacks = args[2]
+                args = (
+                    args[0],
+                    args[1],
+                    new_callbacks,
+                ) + args[3:]
+            else:
+                existing_callbacks = []
+
+            if existing_callbacks:
+                if isinstance(existing_callbacks, list):
+                    for cb in existing_callbacks:
+                        new_callbacks.append(cb)
+                elif isinstance(existing_callbacks, BaseCallbackHandler):
+                    new_callbacks.append(existing_callbacks)
+                else:
+                    logger.debug("Unknown callback type: %s", existing_callbacks)
+
+            already_added = False
+            for callback in new_callbacks:
+                if isinstance(callback, SentryLangchainCallback):
+                    already_added = True
+
+            if not already_added:
+                new_callbacks.append(
+                    SentryLangchainCallback(
+                        integration.max_spans,
+                        integration.include_prompts,
+                        integration.tiktoken_encoding_name,
+                    )
+                )
+        return f(*args, **kwargs)
+
+    return new_configure