# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Common utilities for the SDK.""" import base64 import datetime import enum import typing from typing import Union import uuid import pydantic from pydantic import alias_generators from . import _api_client def set_value_by_path(data, keys, value): """Examples: set_value_by_path({}, ['a', 'b'], v) -> {'a': {'b': v}} set_value_by_path({}, ['a', 'b[]', c], [v1, v2]) -> {'a': {'b': [{'c': v1}, {'c': v2}]}} set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3) -> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}} """ if value is None: return for i, key in enumerate(keys[:-1]): if key.endswith('[]'): key_name = key[:-2] if key_name not in data: if isinstance(value, list): data[key_name] = [{} for _ in range(len(value))] else: raise ValueError( f'value {value} must be a list given an array path {key}' ) if isinstance(value, list): for j, d in enumerate(data[key_name]): set_value_by_path(d, keys[i + 1 :], value[j]) else: for d in data[key_name]: set_value_by_path(d, keys[i + 1 :], value) return data = data.setdefault(key, {}) existing_data = data.get(keys[-1]) # If there is an existing value, merge, not overwrite. if existing_data is not None: # Don't overwrite existing non-empty value with new empty value. # This is triggered when handling tuning datasets. if not value: pass # Don't fail when overwriting value with same value elif value == existing_data: pass # Instead of overwriting dictionary with another dictionary, merge them. # This is important for handling training and validation datasets in tuning. elif isinstance(existing_data, dict) and isinstance(value, dict): # Merging dictionaries. Consider deep merging in the future. existing_data.update(value) else: raise ValueError( f'Cannot set value for an existing key. Key: {keys[-1]};' f' Existing value: {existing_data}; New value: {value}.' ) else: data[keys[-1]] = value def get_value_by_path(data: object, keys: list[str]): """Examples: get_value_by_path({'a': {'b': v}}, ['a', 'b']) -> v get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c']) -> [v1, v2] """ if keys == ['_self']: return data for i, key in enumerate(keys): if not data: return None if key.endswith('[]'): key_name = key[:-2] if key_name in data: return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]] else: return None else: if key in data: data = data[key] elif isinstance(data, BaseModel) and hasattr(data, key): data = getattr(data, key) else: return None return data class BaseModule: def __init__(self, api_client_: _api_client.ApiClient): self._api_client = api_client_ def convert_to_dict(obj: dict[str, object]) -> dict[str, object]: """Recursively converts a given object to a dictionary. If the object is a Pydantic model, it uses the model's `model_dump()` method. Args: obj: The object to convert. Returns: A dictionary representation of the object. """ if isinstance(obj, pydantic.BaseModel): return obj.model_dump(exclude_none=True) elif isinstance(obj, dict): return {key: convert_to_dict(value) for key, value in obj.items()} elif isinstance(obj, list): return [convert_to_dict(item) for item in obj] else: return obj def _remove_extra_fields( model: pydantic.BaseModel, response: dict[str, object] ) -> None: """Removes extra fields from the response that are not in the model. Mutates the response in place. """ key_values = list(response.items()) for key, value in key_values: # Need to convert to snake case to match model fields names # ex: UsageMetadata alias_map = { field_info.alias: key for key, field_info in model.model_fields.items() } if key not in model.model_fields and key not in alias_map: response.pop(key) continue key = alias_map.get(key, key) annotation = model.model_fields[key].annotation # Get the BaseModel if Optional if typing.get_origin(annotation) is Union: annotation = typing.get_args(annotation)[0] # if dict, assume BaseModel but also check that field type is not dict # example: FunctionCall.args if isinstance(value, dict) and typing.get_origin(annotation) is not dict: _remove_extra_fields(annotation, value) elif isinstance(value, list): for item in value: # assume a list of dict is list of BaseModel if isinstance(item, dict): _remove_extra_fields(typing.get_args(annotation)[0], item) class BaseModel(pydantic.BaseModel): model_config = pydantic.ConfigDict( alias_generator=alias_generators.to_camel, populate_by_name=True, from_attributes=True, protected_namespaces=(), extra='forbid', # This allows us to use arbitrary types in the model. E.g. PIL.Image. arbitrary_types_allowed=True, ser_json_bytes='base64', val_json_bytes='base64', ) @classmethod def _from_response( cls, response: dict[str, object], kwargs: dict[str, object] ) -> 'BaseModel': # To maintain forward compatibility, we need to remove extra fields from # the response. # We will provide another mechanism to allow users to access these fields. _remove_extra_fields(cls, response) validated_response = cls.model_validate(response) return validated_response def to_json_dict(self) -> dict[str, object]: return self.model_dump(exclude_none=True, mode='json') class CaseInSensitiveEnum(str, enum.Enum): """Case insensitive enum.""" @classmethod def _missing_(cls, value): try: return cls[value.upper()] # Try to access directly with uppercase except KeyError: try: return cls[value.lower()] # Try to access directly with lowercase except KeyError as e: raise ValueError(f"{value} is not a valid {cls.__name__}") from e def timestamped_unique_name() -> str: """Composes a timestamped unique name. Returns: A string representing a unique name. """ timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S') unique_id = uuid.uuid4().hex[0:5] return f'{timestamp}_{unique_id}' def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]: """Converts unserializable types in dict to json.dumps() compatible types. This function is called in models.py after calling convert_to_dict(). The convert_to_dict() can convert pydantic object to dict. However, the input to convert_to_dict() is dict mixed of pydantic object and nested dict(the output of converters). So they may be bytes in the dict and they are out of `ser_json_bytes` control in model_dump(mode='json') called in `convert_to_dict`, as well as datetime deserialization in Pydantic json mode. Returns: A dictionary with json.dumps() incompatible type (e.g. bytes datetime) to compatible type (e.g. base64 encoded string, isoformat date string). """ processed_data = {} if not isinstance(data, dict): return data for key, value in data.items(): if isinstance(value, bytes): processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii') elif isinstance(value, datetime.datetime): processed_data[key] = value.isoformat() elif isinstance(value, dict): processed_data[key] = encode_unserializable_types(value) elif isinstance(value, list): if all(isinstance(v, bytes) for v in value): processed_data[key] = [ base64.urlsafe_b64encode(v).decode('ascii') for v in value ] if all(isinstance(v, datetime.datetime) for v in value): processed_data[key] = [v.isoformat() for v in value] else: processed_data[key] = [encode_unserializable_types(v) for v in value] else: processed_data[key] = value return processed_data