Diffstat (limited to '.venv/lib/python3.12/site-packages/google/genai/_transformers.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/google/genai/_transformers.py | 621
1 file changed, 621 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/google/genai/_transformers.py b/.venv/lib/python3.12/site-packages/google/genai/_transformers.py
new file mode 100644
index 00000000..f1b392e9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/_transformers.py
@@ -0,0 +1,621 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Transformers for Google GenAI SDK."""
+
+import base64
+from collections.abc import Iterable, Mapping
+import inspect
+import io
+import re
+import time
+import typing
+from typing import Any, GenericAlias, Optional, Union
+
+import PIL.Image
+import PIL.PngImagePlugin
+import pydantic
+
+from . import _api_client
+from . import types
+
+
+def _resource_name(
+    client: _api_client.ApiClient,
+    resource_name: str,
+    *,
+    collection_identifier: str,
+    collection_hierarchy_depth: int = 2,
+):
+  # pylint: disable=line-too-long
+  """Prepends resource name with project, location, collection_identifier if needed.
+
+  The collection_identifier will only be prepended if it's not present
+  and the prepending won't violate the collection hierarchy depth.
+  When the prepending condition isn't met, returns the input
+  resource_name.
+
+  Args:
+    client: The API client.
+    resource_name: The user input resource name to be completed.
+    collection_identifier: The collection identifier to be prepended. See
+      collection identifiers in https://google.aip.dev/122.
+    collection_hierarchy_depth: The collection hierarchy depth. Only set this
+      field when the resource has nested collections. For example,
+      `users/vhugo1802/events/birthday-dinner-226`, the collection_identifier
+      is `users` and collection_hierarchy_depth is 4. See nested collections
+      in https://google.aip.dev/122.
+
+  Example:
+
+    resource_name = 'cachedContents/123'
+    client.vertexai = True
+    client.project = 'bar'
+    client.location = 'us-west1'
+    _resource_name(client, 'cachedContents/123',
+        collection_identifier='cachedContents')
+    returns: 'projects/bar/locations/us-west1/cachedContents/123'
+
+  Example:
+
+    resource_name = 'projects/foo/locations/us-central1/cachedContents/123'
+    # resource_name = 'locations/us-central1/cachedContents/123'
+    client.vertexai = True
+    client.project = 'bar'
+    client.location = 'us-west1'
+    _resource_name(client, resource_name,
+        collection_identifier='cachedContents')
+    returns: 'projects/foo/locations/us-central1/cachedContents/123'
+
+  Example:
+
+    resource_name = '123'
+    # resource_name = 'cachedContents/123'
+    client.vertexai = False
+    _resource_name(client, resource_name,
+        collection_identifier='cachedContents')
+    returns: 'cachedContents/123'
+
+  Example:
+
+    resource_name = 'some/wrong/cachedContents/resource/name/123'
+    resource_prefix = 'cachedContents'
+    client.vertexai = False
+    # client.vertexai = True
+    _resource_name(client, resource_name,
+        collection_identifier='cachedContents')
+    returns: 'some/wrong/cachedContents/resource/name/123'
+
+  Returns:
+    The completed resource name.
+  """
+  should_prepend_collection_identifier = (
+      not resource_name.startswith(f'{collection_identifier}/')
+      # Check if prepending the collection identifier won't violate the
+      # collection hierarchy depth.
+      and f'{collection_identifier}/{resource_name}'.count('/') + 1
+      == collection_hierarchy_depth
+  )
+  if client.vertexai:
+    if resource_name.startswith('projects/'):
+      return resource_name
+    elif resource_name.startswith('locations/'):
+      return f'projects/{client.project}/{resource_name}'
+    elif resource_name.startswith(f'{collection_identifier}/'):
+      return f'projects/{client.project}/locations/{client.location}/{resource_name}'
+    elif should_prepend_collection_identifier:
+      return f'projects/{client.project}/locations/{client.location}/{collection_identifier}/{resource_name}'
+    else:
+      return resource_name
+  else:
+    if should_prepend_collection_identifier:
+      return f'{collection_identifier}/{resource_name}'
+    else:
+      return resource_name
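+
+# --- Editor's illustrative sketch (not part of the upstream file) ---
+# How _resource_name() completes a bare cached-content ID, assuming a minimal
+# stand-in object exposing the attributes the function reads (`vertexai`,
+# `project`, `location`); the stub class name is hypothetical:
+#
+#   class _StubClient:
+#     vertexai = True
+#     project = 'my-project'
+#     location = 'us-central1'
+#
+#   _resource_name(_StubClient(), '123', collection_identifier='cachedContents')
+#   # -> 'projects/my-project/locations/us-central1/cachedContents/123'
+#
+#   _StubClient.vertexai = False
+#   _resource_name(_StubClient(), '123', collection_identifier='cachedContents')
+#   # -> 'cachedContents/123'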
+ """ + should_prepend_collection_identifier = ( + not resource_name.startswith(f'{collection_identifier}/') + # Check if prepending the collection identifier won't violate the + # collection hierarchy depth. + and f'{collection_identifier}/{resource_name}'.count('/') + 1 + == collection_hierarchy_depth + ) + if client.vertexai: + if resource_name.startswith('projects/'): + return resource_name + elif resource_name.startswith('locations/'): + return f'projects/{client.project}/{resource_name}' + elif resource_name.startswith(f'{collection_identifier}/'): + return f'projects/{client.project}/locations/{client.location}/{resource_name}' + elif should_prepend_collection_identifier: + return f'projects/{client.project}/locations/{client.location}/{collection_identifier}/{resource_name}' + else: + return resource_name + else: + if should_prepend_collection_identifier: + return f'{collection_identifier}/{resource_name}' + else: + return resource_name + + +def t_model(client: _api_client.ApiClient, model: str): + if not model: + raise ValueError('model is required.') + if client.vertexai: + if ( + model.startswith('projects/') + or model.startswith('models/') + or model.startswith('publishers/') + ): + return model + elif '/' in model: + publisher, model_id = model.split('/', 1) + return f'publishers/{publisher}/models/{model_id}' + else: + return f'publishers/google/models/{model}' + else: + if model.startswith('models/'): + return model + elif model.startswith('tunedModels/'): + return model + else: + return f'models/{model}' + + +def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str: + if api_client.vertexai: + if base_models: + return 'publishers/google/models' + else: + return 'models' + else: + if base_models: + return 'models' + else: + return 'tunedModels' + + +def t_extract_models( + api_client: _api_client.ApiClient, response: dict +) -> list[types.Model]: + if not response: + return [] + elif response.get('models') is not None: + return response.get('models') + elif response.get('tunedModels') is not None: + return response.get('tunedModels') + elif response.get('publisherModels') is not None: + return response.get('publisherModels') + else: + raise ValueError('Cannot determine the models type.') + + +def t_caches_model(api_client: _api_client.ApiClient, model: str): + model = t_model(api_client, model) + if not model: + return None + if model.startswith('publishers/') and api_client.vertexai: + # vertex caches only support model name start with projects. 
+
+
+def pil_to_blob(img):
+  bytesio = io.BytesIO()
+  if isinstance(img, PIL.PngImagePlugin.PngImageFile) or img.mode == 'RGBA':
+    img.save(bytesio, format='PNG')
+    mime_type = 'image/png'
+  else:
+    img.save(bytesio, format='JPEG')
+    mime_type = 'image/jpeg'
+  bytesio.seek(0)
+  data = bytesio.read()
+  return types.Blob(mime_type=mime_type, data=data)
+
+
+PartType = Union[types.Part, types.PartDict, str, PIL.Image.Image]
+
+
+def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
+  if not part:
+    raise ValueError('content part is required.')
+  if isinstance(part, str):
+    return types.Part(text=part)
+  if isinstance(part, PIL.Image.Image):
+    return types.Part(inline_data=pil_to_blob(part))
+  if isinstance(part, types.File):
+    if not part.uri or not part.mime_type:
+      raise ValueError('file uri and mime_type are required.')
+    return types.Part.from_uri(part.uri, part.mime_type)
+  else:
+    return part
+
+
+def t_parts(
+    client: _api_client.ApiClient, parts: Union[list, PartType]
+) -> list[types.Part]:
+  if parts is None:
+    raise ValueError('content parts are required.')
+  if isinstance(parts, list):
+    return [t_part(client, part) for part in parts]
+  else:
+    return [t_part(client, parts)]
+
+
+def t_image_predictions(
+    client: _api_client.ApiClient,
+    predictions: Optional[Iterable[Mapping[str, Any]]],
+) -> Optional[list[types.GeneratedImage]]:
+  if not predictions:
+    return None
+  images = []
+  for prediction in predictions:
+    if prediction.get('image'):
+      images.append(
+          types.GeneratedImage(
+              image=types.Image(
+                  gcs_uri=prediction['image']['gcsUri'],
+                  image_bytes=prediction['image']['imageBytes'],
+              )
+          )
+      )
+  return images
+
+
+ContentType = Union[types.Content, types.ContentDict, PartType]
+
+
+def t_content(
+    client: _api_client.ApiClient,
+    content: ContentType,
+):
+  if not content:
+    raise ValueError('content is required.')
+  if isinstance(content, types.Content):
+    return content
+  if isinstance(content, dict):
+    return types.Content.model_validate(content)
+  return types.Content(role='user', parts=t_parts(client, content))
+
+
+def t_contents_for_embed(
+    client: _api_client.ApiClient,
+    contents: Union[list[types.Content], list[types.ContentDict], ContentType],
+):
+  if client.vertexai and isinstance(contents, list):
+    # TODO: Assert that only text is supported.
+    return [t_content(client, content).parts[0].text for content in contents]
+  elif client.vertexai:
+    return [t_content(client, contents).parts[0].text]
+  elif isinstance(contents, list):
+    return [t_content(client, content) for content in contents]
+  else:
+    return [t_content(client, contents)]
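+
+# --- Editor's illustrative sketch (not part of the upstream file) ---
+# t_content()/t_parts() normalize plain strings (and PIL images) into the
+# typed Content/Part hierarchy; the client argument is not consulted for
+# string parts:
+#
+#   t_content(client, 'why is the sky blue?')
+#   # -> types.Content(role='user',
+#   #                  parts=[types.Part(text='why is the sky blue?')])
+#
+#   t_parts(client, ['a', types.Part(text='b')])
+#   # -> [types.Part(text='a'), types.Part(text='b')]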
+
+
+def t_contents(
+    client: _api_client.ApiClient,
+    contents: Union[list[types.Content], list[types.ContentDict], ContentType],
+):
+  if not contents:
+    raise ValueError('contents are required.')
+  if isinstance(contents, list):
+    return [t_content(client, content) for content in contents]
+  else:
+    return [t_content(client, contents)]
+
+
+def process_schema(
+    data: dict[str, Any], client: Optional[_api_client.ApiClient] = None
+):
+  if isinstance(data, dict):
+    # Iterate over a copy of keys to allow deletion.
+    for key in list(data.keys()):
+      # Only delete 'title' for the Gemini API.
+      if client and not client.vertexai and key == 'title':
+        del data[key]
+      else:
+        process_schema(data[key], client)
+  elif isinstance(data, list):
+    for item in data:
+      process_schema(item, client)
+
+  return data
+
+
+def _build_schema(fname: str, fields_dict: dict[str, Any]) -> dict[str, Any]:
+  parameters = pydantic.create_model(fname, **fields_dict).model_json_schema()
+  defs = parameters.pop('$defs', {})
+
+  for _, value in defs.items():
+    unpack_defs(value, defs)
+
+  unpack_defs(parameters, defs)
+  return parameters['properties']['dummy']
+
+
+def unpack_defs(schema: dict[str, Any], defs: dict[str, Any]):
+  """Unpacks the $defs values in the schema generated by pydantic so they can be understood by the API.
+
+  Example of a schema before and after unpacking:
+
+  Before:
+
+    `schema`
+
+    {'properties': {
+         'dummy': {
+             'items': {'$ref': '#/$defs/CountryInfo'},
+             'title': 'Dummy',
+             'type': 'array'
+         }
+     },
+     'required': ['dummy'],
+     'title': 'dummy',
+     'type': 'object'}
+
+    `defs`
+
+    {'CountryInfo': {'properties': {'continent': {'title': 'Continent', 'type':
+    'string'}, 'gdp': {'title': 'Gdp', 'type': 'integer'}}, 'required':
+    ['continent', 'gdp'], 'title': 'CountryInfo', 'type': 'object'}}
+
+  After:
+
+    `schema`
+
+    {'properties': {
+         'continent': {'title': 'Continent', 'type': 'string'},
+         'gdp': {'title': 'Gdp', 'type': 'integer'}
+     },
+     'required': ['continent', 'gdp'],
+     'title': 'CountryInfo',
+     'type': 'object'}
+  """
+  properties = schema.get('properties', None)
+  if properties is None:
+    return
+
+  for name, value in properties.items():
+    ref_key = value.get('$ref', None)
+    if ref_key is not None:
+      ref = defs[ref_key.split('defs/')[-1]]
+      unpack_defs(ref, defs)
+      properties[name] = ref
+      continue
+
+    anyof = value.get('anyOf', None)
+    if anyof is not None:
+      for i, atype in enumerate(anyof):
+        ref_key = atype.get('$ref', None)
+        if ref_key is not None:
+          ref = defs[ref_key.split('defs/')[-1]]
+          unpack_defs(ref, defs)
+          anyof[i] = ref
+      continue
+
+    items = value.get('items', None)
+    if items is not None:
+      ref_key = items.get('$ref', None)
+      if ref_key is not None:
+        ref = defs[ref_key.split('defs/')[-1]]
+        unpack_defs(ref, defs)
+        value['items'] = ref
+      continue
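+
+# --- Editor's illustrative sketch (not part of the upstream file) ---
+# process_schema() mutates a JSON-schema dict in place; for a Gemini API
+# client (client.vertexai == False) it strips the pydantic-generated 'title'
+# keys at every nesting level. The client variable below is a hypothetical
+# non-Vertex ApiClient:
+#
+#   schema = {'title': 'CountryInfo',
+#             'type': 'object',
+#             'properties': {'gdp': {'title': 'Gdp', 'type': 'integer'}}}
+#   process_schema(schema, gemini_client)
+#   # -> {'type': 'object', 'properties': {'gdp': {'type': 'integer'}}}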
+
+
+def t_schema(
+    client: _api_client.ApiClient, origin: Union[types.SchemaUnionDict, Any]
+) -> Optional[types.Schema]:
+  if not origin:
+    return None
+  if isinstance(origin, dict):
+    return process_schema(origin, client)
+  if isinstance(origin, types.Schema):
+    if dict(origin) == dict(types.Schema()):
+      # response_schema value was coerced to an empty Schema instance because
+      # it did not adhere to the Schema field annotation.
+      raise ValueError('Unsupported schema type.')
+    schema = process_schema(origin.model_dump(exclude_unset=True), client)
+    return types.Schema.model_validate(schema)
+  if isinstance(origin, GenericAlias):
+    if origin.__origin__ is list:
+      if isinstance(origin.__args__[0], typing.types.UnionType):
+        raise ValueError(f'Unsupported schema type: GenericAlias {origin}')
+      if issubclass(origin.__args__[0], pydantic.BaseModel):
+        # Handle cases where the response schema is `list[pydantic.BaseModel]`.
+        list_schema = _build_schema(
+            'dummy', {'dummy': (origin, pydantic.Field())}
+        )
+        list_schema = process_schema(list_schema, client)
+        return types.Schema.model_validate(list_schema)
+    raise ValueError(f'Unsupported schema type: GenericAlias {origin}')
+  if issubclass(origin, pydantic.BaseModel):
+    schema = process_schema(origin.model_json_schema(), client)
+    return types.Schema.model_validate(schema)
+  raise ValueError(f'Unsupported schema type: {origin}')
+
+
+def t_speech_config(
+    _: _api_client.ApiClient, origin: Union[types.SpeechConfigUnionDict, Any]
+) -> Optional[types.SpeechConfig]:
+  if not origin:
+    return None
+  if isinstance(origin, types.SpeechConfig):
+    return origin
+  if isinstance(origin, str):
+    return types.SpeechConfig(
+        voice_config=types.VoiceConfig(
+            prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=origin)
+        )
+    )
+  if (
+      isinstance(origin, dict)
+      and 'voice_config' in origin
+      and 'prebuilt_voice_config' in origin['voice_config']
+  ):
+    return types.SpeechConfig(
+        voice_config=types.VoiceConfig(
+            prebuilt_voice_config=types.PrebuiltVoiceConfig(
+                voice_name=origin['voice_config']['prebuilt_voice_config'].get(
+                    'voice_name'
+                )
+            )
+        )
+    )
+  raise ValueError(f'Unsupported speechConfig type: {type(origin)}')
+
+
+def t_tool(client: _api_client.ApiClient, origin) -> Optional[types.Tool]:
+  if not origin:
+    return None
+  if inspect.isfunction(origin) or inspect.ismethod(origin):
+    return types.Tool(
+        function_declarations=[
+            types.FunctionDeclaration.from_callable(client, origin)
+        ]
+    )
+  else:
+    return origin
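+
+# --- Editor's illustrative sketch (not part of the upstream file) ---
+# t_schema() above accepts a pydantic model class or list[...] of one, and
+# t_speech_config() accepts a bare voice-name string. CountryInfo and the
+# voice name are example values, not upstream definitions:
+#
+#   class CountryInfo(pydantic.BaseModel):
+#     continent: str
+#     gdp: int
+#
+#   t_schema(client, CountryInfo)        # Schema from model_json_schema()
+#   t_schema(client, list[CountryInfo])  # Schema for an array of objects
+#
+#   t_speech_config(None, 'some-voice')
+#   # -> SpeechConfig(voice_config=VoiceConfig(
+#   #        prebuilt_voice_config=PrebuiltVoiceConfig(voice_name='some-voice')))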
+
+
+# Only support functions now.
+def t_tools(
+    client: _api_client.ApiClient, origin: list[Any]
+) -> list[types.Tool]:
+  if not origin:
+    return []
+  function_tool = types.Tool(function_declarations=[])
+  tools = []
+  for tool in origin:
+    transformed_tool = t_tool(client, tool)
+    # All functions should be merged into one tool.
+    if transformed_tool.function_declarations:
+      function_tool.function_declarations += (
+          transformed_tool.function_declarations
+      )
+    else:
+      tools.append(transformed_tool)
+  if function_tool.function_declarations:
+    tools.append(function_tool)
+  return tools
+
+
+def t_cached_content_name(client: _api_client.ApiClient, name: str):
+  return _resource_name(client, name, collection_identifier='cachedContents')
+
+
+def t_batch_job_source(client: _api_client.ApiClient, src: str):
+  if src.startswith('gs://'):
+    return types.BatchJobSource(
+        format='jsonl',
+        gcs_uri=[src],
+    )
+  elif src.startswith('bq://'):
+    return types.BatchJobSource(
+        format='bigquery',
+        bigquery_uri=src,
+    )
+  else:
+    raise ValueError(f'Unsupported source: {src}')
+
+
+def t_batch_job_destination(client: _api_client.ApiClient, dest: str):
+  if dest.startswith('gs://'):
+    return types.BatchJobDestination(
+        format='jsonl',
+        gcs_uri=dest,
+    )
+  elif dest.startswith('bq://'):
+    return types.BatchJobDestination(
+        format='bigquery',
+        bigquery_uri=dest,
+    )
+  else:
+    raise ValueError(f'Unsupported destination: {dest}')
+
+
+def t_batch_job_name(client: _api_client.ApiClient, name: str):
+  if not client.vertexai:
+    return name
+
+  pattern = r'^projects/[^/]+/locations/[^/]+/batchPredictionJobs/[^/]+$'
+  if re.match(pattern, name):
+    return name.split('/')[-1]
+  elif name.isdigit():
+    return name
+  else:
+    raise ValueError(f'Invalid batch job name: {name}.')
+
+
+LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0
+LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0
+LRO_POLLING_TIMEOUT_SECONDS = 900.0
+LRO_POLLING_MULTIPLIER = 1.5
+
+
+def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
+  if (name := struct.get('name')) and '/operations/' in name:
+    operation: dict[str, Any] = struct
+    total_seconds = 0.0
+    delay_seconds = LRO_POLLING_INITIAL_DELAY_SECONDS
+    while not operation.get('done'):
+      if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
+        raise RuntimeError(f'Operation {name} timed out.\n{operation}')
+      # TODO(b/374433890): Replace with LRO module once it's available.
+      operation: dict[str, Any] = api_client.request(
+          http_method='GET', path=name, request_dict={}
+      )
+      time.sleep(delay_seconds)
+      # Accumulate the delay actually slept so the timeout can trigger.
+      total_seconds += delay_seconds
+      # Exponential backoff
+      delay_seconds = min(
+          delay_seconds * LRO_POLLING_MULTIPLIER,
+          LRO_POLLING_MAXIMUM_DELAY_SECONDS,
+      )
+    if error := operation.get('error'):
+      raise RuntimeError(
+          f'Operation {name} failed with error: {error}.\n{operation}'
+      )
+    return operation.get('response')
+  else:
+    return struct
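+
+# --- Editor's illustrative sketch (not part of the upstream file) ---
+# The polling delay above grows geometrically and is capped, so successive
+# sleeps are 1.0, 1.5, 2.25, 3.375, ... seconds up to the 20-second maximum:
+#
+#   delay = LRO_POLLING_INITIAL_DELAY_SECONDS
+#   for _ in range(8):
+#     print(delay)  # 1.0, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.39..., 17.08...
+#     delay = min(delay * LRO_POLLING_MULTIPLIER,
+#                 LRO_POLLING_MAXIMUM_DELAY_SECONDS)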
+
+
+def t_file_name(
+    api_client: _api_client.ApiClient, name: Union[str, types.File]
+):
+  # Remove the files/ prefix since it's added to the url path.
+  if isinstance(name, types.File):
+    name = name.name
+
+  if name is None:
+    raise ValueError('File name is required.')
+
+  if name.startswith('https://'):
+    suffix = name.split('files/')[1]
+    match = re.match('[a-z0-9]+', suffix)
+    if match is None:
+      raise ValueError(f'Could not extract file name from URI: {name}')
+    name = match.group(0)
+  elif name.startswith('files/'):
+    name = name.split('files/')[1]
+
+  return name
+
+
+def t_tuning_job_status(
+    api_client: _api_client.ApiClient, status: str
+) -> types.JobState:
+  if status == 'STATE_UNSPECIFIED':
+    return 'JOB_STATE_UNSPECIFIED'
+  elif status == 'CREATING':
+    return 'JOB_STATE_RUNNING'
+  elif status == 'ACTIVE':
+    return 'JOB_STATE_SUCCEEDED'
+  elif status == 'FAILED':
+    return 'JOB_STATE_FAILED'
+  else:
+    return status
+
+
+# Some fields don't accept URL-safe base64 encoding.
+# We shouldn't use this transformer if the backend adheres to the Cloud Type
+# format https://cloud.google.com/docs/discovery/type-format.
+# TODO(b/389133914,b/390320301): Remove the hack after the backend fixes the
+# issue.
+def t_bytes(api_client: _api_client.ApiClient, data: bytes) -> str:
+  if not isinstance(data, bytes):
+    return data
+  return base64.b64encode(data).decode('ascii')
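+
+# --- Editor's illustrative sketch (not part of the upstream file) ---
+# t_file_name() strips the 'files/' prefix (or extracts the ID from a
+# download URL), and t_bytes() re-encodes raw bytes as standard base64:
+#
+#   t_file_name(client, 'files/abc123')  # -> 'abc123'
+#
+#   t_bytes(client, b'\x00\xff')         # -> 'AP8='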