author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
---|---|---
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed | .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches')
3 files changed, 417 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/Readme.md b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/Readme.md
new file mode 100644
index 00000000..2aa7d7b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/Readme.md
@@ -0,0 +1,6 @@
+# Vertex AI Batch Prediction Jobs
+
+Implementation to call VertexAI Batch endpoints in OpenAI Batch API spec
+
+Vertex Docs: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
+
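The Readme above describes exposing Vertex AI batch prediction through the OpenAI Batch API spec. A minimal usage sketch, assuming litellm routes `custom_llm_provider="vertex_ai"` calls on its files/batches surface to the handler added below; the file name and its contents are placeholders:

```python
import litellm

# Sketch only: upload a JSONL batch file, then create a Vertex AI batch job
# using the OpenAI Batch API spec. File name and contents are placeholders.
file_obj = litellm.create_file(
    file=open("vtx_batch.jsonl", "rb"),
    purpose="batch",
    custom_llm_provider="vertex_ai",
)

batch = litellm.create_batch(
    completion_window="24h",
    endpoint="/v1/chat/completions",
    input_file_id=file_obj.id,  # a gs:// URI for Vertex AI uploads
    custom_llm_provider="vertex_ai",
)
print(batch.id, batch.status)
```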
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py
new file mode 100644
index 00000000..b82268be
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py
@@ -0,0 +1,218 @@
+import json
+from typing import Any, Coroutine, Dict, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import (
+    _get_httpx_client,
+    get_async_httpx_client,
+)
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from litellm.types.llms.openai import CreateBatchRequest
+from litellm.types.llms.vertex_ai import (
+    VERTEX_CREDENTIALS_TYPES,
+    VertexAIBatchPredictionJob,
+)
+from litellm.types.utils import LiteLLMBatch
+
+from .transformation import VertexAIBatchTransformation
+
+
+class VertexAIBatchPrediction(VertexLLM):
+    def __init__(self, gcs_bucket_name: str, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.gcs_bucket_name = gcs_bucket_name
+
+    def create_batch(
+        self,
+        _is_async: bool,
+        create_batch_data: CreateBatchRequest,
+        api_base: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
+
+        sync_handler = _get_httpx_client()
+
+        access_token, project_id = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider="vertex_ai",
+        )
+
+        default_api_base = self.create_vertex_url(
+            vertex_location=vertex_location or "us-central1",
+            vertex_project=vertex_project or project_id,
+        )
+
+        if len(default_api_base.split(":")) > 1:
+            endpoint = default_api_base.split(":")[-1]
+        else:
+            endpoint = ""
+
+        _, api_base = self._check_custom_proxy(
+            api_base=api_base,
+            custom_llm_provider="vertex_ai",
+            gemini_api_key=None,
+            endpoint=endpoint,
+            stream=None,
+            auth_header=None,
+            url=default_api_base,
+        )
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {access_token}",
+        }
+
+        vertex_batch_request: VertexAIBatchPredictionJob = (
+            VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
+                request=create_batch_data
+            )
+        )
+
+        if _is_async is True:
+            return self._async_create_batch(
+                vertex_batch_request=vertex_batch_request,
+                api_base=api_base,
+                headers=headers,
+            )
+
+        response = sync_handler.post(
+            url=api_base,
+            headers=headers,
+            data=json.dumps(vertex_batch_request),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+            response=_json_response
+        )
+        return vertex_batch_response
+
+    async def _async_create_batch(
+        self,
+        vertex_batch_request: VertexAIBatchPredictionJob,
+        api_base: str,
+        headers: Dict[str, str],
+    ) -> LiteLLMBatch:
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.VERTEX_AI,
+        )
+        response = await client.post(
+            url=api_base,
+            headers=headers,
+            data=json.dumps(vertex_batch_request),
+        )
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+            response=_json_response
+        )
+        return vertex_batch_response
+
+    def create_vertex_url(
+        self,
+        vertex_location: str,
+        vertex_project: str,
+    ) -> str:
+        """Return the base url for the vertex garden models"""
+        # POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
+        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"
+
+    def retrieve_batch(
+        self,
+        _is_async: bool,
+        batch_id: str,
+        api_base: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
+        sync_handler = _get_httpx_client()
+
+        access_token, project_id = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider="vertex_ai",
+        )
+
+        default_api_base = self.create_vertex_url(
+            vertex_location=vertex_location or "us-central1",
+            vertex_project=vertex_project or project_id,
+        )
+
+        # Append batch_id to the URL
+        default_api_base = f"{default_api_base}/{batch_id}"
+
+        if len(default_api_base.split(":")) > 1:
+            endpoint = default_api_base.split(":")[-1]
+        else:
+            endpoint = ""
+
+        _, api_base = self._check_custom_proxy(
+            api_base=api_base,
+            custom_llm_provider="vertex_ai",
+            gemini_api_key=None,
+            endpoint=endpoint,
+            stream=None,
+            auth_header=None,
+            url=default_api_base,
+        )
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {access_token}",
+        }
+
+        if _is_async is True:
+            return self._async_retrieve_batch(
+                api_base=api_base,
+                headers=headers,
+            )
+
+        response = sync_handler.get(
+            url=api_base,
+            headers=headers,
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+            response=_json_response
+        )
+        return vertex_batch_response
+
+    async def _async_retrieve_batch(
+        self,
+        api_base: str,
+        headers: Dict[str, str],
+    ) -> LiteLLMBatch:
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.VERTEX_AI,
+        )
+        response = await client.get(
+            url=api_base,
+            headers=headers,
+        )
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+            response=_json_response
+        )
+        return vertex_batch_response
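The handler above returns a `LiteLLMBatch` directly on the sync path and a coroutine when `_is_async=True`. A hypothetical polling sketch using the handler directly (normally litellm's batches API invokes it for you; the project, location, and batch id values are placeholders):

```python
import asyncio

import httpx

from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction

# Hypothetical direct use; bucket, project, and location values are placeholders.
handler = VertexAIBatchPrediction(gcs_bucket_name="litellm-testing-bucket")


async def wait_for_batch(batch_id: str) -> None:
    while True:
        batch = await handler.retrieve_batch(
            _is_async=True,  # returns a coroutine when True
            batch_id=batch_id,
            api_base=None,
            vertex_credentials=None,  # fall back to default credentials
            vertex_project="my-project",
            vertex_location="us-central1",
            timeout=httpx.Timeout(600.0),
            max_retries=None,
        )
        if batch.status in ("completed", "failed", "cancelled", "expired"):
            break
        await asyncio.sleep(30)


# asyncio.run(wait_for_batch("3814889423749775360"))
```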
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py
new file mode 100644
index 00000000..a97f312d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py
@@ -0,0 +1,193 @@
+import uuid
+from typing import Dict
+
+from litellm.llms.vertex_ai.common_utils import (
+    _convert_vertex_datetime_to_openai_datetime,
+)
+from litellm.types.llms.openai import BatchJobStatus, CreateBatchRequest
+from litellm.types.llms.vertex_ai import *
+from litellm.types.utils import LiteLLMBatch
+
+
+class VertexAIBatchTransformation:
+    """
+    Transforms OpenAI Batch requests to Vertex AI Batch requests
+
+    API Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
+    """
+
+    @classmethod
+    def transform_openai_batch_request_to_vertex_ai_batch_request(
+        cls,
+        request: CreateBatchRequest,
+    ) -> VertexAIBatchPredictionJob:
+        """
+        Transforms OpenAI Batch requests to Vertex AI Batch requests
+        """
+        request_display_name = f"litellm-vertex-batch-{uuid.uuid4()}"
+        input_file_id = request.get("input_file_id")
+        if input_file_id is None:
+            raise ValueError("input_file_id is required, but not provided")
+        input_config: InputConfig = InputConfig(
+            gcsSource=GcsSource(uris=input_file_id), instancesFormat="jsonl"
+        )
+        model: str = cls._get_model_from_gcs_file(input_file_id)
+        output_config: OutputConfig = OutputConfig(
+            predictionsFormat="jsonl",
+            gcsDestination=GcsDestination(
+                outputUriPrefix=cls._get_gcs_uri_prefix_from_file(input_file_id)
+            ),
+        )
+        return VertexAIBatchPredictionJob(
+            inputConfig=input_config,
+            outputConfig=output_config,
+            model=model,
+            displayName=request_display_name,
+        )
+
+    @classmethod
+    def transform_vertex_ai_batch_response_to_openai_batch_response(
+        cls, response: VertexBatchPredictionResponse
+    ) -> LiteLLMBatch:
+        return LiteLLMBatch(
+            id=cls._get_batch_id_from_vertex_ai_batch_response(response),
+            completion_window="24hrs",
+            created_at=_convert_vertex_datetime_to_openai_datetime(
+                vertex_datetime=response.get("createTime", "")
+            ),
+            endpoint="",
+            input_file_id=cls._get_input_file_id_from_vertex_ai_batch_response(
+                response
+            ),
+            object="batch",
+            status=cls._get_batch_job_status_from_vertex_ai_batch_response(response),
+            error_file_id=None,  # Vertex AI doesn't seem to have a direct equivalent
+            output_file_id=cls._get_output_file_id_from_vertex_ai_batch_response(
+                response
+            ),
+        )
+
+    @classmethod
+    def _get_batch_id_from_vertex_ai_batch_response(
+        cls, response: VertexBatchPredictionResponse
+    ) -> str:
+        """
+        Gets the batch id from the Vertex AI Batch response safely
+
+        vertex response: `projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360`
+        returns: `3814889423749775360`
+        """
+        _name = response.get("name", "")
+        if not _name:
+            return ""
+
+        # Split by '/' and get the last part if it exists
+        parts = _name.split("/")
+        return parts[-1] if parts else _name
+
+    @classmethod
+    def _get_input_file_id_from_vertex_ai_batch_response(
+        cls, response: VertexBatchPredictionResponse
+    ) -> str:
+        """
+        Gets the input file id from the Vertex AI Batch response
+        """
+        input_file_id: str = ""
+        input_config = response.get("inputConfig")
+        if input_config is None:
+            return input_file_id
+
+        gcs_source = input_config.get("gcsSource")
+        if gcs_source is None:
+            return input_file_id
+
+        uris = gcs_source.get("uris", "")
+        if len(uris) == 0:
+            return input_file_id
+
+        return uris[0]
+
+    @classmethod
+    def _get_output_file_id_from_vertex_ai_batch_response(
+        cls, response: VertexBatchPredictionResponse
+    ) -> str:
+        """
+        Gets the output file id from the Vertex AI Batch response
+        """
+        output_file_id: str = ""
+        output_config = response.get("outputConfig")
+        if output_config is None:
+            return output_file_id
+
+        gcs_destination = output_config.get("gcsDestination")
+        if gcs_destination is None:
+            return output_file_id
+
+        output_uri_prefix = gcs_destination.get("outputUriPrefix", "")
+        return output_uri_prefix
+
+    @classmethod
+    def _get_batch_job_status_from_vertex_ai_batch_response(
+        cls, response: VertexBatchPredictionResponse
+    ) -> BatchJobStatus:
+        """
+        Gets the batch job status from the Vertex AI Batch response
+
+        ref: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/JobState
+        """
+        state_mapping: Dict[str, BatchJobStatus] = {
+            "JOB_STATE_UNSPECIFIED": "failed",
+            "JOB_STATE_QUEUED": "validating",
+            "JOB_STATE_PENDING": "validating",
+            "JOB_STATE_RUNNING": "in_progress",
+            "JOB_STATE_SUCCEEDED": "completed",
+            "JOB_STATE_FAILED": "failed",
+            "JOB_STATE_CANCELLING": "cancelling",
+            "JOB_STATE_CANCELLED": "cancelled",
+            "JOB_STATE_PAUSED": "in_progress",
+            "JOB_STATE_EXPIRED": "expired",
+            "JOB_STATE_UPDATING": "in_progress",
+            "JOB_STATE_PARTIALLY_SUCCEEDED": "completed",
+        }
+
+        vertex_state = response.get("state", "JOB_STATE_UNSPECIFIED")
+        return state_mapping[vertex_state]
+
+    @classmethod
+    def _get_gcs_uri_prefix_from_file(cls, input_file_id: str) -> str:
+        """
+        Gets the gcs uri prefix from the input file id
+
+        Example:
+        input_file_id: "gs://litellm-testing-bucket/vtx_batch.jsonl"
+        returns: "gs://litellm-testing-bucket"
+
+        input_file_id: "gs://litellm-testing-bucket/batches/vtx_batch.jsonl"
+        returns: "gs://litellm-testing-bucket/batches"
+        """
+        # Split the path and remove the filename
+        path_parts = input_file_id.rsplit("/", 1)
+        return path_parts[0]
+
+    @classmethod
+    def _get_model_from_gcs_file(cls, gcs_file_uri: str) -> str:
+        """
+        Extracts the model from the gcs file uri
+
+        When files are uploaded using LiteLLM (/v1/files), the model is stored in the gcs file uri
+
+        Why?
+        - Because Vertex Requires the `model` param in create batch jobs request, but OpenAI does not require this
+
+        gcs_file_uri format: gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8
+        returns: "publishers/google/models/gemini-1.5-flash-001"
+        """
+        from urllib.parse import unquote
+
+        decoded_uri = unquote(gcs_file_uri)
+
+        model_path = decoded_uri.split("publishers/")[1]
+        parts = model_path.split("/")
+        model = f"publishers/{'/'.join(parts[:3])}"
+        return model
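For illustration, a hedged sketch of what the transformation above produces when the input file was uploaded through litellm's `/v1/files` route; the GCS path and model segment are placeholders patterned on the docstrings:

```python
from litellm.llms.vertex_ai.batches.transformation import VertexAIBatchTransformation

# Placeholder input_file_id in the format the docstrings describe; the model
# path is embedded in the GCS URI by litellm's file upload.
input_file_id = (
    "gs://litellm-testing-bucket/litellm-vertex-files/"
    "publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8"
)

vertex_job = VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
    request={"input_file_id": input_file_id}  # other CreateBatchRequest fields omitted
)

# Expected shape (displayName carries a random uuid suffix):
# {
#   "inputConfig": {"gcsSource": {"uris": "gs://...e9412502-..."}, "instancesFormat": "jsonl"},
#   "outputConfig": {
#       "predictionsFormat": "jsonl",
#       "gcsDestination": {"outputUriPrefix": "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001"},
#   },
#   "model": "publishers/google/models/gemini-1.5-flash-001",
#   "displayName": "litellm-vertex-batch-<uuid>",
# }
```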