aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/pydantic/_internal/_discriminated_union.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/pydantic/_internal/_discriminated_union.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/pydantic/_internal/_discriminated_union.py')
-rw-r--r--.venv/lib/python3.12/site-packages/pydantic/_internal/_discriminated_union.py503
1 files changed, 503 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/pydantic/_internal/_discriminated_union.py b/.venv/lib/python3.12/site-packages/pydantic/_internal/_discriminated_union.py
new file mode 100644
index 00000000..29a50a5a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/pydantic/_internal/_discriminated_union.py
@@ -0,0 +1,503 @@
+from __future__ import annotations as _annotations
+
+from typing import TYPE_CHECKING, Any, Hashable, Sequence
+
+from pydantic_core import CoreSchema, core_schema
+
+from ..errors import PydanticUserError
+from . import _core_utils
+from ._core_utils import (
+ CoreSchemaField,
+ collect_definitions,
+)
+
+if TYPE_CHECKING:
+ from ..types import Discriminator
+
+CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'
+
+
+class MissingDefinitionForUnionRef(Exception):
+ """Raised when applying a discriminated union discriminator to a schema
+ requires a definition that is not yet defined
+ """
+
+ def __init__(self, ref: str) -> None:
+ self.ref = ref
+ super().__init__(f'Missing definition for ref {self.ref!r}')
+
+
+def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
+ schema.setdefault('metadata', {})
+ metadata = schema.get('metadata')
+ assert metadata is not None
+ metadata[CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY] = discriminator
+
+
+def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
+ # We recursively walk through the `schema` passed to `apply_discriminators`, applying discriminators
+ # where necessary at each level. During this recursion, we allow references to be resolved from the definitions
+ # that are originally present on the original, outermost `schema`. Before `apply_discriminators` is called,
+ # `simplify_schema_references` is called on the schema (in the `clean_schema` function),
+ # which often puts the definitions in the outermost schema.
+ global_definitions: dict[str, CoreSchema] = collect_definitions(schema)
+
+ def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
+ nonlocal global_definitions
+
+ s = recurse(s, inner)
+ if s['type'] == 'tagged-union':
+ return s
+
+ metadata = s.get('metadata', {})
+ discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
+ if discriminator is not None:
+ s = apply_discriminator(s, discriminator, global_definitions)
+ return s
+
+ return _core_utils.walk_core_schema(schema, inner, copy=False)
+
+
+def apply_discriminator(
+ schema: core_schema.CoreSchema,
+ discriminator: str | Discriminator,
+ definitions: dict[str, core_schema.CoreSchema] | None = None,
+) -> core_schema.CoreSchema:
+ """Applies the discriminator and returns a new core schema.
+
+ Args:
+ schema: The input schema.
+ discriminator: The name of the field which will serve as the discriminator.
+ definitions: A mapping of schema ref to schema.
+
+ Returns:
+ The new core schema.
+
+ Raises:
+ TypeError:
+ - If `discriminator` is used with invalid union variant.
+ - If `discriminator` is used with `Union` type with one variant.
+ - If `discriminator` value mapped to multiple choices.
+ MissingDefinitionForUnionRef:
+ If the definition for ref is missing.
+ PydanticUserError:
+ - If a model in union doesn't have a discriminator field.
+ - If discriminator field has a non-string alias.
+ - If discriminator fields have different aliases.
+ - If discriminator field not of type `Literal`.
+ """
+ from ..types import Discriminator
+
+ if isinstance(discriminator, Discriminator):
+ if isinstance(discriminator.discriminator, str):
+ discriminator = discriminator.discriminator
+ else:
+ return discriminator._convert_schema(schema)
+
+ return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)
+
+
+class _ApplyInferredDiscriminator:
+ """This class is used to convert an input schema containing a union schema into one where that union is
+ replaced with a tagged-union, with all the associated debugging and performance benefits.
+
+ This is done by:
+ * Validating that the input schema is compatible with the provided discriminator
+ * Introspecting the schema to determine which discriminator values should map to which union choices
+ * Handling various edge cases such as 'definitions', 'default', 'nullable' schemas, and more
+
+ I have chosen to implement the conversion algorithm in this class, rather than a function,
+ to make it easier to maintain state while recursively walking the provided CoreSchema.
+ """
+
+ def __init__(self, discriminator: str, definitions: dict[str, core_schema.CoreSchema]):
+ # `discriminator` should be the name of the field which will serve as the discriminator.
+ # It must be the python name of the field, and *not* the field's alias. Note that as of now,
+ # all members of a discriminated union _must_ use a field with the same name as the discriminator.
+ # This may change if/when we expose a way to manually specify the TaggedUnionSchema's choices.
+ self.discriminator = discriminator
+
+ # `definitions` should contain a mapping of schema ref to schema for all schemas which might
+ # be referenced by some choice
+ self.definitions = definitions
+
+ # `_discriminator_alias` will hold the value, if present, of the alias for the discriminator
+ #
+ # Note: following the v1 implementation, we currently disallow the use of different aliases
+ # for different choices. This is not a limitation of pydantic_core, but if we try to handle
+ # this, the inference logic gets complicated very quickly, and could result in confusing
+ # debugging challenges for users making subtle mistakes.
+ #
+ # Rather than trying to do the most powerful inference possible, I think we should eventually
+ # expose a way to more-manually control the way the TaggedUnionSchema is constructed through
+ # the use of a new type which would be placed as an Annotation on the Union type. This would
+ # provide the full flexibility/power of pydantic_core's TaggedUnionSchema where necessary for
+ # more complex cases, without over-complicating the inference logic for the common cases.
+ self._discriminator_alias: str | None = None
+
+ # `_should_be_nullable` indicates whether the converted union has `None` as an allowed value.
+ # If `None` is an acceptable value of the (possibly-wrapped) union, we ignore it while
+ # constructing the TaggedUnionSchema, but set the `_should_be_nullable` attribute to True.
+ # Once we have constructed the TaggedUnionSchema, if `_should_be_nullable` is True, we ensure
+ # that the final schema gets wrapped as a NullableSchema. This has the same semantics on the
+ # python side, but resolves the issue that `None` cannot correspond to any discriminator values.
+ self._should_be_nullable = False
+
+ # `_is_nullable` is used to track if the final produced schema will definitely be nullable;
+ # we set it to True if the input schema is wrapped in a nullable schema that we know will be preserved
+ # as an indication that, even if None is discovered as one of the union choices, we will not need to wrap
+ # the final value in another nullable schema.
+ #
+ # This is more complicated than just checking for the final outermost schema having type 'nullable' thanks
+ # to the possible presence of other wrapper schemas such as DefinitionsSchema, WithDefaultSchema, etc.
+ self._is_nullable = False
+
+ # `_choices_to_handle` serves as a stack of choices to add to the tagged union. Initially, choices
+ # from the union in the wrapped schema will be appended to this list, and the recursive choice-handling
+ # algorithm may add more choices to this stack as (nested) unions are encountered.
+ self._choices_to_handle: list[core_schema.CoreSchema] = []
+
+ # `_tagged_union_choices` is built during the call to `apply`, and will hold the choices to be included
+ # in the output TaggedUnionSchema that will replace the union from the input schema
+ self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {}
+
+ # `_used` is changed to True after applying the discriminator to prevent accidental reuse
+ self._used = False
+
+ def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
+ """Return a new CoreSchema based on `schema` that uses a tagged-union with the discriminator provided
+ to this class.
+
+ Args:
+ schema: The input schema.
+
+ Returns:
+ The new core schema.
+
+ Raises:
+ TypeError:
+ - If `discriminator` is used with invalid union variant.
+ - If `discriminator` is used with `Union` type with one variant.
+ - If `discriminator` value mapped to multiple choices.
+ ValueError:
+ If the definition for ref is missing.
+ PydanticUserError:
+ - If a model in union doesn't have a discriminator field.
+ - If discriminator field has a non-string alias.
+ - If discriminator fields have different aliases.
+ - If discriminator field not of type `Literal`.
+ """
+ assert not self._used
+ schema = self._apply_to_root(schema)
+ if self._should_be_nullable and not self._is_nullable:
+ schema = core_schema.nullable_schema(schema)
+ self._used = True
+ return schema
+
+ def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
+ """This method handles the outer-most stage of recursion over the input schema:
+ unwrapping nullable or definitions schemas, and calling the `_handle_choice`
+ method iteratively on the choices extracted (recursively) from the possibly-wrapped union.
+ """
+ if schema['type'] == 'nullable':
+ self._is_nullable = True
+ wrapped = self._apply_to_root(schema['schema'])
+ nullable_wrapper = schema.copy()
+ nullable_wrapper['schema'] = wrapped
+ return nullable_wrapper
+
+ if schema['type'] == 'definitions':
+ wrapped = self._apply_to_root(schema['schema'])
+ definitions_wrapper = schema.copy()
+ definitions_wrapper['schema'] = wrapped
+ return definitions_wrapper
+
+ if schema['type'] != 'union':
+ # If the schema is not a union, it probably means it just had a single member and
+ # was flattened by pydantic_core.
+ # However, it still may make sense to apply the discriminator to this schema,
+ # as a way to get discriminated-union-style error messages, so we allow this here.
+ schema = core_schema.union_schema([schema])
+
+ # Reverse the choices list before extending the stack so that they get handled in the order they occur
+ choices_schemas = [v[0] if isinstance(v, tuple) else v for v in schema['choices'][::-1]]
+ self._choices_to_handle.extend(choices_schemas)
+ while self._choices_to_handle:
+ choice = self._choices_to_handle.pop()
+ self._handle_choice(choice)
+
+ if self._discriminator_alias is not None and self._discriminator_alias != self.discriminator:
+ # * We need to annotate `discriminator` as a union here to handle both branches of this conditional
+ # * We need to annotate `discriminator` as list[list[str | int]] and not list[list[str]] due to the
+ # invariance of list, and because list[list[str | int]] is the type of the discriminator argument
+ # to tagged_union_schema below
+ # * See the docstring of pydantic_core.core_schema.tagged_union_schema for more details about how to
+ # interpret the value of the discriminator argument to tagged_union_schema. (The list[list[str]] here
+ # is the appropriate way to provide a list of fallback attributes to check for a discriminator value.)
+ discriminator: str | list[list[str | int]] = [[self.discriminator], [self._discriminator_alias]]
+ else:
+ discriminator = self.discriminator
+ return core_schema.tagged_union_schema(
+ choices=self._tagged_union_choices,
+ discriminator=discriminator,
+ custom_error_type=schema.get('custom_error_type'),
+ custom_error_message=schema.get('custom_error_message'),
+ custom_error_context=schema.get('custom_error_context'),
+ strict=False,
+ from_attributes=True,
+ ref=schema.get('ref'),
+ metadata=schema.get('metadata'),
+ serialization=schema.get('serialization'),
+ )
+
+ def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
+ """This method handles the "middle" stage of recursion over the input schema.
+ Specifically, it is responsible for handling each choice of the outermost union
+ (and any "coalesced" choices obtained from inner unions).
+
+ Here, "handling" entails:
+ * Coalescing nested unions and compatible tagged-unions
+ * Tracking the presence of 'none' and 'nullable' schemas occurring as choices
+ * Validating that each allowed discriminator value maps to a unique choice
+ * Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
+ """
+ if choice['type'] == 'definition-ref':
+ if choice['schema_ref'] not in self.definitions:
+ raise MissingDefinitionForUnionRef(choice['schema_ref'])
+
+ if choice['type'] == 'none':
+ self._should_be_nullable = True
+ elif choice['type'] == 'definitions':
+ self._handle_choice(choice['schema'])
+ elif choice['type'] == 'nullable':
+ self._should_be_nullable = True
+ self._handle_choice(choice['schema']) # unwrap the nullable schema
+ elif choice['type'] == 'union':
+ # Reverse the choices list before extending the stack so that they get handled in the order they occur
+ choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
+ self._choices_to_handle.extend(choices_schemas)
+ elif choice['type'] not in {
+ 'model',
+ 'typed-dict',
+ 'tagged-union',
+ 'lax-or-strict',
+ 'dataclass',
+ 'dataclass-args',
+ 'definition-ref',
+ } and not _core_utils.is_function_with_inner_schema(choice):
+ # We should eventually handle 'definition-ref' as well
+ raise TypeError(
+ f'{choice["type"]!r} is not a valid discriminated union variant;'
+ ' should be a `BaseModel` or `dataclass`'
+ )
+ else:
+ if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
+ # In this case, this inner tagged-union is compatible with the outer tagged-union,
+ # and its choices can be coalesced into the outer TaggedUnionSchema.
+ subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
+ # Reverse the choices list before extending the stack so that they get handled in the order they occur
+ self._choices_to_handle.extend(subchoices[::-1])
+ return
+
+ inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None)
+ self._set_unique_choice_for_values(choice, inferred_discriminator_values)
+
+ def _is_discriminator_shared(self, choice: core_schema.TaggedUnionSchema) -> bool:
+ """This method returns a boolean indicating whether the discriminator for the `choice`
+ is the same as that being used for the outermost tagged union. This is used to
+ determine whether this TaggedUnionSchema choice should be "coalesced" into the top level,
+ or whether it should be treated as a separate (nested) choice.
+ """
+ inner_discriminator = choice['discriminator']
+ return inner_discriminator == self.discriminator or (
+ isinstance(inner_discriminator, list)
+ and (self.discriminator in inner_discriminator or [self.discriminator] in inner_discriminator)
+ )
+
+ def _infer_discriminator_values_for_choice( # noqa C901
+ self, choice: core_schema.CoreSchema, source_name: str | None
+ ) -> list[str | int]:
+ """This function recurses over `choice`, extracting all discriminator values that should map to this choice.
+
+ `model_name` is accepted for the purpose of producing useful error messages.
+ """
+ if choice['type'] == 'definitions':
+ return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
+ elif choice['type'] == 'function-plain':
+ raise TypeError(
+ f'{choice["type"]!r} is not a valid discriminated union variant;'
+ ' should be a `BaseModel` or `dataclass`'
+ )
+ elif _core_utils.is_function_with_inner_schema(choice):
+ return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
+ elif choice['type'] == 'lax-or-strict':
+ return sorted(
+ set(
+ self._infer_discriminator_values_for_choice(choice['lax_schema'], source_name=None)
+ + self._infer_discriminator_values_for_choice(choice['strict_schema'], source_name=None)
+ )
+ )
+
+ elif choice['type'] == 'tagged-union':
+ values: list[str | int] = []
+ # Ignore str/int "choices" since these are just references to other choices
+ subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
+ for subchoice in subchoices:
+ subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None)
+ values.extend(subchoice_values)
+ return values
+
+ elif choice['type'] == 'union':
+ values = []
+ for subchoice in choice['choices']:
+ subchoice_schema = subchoice[0] if isinstance(subchoice, tuple) else subchoice
+ subchoice_values = self._infer_discriminator_values_for_choice(subchoice_schema, source_name=None)
+ values.extend(subchoice_values)
+ return values
+
+ elif choice['type'] == 'nullable':
+ self._should_be_nullable = True
+ return self._infer_discriminator_values_for_choice(choice['schema'], source_name=None)
+
+ elif choice['type'] == 'model':
+ return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)
+
+ elif choice['type'] == 'dataclass':
+ return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)
+
+ elif choice['type'] == 'model-fields':
+ return self._infer_discriminator_values_for_model_choice(choice, source_name=source_name)
+
+ elif choice['type'] == 'dataclass-args':
+ return self._infer_discriminator_values_for_dataclass_choice(choice, source_name=source_name)
+
+ elif choice['type'] == 'typed-dict':
+ return self._infer_discriminator_values_for_typed_dict_choice(choice, source_name=source_name)
+
+ elif choice['type'] == 'definition-ref':
+ schema_ref = choice['schema_ref']
+ if schema_ref not in self.definitions:
+ raise MissingDefinitionForUnionRef(schema_ref)
+ return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)
+ else:
+ raise TypeError(
+ f'{choice["type"]!r} is not a valid discriminated union variant;'
+ ' should be a `BaseModel` or `dataclass`'
+ )
+
+ def _infer_discriminator_values_for_typed_dict_choice(
+ self, choice: core_schema.TypedDictSchema, source_name: str | None = None
+ ) -> list[str | int]:
+ """This method just extracts the _infer_discriminator_values_for_choice logic specific to TypedDictSchema
+ for the sake of readability.
+ """
+ source = 'TypedDict' if source_name is None else f'TypedDict {source_name!r}'
+ field = choice['fields'].get(self.discriminator)
+ if field is None:
+ raise PydanticUserError(
+ f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
+ )
+ return self._infer_discriminator_values_for_field(field, source)
+
+ def _infer_discriminator_values_for_model_choice(
+ self, choice: core_schema.ModelFieldsSchema, source_name: str | None = None
+ ) -> list[str | int]:
+ source = 'ModelFields' if source_name is None else f'Model {source_name!r}'
+ field = choice['fields'].get(self.discriminator)
+ if field is None:
+ raise PydanticUserError(
+ f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
+ )
+ return self._infer_discriminator_values_for_field(field, source)
+
+ def _infer_discriminator_values_for_dataclass_choice(
+ self, choice: core_schema.DataclassArgsSchema, source_name: str | None = None
+ ) -> list[str | int]:
+ source = 'DataclassArgs' if source_name is None else f'Dataclass {source_name!r}'
+ for field in choice['fields']:
+ if field['name'] == self.discriminator:
+ break
+ else:
+ raise PydanticUserError(
+ f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
+ )
+ return self._infer_discriminator_values_for_field(field, source)
+
+ def _infer_discriminator_values_for_field(self, field: CoreSchemaField, source: str) -> list[str | int]:
+ if field['type'] == 'computed-field':
+ # This should never occur as a discriminator, as it is only relevant to serialization
+ return []
+ alias = field.get('validation_alias', self.discriminator)
+ if not isinstance(alias, str):
+ raise PydanticUserError(
+ f'Alias {alias!r} is not supported in a discriminated union', code='discriminator-alias-type'
+ )
+ if self._discriminator_alias is None:
+ self._discriminator_alias = alias
+ elif self._discriminator_alias != alias:
+ raise PydanticUserError(
+ f'Aliases for discriminator {self.discriminator!r} must be the same '
+ f'(got {alias}, {self._discriminator_alias})',
+ code='discriminator-alias',
+ )
+ return self._infer_discriminator_values_for_inner_schema(field['schema'], source)
+
+ def _infer_discriminator_values_for_inner_schema(
+ self, schema: core_schema.CoreSchema, source: str
+ ) -> list[str | int]:
+ """When inferring discriminator values for a field, we typically extract the expected values from a literal
+ schema. This function does that, but also handles nested unions and defaults.
+ """
+ if schema['type'] == 'literal':
+ return schema['expected']
+
+ elif schema['type'] == 'union':
+ # Generally when multiple values are allowed they should be placed in a single `Literal`, but
+ # we add this case to handle the situation where a field is annotated as a `Union` of `Literal`s.
+ # For example, this lets us handle `Union[Literal['key'], Union[Literal['Key'], Literal['KEY']]]`
+ values: list[Any] = []
+ for choice in schema['choices']:
+ choice_schema = choice[0] if isinstance(choice, tuple) else choice
+ choice_values = self._infer_discriminator_values_for_inner_schema(choice_schema, source)
+ values.extend(choice_values)
+ return values
+
+ elif schema['type'] == 'default':
+ # This will happen if the field has a default value; we ignore it while extracting the discriminator values
+ return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)
+
+ elif schema['type'] == 'function-after':
+ # After validators don't affect the discriminator values
+ return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)
+
+ elif schema['type'] in {'function-before', 'function-wrap', 'function-plain'}:
+ validator_type = repr(schema['type'].split('-')[1])
+ raise PydanticUserError(
+ f'Cannot use a mode={validator_type} validator in the'
+ f' discriminator field {self.discriminator!r} of {source}',
+ code='discriminator-validator',
+ )
+
+ else:
+ raise PydanticUserError(
+ f'{source} needs field {self.discriminator!r} to be of type `Literal`',
+ code='discriminator-needs-literal',
+ )
+
+ def _set_unique_choice_for_values(self, choice: core_schema.CoreSchema, values: Sequence[str | int]) -> None:
+ """This method updates `self.tagged_union_choices` so that all provided (discriminator) `values` map to the
+ provided `choice`, validating that none of these values already map to another (different) choice.
+ """
+ for discriminator_value in values:
+ if discriminator_value in self._tagged_union_choices:
+ # It is okay if `value` is already in tagged_union_choices as long as it maps to the same value.
+ # Because tagged_union_choices may map values to other values, we need to walk the choices dict
+ # until we get to a "real" choice, and confirm that is equal to the one assigned.
+ existing_choice = self._tagged_union_choices[discriminator_value]
+ if existing_choice != choice:
+ raise TypeError(
+ f'Value {discriminator_value!r} for discriminator '
+ f'{self.discriminator!r} mapped to multiple choices'
+ )
+ else:
+ self._tagged_union_choices[discriminator_value] = choice