# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Common utilities for the SDK."""
import base64
import datetime
import enum
import typing
from typing import Union
import uuid
import pydantic
from pydantic import alias_generators
from . import _api_client


def set_value_by_path(data, keys, value):
"""Examples:
set_value_by_path({}, ['a', 'b'], v)
-> {'a': {'b': v}}
set_value_by_path({}, ['a', 'b[]', c], [v1, v2])
-> {'a': {'b': [{'c': v1}, {'c': v2}]}}
set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3)
-> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}}
"""
if value is None:
return
for i, key in enumerate(keys[:-1]):
if key.endswith('[]'):
key_name = key[:-2]
if key_name not in data:
if isinstance(value, list):
data[key_name] = [{} for _ in range(len(value))]
else:
raise ValueError(
f'value {value} must be a list given an array path {key}'
)
if isinstance(value, list):
for j, d in enumerate(data[key_name]):
set_value_by_path(d, keys[i + 1 :], value[j])
else:
for d in data[key_name]:
set_value_by_path(d, keys[i + 1 :], value)
return
data = data.setdefault(key, {})
existing_data = data.get(keys[-1])
# If there is an existing value, merge, not overwrite.
if existing_data is not None:
# Don't overwrite existing non-empty value with new empty value.
# This is triggered when handling tuning datasets.
if not value:
pass
# Don't fail when overwriting value with same value
elif value == existing_data:
pass
# Instead of overwriting dictionary with another dictionary, merge them.
# This is important for handling training and validation datasets in tuning.
elif isinstance(existing_data, dict) and isinstance(value, dict):
# Merging dictionaries. Consider deep merging in the future.
existing_data.update(value)
else:
raise ValueError(
f'Cannot set value for an existing key. Key: {keys[-1]};'
f' Existing value: {existing_data}; New value: {value}.'
)
else:
data[keys[-1]] = value
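

# Illustrative usage of set_value_by_path (a sketch; 'config', 'parts', and
# 'text' below are placeholder keys, not real API field paths):
#
#   body: dict = {}
#   set_value_by_path(body, ['config', 'temperature'], 0.5)
#   # body == {'config': {'temperature': 0.5}}
#   set_value_by_path(body, ['config', 'parts[]', 'text'], ['a', 'b'])
#   # body == {'config': {'temperature': 0.5,
#   #                     'parts': [{'text': 'a'}, {'text': 'b'}]}}
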
def get_value_by_path(data: object, keys: list[str]):
"""Examples:
get_value_by_path({'a': {'b': v}}, ['a', 'b'])
-> v
get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c'])
-> [v1, v2]
"""
if keys == ['_self']:
return data
for i, key in enumerate(keys):
if not data:
return None
if key.endswith('[]'):
key_name = key[:-2]
if key_name in data:
return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]]
else:
return None
else:
if key in data:
data = data[key]
elif isinstance(data, BaseModel) and hasattr(data, key):
data = getattr(data, key)
else:
return None
return data
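

# Illustrative usage of get_value_by_path (mirrors the docstring; keys are
# placeholders):
#
#   data = {'config': {'parts': [{'text': 'a'}, {'text': 'b'}]}}
#   get_value_by_path(data, ['config', 'parts[]', 'text'])  # -> ['a', 'b']
#   get_value_by_path(data, ['config', 'missing'])          # -> None
#   get_value_by_path(data, ['_self'])                      # -> data itself
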
class BaseModule:
def __init__(self, api_client_: _api_client.ApiClient):
    self._api_client = api_client_


def convert_to_dict(obj: object) -> object:
"""Recursively converts a given object to a dictionary.
If the object is a Pydantic model, it uses the model's `model_dump()` method.
Args:
obj: The object to convert.
Returns:
A dictionary representation of the object.
"""
if isinstance(obj, pydantic.BaseModel):
return obj.model_dump(exclude_none=True)
elif isinstance(obj, dict):
return {key: convert_to_dict(value) for key, value in obj.items()}
elif isinstance(obj, list):
return [convert_to_dict(item) for item in obj]
else:
return obj
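

# A hedged example of convert_to_dict; `_Part` is a made-up Pydantic model used
# only for illustration:
#
#   class _Part(pydantic.BaseModel):
#     text: typing.Optional[str] = None
#     inline_data: typing.Optional[bytes] = None
#
#   convert_to_dict({'role': 'user', 'parts': [_Part(text='hi')]})
#   # -> {'role': 'user', 'parts': [{'text': 'hi'}]}
#   # (None-valued fields are dropped because of exclude_none=True.)
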
def _remove_extra_fields(
model: pydantic.BaseModel, response: dict[str, object]
) -> None:
"""Removes extra fields from the response that are not in the model.
Mutates the response in place.
"""
  # Map the camelCase aliases back to the snake_case model field names.
  # ex: 'usageMetadata' -> 'usage_metadata'
  alias_map = {
      field_info.alias: key for key, field_info in model.model_fields.items()
  }
  key_values = list(response.items())
  for key, value in key_values:
if key not in model.model_fields and key not in alias_map:
response.pop(key)
continue
key = alias_map.get(key, key)
annotation = model.model_fields[key].annotation
# Get the BaseModel if Optional
if typing.get_origin(annotation) is Union:
annotation = typing.get_args(annotation)[0]
# if dict, assume BaseModel but also check that field type is not dict
# example: FunctionCall.args
if isinstance(value, dict) and typing.get_origin(annotation) is not dict:
_remove_extra_fields(annotation, value)
elif isinstance(value, list):
for item in value:
# assume a list of dict is list of BaseModel
if isinstance(item, dict):
_remove_extra_fields(typing.get_args(annotation)[0], item)
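

# Sketch of what _remove_extra_fields does; `_Usage` and `_Reply` are
# hypothetical stand-ins for generated response models:
#
#   class _Usage(pydantic.BaseModel):
#     total_tokens: typing.Optional[int] = None
#
#   class _Reply(pydantic.BaseModel):
#     text: typing.Optional[str] = None
#     usage: typing.Optional[_Usage] = None
#
#   response = {'text': 'hi', 'usage': {'total_tokens': 3, 'new_field': 1}}
#   _remove_extra_fields(_Reply, response)
#   # response is mutated in place to {'text': 'hi', 'usage': {'total_tokens': 3}}
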
class BaseModel(pydantic.BaseModel):
model_config = pydantic.ConfigDict(
alias_generator=alias_generators.to_camel,
populate_by_name=True,
from_attributes=True,
protected_namespaces=(),
extra='forbid',
# This allows us to use arbitrary types in the model. E.g. PIL.Image.
arbitrary_types_allowed=True,
ser_json_bytes='base64',
val_json_bytes='base64',
)
@classmethod
def _from_response(
cls, response: dict[str, object], kwargs: dict[str, object]
) -> 'BaseModel':
# To maintain forward compatibility, we need to remove extra fields from
# the response.
# We will provide another mechanism to allow users to access these fields.
_remove_extra_fields(cls, response)
validated_response = cls.model_validate(response)
return validated_response
def to_json_dict(self) -> dict[str, object]:
return self.model_dump(exclude_none=True, mode='json')
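

# Hedged example of how BaseModel subclasses behave; `_Candidate` is a made-up
# model, not a real SDK type:
#
#   class _Candidate(BaseModel):
#     finish_reason: typing.Optional[str] = None
#
#   # camelCase keys from the API map onto snake_case fields via the to_camel
#   # alias generator; _from_response() strips unknown keys first, since
#   # extra='forbid' would otherwise make validation fail.
#   c = _Candidate._from_response({'finishReason': 'STOP', 'extraField': 1}, {})
#   c.finish_reason     # -> 'STOP'
#   c.to_json_dict()    # -> {'finish_reason': 'STOP'}
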
class CaseInSensitiveEnum(str, enum.Enum):
"""Case insensitive enum."""
@classmethod
def _missing_(cls, value):
try:
return cls[value.upper()] # Try to access directly with uppercase
except KeyError:
try:
return cls[value.lower()] # Try to access directly with lowercase
except KeyError as e:
raise ValueError(f"{value} is not a valid {cls.__name__}") from e
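

# Sketch of case-insensitive lookup; `_Modality` is an illustrative enum, not a
# real SDK type:
#
#   class _Modality(CaseInSensitiveEnum):
#     TEXT = 'TEXT'
#     IMAGE = 'IMAGE'
#
#   _Modality('TEXT')    # direct value match -> _Modality.TEXT
#   _Modality('text')    # falls through to _missing_ -> _Modality.TEXT
#   _Modality('audio')   # no match either way -> ValueError
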
def timestamped_unique_name() -> str:
"""Composes a timestamped unique name.
Returns:
A string representing a unique name.
"""
timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
unique_id = uuid.uuid4().hex[0:5]
return f'{timestamp}_{unique_id}'
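

# For example, a call made at 2024-01-31 09:30:45 would return something like
# '20240131093045_1a2b3' (the 5-character suffix is random hex; the value shown
# here is made up).
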
def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
"""Converts unserializable types in dict to json.dumps() compatible types.
This function is called in models.py after calling convert_to_dict(). The
convert_to_dict() can convert pydantic object to dict. However, the input to
convert_to_dict() is dict mixed of pydantic object and nested dict(the output
of converters). So they may be bytes in the dict and they are out of
`ser_json_bytes` control in model_dump(mode='json') called in
`convert_to_dict`, as well as datetime deserialization in Pydantic json mode.
Returns:
A dictionary with json.dumps() incompatible type (e.g. bytes datetime)
to compatible type (e.g. base64 encoded string, isoformat date string).
"""
processed_data = {}
if not isinstance(data, dict):
return data
for key, value in data.items():
if isinstance(value, bytes):
processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii')
elif isinstance(value, datetime.datetime):
processed_data[key] = value.isoformat()
elif isinstance(value, dict):
processed_data[key] = encode_unserializable_types(value)
elif isinstance(value, list):
if all(isinstance(v, bytes) for v in value):
processed_data[key] = [
base64.urlsafe_b64encode(v).decode('ascii') for v in value
]
      # Use elif so a list of bytes (already encoded above) is not overwritten
      # by the generic else branch below.
      elif all(isinstance(v, datetime.datetime) for v in value):
processed_data[key] = [v.isoformat() for v in value]
else:
processed_data[key] = [encode_unserializable_types(v) for v in value]
else:
processed_data[key] = value
return processed_data
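

# A hedged end-to-end sketch (the keys below are placeholders):
#
#   payload = {
#       'data': b'\x00\x01',
#       'created': datetime.datetime(2024, 1, 31, 9, 30, 45),
#       'chunks': [b'\x02', b'\x03'],
#   }
#   encode_unserializable_types(payload)
#   # -> {'data': 'AAE=', 'created': '2024-01-31T09:30:45',
#   #     'chunks': ['Ag==', 'Aw==']}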