#
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Common utilities for the SDK."""

import base64
import datetime
import enum
import typing
from typing import Union
import uuid

import pydantic
from pydantic import alias_generators

from . import _api_client


def set_value_by_path(data, keys, value):
  """Writes `value` into the nested dict `data` at the location named by `keys`.

  Intermediate dictionaries are created as needed. A path segment ending in
  '[]' names a list of dictionaries; the rest of the path is applied to every
  element of that list. When `value` is itself a list, element j of the list
  receives `value[j]`; otherwise every element receives the same `value`.

  Examples:

  set_value_by_path({}, ['a', 'b'], v)
    -> {'a': {'b': v}}
  set_value_by_path({}, ['a', 'b[]', 'c'], [v1, v2])
    -> {'a': {'b': [{'c': v1}, {'c': v2}]}}
  set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3)
    -> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}}

  Args:
    data: The dictionary to mutate in place.
    keys: The path segments; the last one is the key that receives the value.
    value: The value to store. `None` is ignored.

  Raises:
    ValueError: If an array segment is missing from `data` and `value` is not
      a list, or if the final key already holds a different value that cannot
      be merged with `value`.
  """
  if value is None:
    return

  node = data
  for depth, segment in enumerate(keys[:-1]):
    if not segment.endswith('[]'):
      # Plain segment: descend, creating the nested dict when absent.
      node = node.setdefault(segment, {})
      continue

    list_key = segment[:-2]
    value_is_list = isinstance(value, list)
    if list_key not in node:
      if not value_is_list:
        raise ValueError(
            f'value {value} must be a list given an array path {segment}'
        )
      # One empty dict per incoming element.
      node[list_key] = [{} for _ in range(len(value))]

    remaining = keys[depth + 1 :]
    for index, element in enumerate(node[list_key]):
      set_value_by_path(
          element, remaining, value[index] if value_is_list else value
      )
    # The recursive calls handled the rest of the path.
    return

  leaf = keys[-1]
  current = node.get(leaf)
  if current is None:
    node[leaf] = value
    return

  # An existing value is merged rather than overwritten.
  # - A new empty value never replaces an existing one (this is triggered
  #   when handling tuning datasets).
  # - Re-setting the same value is a no-op.
  if not value or value == current:
    return

  # Two dictionaries are merged shallowly. This is important for handling
  # training and validation datasets in tuning.
  if isinstance(current, dict) and isinstance(value, dict):
    current.update(value)
    return

  raise ValueError(
      f'Cannot set value for an existing key. Key: {leaf};'
      f' Existing value: {current}; New value: {value}.'
  )


def get_value_by_path(data: object, keys: list[str]):
  """Reads the value found by following `keys` through `data`.

  A path segment ending in '[]' names a list; the remaining segments are
  resolved against every element of that list and the results are returned as
  a list. The single-segment path ['_self'] returns `data` itself.

  Examples:

  get_value_by_path({'a': {'b': v}}, ['a', 'b'])
    -> v
  get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c'])
    -> [v1, v2]

  Args:
    data: The container to read from.
    keys: The path segments to follow.

  Returns:
    The value at the path, or None when a segment is missing or a falsy
    container is reached before the path is exhausted.
  """
  if keys == ['_self']:
    return data

  cursor = data
  for position, segment in enumerate(keys):
    # Nothing to descend into.
    if not cursor:
      return None

    if segment.endswith('[]'):
      list_key = segment[:-2]
      if list_key not in cursor:
        return None
      tail = keys[position + 1 :]
      return [get_value_by_path(item, tail) for item in cursor[list_key]]

    if segment in cursor:
      cursor = cursor[segment]
    elif isinstance(cursor, BaseModel) and hasattr(cursor, segment):
      # Model instances are read through attribute access.
      cursor = getattr(cursor, segment)
    else:
      return None

  return cursor


class BaseModule:
  """Base class that keeps a reference to an API client."""

  def __init__(self, api_client_: _api_client.ApiClient) -> None:
    """Stores the given client on the instance.

    Args:
      api_client_: The client to keep as `self._api_client`.
    """
    self._api_client = api_client_


def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
  """Recursively replaces Pydantic models inside `obj` with plain dicts.

  Pydantic models are dumped with `model_dump(exclude_none=True)`. Dicts and
  lists are rebuilt with each of their values converted in turn. Anything else
  is returned unchanged.

  Args:
    obj: The object to convert.

  Returns:
    The converted object: a dict for models and dicts, a list for lists, and
    the input itself otherwise.
  """
  if isinstance(obj, pydantic.BaseModel):
    return obj.model_dump(exclude_none=True)

  if isinstance(obj, dict):
    converted = {}
    for name, item in obj.items():
      converted[name] = convert_to_dict(item)
    return converted

  if isinstance(obj, list):
    return list(map(convert_to_dict, obj))

  return obj


def _remove_extra_fields(
    model: 'type[pydantic.BaseModel]', response: dict[str, object]
) -> None:
  """Removes extra fields from the response that are not in the model.

  Mutates the response in place. A response key is kept when it matches either
  a field name or a field alias of `model`. Kept values that are dicts (or
  lists containing dicts) are cleaned recursively against the model type found
  in the field's annotation.

  Nested values are only descended into when the annotated type exposes
  `model_fields`. Values under free-form annotations (e.g. `Any`, `dict`,
  `dict[str, Any]`, an unparameterised `list`) are left untouched.

  Args:
    model: The model class whose `model_fields` define the allowed keys.
    response: The response dictionary to prune in place.
  """
  if not response:
    return

  model_fields = model.model_fields
  # Response keys may use the alias rather than the field name
  # (ex: UsageMetadata), so map alias -> field name. The mapping does not
  # depend on the response and is therefore built once.
  alias_to_name = {
      field_info.alias: name for name, field_info in model_fields.items()
  }

  # Iterate over a snapshot because keys are popped from `response`.
  for key, value in list(response.items()):
    if key not in model_fields and key not in alias_to_name:
      response.pop(key)
      continue

    field_name = alias_to_name.get(key, key)
    annotation = model_fields[field_name].annotation

    # Get the BaseModel if Optional: use the first member of the Union.
    if typing.get_origin(annotation) is Union:
      annotation = typing.get_args(annotation)[0]

    if isinstance(value, dict):
      # A dict value is assumed to be a nested model unless the field itself
      # is a dict (example: FunctionCall.args) or not a model type at all.
      if typing.get_origin(annotation) is not dict and hasattr(
          annotation, 'model_fields'
      ):
        _remove_extra_fields(annotation, value)
    elif isinstance(value, list):
      item_args = typing.get_args(annotation)
      item_type = item_args[0] if item_args else None
      # A list of dicts is assumed to be a list of models; skip element types
      # that are not models.
      if hasattr(item_type, 'model_fields'):
        for item in value:
          if isinstance(item, dict):
            _remove_extra_fields(item_type, item)


class BaseModel(pydantic.BaseModel):
  """Pydantic base model carrying the configuration shared by SDK types.

  Fields get camelCase aliases while remaining settable by their Python names,
  unknown fields are rejected, and bytes are (de)serialized as base64 in JSON
  mode.
  """

  model_config = pydantic.ConfigDict(
      alias_generator=alias_generators.to_camel,
      populate_by_name=True,
      from_attributes=True,
      protected_namespaces=(),
      extra='forbid',
      # Lets models hold arbitrary (non-pydantic) types, e.g. PIL.Image.
      arbitrary_types_allowed=True,
      ser_json_bytes='base64',
      val_json_bytes='base64',
  )

  @classmethod
  def _from_response(
      cls, response: dict[str, object], kwargs: dict[str, object]
  ) -> 'BaseModel':
    """Builds an instance of `cls` from a response dictionary.

    Fields the model does not declare are stripped from `response` (in place)
    before validation. This keeps the SDK forward compatible with responses
    that carry new fields, since the model config forbids extras.

    Args:
      response: The response dictionary; mutated in place.
      kwargs: Not used by this implementation.

    Returns:
      The validated model instance.
    """
    _remove_extra_fields(cls, response)
    return cls.model_validate(response)

  def to_json_dict(self) -> dict[str, object]:
    """Dumps the model in JSON mode, leaving out fields whose value is None."""
    return self.model_dump(mode='json', exclude_none=True)


class CaseInSensitiveEnum(str, enum.Enum):
  """Case insensitive enum.

  When a value does not match any member value exactly, the lookup falls back
  to the member *names*: first the upper-cased value, then the lower-cased
  value.
  """

  @classmethod
  def _missing_(cls, value):
    """Resolves a value that did not match any member value exactly.

    Args:
      value: The value passed to the enum constructor.

    Returns:
      The member whose name equals `value.upper()` or `value.lower()`, or None
      for non-string values so that `Enum` raises its standard ValueError.

    Raises:
      ValueError: If `value` is a string that matches no member name in
        either case.
    """
    # Only strings can be case-folded. Returning None makes Enum report the
    # usual "not a valid ..." ValueError instead of leaking an AttributeError
    # from `value.upper()`.
    if not isinstance(value, str):
      return None

    try:
      return cls[value.upper()]  # Try to access directly with uppercase
    except KeyError:
      pass

    try:
      return cls[value.lower()]  # Try to access directly with lowercase
    except KeyError as e:
      raise ValueError(f"{value} is not a valid {cls.__name__}") from e


def timestamped_unique_name() -> str:
  """Builds a name from the current local time and a short random suffix.

  Returns:
      A string of the form '<YYYYmmddHHMMSS>_<5 hex chars>', where the suffix
      is the first five hex digits of a random UUID.
  """
  now = datetime.datetime.now()
  suffix = uuid.uuid4().hex[:5]
  return f'{now:%Y%m%d%H%M%S}_{suffix}'


def _encode_unserializable_value(value):
  """Converts a single value to a json.dumps() compatible form."""
  if isinstance(value, bytes):
    return base64.urlsafe_b64encode(value).decode('ascii')
  if isinstance(value, datetime.datetime):
    return value.isoformat()
  if isinstance(value, dict):
    return encode_unserializable_types(value)
  if isinstance(value, list):
    # Encode element by element so that lists of bytes, lists of datetimes,
    # mixed lists and nested lists are all handled.
    return [_encode_unserializable_value(item) for item in value]
  return value


def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
  """Converts unserializable types in dict to json.dumps() compatible types.

  This function is called in models.py after calling convert_to_dict(). The
  convert_to_dict() can convert pydantic object to dict. However, the input to
  convert_to_dict() is dict mixed of pydantic object and nested dict (the
  output of converters). So there may be bytes in the dict that are out of
  `ser_json_bytes` control in model_dump(mode='json') called in
  `convert_to_dict`, and datetime values that Pydantic's json mode did not get
  a chance to serialize.

  Args:
    data: The dictionary to process. It is not mutated. A non-dict input is
      returned unchanged.

  Returns:
    A new dictionary in which json.dumps() incompatible values are replaced by
    compatible ones: bytes become URL-safe base64 strings and datetimes become
    ISO-format strings, including inside nested dicts and lists.
  """
  if not isinstance(data, dict):
    return data
  return {
      key: _encode_unserializable_value(value) for key, value in data.items()
  }