Commit 4cdf60c

feat: upgrade ms to 2.2 (#756)
1 parent 062fe1a commit 4cdf60c

5 files changed: +94 −6 lines


mindcv/models/hrnet.py

Lines changed: 10 additions & 6 deletions

@@ -4,6 +4,7 @@
 """
 from typing import Any, Dict, List, Optional, Tuple, Type, Union

+import mindspore as ms
 import mindspore.nn as nn
 import mindspore.ops as ops
 from mindspore import Tensor
@@ -329,23 +330,26 @@ def construct(self, x: List[Tensor]) -> List[Tensor]:
         if self.num_branches == 1:
            return [self.branches[0](x[0])]

+        x2 = []
         for i in range(self.num_branches):
-            x[i] = self.branches[i](x[i])
+            x2.append(self.branches[i](x[i]))

         x_fuse = []

         for i in range(len(self.fuse_layers)):
-            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+            y = x2[0] if i == 0 else self.fuse_layers[i][0](x2[0])
             for j in range(1, self.num_branches):
                 if i == j:
-                    y = y + x[j]
+                    y = y + x2[j]
                 elif j > i:
-                    _, _, height, width = x[i].shape
-                    t = self.fuse_layers[i][j](x[j])
+                    _, _, height, width = x2[i].shape
+                    t = self.fuse_layers[i][j](x2[j])
+                    t = ops.cast(t, ms.float32)
                     t = ops.ResizeNearestNeighbor((height, width))(t)
+                    t = ops.cast(t, ms.float16)
                     y = y + t
                 else:
-                    y = y + self.fuse_layers[i][j](x[j])
+                    y = y + self.fuse_layers[i][j](x2[j])
             x_fuse.append(self.relu(y))

         if not self.multi_scale_output:
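
Note on the hrnet.py change: branch outputs are now written into a new list x2 instead of mutating the input list x in place, and the nearest-neighbor resize is wrapped in a cast to float32 and back to float16, presumably so the resize runs in full precision while the surrounding activations stay in half precision. Below is a minimal standalone sketch of that cast-around-resize pattern; it generalizes the diff slightly by restoring the caller's dtype rather than hard-coding float16, and the float16 input is an assumption for illustration, not taken from the commit.

import numpy as np
import mindspore as ms
import mindspore.ops as ops
from mindspore import Tensor

def resize_nn_in_fp32(t: Tensor, height: int, width: int) -> Tensor:
    # Illustrative helper, not part of the commit: run the nearest-neighbor
    # resize in float32, then hand back a tensor in the caller's dtype.
    original_dtype = t.dtype
    t = ops.cast(t, ms.float32)
    t = ops.ResizeNearestNeighbor((height, width))(t)
    return ops.cast(t, original_dtype)

x = Tensor(np.random.randn(1, 8, 16, 16), ms.float16)  # assumed half-precision activations
y = resize_nn_in_fp32(x, 32, 32)
print(y.shape, y.dtype)  # (1, 8, 32, 32) Float16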

mindcv/optim/adamw.py

Lines changed: 21 additions & 0 deletions

@@ -159,6 +159,27 @@ def __init__(
         self.reciprocal_scale = Tensor(1.0 / loss_scale, ms.float32)
         self.clip = clip

+    def get_lr(self):
+        """
+        The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
+        on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.
+
+        Returns:
+            float, the learning rate of current step.
+        """
+        lr = self.learning_rate
+        if self.dynamic_lr:
+            if self.is_group_lr:
+                lr = ()
+                for learning_rate in self.learning_rate:
+                    current_dynamic_lr = learning_rate(self.global_step).reshape(())
+                    lr += (current_dynamic_lr,)
+            else:
+                lr = self.learning_rate(self.global_step).reshape(())
+        if self._is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
+        return lr
+
     @jit
     def construct(self, gradients):
         lr = self.get_lr()
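
The added get_lr() override appears to follow the base mindspore.nn.Optimizer logic: with a static learning rate it returns the stored value, and with a dynamic one it evaluates the schedule at global_step (one value per group when group learning rates are used) and then advances global_step. The sketch below shows what learning_rate(self.global_step) evaluates to when the optimizer is handed a LearningRateSchedule; the warm-up schedule and its numbers are illustrative, not taken from mindcv.

import mindspore as ms
import mindspore.ops as ops
from mindspore import Tensor, nn

class WarmupConstantLR(nn.LearningRateSchedule):
    # Illustrative schedule (not from mindcv): linear warm-up to a constant lr.
    def __init__(self, base_lr, warmup_steps):
        super().__init__()
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.one = Tensor(1.0, ms.float32)

    def construct(self, global_step):
        step = ops.cast(global_step, ms.float32)
        scale = ops.minimum(step / self.warmup_steps, self.one)
        return self.base_lr * scale

schedule = WarmupConstantLR(1e-3, warmup_steps=100)
print(schedule(Tensor(50, ms.int32)))   # 0.0005 -- the value get_lr() reshapes to a 0-d tensor
print(schedule(Tensor(500, ms.int32)))  # 0.001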

mindcv/optim/adan.py

Lines changed: 21 additions & 0 deletions

@@ -149,6 +149,27 @@ def __init__(

         self.weight_decay = Tensor(weight_decay, mstype.float32)

+    def get_lr(self):
+        """
+        The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
+        on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.
+
+        Returns:
+            float, the learning rate of current step.
+        """
+        lr = self.learning_rate
+        if self.dynamic_lr:
+            if self.is_group_lr:
+                lr = ()
+                for learning_rate in self.learning_rate:
+                    current_dynamic_lr = learning_rate(self.global_step).reshape(())
+                    lr += (current_dynamic_lr,)
+            else:
+                lr = self.learning_rate(self.global_step).reshape(())
+        if self._is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
+        return lr
+
     @jit
     def construct(self, gradients):
         params = self._parameters
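
The same override is added to Adan. One detail worth spelling out is the is_group_lr branch: when parameter groups carry their own learning rates, get_lr() returns a tuple with one entry per group rather than a single value. A quick sketch of that convention, using the stock nn.Momentum optimizer only to build the groups; the layer and the per-group values are illustrative.

from mindspore import nn

net = nn.Dense(4, 2)
groups = [
    {"params": [net.weight], "lr": 1e-2},  # per-group learning rates (illustrative values)
    {"params": [net.bias], "lr": 1e-3},
]
opt = nn.Momentum(groups, learning_rate=1e-2, momentum=0.9)
print(opt.is_group_lr)  # True -- get_lr() now yields one learning rate per group
print(opt.get_lr())     # a tuple: (lr for the weight group, lr for the bias group)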

mindcv/optim/lion.py

Lines changed: 21 additions & 0 deletions

@@ -147,6 +147,27 @@ def __init__(
         self.reciprocal_scale = Tensor(1.0 / loss_scale, ms.float32)
         self.clip = clip

+    def get_lr(self):
+        """
+        The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
+        on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.
+
+        Returns:
+            float, the learning rate of current step.
+        """
+        lr = self.learning_rate
+        if self.dynamic_lr:
+            if self.is_group_lr:
+                lr = ()
+                for learning_rate in self.learning_rate:
+                    current_dynamic_lr = learning_rate(self.global_step).reshape(())
+                    lr += (current_dynamic_lr,)
+            else:
+                lr = self.learning_rate(self.global_step).reshape(())
+        if self._is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
+        return lr
+
     @jit
     def construct(self, gradients):
         lr = self.get_lr()
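
The override is identical for Lion; its last two lines are the bookkeeping part, advancing global_step by one whenever the learning rate or weight decay is dynamic. self.assignadd and self.global_step_increase_tensor presumably come from the mindspore.nn.Optimizer base class, so the step update amounts to an in-place add on a counter parameter. A standalone sketch of that add, with illustrative names:

import mindspore as ms
from mindspore import Parameter, Tensor, ops

# Standalone sketch of the step bookkeeping at the end of get_lr():
# an in-place add on a counter parameter, mirroring
# self.assignadd(self.global_step, self.global_step_increase_tensor).
global_step = Parameter(Tensor([0], ms.int32), name="global_step")
increase = Tensor([1], ms.int32)

ops.assign_add(global_step, increase)  # one get_lr() call with a dynamic lr or weight decay
ops.assign_add(global_step, increase)  # a second call
print(global_step.asnumpy())           # [2]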

mindcv/optim/nadam.py

Lines changed: 21 additions & 0 deletions

@@ -53,6 +53,27 @@ def __init__(
         self.mu_schedule = Parameter(initializer(1, [1], ms.float32), name="mu_schedule")
         self.beta2_power = Parameter(initializer(1, [1], ms.float32), name="beta2_power")

+    def get_lr(self):
+        """
+        The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
+        on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.
+
+        Returns:
+            float, the learning rate of current step.
+        """
+        lr = self.learning_rate
+        if self.dynamic_lr:
+            if self.is_group_lr:
+                lr = ()
+                for learning_rate in self.learning_rate:
+                    current_dynamic_lr = learning_rate(self.global_step).reshape(())
+                    lr += (current_dynamic_lr,)
+            else:
+                lr = self.learning_rate(self.global_step).reshape(())
+        if self._is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
+        return lr
+
     @jit
     def construct(self, gradients):
         lr = self.get_lr()
