aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py310
1 files changed, 310 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py b/.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py
new file mode 100644
index 00000000..db8d377b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/_extra_utils.py
@@ -0,0 +1,310 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Extra utils depending on types that are shared between sync and async modules.
+"""
+
+import inspect
+import logging
+from typing import Any, Callable, Dict, get_args, get_origin, Optional, types as typing_types, Union
+
+import pydantic
+
+from . import _common
+from . import errors
+from . import types
+
+
+_DEFAULT_MAX_REMOTE_CALLS_AFC = 10
+
+
+def format_destination(
+ src: str,
+ config: Optional[types.CreateBatchJobConfigOrDict] = None,
+) -> types.CreateBatchJobConfig:
+ """Formats the destination uri based on the source uri."""
+ config = (
+ types._CreateBatchJobParameters(config=config).config
+ or types.CreateBatchJobConfig()
+ )
+
+ unique_name = None
+ if not config.display_name:
+ unique_name = _common.timestamped_unique_name()
+ config.display_name = f'genai_batch_job_{unique_name}'
+
+ if not config.dest:
+ if src.startswith('gs://') and src.endswith('.jsonl'):
+ # If source uri is "gs://bucket/path/to/src.jsonl", then the destination
+ # uri prefix will be "gs://bucket/path/to/src/dest".
+ config.dest = f'{src[:-6]}/dest'
+ elif src.startswith('bq://'):
+ # If source uri is "bq://project.dataset.src", then the destination
+ # uri will be "bq://project.dataset.src_dest_TIMESTAMP_UUID".
+ unique_name = unique_name or _common.timestamped_unique_name()
+ config.dest = f'{src}_dest_{unique_name}'
+ else:
+ raise ValueError(f'Unsupported source: {src}')
+ return config
+
+
+def get_function_map(
+ config: Optional[types.GenerateContentConfigOrDict] = None,
+) -> dict[str, object]:
+ """Returns a function map from the config."""
+ config_model = (
+ types.GenerateContentConfig(**config)
+ if config and isinstance(config, dict)
+ else config
+ )
+ function_map = {}
+ if not config_model:
+ return function_map
+ if config_model.tools:
+ for tool in config_model.tools:
+ if callable(tool):
+ if inspect.iscoroutinefunction(tool):
+ raise errors.UnsupportedFunctionError(
+ f'Function {tool.__name__} is a coroutine function, which is not'
+ ' supported for automatic function calling. Please manually invoke'
+ f' {tool.__name__} to get the function response.'
+ )
+ function_map[tool.__name__] = tool
+ return function_map
+
+
+def convert_number_values_for_function_call_args(
+ args: Union[dict[str, object], list[object], object],
+) -> Union[dict[str, object], list[object], object]:
+ """Converts float values with no decimal to integers."""
+ if isinstance(args, float) and args.is_integer():
+ return int(args)
+ if isinstance(args, dict):
+ return {
+ key: convert_number_values_for_function_call_args(value)
+ for key, value in args.items()
+ }
+ if isinstance(args, list):
+ return [
+ convert_number_values_for_function_call_args(value) for value in args
+ ]
+ return args
+
+
+def _is_annotation_pydantic_model(annotation: Any) -> bool:
+ return inspect.isclass(annotation) and issubclass(
+ annotation, pydantic.BaseModel
+ )
+
+
+def convert_if_exist_pydantic_model(
+ value: Any, annotation: Any, param_name: str, func_name: str
+) -> Any:
+ if isinstance(value, dict) and _is_annotation_pydantic_model(annotation):
+ try:
+ return annotation(**value)
+ except pydantic.ValidationError as e:
+ raise errors.UnknownFunctionCallArgumentError(
+ f'Failed to parse parameter {param_name} for function'
+ f' {func_name} from function call part because function call argument'
+ f' value {value} is not compatible with parameter annotation'
+ f' {annotation}, due to error {e}'
+ )
+ if isinstance(value, list) and get_origin(annotation) == list:
+ item_type = get_args(annotation)[0]
+ return [
+ convert_if_exist_pydantic_model(item, item_type, param_name, func_name)
+ for item in value
+ ]
+ if isinstance(value, dict) and get_origin(annotation) == dict:
+ _, value_type = get_args(annotation)
+ return {
+ k: convert_if_exist_pydantic_model(v, value_type, param_name, func_name)
+ for k, v in value.items()
+ }
+ # example 1: typing.Union[int, float]
+ # example 2: int | float equivalent to typing.types.UnionType[int, float]
+ if get_origin(annotation) in (Union, typing_types.UnionType):
+ for arg in get_args(annotation):
+ if isinstance(value, arg) or (
+ isinstance(value, dict) and _is_annotation_pydantic_model(arg)
+ ):
+ try:
+ return convert_if_exist_pydantic_model(
+ value, arg, param_name, func_name
+ )
+ # do not raise here because there could be multiple pydantic model types
+ # in the union type.
+ except pydantic.ValidationError:
+ continue
+ # if none of the union type is matched, raise error
+ raise errors.UnknownFunctionCallArgumentError(
+ f'Failed to parse parameter {param_name} for function'
+ f' {func_name} from function call part because function call argument'
+ f' value {value} cannot be converted to parameter annotation'
+ f' {annotation}.'
+ )
+ # the only exception for value and annotation type to be different is int and
+ # float. see convert_number_values_for_function_call_args function for context
+ if isinstance(value, int) and annotation is float:
+ return value
+ if not isinstance(value, annotation):
+ raise errors.UnknownFunctionCallArgumentError(
+ f'Failed to parse parameter {param_name} for function {func_name} from'
+ f' function call part because function call argument value {value} is'
+ f' not compatible with parameter annotation {annotation}.'
+ )
+ return value
+
+
+def invoke_function_from_dict_args(
+ args: Dict[str, Any], function_to_invoke: Callable
+) -> Any:
+ signature = inspect.signature(function_to_invoke)
+ func_name = function_to_invoke.__name__
+ converted_args = {}
+ for param_name, param in signature.parameters.items():
+ if param_name in args:
+ converted_args[param_name] = convert_if_exist_pydantic_model(
+ args[param_name],
+ param.annotation,
+ param_name,
+ func_name,
+ )
+ try:
+ return function_to_invoke(**converted_args)
+ except Exception as e:
+ raise errors.FunctionInvocationError(
+ f'Failed to invoke function {func_name} with converted arguments'
+ f' {converted_args} from model returned function call argument'
+ f' {args} because of error {e}'
+ )
+
+
+def get_function_response_parts(
+ response: types.GenerateContentResponse,
+ function_map: dict[str, object],
+) -> list[types.Part]:
+ """Returns the function response parts from the response."""
+ func_response_parts = []
+ for part in response.candidates[0].content.parts:
+ if not part.function_call:
+ continue
+ func_name = part.function_call.name
+ func = function_map[func_name]
+ args = convert_number_values_for_function_call_args(part.function_call.args)
+ try:
+ response = {'result': invoke_function_from_dict_args(args, func)}
+ except Exception as e: # pylint: disable=broad-except
+ response = {'error': str(e)}
+ func_response = types.Part.from_function_response(func_name, response)
+
+ func_response_parts.append(func_response)
+ return func_response_parts
+
+
+def should_disable_afc(
+ config: Optional[types.GenerateContentConfigOrDict] = None,
+) -> bool:
+ """Returns whether automatic function calling is enabled."""
+ config_model = (
+ types.GenerateContentConfig(**config)
+ if config and isinstance(config, dict)
+ else config
+ )
+
+ # If max_remote_calls is less or equal to 0, warn and disable AFC.
+ if (
+ config_model
+ and config_model.automatic_function_calling
+ and config_model.automatic_function_calling.maximum_remote_calls
+ is not None
+ and int(config_model.automatic_function_calling.maximum_remote_calls)
+ <= 0
+ ):
+ logging.warning(
+ 'max_remote_calls in automatic_function_calling_config'
+ f' {config_model.automatic_function_calling.maximum_remote_calls} is'
+ ' less than or equal to 0. Disabling automatic function calling.'
+ ' Please set max_remote_calls to a positive integer.'
+ )
+ return True
+
+ # Default to enable AFC if not specified.
+ if (
+ not config_model
+ or not config_model.automatic_function_calling
+ or config_model.automatic_function_calling.disable is None
+ ):
+ return False
+
+ if (
+ config_model.automatic_function_calling.disable
+ and config_model.automatic_function_calling.maximum_remote_calls
+ is not None
+ and int(config_model.automatic_function_calling.maximum_remote_calls) > 0
+ ):
+ logging.warning(
+ '`automatic_function_calling.disable` is set to `True`. But'
+ ' `automatic_function_calling.maximum_remote_calls` is set to be a'
+ ' positive number'
+ f' {config_model.automatic_function_calling.maximum_remote_calls}.'
+ ' Disabling automatic function calling. If you want to enable'
+ ' automatic function calling, please set'
+ ' `automatic_function_calling.disable` to `False` or leave it unset,'
+ ' and set `automatic_function_calling.maximum_remote_calls` to a'
+ ' positive integer or leave'
+ ' `automatic_function_calling.maximum_remote_calls` unset.'
+ )
+
+ return config_model.automatic_function_calling.disable
+
+
+def get_max_remote_calls_afc(
+ config: Optional[types.GenerateContentConfigOrDict] = None,
+) -> int:
+ """Returns the remaining remote calls for automatic function calling."""
+ if should_disable_afc(config):
+ raise ValueError(
+ 'automatic function calling is not enabled, but SDK is trying to get'
+ ' max remote calls.'
+ )
+ config_model = (
+ types.GenerateContentConfig(**config)
+ if config and isinstance(config, dict)
+ else config
+ )
+ if (
+ not config_model
+ or not config_model.automatic_function_calling
+ or config_model.automatic_function_calling.maximum_remote_calls is None
+ ):
+ return _DEFAULT_MAX_REMOTE_CALLS_AFC
+ return int(config_model.automatic_function_calling.maximum_remote_calls)
+
+def should_append_afc_history(
+ config: Optional[types.GenerateContentConfigOrDict] = None,
+) -> bool:
+ config_model = (
+ types.GenerateContentConfig(**config)
+ if config and isinstance(config, dict)
+ else config
+ )
+ if (
+ not config_model
+ or not config_model.automatic_function_calling
+ ):
+ return True
+ return not config_model.automatic_function_calling.ignore_call_history