diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/integrations/argilla.py | 392 |
1 file changed, 392 insertions, 0 deletions
"""
Send logs to Argilla for annotation
"""

import asyncio
import json
import os
import random
import types
from typing import Any, Dict, List, Optional

import httpx
from pydantic import BaseModel  # type: ignore

import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.types.integrations.argilla import (
    SUPPORTED_PAYLOAD_FIELDS,
    ArgillaCredentialsObject,
    ArgillaItem,
)
from litellm.types.utils import StandardLoggingPayload


def is_serializable(value: Any) -> bool:
    """Return False for values that are known to be non-JSON-serializable."""
    non_serializable_types = (
        types.CoroutineType,
        types.FunctionType,
        types.GeneratorType,
        BaseModel,
    )
    return not isinstance(value, non_serializable_types)


def _get_env_sampling_rate(env_var: str) -> float:
    """
    Parse a sampling rate (0.0-1.0) from the environment variable *env_var*.

    Fixes the original ``.strip().isdigit()`` gate, which rejected fractional
    values such as "0.5" and silently fell back to 1.0. Any value that does
    not parse as a float still falls back to 1.0 (log everything).
    """
    raw = os.getenv(env_var)
    if raw is None:
        return 1.0
    try:
        return float(raw)
    except ValueError:
        return 1.0


class ArgillaLogger(CustomBatchLogger):
    """
    Batch logger that queues LiteLLM request/response payloads and ships them
    to an Argilla dataset for human annotation.

    Configuration (constructor args override environment variables):
        ARGILLA_API_KEY / argilla_api_key       - required API key.
        ARGILLA_BASE_URL / argilla_base_url     - defaults to http://localhost:6900/.
        ARGILLA_DATASET_NAME / argilla_dataset_name - defaults to "litellm-completion".
        ARGILLA_SAMPLING_RATE                   - fraction of events to log (0.0-1.0).
        ARGILLA_BATCH_SIZE                      - flush threshold for the queue.
    """

    def __init__(
        self,
        argilla_api_key: Optional[str] = None,
        argilla_dataset_name: Optional[str] = None,
        argilla_base_url: Optional[str] = None,
        **kwargs,
    ):
        # The transformation object maps Argilla field names -> payload fields
        # ("messages"/"response"/...); logging is impossible without it.
        if litellm.argilla_transformation_object is None:
            raise Exception(
                "'litellm.argilla_transformation_object' is required, to log your payload to Argilla."
            )
        self.validate_argilla_transformation_object(
            litellm.argilla_transformation_object
        )
        self.argilla_transformation_object = litellm.argilla_transformation_object
        self.default_credentials = self.get_credentials_from_env(
            argilla_api_key=argilla_api_key,
            argilla_dataset_name=argilla_dataset_name,
            argilla_base_url=argilla_base_url,
        )
        # Original code used `.isdigit()` here, which broke fractional rates.
        self.sampling_rate: float = _get_env_sampling_rate("ARGILLA_SAMPLING_RATE")

        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        _batch_size = (
            os.getenv("ARGILLA_BATCH_SIZE", None) or litellm.argilla_batch_size
        )
        if _batch_size:
            self.batch_size = int(_batch_size)
        # NOTE(review): create_task requires a running event loop; this mirrors
        # the original behavior and will raise if constructed outside one.
        asyncio.create_task(self.periodic_flush())
        self.flush_lock = asyncio.Lock()
        super().__init__(**kwargs, flush_lock=self.flush_lock)

    def validate_argilla_transformation_object(
        self, argilla_transformation_object: Dict[str, Any]
    ) -> None:
        """
        Validate the user-supplied transformation mapping.

        Raises:
            Exception: if the mapping is not a dict, or any mapped value is
                not one of SUPPORTED_PAYLOAD_FIELDS.
        """
        if not isinstance(argilla_transformation_object, dict):
            raise Exception(
                "'argilla_transformation_object' must be a dictionary, to log your payload to Argilla."
            )

        for v in argilla_transformation_object.values():
            if v not in SUPPORTED_PAYLOAD_FIELDS:
                raise Exception(
                    f"All values in argilla_transformation_object must be a key in SUPPORTED_PAYLOAD_FIELDS, {v} is not a valid key."
                )

    def get_credentials_from_env(
        self,
        argilla_api_key: Optional[str],
        argilla_dataset_name: Optional[str],
        argilla_base_url: Optional[str],
    ) -> ArgillaCredentialsObject:
        """
        Resolve Argilla credentials from arguments, falling back to env vars.

        If the named dataset already exists on the server, the human-readable
        name is swapped for the server-side dataset id (the records API wants
        the id, not the name).

        Raises:
            Exception: if no API key is provided or found in the environment.
        """
        _credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
        if _credentials_api_key is None:
            raise Exception(
                "Invalid Argilla API Key given. _credentials_api_key=None."
            )

        # The `or` fallbacks guarantee non-None values, so the original dead
        # `is None` re-checks have been removed.
        _credentials_base_url = (
            argilla_base_url
            or os.getenv("ARGILLA_BASE_URL")
            or "http://localhost:6900/"
        )
        _credentials_dataset_name = (
            argilla_dataset_name
            or os.getenv("ARGILLA_DATASET_NAME")
            or "litellm-completion"
        )

        # Look up the dataset by name; use its id when found.
        dataset_response = litellm.module_level_client.get(
            url=f"{_credentials_base_url}/api/v1/me/datasets?name={_credentials_dataset_name}",
            headers={"X-Argilla-Api-Key": _credentials_api_key},
        )
        json_response = dataset_response.json()
        if (
            "items" in json_response
            and isinstance(json_response["items"], list)
            and len(json_response["items"]) > 0
        ):
            _credentials_dataset_name = json_response["items"][0]["id"]

        return ArgillaCredentialsObject(
            ARGILLA_API_KEY=_credentials_api_key,
            ARGILLA_BASE_URL=_credentials_base_url,
            ARGILLA_DATASET_NAME=_credentials_dataset_name,
        )

    def get_chat_messages(
        self, payload: StandardLoggingPayload
    ) -> List[Dict[str, Any]]:
        """
        Extract the chat messages from a standard logging payload.

        Returns a list of message dicts; a single dict is wrapped in a list.

        Raises:
            Exception: if messages are missing or in an unrecognized shape.
        """
        payload_messages = payload.get("messages", None)

        if payload_messages is None:
            raise Exception("No chat messages found in payload.")

        if (
            isinstance(payload_messages, list)
            and len(payload_messages) > 0
            and isinstance(payload_messages[0], dict)
        ):
            return payload_messages
        elif isinstance(payload_messages, dict):
            return [payload_messages]
        else:
            raise Exception(f"Invalid chat messages format: {payload_messages}")

    def get_str_response(self, payload: StandardLoggingPayload) -> str:
        """
        Extract the model response as a plain string.

        Accepts either a raw string or an OpenAI-style completion dict, in
        which case the first choice's message content is returned ("" when
        any level of the structure is missing).

        Raises:
            Exception: if the response is absent or of an unexpected type.
        """
        response = payload["response"]

        if response is None:
            raise Exception("No response found in payload.")

        if isinstance(response, str):
            return response
        elif isinstance(response, dict):
            return (
                response.get("choices", [{}])[0].get("message", {}).get("content", "")
            )
        else:
            raise Exception(f"Invalid response format: {response}")

    def _prepare_log_data(
        self, kwargs, response_obj, start_time, end_time
    ) -> Optional[ArgillaItem]:
        """
        Build an ArgillaItem from the standard logging payload, mapping each
        Argilla field name through ``argilla_transformation_object``.

        Raises:
            Exception: if kwargs carries no "standard_logging_object".
        """
        # Original wrapped this in `try: ... except Exception: raise`, which
        # is a no-op; exceptions propagate to the caller either way.
        payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )

        if payload is None:
            raise Exception("Error logging request payload. Payload=none.")

        argilla_message = self.get_chat_messages(payload)
        argilla_response = self.get_str_response(payload)
        argilla_item: ArgillaItem = {"fields": {}}
        for k, v in self.argilla_transformation_object.items():
            if v == "messages":
                argilla_item["fields"][k] = argilla_message
            elif v == "response":
                argilla_item["fields"][k] = argilla_response
            else:
                argilla_item["fields"][k] = payload.get(v, None)

        return argilla_item

    def _send_batch(self) -> None:
        """
        Synchronously POST the queued items to the Argilla bulk records
        endpoint and clear the queue. Errors are logged, never raised.

        NOTE(review): this sends the bare list while async_send_batch PUTs
        ``{"items": [...]}`` — one of the two payload shapes is likely wrong;
        verify against the Argilla records/bulk API before relying on the
        sync path.
        """
        if not self.log_queue:
            return

        argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
        argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]

        url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"

        argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]

        headers = {"X-Argilla-Api-Key": argilla_api_key}

        try:
            response = litellm.module_level_client.post(
                url=url,
                json=self.log_queue,
                headers=headers,
            )

            if response.status_code >= 300:
                verbose_logger.error(
                    f"Argilla Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    f"Batch of {len(self.log_queue)} runs successfully created"
                )

            self.log_queue.clear()
        except Exception:
            verbose_logger.exception("Argilla Layer Error - Error sending batch.")

    def log_success_event(self, kwargs, response_obj, start_time, end_time) -> None:
        """
        Sync success hook: sample, transform, enqueue; flush when the queue
        reaches batch_size. Never raises.
        """
        try:
            # Fixed copy-paste bug: this read LANGSMITH_SAMPLING_RATE; use the
            # Argilla rate computed in __init__, matching the async path.
            sampling_rate = self.sampling_rate
            random_sample = random.random()
            if random_sample > sampling_rate:
                verbose_logger.info(
                    "Skipping Argilla logging. Sampling rate={}, random_sample={}".format(
                        sampling_rate, random_sample
                    )
                )
                return  # Skip logging
            verbose_logger.debug(
                "Argilla Sync Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            if data is None:
                return

            self.log_queue.append(data)
            verbose_logger.debug(
                f"Argilla, event added to queue. Will flush in {self.flush_interval} seconds..."
            )

            if len(self.log_queue) >= self.batch_size:
                self._send_batch()

        except Exception:
            verbose_logger.exception("Argilla Layer Error - log_success_event error")

    async def async_log_success_event(
        self, kwargs, response_obj, start_time, end_time
    ) -> None:
        """
        Async success hook: sample, transform, let CustomLogger callbacks
        filter/modify the item, enqueue, and flush at batch_size. Never raises.
        """
        try:
            sampling_rate = self.sampling_rate
            random_sample = random.random()
            if random_sample > sampling_rate:
                verbose_logger.info(
                    "Skipping Argilla logging. Sampling rate={}, random_sample={}".format(
                        sampling_rate, random_sample
                    )
                )
                return  # Skip logging
            verbose_logger.debug(
                "Argilla Async Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            payload: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )

            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)

            ## ALLOW CUSTOM LOGGERS TO MODIFY / FILTER DATA BEFORE LOGGING
            for callback in litellm.callbacks:
                if isinstance(callback, CustomLogger):
                    try:
                        if data is None:
                            break
                        data = await callback.async_dataset_hook(data, payload)
                    except NotImplementedError:
                        # Callback doesn't implement the hook; leave data as-is.
                        pass

            if data is None:
                return

            self.log_queue.append(data)
            verbose_logger.debug(
                "Argilla logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Argilla Layer Error - error logging async success event."
            )

    async def async_log_failure_event(
        self, kwargs, response_obj, start_time, end_time
    ) -> None:
        """
        Async failure hook: same sampling/enqueue/flush flow as success,
        without the callback filtering step. Never raises.
        """
        sampling_rate = self.sampling_rate
        random_sample = random.random()
        if random_sample > sampling_rate:
            verbose_logger.info(
                "Skipping Argilla logging. Sampling rate={}, random_sample={}".format(
                    sampling_rate, random_sample
                )
            )
            return  # Skip logging
        verbose_logger.info("Argilla Failure Event Logging!")
        try:
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Argilla logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Argilla Layer Error - error logging async failure event."
            )

    async def async_send_batch(self) -> None:
        """
        sends runs to /batch endpoint

        Sends runs from self.log_queue

        Returns: None

        Raises: Does not raise an exception, will only verbose_logger.exception()
        """
        if not self.log_queue:
            return

        argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
        argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]

        url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"

        argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]

        headers = {"X-Argilla-Api-Key": argilla_api_key}

        try:
            response = await self.async_httpx_client.put(
                url=url,
                data=json.dumps(
                    {
                        "items": self.log_queue,
                    }
                ),
                headers=headers,
                # httpx timeouts are in seconds; the original 60000 (~16.7 h)
                # was almost certainly meant to be 60 s.
                timeout=60,
            )
            response.raise_for_status()

            # raise_for_status() already raised for 4xx/5xx, so this branch
            # only catches 3xx responses.
            if response.status_code >= 300:
                verbose_logger.error(
                    f"Argilla Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    "Batch of %s runs successfully created", len(self.log_queue)
                )
        except httpx.HTTPStatusError:
            verbose_logger.exception("Argilla HTTP Error")
        except Exception:
            verbose_logger.exception("Argilla Layer Error")