
Commit 5ad36e0

add test of GradNorm
1 parent f1b7006 commit 5ad36e0

File tree: 2 files changed, +172 −3 lines


paddlescience/network/grad_norm.py

Lines changed: 12 additions & 3 deletions

@@ -16,7 +16,7 @@
 import paddle
 import paddle.nn
 from paddle.nn.initializer import Assign
-from .network_base import NetworkBase
+from network_base import NetworkBase


 class GradNorm(NetworkBase):
@@ -34,8 +34,7 @@ def __init__(self, net, n_loss, alpha, weight_attr=None):
         if not isinstance(net, NetworkBase):
             raise TypeError()("'net' must be a NetworkBase subclass instance.")
         if not hasattr(net, 'get_shared_layer'):
-            raise TypeError(
-                "'net' must be must have 'get_shared_layer' method.")
+            raise TypeError("'net' must have 'get_shared_layer' method.")
         if n_loss <= 1:
             raise ValueError(
                 "'n_loss' must be greater than 1, but got {}".format(n_loss))
@@ -47,6 +46,7 @@ def __init__(self, net, n_loss, alpha, weight_attr=None):
             raise ValueError(
                 "weight_attr must have same length with loss weights.")

+        self.n_loss = n_loss
         self.net = net
         self.loss_weights = self.create_parameter(
             shape=[n_loss],
@@ -98,6 +98,15 @@ def get_grad_norm_loss(self, losses):

         return grad_norm_loss

+    def renormalize(self):
+        normalize_coeff = self.n_loss / paddle.sum(self.loss_weights)
+        self.loss_weights = self.create_parameter(
+            shape=[self.n_loss],
+            attr=Assign(self.loss_weights * normalize_coeff),
+            dtype=self._dtype,
+            is_bias=False)
+        self.set_grad()
+
     def reset_initial_losses(self):
         self.initial_losses = None
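For context on the new method: renormalize() rescales the loss weights by n_loss / sum(loss_weights), so after the rescale the weights sum back to n_loss. A minimal standalone sketch of that arithmetic follows; it is not part of the commit, and the weight values are made up for illustration.

# Illustrative only: what renormalize() computes, with made-up weight values.
import paddle

n_loss = 3
loss_weights = paddle.to_tensor([0.5, 1.2, 2.3])     # stand-in for GradNorm's loss_weights
normalize_coeff = n_loss / paddle.sum(loss_weights)  # same formula as in the diff above
renormalized = loss_weights * normalize_coeff
print(float(paddle.sum(renormalized)))               # prints 3.0, i.e. n_loss again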

tests/test_api/test_GradNorm.py

Lines changed: 160 additions & 0 deletions (new file)

@@ -0,0 +1,160 @@
+"""
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+from functools import partial
+
+import numpy as np
+import paddle
+import paddlescience as psci
+import pytest
+
+from apibase import APIBase
+
+GLOBAL_SEED = 22
+np.random.seed(GLOBAL_SEED)
+paddle.seed(GLOBAL_SEED)
+paddle.disable_static()
+
+loss_func = [
+    paddle.sum, paddle.mean, partial(
+        paddle.norm, p=2), partial(
+            paddle.norm, p=3)
+]
+
+
+def randtool(dtype, low, high, shape, seed=None):
+    if seed is not None:
+        np.random.seed(seed)
+
+    if dtype == "int":
+        return np.random.randint(low, high, shape)
+
+    elif dtype == "float":
+        return low + (high - low) * np.random.random(shape)
+
+
+def cal_gradnorm(ins,
+                 num_ins,
+                 num_outs,
+                 num_layers,
+                 hidden_size,
+                 n_loss=3,
+                 alpha=0.5,
+                 activation='tanh',
+                 weight_attr=None):
+
+    net = psci.network.FCNet(
+        num_ins=num_ins,
+        num_outs=num_outs,
+        num_layers=num_layers,
+        hidden_size=hidden_size,
+        activation=activation)
+
+    for i in range(num_layers):
+        net._weights[i] = paddle.ones_like(net._weights[i])
+
+    grad_norm = psci.network.GradNorm(
+        net=net, n_loss=n_loss, alpha=alpha, weight_attr=weight_attr)
+    res = grad_norm.nn_func(ins)
+
+    losses = []
+    for idx in range(n_loss):
+        losses.append(loss_func[idx](res))
+    weighted_loss = grad_norm.loss_weights * paddle.concat(losses)
+    loss = paddle.sum(weighted_loss)
+    loss.backward(retain_graph=True)
+    grad_norm_loss = grad_norm.get_grad_norm_loss(losses)
+    return grad_norm_loss
+
+
+class TestGradNorm(APIBase):
+    def hook(self):
+        """
+        implement
+        """
+        self.types = [np.float32]
+        # self.debug = True
+        # enable check grad
+        self.static = False
+        self.enable_backward = False
+
+
+obj = TestGradNorm(cal_gradnorm)
+
+
+@pytest.mark.api_network_GradNorm
+def test_GradNorm0():
+    xy_data = np.array([[0.1, 0.5, 0.3, 0.4, 0.2]])
+    u = np.array([1.138526], dtype=np.float32)
+    obj.run(res=u,
+            ins=xy_data,
+            num_ins=5,
+            num_outs=3,
+            num_layers=2,
+            hidden_size=1)
+
+
+@pytest.mark.api_network_GradNorm
+def test_GradNorm1():
+    xy_data = randtool("float", 0, 10, (9, 2), GLOBAL_SEED)
+    u = np.array([20.636574])
+    obj.run(res=u,
+            ins=xy_data,
+            num_ins=2,
+            num_outs=3,
+            num_layers=2,
+            hidden_size=1,
+            n_loss=4)
+
+
+@pytest.mark.api_network_GradNorm
+def test_GradNorm2():
+    xy_data = randtool("float", 0, 1, (9, 3), GLOBAL_SEED)
+    u = np.array([7.633053])
+    obj.run(res=u,
+            ins=xy_data,
+            num_ins=3,
+            num_outs=1,
+            num_layers=2,
+            hidden_size=1,
+            activation='sigmoid')
+
+
+@pytest.mark.api_network_GradNorm
+def test_GradNorm3():
+    xy_data = randtool("float", 0, 1, (9, 4), GLOBAL_SEED)
+    u = np.array([41.803569])
+    obj.run(res=u,
+            ins=xy_data,
+            num_ins=4,
+            num_outs=3,
+            num_layers=2,
+            hidden_size=10,
+            activation='sigmoid',
+            n_loss=2,
+            alpha=0.2)
+
+
+@pytest.mark.api_network_GradNorm
+def test_GradNorm4():
+    xy_data = randtool("float", 0, 1, (9, 5), GLOBAL_SEED)
+    u = np.array([12.606881])
+    obj.run(res=u,
+            ins=xy_data,
+            num_ins=5,
+            num_outs=1,
+            num_layers=3,
+            hidden_size=2,
+            weight_attr=[1.0, 2.0, 3.0])
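The new cases are tagged with the api_network_GradNorm pytest marker. Assuming the repository's usual test layout under tests/test_api (only the file name and marker below come from the diff; the invocation itself is a guess at typical local usage), they could be run with something like:

# Hypothetical local run of just the new GradNorm cases (from tests/test_api):
import pytest

# equivalent to: python -m pytest -v -m api_network_GradNorm test_GradNorm.py
pytest.main(["-v", "-m", "api_network_GradNorm", "test_GradNorm.py"])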
