about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/google/genai/_common.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/google/genai/_common.py')
-rw-r--r-- .venv/lib/python3.12/site-packages/google/genai/_common.py | 272
1 files changed, 272 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/google/genai/_common.py b/.venv/lib/python3.12/site-packages/google/genai/_common.py
new file mode 100644
index 00000000..3211fd40
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/_common.py
@@ -0,0 +1,272 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Common utilities for the SDK."""
+
+import base64
+import datetime
+import enum
+import typing
+from typing import Union
+import uuid
+
+import pydantic
+from pydantic import alias_generators
+
+from . import _api_client
+
+
+def set_value_by_path(data, keys, value):
+  """Writes `value` into the nested dict `data` at the location named by `keys`.
+
+  `data` is mutated in place and nothing is returned. A `None` value is a
+  no-op. A key ending in '[]' names a list of dicts: the remaining keys are
+  applied to every element of that list, pairing element `j` with `value[j]`
+  when `value` is a list and with `value` itself otherwise.
+
+  Examples:
+
+  set_value_by_path({}, ['a', 'b'], v)
+    -> {'a': {'b': v}}
+  set_value_by_path({}, ['a', 'b[]', 'c'], [v1, v2])
+    -> {'a': {'b': [{'c': v1}, {'c': v2}]}}
+  set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3)
+    -> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}}
+
+  Raises:
+    ValueError: If an array key has to be created but `value` is not a list,
+      or if the final key already holds a different value that cannot be
+      merged with `value`.
+  """
+  if value is None:
+    return
+
+  value_is_list = isinstance(value, list)
+  for depth, segment in enumerate(keys[:-1]):
+    if not segment.endswith('[]'):
+      # Plain key: descend, creating the intermediate dict on demand.
+      data = data.setdefault(segment, {})
+      continue
+
+    array_name = segment[:-2]
+    remaining = keys[depth + 1 :]
+    if array_name not in data:
+      if not value_is_list:
+        raise ValueError(
+            f'value {value} must be a list given an array path {segment}'
+        )
+      # One empty dict per incoming value.
+      data[array_name] = [{} for _ in value]
+    for index, element in enumerate(data[array_name]):
+      set_value_by_path(
+          element, remaining, value[index] if value_is_list else value
+      )
+    # The recursive calls consumed the rest of the path.
+    return
+
+  leaf = keys[-1]
+  current = data.get(leaf)
+  if current is None:
+    data[leaf] = value
+    return
+
+  # An existing value is merged, never silently overwritten.
+  # - A new empty value must not clobber existing data (this is triggered when
+  #   handling tuning datasets).
+  # - Re-setting the same value is harmless.
+  if not value or value == current:
+    return
+  # Two dicts are merged shallowly. This is important for handling training
+  # and validation datasets in tuning. Consider deep merging in the future.
+  if isinstance(current, dict) and isinstance(value, dict):
+    current.update(value)
+    return
+  raise ValueError(
+      f'Cannot set value for an existing key. Key: {leaf};'
+      f' Existing value: {current}; New value: {value}.'
+  )
+
+
+def get_value_by_path(data: object, keys: list[str]):
+  """Reads the value found by following `keys` through `data`.
+
+  The special path ['_self'] returns `data` itself. A key ending in '[]' names
+  a list; the rest of the path is resolved against every element and the
+  results are returned as a list. `None` is returned as soon as a falsy
+  intermediate value is reached or a key cannot be resolved.
+
+  Examples:
+
+  get_value_by_path({'a': {'b': v}}, ['a', 'b'])
+    -> v
+  get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c'])
+    -> [v1, v2]
+  """
+  if keys == ['_self']:
+    return data
+
+  current = data
+  for position, segment in enumerate(keys):
+    if not current:
+      return None
+
+    if segment.endswith('[]'):
+      array_name = segment[:-2]
+      if array_name not in current:
+        return None
+      tail = keys[position + 1 :]
+      return [get_value_by_path(element, tail) for element in current[array_name]]
+
+    if segment in current:
+      current = current[segment]
+      continue
+    # Fall back to attribute access when walking through a model instance.
+    if not (isinstance(current, BaseModel) and hasattr(current, segment)):
+      return None
+    current = getattr(current, segment)
+
+  return current
+
+
+class BaseModule:
+  """Holds a reference to the `ApiClient` passed at construction."""
+
+  def __init__(self, api_client_: _api_client.ApiClient):
+    # Stored under a private name; presumably read by subclasses -- confirm
+    # against callers.
+    self._api_client = api_client_
+
+
+def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
+  """Recursively replaces Pydantic models inside `obj` with plain dicts.
+
+  A Pydantic model is dumped with `model_dump(exclude_none=True)`. Dicts and
+  lists are rebuilt with every value/item converted the same way. Any other
+  object is returned unchanged.
+
+  Args:
+    obj: The object to convert.
+
+  Returns:
+    The converted object: a dict for model and dict inputs, a list for list
+    inputs, otherwise `obj` itself.
+  """
+  if isinstance(obj, pydantic.BaseModel):
+    return obj.model_dump(exclude_none=True)
+
+  if isinstance(obj, dict):
+    converted = {}
+    for name, item in obj.items():
+      converted[name] = convert_to_dict(item)
+    return converted
+
+  if isinstance(obj, list):
+    return list(map(convert_to_dict, obj))
+
+  return obj
+
+
+def _remove_extra_fields(
+    model: pydantic.BaseModel, response: dict[str, object]
+) -> None:
+  """Removes extra fields from the response that are not in the model.
+
+  Mutates the response in place. Keys are matched against both the model's
+  field names and the field aliases. Nested dicts, and dicts inside lists, are
+  pruned recursively against the annotated type of the corresponding field.
+
+  Args:
+    model: The model class whose `model_fields` describe the allowed keys. If
+      the object has no `model_fields` (for example the `dict[str, Any]`
+      element type of a `list[dict[str, Any]]` field), the response is left
+      untouched instead of failing.
+    response: The response dict to prune.
+  """
+  model_fields = getattr(model, 'model_fields', None)
+  if model_fields is None:
+    # Not a model, so there is no schema to prune against.
+    return
+
+  # Need to convert to snake case to match model fields names
+  # ex: UsageMetadata
+  # Built once per call: it does not depend on the response key.
+  alias_to_name = {
+      field_info.alias: name for name, field_info in model_fields.items()
+  }
+
+  # Iterate over a snapshot because keys are popped from `response` below.
+  for key, value in list(response.items()):
+    if key not in model_fields and key not in alias_to_name:
+      response.pop(key)
+      continue
+
+    field_name = alias_to_name.get(key, key)
+    annotation = model_fields[field_name].annotation
+
+    # Get the BaseModel if Optional
+    if typing.get_origin(annotation) is Union:
+      annotation = typing.get_args(annotation)[0]
+
+    if isinstance(value, dict):
+      # if dict, assume BaseModel but also check that field type is not dict
+      # example: FunctionCall.args
+      if typing.get_origin(annotation) is not dict:
+        _remove_extra_fields(annotation, value)
+    elif isinstance(value, list):
+      item_types = typing.get_args(annotation)
+      if not item_types:
+        # Unparameterised list annotation: the element type is unknown.
+        continue
+      for item in value:
+        # assume a list of dict is list of BaseModel
+        if isinstance(item, dict):
+          _remove_extra_fields(item_types[0], item)
+
+
+class BaseModel(pydantic.BaseModel):
+  # Pydantic base class carrying the configuration below: camelCase aliases
+  # (population by field name also allowed), unknown fields rejected, bytes
+  # (de)serialized as base64 in JSON mode.
+
+  model_config = pydantic.ConfigDict(
+      alias_generator=alias_generators.to_camel,
+      populate_by_name=True,
+      from_attributes=True,
+      extra='forbid',
+      protected_namespaces=(),
+      # This allows us to use arbitrary types in the model. E.g. PIL.Image.
+      arbitrary_types_allowed=True,
+      val_json_bytes='base64',
+      ser_json_bytes='base64',
+  )
+
+  @classmethod
+  def _from_response(
+      cls, response: dict[str, object], kwargs: dict[str, object]
+  ) -> 'BaseModel':
+    """Validates `response` into an instance of `cls`.
+
+    Keys that the model does not declare are stripped from `response` first
+    (in place), so that new server-side fields do not trip `extra='forbid'`.
+    We will provide another mechanism to allow users to access these fields.
+
+    Args:
+      response: The response dict; mutated in place.
+      kwargs: Not used by this method.
+
+    Returns:
+      The validated model instance.
+    """
+    _remove_extra_fields(cls, response)
+    return cls.model_validate(response)
+
+  def to_json_dict(self) -> dict[str, object]:
+    """Dumps the model in JSON mode, leaving out fields whose value is None."""
+    return self.model_dump(mode='json', exclude_none=True)
+
+
+class CaseInSensitiveEnum(str, enum.Enum):
+  """Case insensitive enum."""
+
+  @classmethod
+  def _missing_(cls, value):
+    """Resolves a value that matched no member value, using member names.
+
+    The upper-cased form of `value` is tried as a member name first, then the
+    lower-cased form.
+
+    Args:
+      value: The value passed to the enum constructor.
+
+    Returns:
+      The member whose name equals `value.upper()` or `value.lower()`.
+
+    Raises:
+      ValueError: If `value` is not a string or matches no member name.
+    """
+    # Non-string input has no upper()/lower(); report it as an invalid value
+    # rather than leaking an AttributeError out of the enum lookup.
+    if isinstance(value, str):
+      members = cls.__members__
+      for candidate in (value.upper(), value.lower()):
+        if candidate in members:
+          return members[candidate]
+    raise ValueError(f"{value} is not a valid {cls.__name__}")
+
+
+def timestamped_unique_name() -> str:
+  """Builds a name from the current local time and a short random suffix.
+
+  Returns:
+      A string of the form 'YYYYmmddHHMMSS_xxxxx', where 'xxxxx' is the first
+      five hex characters of a random UUID4.
+  """
+  now = datetime.datetime.now()
+  suffix = uuid.uuid4().hex[:5]
+  return f'{now:%Y%m%d%H%M%S}_{suffix}'
+
+
+def _encode_unserializable_value(value):
+  """Converts one value (recursing into dicts and lists) for json.dumps()."""
+  if isinstance(value, bytes):
+    return base64.urlsafe_b64encode(value).decode('ascii')
+  if isinstance(value, datetime.datetime):
+    return value.isoformat()
+  if isinstance(value, dict):
+    return encode_unserializable_types(value)
+  if isinstance(value, list):
+    return [_encode_unserializable_value(item) for item in value]
+  return value
+
+
+def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
+  """Converts unserializable types in dict to json.dumps() compatible types.
+
+  This function is called in models.py after calling convert_to_dict(). The
+  convert_to_dict() can convert pydantic object to dict. However, the input to
+  convert_to_dict() is dict mixed of pydantic object and nested dict(the output
+  of converters). So there may be bytes in the dict that are out of
+  `ser_json_bytes` control in model_dump(mode='json') called in
+  `convert_to_dict`, as well as datetime values that Pydantic's json mode
+  would otherwise have serialized.
+
+  Values are converted wherever they appear: directly under a key, in nested
+  dicts, and in lists (including lists that mix bytes, datetimes, dicts and
+  other values). The input is not modified.
+
+  Args:
+    data: The dict to convert. A non-dict argument is returned unchanged.
+
+  Returns:
+    A new dictionary with json.dumps() incompatible types (bytes, datetime)
+    replaced by compatible ones (url-safe base64 string, ISO 8601 string).
+  """
+  if not isinstance(data, dict):
+    return data
+  return {
+      key: _encode_unserializable_value(value) for key, value in data.items()
+  }