about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/Readme.md6
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py218
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py193
3 files changed, 417 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/Readme.md b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/Readme.md
new file mode 100644
index 00000000..2aa7d7b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/Readme.md
@@ -0,0 +1,6 @@
+# Vertex AI Batch Prediction Jobs
+
+Implementation to call VertexAI Batch endpoints in OpenAI Batch API spec
+
+Vertex Docs: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
+
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py
new file mode 100644
index 00000000..b82268be
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py
@@ -0,0 +1,218 @@
+import json
+from typing import Any, Coroutine, Dict, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import (
+    _get_httpx_client,
+    get_async_httpx_client,
+)
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from litellm.types.llms.openai import CreateBatchRequest
+from litellm.types.llms.vertex_ai import (
+    VERTEX_CREDENTIALS_TYPES,
+    VertexAIBatchPredictionJob,
+)
+from litellm.types.utils import LiteLLMBatch
+
+from .transformation import VertexAIBatchTransformation
+
+
+class VertexAIBatchPrediction(VertexLLM):
+    def __init__(self, gcs_bucket_name: str, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.gcs_bucket_name = gcs_bucket_name
+
+    def create_batch(
+        self,
+        _is_async: bool,
+        create_batch_data: CreateBatchRequest,
+        api_base: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
+
+        sync_handler = _get_httpx_client()
+
+        access_token, project_id = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider="vertex_ai",
+        )
+
+        default_api_base = self.create_vertex_url(
+            vertex_location=vertex_location or "us-central1",
+            vertex_project=vertex_project or project_id,
+        )
+
+        if len(default_api_base.split(":")) > 1:
+            endpoint = default_api_base.split(":")[-1]
+        else:
+            endpoint = ""
+
+        _, api_base = self._check_custom_proxy(
+            api_base=api_base,
+            custom_llm_provider="vertex_ai",
+            gemini_api_key=None,
+            endpoint=endpoint,
+            stream=None,
+            auth_header=None,
+            url=default_api_base,
+        )
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {access_token}",
+        }
+
+        vertex_batch_request: VertexAIBatchPredictionJob = (
+            VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
+                request=create_batch_data
+            )
+        )
+
+        if _is_async is True:
+            return self._async_create_batch(
+                vertex_batch_request=vertex_batch_request,
+                api_base=api_base,
+                headers=headers,
+            )
+
+        response = sync_handler.post(
+            url=api_base,
+            headers=headers,
+            data=json.dumps(vertex_batch_request),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+            response=_json_response
+        )
+        return vertex_batch_response
+
+    async def _async_create_batch(
+        self,
+        vertex_batch_request: VertexAIBatchPredictionJob,
+        api_base: str,
+        headers: Dict[str, str],
+    ) -> LiteLLMBatch:
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.VERTEX_AI,
+        )
+        response = await client.post(
+            url=api_base,
+            headers=headers,
+            data=json.dumps(vertex_batch_request),
+        )
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+            response=_json_response
+        )
+        return vertex_batch_response
+
+    def create_vertex_url(
+        self,
+        vertex_location: str,
+        vertex_project: str,
+    ) -> str:
+        """Return the base url for the vertex garden models"""
+        #  POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
+        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"
+
+    def retrieve_batch(
+        self,
+        _is_async: bool,
+        batch_id: str,
+        api_base: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
+        sync_handler = _get_httpx_client()
+
+        access_token, project_id = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider="vertex_ai",
+        )
+
+        default_api_base = self.create_vertex_url(
+            vertex_location=vertex_location or "us-central1",
+            vertex_project=vertex_project or project_id,
+        )
+
+        # Append batch_id to the URL
+        default_api_base = f"{default_api_base}/{batch_id}"
+
+        if len(default_api_base.split(":")) > 1:
+            endpoint = default_api_base.split(":")[-1]
+        else:
+            endpoint = ""
+
+        _, api_base = self._check_custom_proxy(
+            api_base=api_base,
+            custom_llm_provider="vertex_ai",
+            gemini_api_key=None,
+            endpoint=endpoint,
+            stream=None,
+            auth_header=None,
+            url=default_api_base,
+        )
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {access_token}",
+        }
+
+        if _is_async is True:
+            return self._async_retrieve_batch(
+                api_base=api_base,
+                headers=headers,
+            )
+
+        response = sync_handler.get(
+            url=api_base,
+            headers=headers,
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+            response=_json_response
+        )
+        return vertex_batch_response
+
+    async def _async_retrieve_batch(
+        self,
+        api_base: str,
+        headers: Dict[str, str],
+    ) -> LiteLLMBatch:
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.VERTEX_AI,
+        )
+        response = await client.get(
+            url=api_base,
+            headers=headers,
+        )
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+            response=_json_response
+        )
+        return vertex_batch_response
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py
new file mode 100644
index 00000000..a97f312d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py
@@ -0,0 +1,193 @@
+import uuid
+from typing import Dict
+
+from litellm.llms.vertex_ai.common_utils import (
+    _convert_vertex_datetime_to_openai_datetime,
+)
+from litellm.types.llms.openai import BatchJobStatus, CreateBatchRequest
+from litellm.types.llms.vertex_ai import *
+from litellm.types.utils import LiteLLMBatch
+
+
+class VertexAIBatchTransformation:
+    """
+    Transforms OpenAI Batch requests to Vertex AI Batch requests
+
+    API Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
+    """
+
+    @classmethod
+    def transform_openai_batch_request_to_vertex_ai_batch_request(
+        cls,
+        request: CreateBatchRequest,
+    ) -> VertexAIBatchPredictionJob:
+        """
+        Transforms OpenAI Batch requests to Vertex AI Batch requests
+        """
+        request_display_name = f"litellm-vertex-batch-{uuid.uuid4()}"
+        input_file_id = request.get("input_file_id")
+        if input_file_id is None:
+            raise ValueError("input_file_id is required, but not provided")
+        input_config: InputConfig = InputConfig(
+            gcsSource=GcsSource(uris=input_file_id), instancesFormat="jsonl"
+        )
+        model: str = cls._get_model_from_gcs_file(input_file_id)
+        output_config: OutputConfig = OutputConfig(
+            predictionsFormat="jsonl",
+            gcsDestination=GcsDestination(
+                outputUriPrefix=cls._get_gcs_uri_prefix_from_file(input_file_id)
+            ),
+        )
+        return VertexAIBatchPredictionJob(
+            inputConfig=input_config,
+            outputConfig=output_config,
+            model=model,
+            displayName=request_display_name,
+        )
+
+    @classmethod
+    def transform_vertex_ai_batch_response_to_openai_batch_response(
+        cls, response: VertexBatchPredictionResponse
+    ) -> LiteLLMBatch:
+        return LiteLLMBatch(
+            id=cls._get_batch_id_from_vertex_ai_batch_response(response),
+            completion_window="24hrs",
+            created_at=_convert_vertex_datetime_to_openai_datetime(
+                vertex_datetime=response.get("createTime", "")
+            ),
+            endpoint="",
+            input_file_id=cls._get_input_file_id_from_vertex_ai_batch_response(
+                response
+            ),
+            object="batch",
+            status=cls._get_batch_job_status_from_vertex_ai_batch_response(response),
+            error_file_id=None,  # Vertex AI doesn't seem to have a direct equivalent
+            output_file_id=cls._get_output_file_id_from_vertex_ai_batch_response(
+                response
+            ),
+        )
+
+    @classmethod
+    def _get_batch_id_from_vertex_ai_batch_response(
+        cls, response: VertexBatchPredictionResponse
+    ) -> str:
+        """
+        Gets the batch id from the Vertex AI Batch response safely
+
+        vertex response: `projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360`
+        returns: `3814889423749775360`
+        """
+        _name = response.get("name", "")
+        if not _name:
+            return ""
+
+        # Split by '/' and get the last part if it exists
+        parts = _name.split("/")
+        return parts[-1] if parts else _name
+
+    @classmethod
+    def _get_input_file_id_from_vertex_ai_batch_response(
+        cls, response: VertexBatchPredictionResponse
+    ) -> str:
+        """
+        Gets the input file id from the Vertex AI Batch response
+        """
+        input_file_id: str = ""
+        input_config = response.get("inputConfig")
+        if input_config is None:
+            return input_file_id
+
+        gcs_source = input_config.get("gcsSource")
+        if gcs_source is None:
+            return input_file_id
+
+        uris = gcs_source.get("uris", "")
+        if len(uris) == 0:
+            return input_file_id
+
+        return uris[0]
+
+    @classmethod
+    def _get_output_file_id_from_vertex_ai_batch_response(
+        cls, response: VertexBatchPredictionResponse
+    ) -> str:
+        """
+        Gets the output file id from the Vertex AI Batch response
+        """
+        output_file_id: str = ""
+        output_config = response.get("outputConfig")
+        if output_config is None:
+            return output_file_id
+
+        gcs_destination = output_config.get("gcsDestination")
+        if gcs_destination is None:
+            return output_file_id
+
+        output_uri_prefix = gcs_destination.get("outputUriPrefix", "")
+        return output_uri_prefix
+
+    @classmethod
+    def _get_batch_job_status_from_vertex_ai_batch_response(
+        cls, response: VertexBatchPredictionResponse
+    ) -> BatchJobStatus:
+        """
+        Gets the batch job status from the Vertex AI Batch response
+
+        ref: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/JobState
+        """
+        state_mapping: Dict[str, BatchJobStatus] = {
+            "JOB_STATE_UNSPECIFIED": "failed",
+            "JOB_STATE_QUEUED": "validating",
+            "JOB_STATE_PENDING": "validating",
+            "JOB_STATE_RUNNING": "in_progress",
+            "JOB_STATE_SUCCEEDED": "completed",
+            "JOB_STATE_FAILED": "failed",
+            "JOB_STATE_CANCELLING": "cancelling",
+            "JOB_STATE_CANCELLED": "cancelled",
+            "JOB_STATE_PAUSED": "in_progress",
+            "JOB_STATE_EXPIRED": "expired",
+            "JOB_STATE_UPDATING": "in_progress",
+            "JOB_STATE_PARTIALLY_SUCCEEDED": "completed",
+        }
+
+        vertex_state = response.get("state", "JOB_STATE_UNSPECIFIED")
+        return state_mapping[vertex_state]
+
+    @classmethod
+    def _get_gcs_uri_prefix_from_file(cls, input_file_id: str) -> str:
+        """
+        Gets the gcs uri prefix from the input file id
+
+        Example:
+        input_file_id: "gs://litellm-testing-bucket/vtx_batch.jsonl"
+        returns: "gs://litellm-testing-bucket"
+
+        input_file_id: "gs://litellm-testing-bucket/batches/vtx_batch.jsonl"
+        returns: "gs://litellm-testing-bucket/batches"
+        """
+        # Split the path and remove the filename
+        path_parts = input_file_id.rsplit("/", 1)
+        return path_parts[0]
+
+    @classmethod
+    def _get_model_from_gcs_file(cls, gcs_file_uri: str) -> str:
+        """
+        Extracts the model from the gcs file uri
+
+        When files are uploaded using LiteLLM (/v1/files), the model is stored in the gcs file uri
+
+        Why?
+        - Because Vertex Requires the `model` param in create batch jobs request, but OpenAI does not require this
+
+
+        gcs_file_uri format: gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8
+        returns: "publishers/google/models/gemini-1.5-flash-001"
+        """
+        from urllib.parse import unquote
+
+        decoded_uri = unquote(gcs_file_uri)
+
+        model_path = decoded_uri.split("publishers/")[1]
+        parts = model_path.split("/")
+        model = f"publishers/{'/'.join(parts[:3])}"
+        return model