Skip to content

add ST and CI workflow #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
7 commits merged on Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions mindocr/utils/train_step_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
import mindspore.context as context


_ema_op = C.MultitypeFuncGraph("grad_ema_op")
_grad_scale = C.MultitypeFuncGraph("grad_scale")
Expand Down Expand Up @@ -62,6 +64,9 @@ def __init__(
self.weights_all = ms.ParameterTuple(list(network.get_parameters()))
self.ema_weight = self.weights_all.clone("ema", init="same")

self.is_cpu_device = context.get_context("device_target") == 'CPU' # to support CPU run
print('\n====-> device: ', context.get_context("device_target") )

def ema_update(self):
"""Update EMA parameters."""
self.updates += 1
Expand All @@ -75,8 +80,9 @@ def construct(self, *inputs):
weights = self.weights
loss = self.network(*inputs)
scaling_sens = self.scale_sense

status, scaling_sens = self.start_overflow_check(loss, scaling_sens)

if not self.is_cpu_device:
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)

scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
Expand All @@ -85,8 +91,13 @@ def construct(self, *inputs):
grads = self.grad_reducer(grads)

# get the overflow buffer
cond = self.get_overflow_status(status, grads)
overflow = self.process_loss_scale(cond)
if not self.is_cpu_device:
cond = self.get_overflow_status(status, grads)
overflow = self.process_loss_scale(cond)
else:
overflow = False
cond = False

if self.drop_overflow_update:
# if there is no overflow, do optimize
if not overflow :
Expand Down
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ lmdb
mindcv
pyclipper
shapely
tqdm
addict
scikit-learn
matplotlib
rapidfuzz==2.13.7
numpy==1.21.6
opencv-python-headless==3.4.18.65
Expand Down
2 changes: 1 addition & 1 deletion tests/st/rec_crnn_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ train:
batch_size: *batch_size
drop_remainder: True
max_rowsize: 16
num_workers: 10
num_workers: 2

eval:
ckpt_load_path: './tmp_rec/best.ckpt'
Expand Down
4 changes: 2 additions & 2 deletions tests/st/test_train_eval_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ def test_train_eval(task):
image_dir = f'{data_dir}/{split}/dogs'
new_label_path = f'data/Canidae/{split}/{task}_gt.txt'
img_paths = glob.glob(os.path.join(image_dir, '*.JPEG'))
print(len(img_paths))
#print(len(img_paths))
with open(new_label_path, 'w') as f_w:
with open(label_path, 'r') as f_r:
i = 0
for line in f_r:
_, label = line.strip().split('\t')
print(i)
#print(i)
img_name = os.path.basename(img_paths[i])
new_img_label = img_name + '\t' + label
f_w.write(new_img_label + '\n')
Expand Down