aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/google/genai/_automatic_function_calling_util.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/google/genai/_automatic_function_calling_util.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/google/genai/_automatic_function_calling_util.py')
-rw-r--r--.venv/lib/python3.12/site-packages/google/genai/_automatic_function_calling_util.py294
1 files changed, 294 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/google/genai/_automatic_function_calling_util.py b/.venv/lib/python3.12/site-packages/google/genai/_automatic_function_calling_util.py
new file mode 100644
index 00000000..12d1df7c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/genai/_automatic_function_calling_util.py
@@ -0,0 +1,294 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import inspect
+import types as typing_types
+from typing import Any, Callable, Literal, Union, _GenericAlias, get_args, get_origin
+import pydantic
+from . import types
+
+_py_builtin_type_to_schema_type = {
+ str: 'STRING',
+ int: 'INTEGER',
+ float: 'NUMBER',
+ bool: 'BOOLEAN',
+ list: 'ARRAY',
+ dict: 'OBJECT',
+}
+
+
+def _is_builtin_primitive_or_compound(
+ annotation: inspect.Parameter.annotation,
+) -> bool:
+ return annotation in _py_builtin_type_to_schema_type.keys()
+
+
+def _raise_for_any_of_if_mldev(schema: types.Schema):
+ if schema.any_of:
+ raise ValueError(
+ 'AnyOf is not supported in function declaration schema for Google AI.'
+ )
+
+
+def _raise_for_default_if_mldev(schema: types.Schema):
+ if schema.default is not None:
+ raise ValueError(
+ 'Default value is not supported in function declaration schema for'
+ ' Google AI.'
+ )
+
+
+def _raise_for_nullable_if_mldev(schema: types.Schema):
+ if schema.nullable:
+ raise ValueError(
+ 'Nullable is not supported in function declaration schema for'
+ ' Google AI.'
+ )
+
+
+def _raise_if_schema_unsupported(variant: str, schema: types.Schema):
+ if not variant == 'VERTEX_AI':
+ _raise_for_any_of_if_mldev(schema)
+ _raise_for_default_if_mldev(schema)
+ _raise_for_nullable_if_mldev(schema)
+
+
+def _is_default_value_compatible(
+ default_value: Any, annotation: inspect.Parameter.annotation
+) -> bool:
+ # None type is expected to be handled external to this function
+ if _is_builtin_primitive_or_compound(annotation):
+ return isinstance(default_value, annotation)
+
+ if (
+ isinstance(annotation, _GenericAlias)
+ or isinstance(annotation, typing_types.GenericAlias)
+ or isinstance(annotation, typing_types.UnionType)
+ ):
+ origin = get_origin(annotation)
+ if origin in (Union, typing_types.UnionType):
+ return any(
+ _is_default_value_compatible(default_value, arg)
+ for arg in get_args(annotation)
+ )
+
+ if origin is dict:
+ return isinstance(default_value, dict)
+
+ if origin is list:
+ if not isinstance(default_value, list):
+ return False
+ # most tricky case, element in list is union type
+ # need to apply any logic within all
+ # see test case test_generic_alias_complex_array_with_default_value
+ # a: typing.List[int | str | float | bool]
+ # default_value: [1, 'a', 1.1, True]
+ return all(
+ any(
+ _is_default_value_compatible(item, arg)
+ for arg in get_args(annotation)
+ )
+ for item in default_value
+ )
+
+ if origin is Literal:
+ return default_value in get_args(annotation)
+
+ # return False for any other unrecognized annotation
+ # let caller handle the raise
+ return False
+
+
+def _parse_schema_from_parameter(
+ variant: str, param: inspect.Parameter, func_name: str
+) -> types.Schema:
+ """parse schema from parameter.
+
+ from the simplest case to the most complex case.
+ """
+ schema = types.Schema()
+ default_value_error_msg = (
+ f'Default value {param.default} of parameter {param} of function'
+ f' {func_name} is not compatible with the parameter annotation'
+ f' {param.annotation}.'
+ )
+ if _is_builtin_primitive_or_compound(param.annotation):
+ if param.default is not inspect.Parameter.empty:
+ if not _is_default_value_compatible(param.default, param.annotation):
+ raise ValueError(default_value_error_msg)
+ schema.default = param.default
+ schema.type = _py_builtin_type_to_schema_type[param.annotation]
+ _raise_if_schema_unsupported(variant, schema)
+ return schema
+ if (
+ isinstance(param.annotation, typing_types.UnionType)
+ # only parse simple UnionType, example int | str | float | bool
+ # complex types.UnionType will be invoked in raise branch
+ and all(
+ (_is_builtin_primitive_or_compound(arg) or arg is type(None))
+ for arg in get_args(param.annotation)
+ )
+ ):
+ schema.type = 'OBJECT'
+ schema.any_of = []
+ unique_types = set()
+ for arg in get_args(param.annotation):
+ if arg.__name__ == 'NoneType': # Optional type
+ schema.nullable = True
+ continue
+ schema_in_any_of = _parse_schema_from_parameter(
+ variant,
+ inspect.Parameter(
+ 'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
+ ),
+ func_name,
+ )
+ if (
+ schema_in_any_of.model_dump_json(exclude_none=True)
+ not in unique_types
+ ):
+ schema.any_of.append(schema_in_any_of)
+ unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
+ if len(schema.any_of) == 1: # param: list | None -> Array
+ schema.type = schema.any_of[0].type
+ schema.any_of = None
+ if (
+ param.default is not inspect.Parameter.empty
+ and param.default is not None
+ ):
+ if not _is_default_value_compatible(param.default, param.annotation):
+ raise ValueError(default_value_error_msg)
+ schema.default = param.default
+ _raise_if_schema_unsupported(variant, schema)
+ return schema
+ if isinstance(param.annotation, _GenericAlias) or isinstance(
+ param.annotation, typing_types.GenericAlias
+ ):
+ origin = get_origin(param.annotation)
+ args = get_args(param.annotation)
+ if origin is dict:
+ schema.type = 'OBJECT'
+ if param.default is not inspect.Parameter.empty:
+ if not _is_default_value_compatible(param.default, param.annotation):
+ raise ValueError(default_value_error_msg)
+ schema.default = param.default
+ _raise_if_schema_unsupported(variant, schema)
+ return schema
+ if origin is Literal:
+ if not all(isinstance(arg, str) for arg in args):
+ raise ValueError(
+ f'Literal type {param.annotation} must be a list of strings.'
+ )
+ schema.type = 'STRING'
+ schema.enum = list(args)
+ if param.default is not inspect.Parameter.empty:
+ if not _is_default_value_compatible(param.default, param.annotation):
+ raise ValueError(default_value_error_msg)
+ schema.default = param.default
+ _raise_if_schema_unsupported(variant, schema)
+ return schema
+ if origin is list:
+ schema.type = 'ARRAY'
+ schema.items = _parse_schema_from_parameter(
+ variant,
+ inspect.Parameter(
+ 'item',
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ annotation=args[0],
+ ),
+ func_name,
+ )
+ if param.default is not inspect.Parameter.empty:
+ if not _is_default_value_compatible(param.default, param.annotation):
+ raise ValueError(default_value_error_msg)
+ schema.default = param.default
+ _raise_if_schema_unsupported(variant, schema)
+ return schema
+ if origin is Union:
+ schema.any_of = []
+ schema.type = 'OBJECT'
+ unique_types = set()
+ for arg in args:
+ if arg.__name__ == 'NoneType': # Optional type
+ schema.nullable = True
+ continue
+ schema_in_any_of = _parse_schema_from_parameter(
+ variant,
+ inspect.Parameter(
+ 'item',
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ annotation=arg,
+ ),
+ func_name,
+ )
+ if (
+ schema_in_any_of.model_dump_json(exclude_none=True)
+ not in unique_types
+ ):
+ schema.any_of.append(schema_in_any_of)
+ unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
+ if len(schema.any_of) == 1: # param: Union[List, None] -> Array
+ schema.type = schema.any_of[0].type
+ schema.any_of = None
+ if (
+ param.default is not None
+ and param.default is not inspect.Parameter.empty
+ ):
+ if not _is_default_value_compatible(param.default, param.annotation):
+ raise ValueError(default_value_error_msg)
+ schema.default = param.default
+ _raise_if_schema_unsupported(variant, schema)
+ return schema
+ # all other generic alias will be invoked in raise branch
+ if (
+ inspect.isclass(param.annotation)
+ # for user defined class, we only support pydantic model
+ and issubclass(param.annotation, pydantic.BaseModel)
+ ):
+ if (
+ param.default is not inspect.Parameter.empty
+ and param.default is not None
+ ):
+ schema.default = param.default
+ schema.type = 'OBJECT'
+ schema.properties = {}
+ for field_name, field_info in param.annotation.model_fields.items():
+ schema.properties[field_name] = _parse_schema_from_parameter(
+ variant,
+ inspect.Parameter(
+ field_name,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ annotation=field_info.annotation,
+ ),
+ func_name,
+ )
+ _raise_if_schema_unsupported(variant, schema)
+ return schema
+ raise ValueError(
+ f'Failed to parse the parameter {param} of function {func_name} for'
+ ' automatic function calling.Automatic function calling works best with'
+ ' simpler function signature schema,consider manually parse your'
+ f' function declaration for function {func_name}.'
+ )
+
+
+def _get_required_fields(schema: types.Schema) -> list[str]:
+ if not schema.properties:
+ return
+ return [
+ field_name
+ for field_name, field_schema in schema.properties.items()
+ if not field_schema.nullable and field_schema.default is None
+ ]