perf: optimize performance for pynative (#730)
add `jit` decorator to optimizer
chujinjin101 authored Sep 6, 2023
1 parent 1b305f1 commit e7b26a4
Showing 4 changed files with 24 additions and 0 deletions.
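All four files receive the same two additions: a version-tolerant import of `jit` and the decorator on the optimizer's `construct`. In PyNative mode each operator inside `construct` is otherwise dispatched from Python one at a time; compiling the update step with `jit` runs it as a single graph call, which is where the performance gain comes from. A minimal sketch of the pattern (the `ToyOptimizer` class below is a placeholder for illustration, not one of the actual mindcv optimizers):

```python
from mindspore.nn.optim import Optimizer

# MindSpore 2.x exposes `jit`; older releases ship the same decorator
# under the name `ms_function`, so fall back to it when needed.
try:
    from mindspore import jit
except ImportError:
    from mindspore import ms_function as jit


class ToyOptimizer(Optimizer):  # hypothetical class, for illustration only
    @jit  # compile the update step into a graph, even under PyNative mode
    def construct(self, gradients):
        # the per-parameter update logic would go here
        return gradients
```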
6 changes: 6 additions & 0 deletions mindcv/optim/adamw.py
@@ -9,6 +9,11 @@
from mindspore.nn.optim import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register

try:
from mindspore import jit
except ImportError:
from mindspore import ms_function as jit


def _check_param_value(beta1, beta2, eps, prim_name):
"""Check the type of inputs."""
@@ -154,6 +159,7 @@ def __init__(
self.reciprocal_scale = Tensor(1.0 / loss_scale, ms.float32)
self.clip = clip

@jit
def construct(self, gradients):
lr = self.get_lr()
gradients = scale_grad(gradients, self.reciprocal_scale)
6 changes: 6 additions & 0 deletions mindcv/optim/adan.py
@@ -5,6 +5,11 @@
from mindspore.common.tensor import Tensor
from mindspore.nn.optim.optimizer import Optimizer, opt_init_args_register

try:
from mindspore import jit
except ImportError:
from mindspore import ms_function as jit

_adan_opt = ops.MultitypeFuncGraph("adan_opt")


@@ -144,6 +149,7 @@ def __init__(

self.weight_decay = Tensor(weight_decay, mstype.float32)

@jit
def construct(self, gradients):
params = self._parameters
moment1 = self.moment1
6 changes: 6 additions & 0 deletions mindcv/optim/lion.py
@@ -8,6 +8,11 @@
from mindspore.nn.optim import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register

try:
from mindspore import jit
except ImportError:
from mindspore import ms_function as jit


def _check_param_value(beta1, beta2, prim_name):
"""Check the type of inputs."""
@@ -142,6 +147,7 @@ def __init__(
self.reciprocal_scale = Tensor(1.0 / loss_scale, ms.float32)
self.clip = clip

@jit
def construct(self, gradients):
lr = self.get_lr()
gradients = scale_grad(gradients, self.reciprocal_scale)
6 changes: 6 additions & 0 deletions mindcv/optim/nadam.py
@@ -9,6 +9,11 @@
from mindspore.nn.optim import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register

try:
from mindspore import jit
except ImportError:
from mindspore import ms_function as jit


def _check_param_value(beta1, beta2, eps, prim_name):
"""Check the type of inputs."""
@@ -48,6 +53,7 @@ def __init__(
self.mu_schedule = Parameter(initializer(1, [1], ms.float32), name="mu_schedule")
self.beta2_power = Parameter(initializer(1, [1], ms.float32), name="beta2_power")

@jit
def construct(self, gradients):
lr = self.get_lr()
params = self.parameters
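For context, a hedged sketch of how one of these optimizers might be driven in PyNative mode after this change; the toy network, shapes, and the `AdamW` constructor arguments are assumptions for illustration, not taken from the commit:

```python
import mindspore as ms
from mindspore import nn
from mindcv.optim.adamw import AdamW  # class name assumed from the changed file

ms.set_context(mode=ms.PYNATIVE_MODE)

net = nn.Dense(16, 4)                # toy network, illustration only
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(net.trainable_params(), learning_rate=1e-3)

def forward_fn(data, label):
    logits = net(data)
    return loss_fn(logits, label)

grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

def train_step(data, label):
    loss, grads = grad_fn(data, label)
    # This call runs the optimizer's `construct`; with the new @jit decorator
    # the whole update executes as one compiled graph instead of op by op.
    optimizer(grads)
    return loss
```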
