about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/model.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/model.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/model.py65
1 files changed, 65 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/model.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/model.py
new file mode 100644
index 00000000..60c17f63
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/model.py
@@ -0,0 +1,65 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import fields, post_load, pre_dump
+
+from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField
+from azure.ai.ml._schema.core.intellectual_property import IntellectualPropertySchema
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.job import CreationContextSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AssetTypes, AzureMLResourceType
+
+from ..core.fields import ArmVersionedStr, StringTransformedEnum, VersionField
+
+module_logger = logging.getLogger(__name__)
+
+
+class ModelSchema(PathAwareSchema):
+    name = fields.Str(required=True)
+    id = ArmVersionedStr(azureml_type=AzureMLResourceType.MODEL, dump_only=True)
+    type = StringTransformedEnum(
+        allowed_values=[
+            AssetTypes.CUSTOM_MODEL,
+            AssetTypes.MLFLOW_MODEL,
+            AssetTypes.TRITON_MODEL,
+        ],
+        metadata={"description": "The storage format for this entity. Used for NCD."},
+    )
+    path = fields.Str()
+    version = VersionField()
+    description = fields.Str()
+    properties = fields.Dict()
+    tags = fields.Dict()
+    stage = fields.Str()
+    utc_time_created = fields.DateTime(format="iso", dump_only=True)
+    flavors = fields.Dict()
+    creation_context = NestedField(CreationContextSchema, dump_only=True)
+    job_name = fields.Str(dump_only=True)
+    latest_version = fields.Str(dump_only=True)
+    datastore = fields.Str(metadata={"description": "Name of the datastore to upload to."}, required=False)
+    intellectual_property = ExperimentalField(NestedField(IntellectualPropertySchema, required=False), dump_only=True)
+    system_metadata = fields.Dict()
+
+    @pre_dump
+    def validate(self, data, **kwargs):
+        if data._intellectual_property:  # pylint: disable=protected-access
+            ipp_field = data._intellectual_property  # pylint: disable=protected-access
+            if ipp_field:
+                setattr(data, "intellectual_property", ipp_field)
+        return data
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities._assets import Model
+
+        return Model(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
+
+
+class AnonymousModelSchema(ModelSchema):
+    name = fields.Str()
+    version = VersionField()