aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_limits.py
blob: 850e9b3d6d6adb19cb09f3b4df4816629686efc8 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument

from marshmallow import fields, post_load, validate

from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta


class CommandJobLimitsSchema(metaclass=PatchedSchemaMeta):
    """Marshmallow schema for the ``limits`` section of a command job.

    On load, the validated field values are turned into a
    :class:`azure.ai.ml.entities.CommandJobLimits` instance.
    """

    # Described like SweepJobLimitsSchema.timeout so that generated schema
    # documentation is consistent across job types.
    timeout = fields.Int(
        metadata={"description": "The max run duration in seconds, after which the job will be cancelled."}
    )

    @post_load
    def make(self, data, **kwargs):
        """Build a CommandJobLimits entity from the deserialized values.

        :param data: Mapping of field name to deserialized value.
        :return: ``CommandJobLimits(**data)``.
        """
        # Imported inside the hook, presumably to avoid a circular import
        # between the schema and entity packages.
        from azure.ai.ml.entities import CommandJobLimits

        return CommandJobLimits(**data)


class SweepJobLimitsSchema(metaclass=PatchedSchemaMeta):
    """Marshmallow schema for the ``limits`` section of a sweep job.

    ``max_total_trials`` is mandatory; the remaining limits are optional.
    On load, the validated values are used to build a
    :class:`azure.ai.ml.sweep.SweepJobLimits`.
    """

    max_concurrent_trials = fields.Int(
        metadata={"description": "Sweep Job max concurrent trials."},
    )
    max_total_trials = fields.Int(
        required=True,
        metadata={"description": "Sweep Job max total trials."},
    )
    timeout = fields.Int(
        metadata={"description": "The max run duration in Seconds, after which the job will be cancelled."},
    )
    trial_timeout = fields.Int(
        metadata={"description": "Sweep Job Trial timeout value."},
    )

    @post_load
    def make(self, data, **kwargs):
        """Convert the loaded field values into a SweepJobLimits entity.

        :param data: Mapping of field name to deserialized value.
        :return: ``SweepJobLimits(**data)``.
        """
        # Deferred import, presumably to keep this schema module free of an
        # import-time dependency on the sweep entity package.
        from azure.ai.ml.sweep import SweepJobLimits as _SweepJobLimits

        limits = _SweepJobLimits(**data)
        return limits


class DoWhileLimitsSchema(metaclass=PatchedSchemaMeta):
    """Marshmallow schema for the limits of a do_while loop.

    The only field is ``max_iteration_count``. It is required and validated
    to lie between 1 and 1000. No ``post_load`` hook is defined on this
    class.
    """

    # marshmallow's Range treats both bounds as inclusive by default.
    max_iteration_count = fields.Int(
        required=True,
        validate=validate.Range(min=1, max=1000),
        metadata={"description": "The max iteration for do_while loop."},
    )