aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py193
1 files changed, 193 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py
new file mode 100644
index 00000000..a97f312d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py
@@ -0,0 +1,193 @@
+import uuid
+from typing import Dict
+
+from litellm.llms.vertex_ai.common_utils import (
+ _convert_vertex_datetime_to_openai_datetime,
+)
+from litellm.types.llms.openai import BatchJobStatus, CreateBatchRequest
+from litellm.types.llms.vertex_ai import *
+from litellm.types.utils import LiteLLMBatch
+
+
+class VertexAIBatchTransformation:
+ """
+ Transforms OpenAI Batch requests to Vertex AI Batch requests
+
+ API Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
+ """
+
+ @classmethod
+ def transform_openai_batch_request_to_vertex_ai_batch_request(
+ cls,
+ request: CreateBatchRequest,
+ ) -> VertexAIBatchPredictionJob:
+ """
+ Transforms OpenAI Batch requests to Vertex AI Batch requests
+ """
+ request_display_name = f"litellm-vertex-batch-{uuid.uuid4()}"
+ input_file_id = request.get("input_file_id")
+ if input_file_id is None:
+ raise ValueError("input_file_id is required, but not provided")
+ input_config: InputConfig = InputConfig(
+ gcsSource=GcsSource(uris=input_file_id), instancesFormat="jsonl"
+ )
+ model: str = cls._get_model_from_gcs_file(input_file_id)
+ output_config: OutputConfig = OutputConfig(
+ predictionsFormat="jsonl",
+ gcsDestination=GcsDestination(
+ outputUriPrefix=cls._get_gcs_uri_prefix_from_file(input_file_id)
+ ),
+ )
+ return VertexAIBatchPredictionJob(
+ inputConfig=input_config,
+ outputConfig=output_config,
+ model=model,
+ displayName=request_display_name,
+ )
+
+ @classmethod
+ def transform_vertex_ai_batch_response_to_openai_batch_response(
+ cls, response: VertexBatchPredictionResponse
+ ) -> LiteLLMBatch:
+ return LiteLLMBatch(
+ id=cls._get_batch_id_from_vertex_ai_batch_response(response),
+ completion_window="24hrs",
+ created_at=_convert_vertex_datetime_to_openai_datetime(
+ vertex_datetime=response.get("createTime", "")
+ ),
+ endpoint="",
+ input_file_id=cls._get_input_file_id_from_vertex_ai_batch_response(
+ response
+ ),
+ object="batch",
+ status=cls._get_batch_job_status_from_vertex_ai_batch_response(response),
+ error_file_id=None, # Vertex AI doesn't seem to have a direct equivalent
+ output_file_id=cls._get_output_file_id_from_vertex_ai_batch_response(
+ response
+ ),
+ )
+
+ @classmethod
+ def _get_batch_id_from_vertex_ai_batch_response(
+ cls, response: VertexBatchPredictionResponse
+ ) -> str:
+ """
+ Gets the batch id from the Vertex AI Batch response safely
+
+ vertex response: `projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360`
+ returns: `3814889423749775360`
+ """
+ _name = response.get("name", "")
+ if not _name:
+ return ""
+
+ # Split by '/' and get the last part if it exists
+ parts = _name.split("/")
+ return parts[-1] if parts else _name
+
+ @classmethod
+ def _get_input_file_id_from_vertex_ai_batch_response(
+ cls, response: VertexBatchPredictionResponse
+ ) -> str:
+ """
+ Gets the input file id from the Vertex AI Batch response
+ """
+ input_file_id: str = ""
+ input_config = response.get("inputConfig")
+ if input_config is None:
+ return input_file_id
+
+ gcs_source = input_config.get("gcsSource")
+ if gcs_source is None:
+ return input_file_id
+
+ uris = gcs_source.get("uris", "")
+ if len(uris) == 0:
+ return input_file_id
+
+ return uris[0]
+
+ @classmethod
+ def _get_output_file_id_from_vertex_ai_batch_response(
+ cls, response: VertexBatchPredictionResponse
+ ) -> str:
+ """
+ Gets the output file id from the Vertex AI Batch response
+ """
+ output_file_id: str = ""
+ output_config = response.get("outputConfig")
+ if output_config is None:
+ return output_file_id
+
+ gcs_destination = output_config.get("gcsDestination")
+ if gcs_destination is None:
+ return output_file_id
+
+ output_uri_prefix = gcs_destination.get("outputUriPrefix", "")
+ return output_uri_prefix
+
+ @classmethod
+ def _get_batch_job_status_from_vertex_ai_batch_response(
+ cls, response: VertexBatchPredictionResponse
+ ) -> BatchJobStatus:
+ """
+ Gets the batch job status from the Vertex AI Batch response
+
+ ref: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/JobState
+ """
+ state_mapping: Dict[str, BatchJobStatus] = {
+ "JOB_STATE_UNSPECIFIED": "failed",
+ "JOB_STATE_QUEUED": "validating",
+ "JOB_STATE_PENDING": "validating",
+ "JOB_STATE_RUNNING": "in_progress",
+ "JOB_STATE_SUCCEEDED": "completed",
+ "JOB_STATE_FAILED": "failed",
+ "JOB_STATE_CANCELLING": "cancelling",
+ "JOB_STATE_CANCELLED": "cancelled",
+ "JOB_STATE_PAUSED": "in_progress",
+ "JOB_STATE_EXPIRED": "expired",
+ "JOB_STATE_UPDATING": "in_progress",
+ "JOB_STATE_PARTIALLY_SUCCEEDED": "completed",
+ }
+
+ vertex_state = response.get("state", "JOB_STATE_UNSPECIFIED")
+ return state_mapping[vertex_state]
+
+ @classmethod
+ def _get_gcs_uri_prefix_from_file(cls, input_file_id: str) -> str:
+ """
+ Gets the gcs uri prefix from the input file id
+
+ Example:
+ input_file_id: "gs://litellm-testing-bucket/vtx_batch.jsonl"
+ returns: "gs://litellm-testing-bucket"
+
+ input_file_id: "gs://litellm-testing-bucket/batches/vtx_batch.jsonl"
+ returns: "gs://litellm-testing-bucket/batches"
+ """
+ # Split the path and remove the filename
+ path_parts = input_file_id.rsplit("/", 1)
+ return path_parts[0]
+
+ @classmethod
+ def _get_model_from_gcs_file(cls, gcs_file_uri: str) -> str:
+ """
+ Extracts the model from the gcs file uri
+
+ When files are uploaded using LiteLLM (/v1/files), the model is stored in the gcs file uri
+
+ Why?
+ - Because Vertex Requires the `model` param in create batch jobs request, but OpenAI does not require this
+
+
+ gcs_file_uri format: gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8
+ returns: "publishers/google/models/gemini-1.5-flash-001"
+ """
+ from urllib.parse import unquote
+
+ decoded_uri = unquote(gcs_file_uri)
+
+ model_path = decoded_uri.split("publishers/")[1]
+ parts = model_path.split("/")
+ model = f"publishers/{'/'.join(parts[:3])}"
+ return model