Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py  317
1 file changed, 317 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py b/.venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py
new file mode 100644
index 00000000..5bf9afd7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py
@@ -0,0 +1,317 @@
+#### What this does ####
+# This file contains the LiteralAILogger class, which logs steps to the Literal AI observability platform.
+import asyncio
+import os
+import uuid
+from typing import List, Optional
+
+import httpx
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.llms.custom_httpx.http_handler import (
+ HTTPHandler,
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+from litellm.types.utils import StandardLoggingPayload
+
+
+class LiteralAILogger(CustomBatchLogger):
+ def __init__(
+ self,
+ literalai_api_key=None,
+ literalai_api_url="https://cloud.getliteral.ai",
+ env=None,
+ **kwargs,
+ ):
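+        """
+        Batch logger that ships LiteLLM call data to the Literal AI platform.
+
+        Minimal usage sketch, registering the instance through LiteLLM's
+        generic callback list (an illustrative assumption about the caller's
+        setup, not the only way to enable this integration):
+
+            import litellm
+
+            litellm.callbacks = [LiteralAILogger(literalai_api_key="my-literal-api-key")]
+            litellm.completion(
+                model="gpt-4o-mini",
+                messages=[{"role": "user", "content": "Hi"}],
+            )
+        """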
+ self.literalai_api_url = os.getenv("LITERAL_API_URL") or literalai_api_url
+ self.headers = {
+ "Content-Type": "application/json",
+ "x-api-key": literalai_api_key or os.getenv("LITERAL_API_KEY"),
+ "x-client-name": "litellm",
+ }
+ if env:
+ self.headers["x-env"] = env
+ self.async_httpx_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+ self.sync_http_handler = HTTPHandler()
+ batch_size = os.getenv("LITERAL_BATCH_SIZE", None)
+ self.flush_lock = asyncio.Lock()
+ super().__init__(
+ **kwargs,
+ flush_lock=self.flush_lock,
+ batch_size=int(batch_size) if batch_size else None,
+ )
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ verbose_logger.debug(
+ "Literal AI Layer Logging - kwargs: %s, response_obj: %s",
+ kwargs,
+ response_obj,
+ )
+ data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+ self.log_queue.append(data)
+ verbose_logger.debug(
+ "Literal AI logging: queue length %s, batch size %s",
+ len(self.log_queue),
+ self.batch_size,
+ )
+ if len(self.log_queue) >= self.batch_size:
+ self._send_batch()
+ except Exception:
+ verbose_logger.exception(
+ "Literal AI Layer Error - error logging success event."
+ )
+
+ def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ verbose_logger.info("Literal AI Failure Event Logging!")
+ try:
+ data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+ self.log_queue.append(data)
+ verbose_logger.debug(
+ "Literal AI logging: queue length %s, batch size %s",
+ len(self.log_queue),
+ self.batch_size,
+ )
+ if len(self.log_queue) >= self.batch_size:
+ self._send_batch()
+ except Exception:
+ verbose_logger.exception(
+ "Literal AI Layer Error - error logging failure event."
+ )
+
+ def _send_batch(self):
+ if not self.log_queue:
+ return
+
+ url = f"{self.literalai_api_url}/api/graphql"
+ query = self._steps_query_builder(self.log_queue)
+ variables = self._steps_variables_builder(self.log_queue)
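+        # Illustrative shape of the POST body for a two-step batch (values
+        # elided; the exact fields come from the builder methods below):
+        #   {"query": "mutation AddStep($id_0: String! ... $id_1: String! ...) {...}",
+        #    "variables": {"id_0": "...", "type_0": "llm", ..., "id_1": "...", ...}}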
+ try:
+ response = self.sync_http_handler.post(
+ url=url,
+ json={
+ "query": query,
+ "variables": variables,
+ },
+ headers=self.headers,
+ )
+
+ if response.status_code >= 300:
+ verbose_logger.error(
+ f"Literal AI Error: {response.status_code} - {response.text}"
+ )
+ else:
+ verbose_logger.debug(
+ f"Batch of {len(self.log_queue)} runs successfully created"
+ )
+ except Exception:
+ verbose_logger.exception("Literal AI Layer Error")
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ verbose_logger.debug(
+ "Literal AI Async Layer Logging - kwargs: %s, response_obj: %s",
+ kwargs,
+ response_obj,
+ )
+ data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+ self.log_queue.append(data)
+ verbose_logger.debug(
+ "Literal AI logging: queue length %s, batch size %s",
+ len(self.log_queue),
+ self.batch_size,
+ )
+ if len(self.log_queue) >= self.batch_size:
+ await self.flush_queue()
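+                # flush_queue() comes from CustomBatchLogger and is expected to
+                # drain self.log_queue via async_send_batch() while holding
+                # self.flush_lock.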
+ except Exception:
+ verbose_logger.exception(
+ "Literal AI Layer Error - error logging async success event."
+ )
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ verbose_logger.info("Literal AI Failure Event Logging!")
+ try:
+ data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+ self.log_queue.append(data)
+ verbose_logger.debug(
+ "Literal AI logging: queue length %s, batch size %s",
+ len(self.log_queue),
+ self.batch_size,
+ )
+ if len(self.log_queue) >= self.batch_size:
+ await self.flush_queue()
+ except Exception:
+ verbose_logger.exception(
+ "Literal AI Layer Error - error logging async failure event."
+ )
+
+ async def async_send_batch(self):
+ if not self.log_queue:
+ return
+
+ url = f"{self.literalai_api_url}/api/graphql"
+ query = self._steps_query_builder(self.log_queue)
+ variables = self._steps_variables_builder(self.log_queue)
+
+ try:
+ response = await self.async_httpx_client.post(
+ url=url,
+ json={
+ "query": query,
+ "variables": variables,
+ },
+ headers=self.headers,
+ )
+ if response.status_code >= 300:
+ verbose_logger.error(
+ f"Literal AI Error: {response.status_code} - {response.text}"
+ )
+ else:
+ verbose_logger.debug(
+ f"Batch of {len(self.log_queue)} runs successfully created"
+ )
+ except httpx.HTTPStatusError as e:
+ verbose_logger.exception(
+ f"Literal AI HTTP Error: {e.response.status_code} - {e.response.text}"
+ )
+ except Exception:
+ verbose_logger.exception("Literal AI Layer Error")
+
+ def _prepare_log_data(self, kwargs, response_obj, start_time, end_time) -> dict:
+ logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object", None
+ )
+
+ if logging_payload is None:
+ raise ValueError("standard_logging_object not found in kwargs")
+ clean_metadata = logging_payload["metadata"]
+ metadata = kwargs.get("litellm_params", {}).get("metadata", {})
+
+ settings = logging_payload["model_parameters"]
+ messages = logging_payload["messages"]
+ response = logging_payload["response"]
+ choices: List = []
+ if isinstance(response, dict) and "choices" in response:
+ choices = response["choices"]
+ message_completion = choices[0]["message"] if choices else None
+ prompt_id = None
+ variables = None
+
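+        # Note: "__literal_prompt__" is read with getattr, i.e. as an object
+        # attribute rather than a dict key, so plain dict messages pass through
+        # this loop unchanged. (The Literal AI SDK is assumed to attach this
+        # attribute to messages rendered from a managed prompt.)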
+ if messages and isinstance(messages, list) and isinstance(messages[0], dict):
+ for message in messages:
+ if literal_prompt := getattr(message, "__literal_prompt__", None):
+ prompt_id = literal_prompt.get("prompt_id")
+ variables = literal_prompt.get("variables")
+ message["uuid"] = literal_prompt.get("uuid")
+ message["templated"] = True
+
+ tools = settings.pop("tools", None)
+
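+        # Sketch of the step payload built below (values illustrative):
+        #   {"id": "1b2f...", "name": "gpt-4o-mini", "type": "llm",
+        #    "startTime": "2025-01-01 12:00:00", "endTime": "2025-01-01 12:00:01",
+        #    "generation": {"inputTokenCount": 12, "outputTokenCount": 34, ...}}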
+ step = {
+ "id": metadata.get("step_id", str(uuid.uuid4())),
+ "error": logging_payload["error_str"],
+ "name": kwargs.get("model", ""),
+ "threadId": metadata.get("literalai_thread_id", None),
+ "parentId": metadata.get("literalai_parent_id", None),
+ "rootRunId": metadata.get("literalai_root_run_id", None),
+ "input": None,
+ "output": None,
+ "type": "llm",
+ "tags": metadata.get("tags", metadata.get("literalai_tags", None)),
+ "startTime": str(start_time),
+ "endTime": str(end_time),
+ "metadata": clean_metadata,
+ "generation": {
+ "inputTokenCount": logging_payload["prompt_tokens"],
+ "outputTokenCount": logging_payload["completion_tokens"],
+ "tokenCount": logging_payload["total_tokens"],
+ "promptId": prompt_id,
+ "variables": variables,
+ "provider": kwargs.get("custom_llm_provider", "litellm"),
+ "model": kwargs.get("model", ""),
+ "duration": (end_time - start_time).total_seconds(),
+ "settings": settings,
+ "messages": messages,
+ "messageCompletion": message_completion,
+ "tools": tools,
+ },
+ }
+ return step
+
+ def _steps_query_variables_builder(self, steps):
+ generated = ""
+        for i in range(len(steps)):
+            generated += f"""$id_{i}: String!
+            $threadId_{i}: String
+            $rootRunId_{i}: String
+            $type_{i}: StepType
+            $startTime_{i}: DateTime
+            $endTime_{i}: DateTime
+            $error_{i}: String
+            $input_{i}: Json
+            $output_{i}: Json
+            $metadata_{i}: Json
+            $parentId_{i}: String
+            $name_{i}: String
+            $tags_{i}: [String!]
+            $generation_{i}: GenerationPayloadInput
+            $scores_{i}: [ScorePayloadInput!]
+            $attachments_{i}: [AttachmentPayloadInput!]
+            """
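+        # For a single step this emits the declarations $id_0: String!,
+        # $threadId_0: String, ..., $attachments_0: [AttachmentPayloadInput!],
+        # one set per step in the batch.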
+ return generated
+
+ def _steps_ingest_steps_builder(self, steps):
+ generated = ""
+        for i in range(len(steps)):
+            generated += f"""
+            step{i}: ingestStep(
+                id: $id_{i}
+                threadId: $threadId_{i}
+                rootRunId: $rootRunId_{i}
+                startTime: $startTime_{i}
+                endTime: $endTime_{i}
+                type: $type_{i}
+                error: $error_{i}
+                input: $input_{i}
+                output: $output_{i}
+                metadata: $metadata_{i}
+                parentId: $parentId_{i}
+                name: $name_{i}
+                tags: $tags_{i}
+                generation: $generation_{i}
+                scores: $scores_{i}
+                attachments: $attachments_{i}
+            ) {{
+                ok
+                message
+            }}
+            """
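+        # Each step becomes an aliased field (step0, step1, ...) calling the
+        # ingestStep mutation with that step's variables, so the whole batch
+        # goes out in a single GraphQL request.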
+ return generated
+
+ def _steps_query_builder(self, steps):
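+        # Assembled mutation for a one-step batch looks roughly like:
+        #   mutation AddStep($id_0: String! $threadId_0: String ...) {
+        #     step0: ingestStep(id: $id_0 threadId: $threadId_0 ...) { ok message }
+        #   }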
+ return f"""
+ mutation AddStep({self._steps_query_variables_builder(steps)}) {{
+ {self._steps_ingest_steps_builder(steps)}
+ }}
+ """
+
+ def _steps_variables_builder(self, steps):
+        def serialize_step(event, index):
+            result = {}
+
+            for key, value in event.items():
+                # Keep only keys whose values are not None, to avoid overriding
+                # existing values server-side.
+                if value is not None:
+                    result[f"{key}_{index}"] = value
+
+            return result
+
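+        # Each step's keys get the step's batch index as a suffix, e.g. a step
+        # {"id": "abc", "type": "llm"} at position 0 contributes
+        # {"id_0": "abc", "type_0": "llm"}; None values are dropped above.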
+ variables = {}
+        for i, step in enumerate(steps):
+            variables.update(serialize_step(step, i))
+ return variables