Diffstat (limited to '.venv/lib/python3.12/site-packages/sentry_sdk/integrations/langchain.py')
-rw-r--r-- .venv/lib/python3.12/site-packages/sentry_sdk/integrations/langchain.py | 465
1 file changed, 465 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/sentry_sdk/integrations/langchain.py b/.venv/lib/python3.12/site-packages/sentry_sdk/integrations/langchain.py
new file mode 100644
index 00000000..431fc46b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sentry_sdk/integrations/langchain.py
@@ -0,0 +1,465 @@
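+"""Sentry instrumentation for LangChain.
+
+This integration patches ``langchain_core.callbacks.manager._configure`` so
+that a ``SentryLangchainCallback`` is injected into every callback manager,
+creating spans for LLM calls, chat models, chains, tools, and agent actions,
+and recording token usage where available.
+
+Minimal usage sketch (illustrative application code, not part of this
+module; the DSN is a placeholder)::
+
+    import sentry_sdk
+    from sentry_sdk.integrations.langchain import LangchainIntegration
+
+    sentry_sdk.init(
+        dsn="<your-dsn>",
+        traces_sample_rate=1.0,
+        send_default_pii=True,  # required for prompt/response capture
+        integrations=[LangchainIntegration(include_prompts=True)],
+    )
+"""
+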
+from collections import OrderedDict
+from functools import wraps
+
+import sentry_sdk
+from sentry_sdk.ai.monitoring import set_ai_pipeline_name, record_token_usage
+from sentry_sdk.consts import OP, SPANDATA
+from sentry_sdk.ai.utils import set_data_normalized
+from sentry_sdk.scope import should_send_default_pii
+from sentry_sdk.tracing import Span
+from sentry_sdk.integrations import DidNotEnable, Integration
+from sentry_sdk.utils import logger, capture_internal_exceptions
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from typing import Any, List, Callable, Dict, Union, Optional
+ from uuid import UUID
+
+try:
+ from langchain_core.messages import BaseMessage
+ from langchain_core.outputs import LLMResult
+ from langchain_core.callbacks import (
+ manager,
+ BaseCallbackHandler,
+ )
+ from langchain_core.agents import AgentAction, AgentFinish
+except ImportError:
+ raise DidNotEnable("langchain not installed")
+
+
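+# Maps LangChain invocation parameters to the Sentry span data attributes
+# they are recorded under.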
+DATA_FIELDS = {
+ "temperature": SPANDATA.AI_TEMPERATURE,
+ "top_p": SPANDATA.AI_TOP_P,
+ "top_k": SPANDATA.AI_TOP_K,
+ "function_call": SPANDATA.AI_FUNCTION_CALL,
+ "tool_calls": SPANDATA.AI_TOOL_CALLS,
+ "tools": SPANDATA.AI_TOOLS,
+ "response_format": SPANDATA.AI_RESPONSE_FORMAT,
+ "logit_bias": SPANDATA.AI_LOGIT_BIAS,
+ "tags": SPANDATA.AI_TAGS,
+}
+
+# To avoid double collecting tokens, we do *not* measure
+# token counts for models for which we have an explicit integration
+NO_COLLECT_TOKEN_MODELS = [
+ "openai-chat",
+ "anthropic-chat",
+ "cohere-chat",
+ "huggingface_endpoint",
+]
+
+
+class LangchainIntegration(Integration):
+ identifier = "langchain"
+ origin = f"auto.ai.{identifier}"
+
+    # The maximum number of spans (e.g., LLM calls) that can be tracked at the same time.
+ max_spans = 1024
+
+ def __init__(
+ self, include_prompts=True, max_spans=1024, tiktoken_encoding_name=None
+ ):
+ # type: (LangchainIntegration, bool, int, Optional[str]) -> None
+ self.include_prompts = include_prompts
+ self.max_spans = max_spans
+ self.tiktoken_encoding_name = tiktoken_encoding_name
+
+ @staticmethod
+ def setup_once():
+ # type: () -> None
+ manager._configure = _wrap_configure(manager._configure)
+
+
+class WatchedSpan:
+    span = None  # type: Span
+    num_completion_tokens = 0  # type: int
+    num_prompt_tokens = 0  # type: int
+    no_collect_tokens = False  # type: bool
+    is_pipeline = False  # type: bool
+
+    def __init__(self, span):
+        # type: (Span) -> None
+        self.span = span
+        # children must be per instance; a class-level list would be shared
+        # by every WatchedSpan.
+        self.children = []  # type: List[WatchedSpan]
+
+
+class SentryLangchainCallback(BaseCallbackHandler):  # type: ignore[misc]
+    """Base callback handler that can be used to handle callbacks from langchain."""
+
+    def __init__(self, max_span_map_size, include_prompts, tiktoken_encoding_name=None):
+        # type: (int, bool, Optional[str]) -> None
+        # Keep the run_id -> span map per instance (a class-level OrderedDict
+        # would be shared across handlers); oldest entries are evicted first.
+        self.span_map = OrderedDict()  # type: OrderedDict[UUID, WatchedSpan]
+        self.max_span_map_size = max_span_map_size
+        self.include_prompts = include_prompts
+
+ self.tiktoken_encoding = None
+ if tiktoken_encoding_name is not None:
+ import tiktoken # type: ignore
+
+ self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name)
+
+ def count_tokens(self, s):
+ # type: (str) -> int
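+        # Without a configured tiktoken encoding this returns 0, and token
+        # usage comes only from provider-reported counts in on_llm_end.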
+ if self.tiktoken_encoding is not None:
+ return len(self.tiktoken_encoding.encode_ordinary(s))
+ return 0
+
+ def gc_span_map(self):
+ # type: () -> None
+
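+        # Pop (and close) the oldest watched spans once the map exceeds its
+        # cap, so that abandoned runs cannot grow the map without bound.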
+ while len(self.span_map) > self.max_span_map_size:
+ run_id, watched_span = self.span_map.popitem(last=False)
+ self._exit_span(watched_span, run_id)
+
+ def _handle_error(self, run_id, error):
+ # type: (UUID, Any) -> None
+ if not run_id or run_id not in self.span_map:
+ return
+
+ span_data = self.span_map[run_id]
+ if not span_data:
+ return
+ sentry_sdk.capture_exception(error, span_data.span.scope)
+ span_data.span.__exit__(None, None, None)
+ del self.span_map[run_id]
+
+ def _normalize_langchain_message(self, message):
+ # type: (BaseMessage) -> Any
+ parsed = {"content": message.content, "role": message.type}
+ parsed.update(message.additional_kwargs)
+ return parsed
+
+ def _create_span(self, run_id, parent_id, **kwargs):
+ # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan
+
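+        # Attach to the parent's span when the parent run is already watched;
+        # otherwise start a new top-level span on the current scope.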
+ watched_span = None # type: Optional[WatchedSpan]
+ if parent_id:
+ parent_span = self.span_map.get(parent_id) # type: Optional[WatchedSpan]
+ if parent_span:
+ watched_span = WatchedSpan(parent_span.span.start_child(**kwargs))
+ parent_span.children.append(watched_span)
+ if watched_span is None:
+ watched_span = WatchedSpan(sentry_sdk.start_span(**kwargs))
+
+ if kwargs.get("op", "").startswith("ai.pipeline."):
+ if kwargs.get("name"):
+ set_ai_pipeline_name(kwargs.get("name"))
+ watched_span.is_pipeline = True
+
+ watched_span.span.__enter__()
+ self.span_map[run_id] = watched_span
+ self.gc_span_map()
+ return watched_span
+
+ def _exit_span(self, span_data, run_id):
+ # type: (SentryLangchainCallback, WatchedSpan, UUID) -> None
+
+ if span_data.is_pipeline:
+ set_ai_pipeline_name(None)
+
+ span_data.span.__exit__(None, None, None)
+ del self.span_map[run_id]
+
+ def on_llm_start(
+ self,
+ serialized,
+ prompts,
+ *,
+ run_id,
+ tags=None,
+ parent_run_id=None,
+ metadata=None,
+ **kwargs,
+ ):
+ # type: (SentryLangchainCallback, Dict[str, Any], List[str], UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Any) -> Any
+ """Run when LLM starts running."""
+ with capture_internal_exceptions():
+ if not run_id:
+ return
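+            # Merge the call-time invocation params with the serialized
+            # constructor kwargs so both are recorded on the span below.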
+ all_params = kwargs.get("invocation_params", {})
+ all_params.update(serialized.get("kwargs", {}))
+ watched_span = self._create_span(
+ run_id,
+                parent_run_id,
+ op=OP.LANGCHAIN_RUN,
+ name=kwargs.get("name") or "Langchain LLM call",
+ origin=LangchainIntegration.origin,
+ )
+ span = watched_span.span
+ if should_send_default_pii() and self.include_prompts:
+ set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompts)
+ for k, v in DATA_FIELDS.items():
+ if k in all_params:
+ set_data_normalized(span, v, all_params[k])
+
+ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
+ # type: (SentryLangchainCallback, Dict[str, Any], List[List[BaseMessage]], UUID, Any) -> Any
+ """Run when Chat Model starts running."""
+ with capture_internal_exceptions():
+ if not run_id:
+ return
+ all_params = kwargs.get("invocation_params", {})
+ all_params.update(serialized.get("kwargs", {}))
+ watched_span = self._create_span(
+ run_id,
+ kwargs.get("parent_run_id"),
+ op=OP.LANGCHAIN_CHAT_COMPLETIONS_CREATE,
+ name=kwargs.get("name") or "Langchain Chat Model",
+ origin=LangchainIntegration.origin,
+ )
+ span = watched_span.span
+ model = all_params.get(
+ "model", all_params.get("model_name", all_params.get("model_id"))
+ )
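+            # _type is matched against NO_COLLECT_TOKEN_MODELS so models with
+            # a dedicated integration are not token-counted a second time.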
+ watched_span.no_collect_tokens = any(
+ x in all_params.get("_type", "") for x in NO_COLLECT_TOKEN_MODELS
+ )
+
+            # Fallback model name when an anthropic client does not report one.
+            if not model and "anthropic" in all_params.get("_type", ""):
+                model = "claude-2"
+ if model:
+ span.set_data(SPANDATA.AI_MODEL_ID, model)
+ if should_send_default_pii() and self.include_prompts:
+ set_data_normalized(
+ span,
+ SPANDATA.AI_INPUT_MESSAGES,
+ [
+ [self._normalize_langchain_message(x) for x in list_]
+ for list_ in messages
+ ],
+ )
+ for k, v in DATA_FIELDS.items():
+ if k in all_params:
+ set_data_normalized(span, v, all_params[k])
+ if not watched_span.no_collect_tokens:
+ for list_ in messages:
+ for message in list_:
+ self.span_map[run_id].num_prompt_tokens += self.count_tokens(
+ message.content
+ ) + self.count_tokens(message.type)
+
+ def on_llm_new_token(self, token, *, run_id, **kwargs):
+ # type: (SentryLangchainCallback, str, UUID, Any) -> Any
+ """Run on new LLM token. Only available when streaming is enabled."""
+ with capture_internal_exceptions():
+ if not run_id or run_id not in self.span_map:
+ return
+ span_data = self.span_map[run_id]
+ if not span_data or span_data.no_collect_tokens:
+ return
+ span_data.num_completion_tokens += self.count_tokens(token)
+
+ def on_llm_end(self, response, *, run_id, **kwargs):
+ # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
+ """Run when LLM ends running."""
+ with capture_internal_exceptions():
+            if not run_id or run_id not in self.span_map:
+                return
+
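+            # Provider-reported token usage is preferred below; the
+            # tiktoken-based counts gathered while streaming are the fallback.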
+ token_usage = (
+ response.llm_output.get("token_usage") if response.llm_output else None
+ )
+
+ span_data = self.span_map[run_id]
+ if not span_data:
+ return
+
+ if should_send_default_pii() and self.include_prompts:
+ set_data_normalized(
+ span_data.span,
+ SPANDATA.AI_RESPONSES,
+ [[x.text for x in list_] for list_ in response.generations],
+ )
+
+ if not span_data.no_collect_tokens:
+ if token_usage:
+ record_token_usage(
+ span_data.span,
+ token_usage.get("prompt_tokens"),
+ token_usage.get("completion_tokens"),
+ token_usage.get("total_tokens"),
+ )
+ else:
+ record_token_usage(
+ span_data.span,
+ span_data.num_prompt_tokens,
+ span_data.num_completion_tokens,
+ )
+
+ self._exit_span(span_data, run_id)
+
+ def on_llm_error(self, error, *, run_id, **kwargs):
+ # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
+ """Run when LLM errors."""
+ with capture_internal_exceptions():
+ self._handle_error(run_id, error)
+
+ def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
+ # type: (SentryLangchainCallback, Dict[str, Any], Dict[str, Any], UUID, Any) -> Any
+ """Run when chain starts running."""
+ with capture_internal_exceptions():
+ if not run_id:
+ return
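+            # A chain without a parent run is treated as the pipeline root
+            # (an ai.pipeline span); nested chains become plain runs.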
+ watched_span = self._create_span(
+ run_id,
+ kwargs.get("parent_run_id"),
+ op=(
+ OP.LANGCHAIN_RUN
+ if kwargs.get("parent_run_id") is not None
+ else OP.LANGCHAIN_PIPELINE
+ ),
+ name=kwargs.get("name") or "Chain execution",
+ origin=LangchainIntegration.origin,
+ )
+ metadata = kwargs.get("metadata")
+ if metadata:
+ set_data_normalized(watched_span.span, SPANDATA.AI_METADATA, metadata)
+
+ def on_chain_end(self, outputs, *, run_id, **kwargs):
+ # type: (SentryLangchainCallback, Dict[str, Any], UUID, Any) -> Any
+ """Run when chain ends running."""
+ with capture_internal_exceptions():
+ if not run_id or run_id not in self.span_map:
+ return
+
+ span_data = self.span_map[run_id]
+ if not span_data:
+ return
+ self._exit_span(span_data, run_id)
+
+ def on_chain_error(self, error, *, run_id, **kwargs):
+ # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
+ """Run when chain errors."""
+ self._handle_error(run_id, error)
+
+ def on_agent_action(self, action, *, run_id, **kwargs):
+ # type: (SentryLangchainCallback, AgentAction, UUID, Any) -> Any
+ with capture_internal_exceptions():
+ if not run_id:
+ return
+ watched_span = self._create_span(
+ run_id,
+ kwargs.get("parent_run_id"),
+ op=OP.LANGCHAIN_AGENT,
+ name=action.tool or "AI tool usage",
+ origin=LangchainIntegration.origin,
+ )
+ if action.tool_input and should_send_default_pii() and self.include_prompts:
+ set_data_normalized(
+ watched_span.span, SPANDATA.AI_INPUT_MESSAGES, action.tool_input
+ )
+
+ def on_agent_finish(self, finish, *, run_id, **kwargs):
+ # type: (SentryLangchainCallback, AgentFinish, UUID, Any) -> Any
+ with capture_internal_exceptions():
+            if not run_id or run_id not in self.span_map:
+                return
+
+ span_data = self.span_map[run_id]
+ if not span_data:
+ return
+ if should_send_default_pii() and self.include_prompts:
+ set_data_normalized(
+ span_data.span, SPANDATA.AI_RESPONSES, finish.return_values.items()
+ )
+ self._exit_span(span_data, run_id)
+
+ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
+ # type: (SentryLangchainCallback, Dict[str, Any], str, UUID, Any) -> Any
+ """Run when tool starts running."""
+ with capture_internal_exceptions():
+ if not run_id:
+ return
+ watched_span = self._create_span(
+ run_id,
+ kwargs.get("parent_run_id"),
+ op=OP.LANGCHAIN_TOOL,
+ name=serialized.get("name") or kwargs.get("name") or "AI tool usage",
+ origin=LangchainIntegration.origin,
+ )
+ if should_send_default_pii() and self.include_prompts:
+ set_data_normalized(
+ watched_span.span,
+ SPANDATA.AI_INPUT_MESSAGES,
+ kwargs.get("inputs", [input_str]),
+ )
+ if kwargs.get("metadata"):
+ set_data_normalized(
+ watched_span.span, SPANDATA.AI_METADATA, kwargs.get("metadata")
+ )
+
+ def on_tool_end(self, output, *, run_id, **kwargs):
+ # type: (SentryLangchainCallback, str, UUID, Any) -> Any
+ """Run when tool ends running."""
+ with capture_internal_exceptions():
+ if not run_id or run_id not in self.span_map:
+ return
+
+ span_data = self.span_map[run_id]
+ if not span_data:
+ return
+ if should_send_default_pii() and self.include_prompts:
+ set_data_normalized(span_data.span, SPANDATA.AI_RESPONSES, output)
+ self._exit_span(span_data, run_id)
+
+ def on_tool_error(self, error, *args, run_id, **kwargs):
+ # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
+ """Run when tool errors."""
+ self._handle_error(run_id, error)
+
+
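+# Wraps langchain_core.callbacks.manager._configure so that every callback
+# manager LangChain builds contains exactly one SentryLangchainCallback.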
+def _wrap_configure(f):
+ # type: (Callable[..., Any]) -> Callable[..., Any]
+
+ @wraps(f)
+ def new_configure(*args, **kwargs):
+ # type: (Any, Any) -> Any
+
+ integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
+ if integration is None:
+ return f(*args, **kwargs)
+
+ with capture_internal_exceptions():
+ new_callbacks = [] # type: List[BaseCallbackHandler]
+ if "local_callbacks" in kwargs:
+ existing_callbacks = kwargs["local_callbacks"]
+ kwargs["local_callbacks"] = new_callbacks
+ elif len(args) > 2:
+ existing_callbacks = args[2]
+ args = (
+ args[0],
+ args[1],
+ new_callbacks,
+ ) + args[3:]
+ else:
+ existing_callbacks = []
+
+ if existing_callbacks:
+ if isinstance(existing_callbacks, list):
+ for cb in existing_callbacks:
+ new_callbacks.append(cb)
+ elif isinstance(existing_callbacks, BaseCallbackHandler):
+ new_callbacks.append(existing_callbacks)
+ else:
+ logger.debug("Unknown callback type: %s", existing_callbacks)
+
+            already_added = any(
+                isinstance(callback, SentryLangchainCallback)
+                for callback in new_callbacks
+            )
+
+ if not already_added:
+ new_callbacks.append(
+ SentryLangchainCallback(
+ integration.max_spans,
+ integration.include_prompts,
+ integration.tiktoken_encoding_name,
+ )
+ )
+ return f(*args, **kwargs)
+
+ return new_configure