Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py | 317
1 file changed, 317 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py b/.venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py
new file mode 100644
index 00000000..5bf9afd7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/literal_ai.py
@@ -0,0 +1,317 @@
+#### What this does ####
+# This file contains the LiteralAILogger class which is used to log steps to the LiteralAI observability platform.
+import asyncio
+import os
+import uuid
+from typing import List, Optional
+
+import httpx
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.llms.custom_httpx.http_handler import (
+    HTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
+from litellm.types.utils import StandardLoggingPayload
+
+
+class LiteralAILogger(CustomBatchLogger):
+    def __init__(
+        self,
+        literalai_api_key=None,
+        literalai_api_url="https://cloud.getliteral.ai",
+        env=None,
+        **kwargs,
+    ):
+        self.literalai_api_url = os.getenv("LITERAL_API_URL") or literalai_api_url
+        self.headers = {
+            "Content-Type": "application/json",
+            "x-api-key": literalai_api_key or os.getenv("LITERAL_API_KEY"),
+            "x-client-name": "litellm",
+        }
+        if env:
+            self.headers["x-env"] = env
+        self.async_httpx_client = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.LoggingCallback
+        )
+        self.sync_http_handler = HTTPHandler()
+        batch_size = os.getenv("LITERAL_BATCH_SIZE", None)
+        self.flush_lock = asyncio.Lock()
+        super().__init__(
+            **kwargs,
+            flush_lock=self.flush_lock,
+            batch_size=int(batch_size) if batch_size else None,
+        )
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            verbose_logger.debug(
+                "Literal AI Layer Logging - kwargs: %s, response_obj: %s",
+                kwargs,
+                response_obj,
+            )
+            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+            self.log_queue.append(data)
+            verbose_logger.debug(
+                "Literal AI logging: queue length %s, batch size %s",
+                len(self.log_queue),
+                self.batch_size,
+            )
+            if len(self.log_queue) >= self.batch_size:
+                self._send_batch()
+        except Exception:
+            verbose_logger.exception(
+                "Literal AI Layer Error - error logging success event."
+            )
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        verbose_logger.info("Literal AI Failure Event Logging!")
+        try:
+            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+            self.log_queue.append(data)
+            verbose_logger.debug(
+                "Literal AI logging: queue length %s, batch size %s",
+                len(self.log_queue),
+                self.batch_size,
+            )
+            if len(self.log_queue) >= self.batch_size:
+                self._send_batch()
+        except Exception:
+            verbose_logger.exception(
+                "Literal AI Layer Error - error logging failure event."
+            )
+
+    def _send_batch(self):
+        if not self.log_queue:
+            return
+
+        url = f"{self.literalai_api_url}/api/graphql"
+        query = self._steps_query_builder(self.log_queue)
+        variables = self._steps_variables_builder(self.log_queue)
+        try:
+            response = self.sync_http_handler.post(
+                url=url,
+                json={
+                    "query": query,
+                    "variables": variables,
+                },
+                headers=self.headers,
+            )
+
+            if response.status_code >= 300:
+                verbose_logger.error(
+                    f"Literal AI Error: {response.status_code} - {response.text}"
+                )
+            else:
+                verbose_logger.debug(
+                    f"Batch of {len(self.log_queue)} runs successfully created"
+                )
+        except Exception:
+            verbose_logger.exception("Literal AI Layer Error")
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            verbose_logger.debug(
+                "Literal AI Async Layer Logging - kwargs: %s, response_obj: %s",
+                kwargs,
+                response_obj,
+            )
+            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+            self.log_queue.append(data)
+            verbose_logger.debug(
+                "Literal AI logging: queue length %s, batch size %s",
+                len(self.log_queue),
+                self.batch_size,
+            )
+            if len(self.log_queue) >= self.batch_size:
+                await self.flush_queue()
+        except Exception:
+            verbose_logger.exception(
+                "Literal AI Layer Error - error logging async success event."
+            )
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        verbose_logger.info("Literal AI Failure Event Logging!")
+        try:
+            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+            self.log_queue.append(data)
+            verbose_logger.debug(
+                "Literal AI logging: queue length %s, batch size %s",
+                len(self.log_queue),
+                self.batch_size,
+            )
+            if len(self.log_queue) >= self.batch_size:
+                await self.flush_queue()
+        except Exception:
+            verbose_logger.exception(
+                "Literal AI Layer Error - error logging async failure event."
+            )
+
+    async def async_send_batch(self):
+        if not self.log_queue:
+            return
+
+        url = f"{self.literalai_api_url}/api/graphql"
+        query = self._steps_query_builder(self.log_queue)
+        variables = self._steps_variables_builder(self.log_queue)
+
+        try:
+            response = await self.async_httpx_client.post(
+                url=url,
+                json={
+                    "query": query,
+                    "variables": variables,
+                },
+                headers=self.headers,
+            )
+            if response.status_code >= 300:
+                verbose_logger.error(
+                    f"Literal AI Error: {response.status_code} - {response.text}"
+                )
+            else:
+                verbose_logger.debug(
+                    f"Batch of {len(self.log_queue)} runs successfully created"
+                )
+        except httpx.HTTPStatusError as e:
+            verbose_logger.exception(
+                f"Literal AI HTTP Error: {e.response.status_code} - {e.response.text}"
+            )
+        except Exception:
+            verbose_logger.exception("Literal AI Layer Error")
+
+    def _prepare_log_data(self, kwargs, response_obj, start_time, end_time) -> dict:
+        logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+            "standard_logging_object", None
+        )
+
+        if logging_payload is None:
+            raise ValueError("standard_logging_object not found in kwargs")
+        clean_metadata = logging_payload["metadata"]
+        metadata = kwargs.get("litellm_params", {}).get("metadata", {})
+
+        settings = logging_payload["model_parameters"]
+        messages = logging_payload["messages"]
+        response = logging_payload["response"]
+        choices: List = []
+        if isinstance(response, dict) and "choices" in response:
+            choices = response["choices"]
+        message_completion = choices[0]["message"] if choices else None
+        prompt_id = None
+        variables = None
+
+        if messages and isinstance(messages, list) and isinstance(messages[0], dict):
+            for message in messages:
+                if literal_prompt := getattr(message, "__literal_prompt__", None):
+                    prompt_id = literal_prompt.get("prompt_id")
+                    variables = literal_prompt.get("variables")
+                    message["uuid"] = literal_prompt.get("uuid")
+                    message["templated"] = True
+
+        tools = settings.pop("tools", None)
+
+        step = {
+            "id": metadata.get("step_id", str(uuid.uuid4())),
+            "error": logging_payload["error_str"],
+            "name": kwargs.get("model", ""),
+            "threadId": metadata.get("literalai_thread_id", None),
+            "parentId": metadata.get("literalai_parent_id", None),
+            "rootRunId": metadata.get("literalai_root_run_id", None),
+            "input": None,
+            "output": None,
+            "type": "llm",
+            "tags": metadata.get("tags", metadata.get("literalai_tags", None)),
+            "startTime": str(start_time),
+            "endTime": str(end_time),
+            "metadata": clean_metadata,
+            "generation": {
+                "inputTokenCount": logging_payload["prompt_tokens"],
+                "outputTokenCount": logging_payload["completion_tokens"],
+                "tokenCount": logging_payload["total_tokens"],
+                "promptId": prompt_id,
+                "variables": variables,
+                "provider": kwargs.get("custom_llm_provider", "litellm"),
+                "model": kwargs.get("model", ""),
+                "duration": (end_time - start_time).total_seconds(),
+                "settings": settings,
+                "messages": messages,
+                "messageCompletion": message_completion,
+                "tools": tools,
+            },
+        }
+        return step
+
+    def _steps_query_variables_builder(self, steps):
+        generated = ""
+        for id in range(len(steps)):
+            generated += f"""$id_{id}: String!
+        $threadId_{id}: String
+        $rootRunId_{id}: String
+        $type_{id}: StepType
+        $startTime_{id}: DateTime
+        $endTime_{id}: DateTime
+        $error_{id}: String
+        $input_{id}: Json
+        $output_{id}: Json
+        $metadata_{id}: Json
+        $parentId_{id}: String
+        $name_{id}: String
+        $tags_{id}: [String!]
+        $generation_{id}: GenerationPayloadInput
+        $scores_{id}: [ScorePayloadInput!]
+        $attachments_{id}: [AttachmentPayloadInput!]
+        """
+        return generated
+
+    def _steps_ingest_steps_builder(self, steps):
+        generated = ""
+        for id in range(len(steps)):
+            generated += f"""
+        step{id}: ingestStep(
+            id: $id_{id}
+            threadId: $threadId_{id}
+            rootRunId: $rootRunId_{id}
+            startTime: $startTime_{id}
+            endTime: $endTime_{id}
+            type: $type_{id}
+            error: $error_{id}
+            input: $input_{id}
+            output: $output_{id}
+            metadata: $metadata_{id}
+            parentId: $parentId_{id}
+            name: $name_{id}
+            tags: $tags_{id}
+            generation: $generation_{id}
+            scores: $scores_{id}
+            attachments: $attachments_{id}
+        ) {{
+            ok
+            message
+        }}
+    """
+        return generated
+
+    def _steps_query_builder(self, steps):
+        return f"""
+    mutation AddStep({self._steps_query_variables_builder(steps)}) {{
+        {self._steps_ingest_steps_builder(steps)}
+    }}
+    """
+
+    def _steps_variables_builder(self, steps):
+        def serialize_step(event, id):
+            result = {}
+
+            for key, value in event.items():
+                # Only keep the keys that are not None to avoid overriding existing values
+                if value is not None:
+                    result[f"{key}_{id}"] = value
+
+            return result
+
+        variables = {}
+        for i in range(len(steps)):
+            step = steps[i]
+            variables.update(serialize_step(step, i))
+        return variables
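
For context, a minimal usage sketch (not part of the diff): since LiteralAILogger subclasses CustomBatchLogger, which is itself a litellm CustomLogger, an instance can be registered through litellm's callback list. The LITERAL_API_KEY and LITERAL_BATCH_SIZE environment variables are the ones __init__ reads above; the key value and model name are illustrative placeholders.

import os

import litellm
from litellm.integrations.literal_ai import LiteralAILogger

# Read by __init__ when no literalai_api_key argument is passed.
os.environ["LITERAL_API_KEY"] = "lsk-..."  # illustrative placeholder
os.environ["LITERAL_BATCH_SIZE"] = "10"    # optional; flush after 10 queued steps

# Register the logger instance; litellm invokes the log_success_event /
# async_log_failure_event hooks on each completion call.
litellm.callbacks = [LiteralAILogger(env="production")]

response = litellm.completion(
    model="gpt-4o-mini",  # illustrative; any litellm-routable model
    messages=[{"role": "user", "content": "Hello"}],
)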
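
And a rough sketch of the GraphQL payload the two builder methods assemble for a single queued step, assuming the class can be instantiated outside the proxy; the trimmed step dict stands in for what _prepare_log_data returns.

from litellm.integrations.literal_ai import LiteralAILogger

logger = LiteralAILogger(literalai_api_key="dummy")  # illustrative key
step = {"id": "step-0", "type": "llm", "name": "gpt-4o-mini"}

# One numbered variable-declaration block and one aliased ingestStep call
# per queued step, roughly:
#   mutation AddStep($id_0: String! ... $attachments_0: [AttachmentPayloadInput!]) {
#     step0: ingestStep(id: $id_0 ... attachments: $attachments_0) { ok message }
#   }
print(logger._steps_query_builder([step]))

# Variables are flattened with the same _<index> suffix; None values are dropped:
#   {'id_0': 'step-0', 'type_0': 'llm', 'name_0': 'gpt-4o-mini'}
print(logger._steps_variables_builder([step]))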