aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/google/genai/_common.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/google/genai/_common.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/google/genai/_common.py')
-rw-r--r--.venv/lib/python3.12/site-packages/google/genai/_common.py272
1 files changed, 272 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/google/genai/_common.py b/.venv/lib/python3.12/site-packages/google/genai/_common.py
new file mode 100644
index 00000000..3211fd40
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/_common.py
@@ -0,0 +1,272 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Common utilities for the SDK."""
+
+import base64
+import datetime
+import enum
+import typing
+from typing import Union
+import uuid
+
+import pydantic
+from pydantic import alias_generators
+
+from . import _api_client
+
+
+def set_value_by_path(data, keys, value):
+ """Examples:
+
+ set_value_by_path({}, ['a', 'b'], v)
+ -> {'a': {'b': v}}
+ set_value_by_path({}, ['a', 'b[]', c], [v1, v2])
+ -> {'a': {'b': [{'c': v1}, {'c': v2}]}}
+ set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3)
+ -> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}}
+ """
+ if value is None:
+ return
+ for i, key in enumerate(keys[:-1]):
+ if key.endswith('[]'):
+ key_name = key[:-2]
+ if key_name not in data:
+ if isinstance(value, list):
+ data[key_name] = [{} for _ in range(len(value))]
+ else:
+ raise ValueError(
+ f'value {value} must be a list given an array path {key}'
+ )
+ if isinstance(value, list):
+ for j, d in enumerate(data[key_name]):
+ set_value_by_path(d, keys[i + 1 :], value[j])
+ else:
+ for d in data[key_name]:
+ set_value_by_path(d, keys[i + 1 :], value)
+ return
+
+ data = data.setdefault(key, {})
+
+ existing_data = data.get(keys[-1])
+ # If there is an existing value, merge, not overwrite.
+ if existing_data is not None:
+ # Don't overwrite existing non-empty value with new empty value.
+ # This is triggered when handling tuning datasets.
+ if not value:
+ pass
+ # Don't fail when overwriting value with same value
+ elif value == existing_data:
+ pass
+ # Instead of overwriting dictionary with another dictionary, merge them.
+ # This is important for handling training and validation datasets in tuning.
+ elif isinstance(existing_data, dict) and isinstance(value, dict):
+ # Merging dictionaries. Consider deep merging in the future.
+ existing_data.update(value)
+ else:
+ raise ValueError(
+ f'Cannot set value for an existing key. Key: {keys[-1]};'
+ f' Existing value: {existing_data}; New value: {value}.'
+ )
+ else:
+ data[keys[-1]] = value
+
+
+def get_value_by_path(data: object, keys: list[str]):
+ """Examples:
+
+ get_value_by_path({'a': {'b': v}}, ['a', 'b'])
+ -> v
+ get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c'])
+ -> [v1, v2]
+ """
+ if keys == ['_self']:
+ return data
+ for i, key in enumerate(keys):
+ if not data:
+ return None
+ if key.endswith('[]'):
+ key_name = key[:-2]
+ if key_name in data:
+ return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]]
+ else:
+ return None
+ else:
+ if key in data:
+ data = data[key]
+ elif isinstance(data, BaseModel) and hasattr(data, key):
+ data = getattr(data, key)
+ else:
+ return None
+ return data
+
+
+class BaseModule:
+
+ def __init__(self, api_client_: _api_client.ApiClient):
+ self._api_client = api_client_
+
+
+def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
+ """Recursively converts a given object to a dictionary.
+
+ If the object is a Pydantic model, it uses the model's `model_dump()` method.
+
+ Args:
+ obj: The object to convert.
+
+ Returns:
+ A dictionary representation of the object.
+ """
+ if isinstance(obj, pydantic.BaseModel):
+ return obj.model_dump(exclude_none=True)
+ elif isinstance(obj, dict):
+ return {key: convert_to_dict(value) for key, value in obj.items()}
+ elif isinstance(obj, list):
+ return [convert_to_dict(item) for item in obj]
+ else:
+ return obj
+
+
+def _remove_extra_fields(
+ model: pydantic.BaseModel, response: dict[str, object]
+) -> None:
+ """Removes extra fields from the response that are not in the model.
+
+ Mutates the response in place.
+ """
+
+ key_values = list(response.items())
+
+ for key, value in key_values:
+ # Need to convert to snake case to match model fields names
+ # ex: UsageMetadata
+ alias_map = {
+ field_info.alias: key for key, field_info in model.model_fields.items()
+ }
+
+ if key not in model.model_fields and key not in alias_map:
+ response.pop(key)
+ continue
+
+ key = alias_map.get(key, key)
+
+ annotation = model.model_fields[key].annotation
+
+ # Get the BaseModel if Optional
+ if typing.get_origin(annotation) is Union:
+ annotation = typing.get_args(annotation)[0]
+
+ # if dict, assume BaseModel but also check that field type is not dict
+ # example: FunctionCall.args
+ if isinstance(value, dict) and typing.get_origin(annotation) is not dict:
+ _remove_extra_fields(annotation, value)
+ elif isinstance(value, list):
+ for item in value:
+ # assume a list of dict is list of BaseModel
+ if isinstance(item, dict):
+ _remove_extra_fields(typing.get_args(annotation)[0], item)
+
+
+class BaseModel(pydantic.BaseModel):
+
+ model_config = pydantic.ConfigDict(
+ alias_generator=alias_generators.to_camel,
+ populate_by_name=True,
+ from_attributes=True,
+ protected_namespaces=(),
+ extra='forbid',
+ # This allows us to use arbitrary types in the model. E.g. PIL.Image.
+ arbitrary_types_allowed=True,
+ ser_json_bytes='base64',
+ val_json_bytes='base64',
+ )
+
+ @classmethod
+ def _from_response(
+ cls, response: dict[str, object], kwargs: dict[str, object]
+ ) -> 'BaseModel':
+ # To maintain forward compatibility, we need to remove extra fields from
+ # the response.
+ # We will provide another mechanism to allow users to access these fields.
+ _remove_extra_fields(cls, response)
+ validated_response = cls.model_validate(response)
+ return validated_response
+
+ def to_json_dict(self) -> dict[str, object]:
+ return self.model_dump(exclude_none=True, mode='json')
+
+
+class CaseInSensitiveEnum(str, enum.Enum):
+ """Case insensitive enum."""
+
+ @classmethod
+ def _missing_(cls, value):
+ try:
+ return cls[value.upper()] # Try to access directly with uppercase
+ except KeyError:
+ try:
+ return cls[value.lower()] # Try to access directly with lowercase
+ except KeyError as e:
+ raise ValueError(f"{value} is not a valid {cls.__name__}") from e
+
+
+def timestamped_unique_name() -> str:
+ """Composes a timestamped unique name.
+
+ Returns:
+ A string representing a unique name.
+ """
+ timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
+ unique_id = uuid.uuid4().hex[0:5]
+ return f'{timestamp}_{unique_id}'
+
+
+def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
+ """Converts unserializable types in dict to json.dumps() compatible types.
+
+ This function is called in models.py after calling convert_to_dict(). The
+ convert_to_dict() can convert pydantic object to dict. However, the input to
+ convert_to_dict() is dict mixed of pydantic object and nested dict(the output
+ of converters). So they may be bytes in the dict and they are out of
+ `ser_json_bytes` control in model_dump(mode='json') called in
+ `convert_to_dict`, as well as datetime deserialization in Pydantic json mode.
+
+ Returns:
+ A dictionary with json.dumps() incompatible type (e.g. bytes datetime)
+ to compatible type (e.g. base64 encoded string, isoformat date string).
+ """
+ processed_data = {}
+ if not isinstance(data, dict):
+ return data
+ for key, value in data.items():
+ if isinstance(value, bytes):
+ processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii')
+ elif isinstance(value, datetime.datetime):
+ processed_data[key] = value.isoformat()
+ elif isinstance(value, dict):
+ processed_data[key] = encode_unserializable_types(value)
+ elif isinstance(value, list):
+ if all(isinstance(v, bytes) for v in value):
+ processed_data[key] = [
+ base64.urlsafe_b64encode(v).decode('ascii') for v in value
+ ]
+ if all(isinstance(v, datetime.datetime) for v in value):
+ processed_data[key] = [v.isoformat() for v in value]
+ else:
+ processed_data[key] = [encode_unserializable_types(v) for v in value]
+ else:
+ processed_data[key] = value
+ return processed_data