diff --git a/mindcv/optim/adamw.py b/mindcv/optim/adamw.py index 29be61ed..19d79c59 100644 --- a/mindcv/optim/adamw.py +++ b/mindcv/optim/adamw.py @@ -9,6 +9,11 @@ from mindspore.nn.optim import Optimizer from mindspore.nn.optim.optimizer import opt_init_args_register +try: + from mindspore import jit +except ImportError: + from mindspore import ms_function as jit + def _check_param_value(beta1, beta2, eps, prim_name): """Check the type of inputs.""" @@ -154,6 +159,7 @@ def __init__( self.reciprocal_scale = Tensor(1.0 / loss_scale, ms.float32) self.clip = clip + @jit def construct(self, gradients): lr = self.get_lr() gradients = scale_grad(gradients, self.reciprocal_scale) diff --git a/mindcv/optim/adan.py b/mindcv/optim/adan.py index 7494c67b..cd3fe1e8 100644 --- a/mindcv/optim/adan.py +++ b/mindcv/optim/adan.py @@ -5,6 +5,11 @@ from mindspore.common.tensor import Tensor from mindspore.nn.optim.optimizer import Optimizer, opt_init_args_register +try: + from mindspore import jit +except ImportError: + from mindspore import ms_function as jit + _adan_opt = ops.MultitypeFuncGraph("adan_opt") @@ -144,6 +149,7 @@ def __init__( self.weight_decay = Tensor(weight_decay, mstype.float32) + @jit def construct(self, gradients): params = self._parameters moment1 = self.moment1 diff --git a/mindcv/optim/lion.py b/mindcv/optim/lion.py index 3de410e7..ad079c5e 100644 --- a/mindcv/optim/lion.py +++ b/mindcv/optim/lion.py @@ -8,6 +8,11 @@ from mindspore.nn.optim import Optimizer from mindspore.nn.optim.optimizer import opt_init_args_register +try: + from mindspore import jit +except ImportError: + from mindspore import ms_function as jit + def _check_param_value(beta1, beta2, prim_name): """Check the type of inputs.""" @@ -142,6 +147,7 @@ def __init__( self.reciprocal_scale = Tensor(1.0 / loss_scale, ms.float32) self.clip = clip + @jit def construct(self, gradients): lr = self.get_lr() gradients = scale_grad(gradients, self.reciprocal_scale) diff --git a/mindcv/optim/nadam.py b/mindcv/optim/nadam.py index a257bf8c..e74c4b32 100644 --- a/mindcv/optim/nadam.py +++ b/mindcv/optim/nadam.py @@ -9,6 +9,11 @@ from mindspore.nn.optim import Optimizer from mindspore.nn.optim.optimizer import opt_init_args_register +try: + from mindspore import jit +except ImportError: + from mindspore import ms_function as jit + def _check_param_value(beta1, beta2, eps, prim_name): """Check the type of inputs.""" @@ -48,6 +53,7 @@ def __init__( self.mu_schedule = Parameter(initializer(1, [1], ms.float32), name="mu_schedule") self.beta2_power = Parameter(initializer(1, [1], ms.float32), name="beta2_power") + @jit def construct(self, gradients): lr = self.get_lr() params = self.parameters