diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/google/genai/_common.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/google/genai/_common.py | 272 |
1 files changed, 272 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/google/genai/_common.py b/.venv/lib/python3.12/site-packages/google/genai/_common.py new file mode 100644 index 00000000..3211fd40 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/google/genai/_common.py @@ -0,0 +1,272 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Common utilities for the SDK.""" + +import base64 +import datetime +import enum +import typing +from typing import Union +import uuid + +import pydantic +from pydantic import alias_generators + +from . import _api_client + + +def set_value_by_path(data, keys, value): + """Examples: + + set_value_by_path({}, ['a', 'b'], v) + -> {'a': {'b': v}} + set_value_by_path({}, ['a', 'b[]', c], [v1, v2]) + -> {'a': {'b': [{'c': v1}, {'c': v2}]}} + set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3) + -> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}} + """ + if value is None: + return + for i, key in enumerate(keys[:-1]): + if key.endswith('[]'): + key_name = key[:-2] + if key_name not in data: + if isinstance(value, list): + data[key_name] = [{} for _ in range(len(value))] + else: + raise ValueError( + f'value {value} must be a list given an array path {key}' + ) + if isinstance(value, list): + for j, d in enumerate(data[key_name]): + set_value_by_path(d, keys[i + 1 :], value[j]) + else: + for d in data[key_name]: + set_value_by_path(d, keys[i + 1 :], value) + return + + data = data.setdefault(key, {}) + + existing_data = data.get(keys[-1]) + # If there is an existing value, merge, not overwrite. + if existing_data is not None: + # Don't overwrite existing non-empty value with new empty value. + # This is triggered when handling tuning datasets. + if not value: + pass + # Don't fail when overwriting value with same value + elif value == existing_data: + pass + # Instead of overwriting dictionary with another dictionary, merge them. + # This is important for handling training and validation datasets in tuning. + elif isinstance(existing_data, dict) and isinstance(value, dict): + # Merging dictionaries. Consider deep merging in the future. + existing_data.update(value) + else: + raise ValueError( + f'Cannot set value for an existing key. Key: {keys[-1]};' + f' Existing value: {existing_data}; New value: {value}.' + ) + else: + data[keys[-1]] = value + + +def get_value_by_path(data: object, keys: list[str]): + """Examples: + + get_value_by_path({'a': {'b': v}}, ['a', 'b']) + -> v + get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c']) + -> [v1, v2] + """ + if keys == ['_self']: + return data + for i, key in enumerate(keys): + if not data: + return None + if key.endswith('[]'): + key_name = key[:-2] + if key_name in data: + return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]] + else: + return None + else: + if key in data: + data = data[key] + elif isinstance(data, BaseModel) and hasattr(data, key): + data = getattr(data, key) + else: + return None + return data + + +class BaseModule: + + def __init__(self, api_client_: _api_client.ApiClient): + self._api_client = api_client_ + + +def convert_to_dict(obj: dict[str, object]) -> dict[str, object]: + """Recursively converts a given object to a dictionary. + + If the object is a Pydantic model, it uses the model's `model_dump()` method. + + Args: + obj: The object to convert. + + Returns: + A dictionary representation of the object. + """ + if isinstance(obj, pydantic.BaseModel): + return obj.model_dump(exclude_none=True) + elif isinstance(obj, dict): + return {key: convert_to_dict(value) for key, value in obj.items()} + elif isinstance(obj, list): + return [convert_to_dict(item) for item in obj] + else: + return obj + + +def _remove_extra_fields( + model: pydantic.BaseModel, response: dict[str, object] +) -> None: + """Removes extra fields from the response that are not in the model. + + Mutates the response in place. + """ + + key_values = list(response.items()) + + for key, value in key_values: + # Need to convert to snake case to match model fields names + # ex: UsageMetadata + alias_map = { + field_info.alias: key for key, field_info in model.model_fields.items() + } + + if key not in model.model_fields and key not in alias_map: + response.pop(key) + continue + + key = alias_map.get(key, key) + + annotation = model.model_fields[key].annotation + + # Get the BaseModel if Optional + if typing.get_origin(annotation) is Union: + annotation = typing.get_args(annotation)[0] + + # if dict, assume BaseModel but also check that field type is not dict + # example: FunctionCall.args + if isinstance(value, dict) and typing.get_origin(annotation) is not dict: + _remove_extra_fields(annotation, value) + elif isinstance(value, list): + for item in value: + # assume a list of dict is list of BaseModel + if isinstance(item, dict): + _remove_extra_fields(typing.get_args(annotation)[0], item) + + +class BaseModel(pydantic.BaseModel): + + model_config = pydantic.ConfigDict( + alias_generator=alias_generators.to_camel, + populate_by_name=True, + from_attributes=True, + protected_namespaces=(), + extra='forbid', + # This allows us to use arbitrary types in the model. E.g. PIL.Image. + arbitrary_types_allowed=True, + ser_json_bytes='base64', + val_json_bytes='base64', + ) + + @classmethod + def _from_response( + cls, response: dict[str, object], kwargs: dict[str, object] + ) -> 'BaseModel': + # To maintain forward compatibility, we need to remove extra fields from + # the response. + # We will provide another mechanism to allow users to access these fields. + _remove_extra_fields(cls, response) + validated_response = cls.model_validate(response) + return validated_response + + def to_json_dict(self) -> dict[str, object]: + return self.model_dump(exclude_none=True, mode='json') + + +class CaseInSensitiveEnum(str, enum.Enum): + """Case insensitive enum.""" + + @classmethod + def _missing_(cls, value): + try: + return cls[value.upper()] # Try to access directly with uppercase + except KeyError: + try: + return cls[value.lower()] # Try to access directly with lowercase + except KeyError as e: + raise ValueError(f"{value} is not a valid {cls.__name__}") from e + + +def timestamped_unique_name() -> str: + """Composes a timestamped unique name. + + Returns: + A string representing a unique name. + """ + timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + unique_id = uuid.uuid4().hex[0:5] + return f'{timestamp}_{unique_id}' + + +def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]: + """Converts unserializable types in dict to json.dumps() compatible types. + + This function is called in models.py after calling convert_to_dict(). The + convert_to_dict() can convert pydantic object to dict. However, the input to + convert_to_dict() is dict mixed of pydantic object and nested dict(the output + of converters). So they may be bytes in the dict and they are out of + `ser_json_bytes` control in model_dump(mode='json') called in + `convert_to_dict`, as well as datetime deserialization in Pydantic json mode. + + Returns: + A dictionary with json.dumps() incompatible type (e.g. bytes datetime) + to compatible type (e.g. base64 encoded string, isoformat date string). + """ + processed_data = {} + if not isinstance(data, dict): + return data + for key, value in data.items(): + if isinstance(value, bytes): + processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii') + elif isinstance(value, datetime.datetime): + processed_data[key] = value.isoformat() + elif isinstance(value, dict): + processed_data[key] = encode_unserializable_types(value) + elif isinstance(value, list): + if all(isinstance(v, bytes) for v in value): + processed_data[key] = [ + base64.urlsafe_b64encode(v).decode('ascii') for v in value + ] + if all(isinstance(v, datetime.datetime) for v in value): + processed_data[key] = [v.isoformat() for v in value] + else: + processed_data[key] = [encode_unserializable_types(v) for v in value] + else: + processed_data[key] = value + return processed_data |