about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py392
1 files changed, 392 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py b/.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py
new file mode 100644
index 00000000..055ad902
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/argilla.py
@@ -0,0 +1,392 @@
+"""
+Send logs to Argilla for annotation
+"""
+
+import asyncio
+import json
+import os
+import random
+import types
+from typing import Any, Dict, List, Optional
+
+import httpx
+from pydantic import BaseModel  # type: ignore
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.llms.custom_httpx.http_handler import (
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
+from litellm.types.integrations.argilla import (
+    SUPPORTED_PAYLOAD_FIELDS,
+    ArgillaCredentialsObject,
+    ArgillaItem,
+)
+from litellm.types.utils import StandardLoggingPayload
+
+
+def is_serializable(value):
+    non_serializable_types = (
+        types.CoroutineType,
+        types.FunctionType,
+        types.GeneratorType,
+        BaseModel,
+    )
+    return not isinstance(value, non_serializable_types)
+
+
+class ArgillaLogger(CustomBatchLogger):
+    def __init__(
+        self,
+        argilla_api_key: Optional[str] = None,
+        argilla_dataset_name: Optional[str] = None,
+        argilla_base_url: Optional[str] = None,
+        **kwargs,
+    ):
+        if litellm.argilla_transformation_object is None:
+            raise Exception(
+                "'litellm.argilla_transformation_object' is required, to log your payload to Argilla."
+            )
+        self.validate_argilla_transformation_object(
+            litellm.argilla_transformation_object
+        )
+        self.argilla_transformation_object = litellm.argilla_transformation_object
+        self.default_credentials = self.get_credentials_from_env(
+            argilla_api_key=argilla_api_key,
+            argilla_dataset_name=argilla_dataset_name,
+            argilla_base_url=argilla_base_url,
+        )
+        self.sampling_rate: float = (
+            float(os.getenv("ARGILLA_SAMPLING_RATE"))  # type: ignore
+            if os.getenv("ARGILLA_SAMPLING_RATE") is not None
+            and os.getenv("ARGILLA_SAMPLING_RATE").strip().isdigit()  # type: ignore
+            else 1.0
+        )
+
+        self.async_httpx_client = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.LoggingCallback
+        )
+        _batch_size = (
+            os.getenv("ARGILLA_BATCH_SIZE", None) or litellm.argilla_batch_size
+        )
+        if _batch_size:
+            self.batch_size = int(_batch_size)
+        asyncio.create_task(self.periodic_flush())
+        self.flush_lock = asyncio.Lock()
+        super().__init__(**kwargs, flush_lock=self.flush_lock)
+
+    def validate_argilla_transformation_object(
+        self, argilla_transformation_object: Dict[str, Any]
+    ):
+        if not isinstance(argilla_transformation_object, dict):
+            raise Exception(
+                "'argilla_transformation_object' must be a dictionary, to log your payload to Argilla."
+            )
+
+        for v in argilla_transformation_object.values():
+            if v not in SUPPORTED_PAYLOAD_FIELDS:
+                raise Exception(
+                    f"All values in argilla_transformation_object must be a key in SUPPORTED_PAYLOAD_FIELDS, {v} is not a valid key."
+                )
+
+    def get_credentials_from_env(
+        self,
+        argilla_api_key: Optional[str],
+        argilla_dataset_name: Optional[str],
+        argilla_base_url: Optional[str],
+    ) -> ArgillaCredentialsObject:
+
+        _credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
+        if _credentials_api_key is None:
+            raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.")
+
+        _credentials_base_url = (
+            argilla_base_url
+            or os.getenv("ARGILLA_BASE_URL")
+            or "http://localhost:6900/"
+        )
+        if _credentials_base_url is None:
+            raise Exception(
+                "Invalid Argilla Base URL given. _credentials_base_url=None."
+            )
+
+        _credentials_dataset_name = (
+            argilla_dataset_name
+            or os.getenv("ARGILLA_DATASET_NAME")
+            or "litellm-completion"
+        )
+        if _credentials_dataset_name is None:
+            raise Exception("Invalid Argilla Dataset give. Value=None.")
+        else:
+            dataset_response = litellm.module_level_client.get(
+                url=f"{_credentials_base_url}/api/v1/me/datasets?name={_credentials_dataset_name}",
+                headers={"X-Argilla-Api-Key": _credentials_api_key},
+            )
+            json_response = dataset_response.json()
+            if (
+                "items" in json_response
+                and isinstance(json_response["items"], list)
+                and len(json_response["items"]) > 0
+            ):
+                _credentials_dataset_name = json_response["items"][0]["id"]
+
+        return ArgillaCredentialsObject(
+            ARGILLA_API_KEY=_credentials_api_key,
+            ARGILLA_BASE_URL=_credentials_base_url,
+            ARGILLA_DATASET_NAME=_credentials_dataset_name,
+        )
+
+    def get_chat_messages(
+        self, payload: StandardLoggingPayload
+    ) -> List[Dict[str, Any]]:
+        payload_messages = payload.get("messages", None)
+
+        if payload_messages is None:
+            raise Exception("No chat messages found in payload.")
+
+        if (
+            isinstance(payload_messages, list)
+            and len(payload_messages) > 0
+            and isinstance(payload_messages[0], dict)
+        ):
+            return payload_messages
+        elif isinstance(payload_messages, dict):
+            return [payload_messages]
+        else:
+            raise Exception(f"Invalid chat messages format: {payload_messages}")
+
+    def get_str_response(self, payload: StandardLoggingPayload) -> str:
+        response = payload["response"]
+
+        if response is None:
+            raise Exception("No response found in payload.")
+
+        if isinstance(response, str):
+            return response
+        elif isinstance(response, dict):
+            return (
+                response.get("choices", [{}])[0].get("message", {}).get("content", "")
+            )
+        else:
+            raise Exception(f"Invalid response format: {response}")
+
+    def _prepare_log_data(
+        self, kwargs, response_obj, start_time, end_time
+    ) -> Optional[ArgillaItem]:
+        try:
+            # Ensure everything in the payload is converted to str
+            payload: Optional[StandardLoggingPayload] = kwargs.get(
+                "standard_logging_object", None
+            )
+
+            if payload is None:
+                raise Exception("Error logging request payload. Payload=none.")
+
+            argilla_message = self.get_chat_messages(payload)
+            argilla_response = self.get_str_response(payload)
+            argilla_item: ArgillaItem = {"fields": {}}
+            for k, v in self.argilla_transformation_object.items():
+                if v == "messages":
+                    argilla_item["fields"][k] = argilla_message
+                elif v == "response":
+                    argilla_item["fields"][k] = argilla_response
+                else:
+                    argilla_item["fields"][k] = payload.get(v, None)
+
+            return argilla_item
+        except Exception:
+            raise
+
+    def _send_batch(self):
+        if not self.log_queue:
+            return
+
+        argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
+        argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
+
+        url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"
+
+        argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
+
+        headers = {"X-Argilla-Api-Key": argilla_api_key}
+
+        try:
+            response = litellm.module_level_client.post(
+                url=url,
+                json=self.log_queue,
+                headers=headers,
+            )
+
+            if response.status_code >= 300:
+                verbose_logger.error(
+                    f"Argilla Error: {response.status_code} - {response.text}"
+                )
+            else:
+                verbose_logger.debug(
+                    f"Batch of {len(self.log_queue)} runs successfully created"
+                )
+
+            self.log_queue.clear()
+        except Exception:
+            verbose_logger.exception("Argilla Layer Error - Error sending batch.")
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            sampling_rate = (
+                float(os.getenv("LANGSMITH_SAMPLING_RATE"))  # type: ignore
+                if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
+                and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit()  # type: ignore
+                else 1.0
+            )
+            random_sample = random.random()
+            if random_sample > sampling_rate:
+                verbose_logger.info(
+                    "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
+                        sampling_rate, random_sample
+                    )
+                )
+                return  # Skip logging
+            verbose_logger.debug(
+                "Langsmith Sync Layer Logging - kwargs: %s, response_obj: %s",
+                kwargs,
+                response_obj,
+            )
+            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+            if data is None:
+                return
+
+            self.log_queue.append(data)
+            verbose_logger.debug(
+                f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
+            )
+
+            if len(self.log_queue) >= self.batch_size:
+                self._send_batch()
+
+        except Exception:
+            verbose_logger.exception("Langsmith Layer Error - log_success_event error")
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            sampling_rate = self.sampling_rate
+            random_sample = random.random()
+            if random_sample > sampling_rate:
+                verbose_logger.info(
+                    "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
+                        sampling_rate, random_sample
+                    )
+                )
+                return  # Skip logging
+            verbose_logger.debug(
+                "Langsmith Async Layer Logging - kwargs: %s, response_obj: %s",
+                kwargs,
+                response_obj,
+            )
+            payload: Optional[StandardLoggingPayload] = kwargs.get(
+                "standard_logging_object", None
+            )
+
+            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+
+            ## ALLOW CUSTOM LOGGERS TO MODIFY / FILTER DATA BEFORE LOGGING
+            for callback in litellm.callbacks:
+                if isinstance(callback, CustomLogger):
+                    try:
+                        if data is None:
+                            break
+                        data = await callback.async_dataset_hook(data, payload)
+                    except NotImplementedError:
+                        pass
+
+            if data is None:
+                return
+
+            self.log_queue.append(data)
+            verbose_logger.debug(
+                "Langsmith logging: queue length %s, batch size %s",
+                len(self.log_queue),
+                self.batch_size,
+            )
+            if len(self.log_queue) >= self.batch_size:
+                await self.flush_queue()
+        except Exception:
+            verbose_logger.exception(
+                "Argilla Layer Error - error logging async success event."
+            )
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        sampling_rate = self.sampling_rate
+        random_sample = random.random()
+        if random_sample > sampling_rate:
+            verbose_logger.info(
+                "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
+                    sampling_rate, random_sample
+                )
+            )
+            return  # Skip logging
+        verbose_logger.info("Langsmith Failure Event Logging!")
+        try:
+            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+            self.log_queue.append(data)
+            verbose_logger.debug(
+                "Langsmith logging: queue length %s, batch size %s",
+                len(self.log_queue),
+                self.batch_size,
+            )
+            if len(self.log_queue) >= self.batch_size:
+                await self.flush_queue()
+        except Exception:
+            verbose_logger.exception(
+                "Langsmith Layer Error - error logging async failure event."
+            )
+
+    async def async_send_batch(self):
+        """
+        sends runs to /batch endpoint
+
+        Sends runs from self.log_queue
+
+        Returns: None
+
+        Raises: Does not raise an exception, will only verbose_logger.exception()
+        """
+        if not self.log_queue:
+            return
+
+        argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
+        argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
+
+        url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"
+
+        argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
+
+        headers = {"X-Argilla-Api-Key": argilla_api_key}
+
+        try:
+            response = await self.async_httpx_client.put(
+                url=url,
+                data=json.dumps(
+                    {
+                        "items": self.log_queue,
+                    }
+                ),
+                headers=headers,
+                timeout=60000,
+            )
+            response.raise_for_status()
+
+            if response.status_code >= 300:
+                verbose_logger.error(
+                    f"Argilla Error: {response.status_code} - {response.text}"
+                )
+            else:
+                verbose_logger.debug(
+                    "Batch of %s runs successfully created", len(self.log_queue)
+                )
+        except httpx.HTTPStatusError:
+            verbose_logger.exception("Argilla HTTP Error")
+        except Exception:
+            verbose_logger.exception("Argilla Layer Error")