-
Notifications
You must be signed in to change notification settings - Fork 234
【PaddlePaddle Hackathon 3 】为 PaddleScience 增加损失函数权重自适应功能 #142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
126a274
add GradNorm without test
Asthestarsfalll f1b7006
fix code style
Asthestarsfalll 5ad36e0
add test of GradNorm
Asthestarsfalll 7c57b8d
fix import error
Asthestarsfalll f2afeda
fix name error
Asthestarsfalll d98f626
fix attribute error
Asthestarsfalll 1746801
keep grad for fcnet
Asthestarsfalll 4728902
decrease rtol
Asthestarsfalll 8bf7b1c
restore rtol
Asthestarsfalll 3ad7710
add code annotation and renormalize
Asthestarsfalll 1bb20b4
Merge branch 'PaddlePaddle:develop' into grad_norm
Asthestarsfalll File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,130 @@ | ||
| # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import numpy as np | ||
| import paddle | ||
| import paddle.nn | ||
| from paddle.nn.initializer import Assign | ||
| from .network_base import NetworkBase | ||
|
|
||
|
|
||
| class GradNorm(NetworkBase): | ||
| r""" | ||
| Gradient normalization for adaptive loss balancing. | ||
| Parameters: | ||
| net(NetworkBase): The network which must have "get_shared_layer" method. | ||
| n_loss(int): The number of loss, must be greater than 1. | ||
| alpha(float): The hyperparameter which controls learning rate, must be greater than 0. | ||
| weight_attr(list, tuple): The inital weights for "loss_weights". If not specified, "loss_weights" will be initialized with 1. | ||
| """ | ||
|
|
||
| def __init__(self, net, n_loss, alpha, weight_attr=None): | ||
| super().__init__() | ||
| if not isinstance(net, NetworkBase): | ||
| raise TypeError("'net' must be a NetworkBase subclass instance.") | ||
| if not hasattr(net, 'get_shared_layer'): | ||
| raise TypeError("'net' must have 'get_shared_layer' method.") | ||
| if n_loss <= 1: | ||
| raise ValueError( | ||
| "'n_loss' must be greater than 1, but got {}".format(n_loss)) | ||
| if alpha < 0: | ||
| raise ValueError("'alpha' is a positive number, but got {}".format( | ||
| alpha)) | ||
| if weight_attr is not None: | ||
| if len(weight_attr) != n_loss: | ||
| raise ValueError( | ||
| "weight_attr must have same length with loss weights.") | ||
|
|
||
| self.n_loss = n_loss | ||
| self.net = net | ||
| self.loss_weights = self.create_parameter( | ||
| shape=[n_loss], | ||
| attr=Assign(weight_attr if weight_attr else [1] * n_loss), | ||
| dtype=self._dtype, | ||
| is_bias=False) | ||
| self.set_grad() | ||
| self.alpha = float(alpha) | ||
| self.initial_losses = None | ||
|
|
||
| def nn_func(self, ins): | ||
| return self.net.nn_func(ins) | ||
|
|
||
| def __getattr__(self, __name): | ||
| try: | ||
| return super().__getattr__(__name) | ||
| except: | ||
| return getattr(self.net, __name) | ||
|
|
||
| def get_grad_norm_loss(self, losses): | ||
| if isinstance(losses, list): | ||
| losses = paddle.concat(losses) | ||
|
|
||
| if self.initial_losses is None: | ||
| self.initial_losses = losses.numpy() | ||
|
|
||
| W = self.net.get_shared_layer() | ||
|
|
||
| # set grad to zero | ||
| if self.loss_weights.grad is not None: | ||
| self.loss_weights.grad.set_value( | ||
| paddle.zeros_like(self.loss_weights)) | ||
|
|
||
| # calulate each loss's grad | ||
| norms = [] | ||
| for i in range(losses.shape[0]): | ||
| grad = paddle.autograd.grad(losses[i], W, retain_graph=True) | ||
| norms.append(paddle.norm(self.loss_weights[i] * grad[0], p=2)) | ||
| norms = paddle.concat(norms) | ||
|
|
||
| # calculate the inverse train rate | ||
| loss_ratio = losses.numpy() / self.initial_losses | ||
| inverse_train_rate = loss_ratio / np.mean(loss_ratio) | ||
|
|
||
| # calculate the mean value of grad | ||
| mean_norm = np.mean(norms.numpy()) | ||
|
|
||
| # convert it to constant, instead of having grads | ||
| constant_term = paddle.to_tensor( | ||
| mean_norm * np.power(inverse_train_rate, self.alpha), | ||
| dtype=self._dtype) | ||
| # calculate the grad norm loss | ||
| grad_norm_loss = paddle.norm(norms - constant_term, p=1) | ||
| # update the grad of loss weights | ||
| self.loss_weights.grad.set_value( | ||
| paddle.autograd.grad(grad_norm_loss, self.loss_weights)[0]) | ||
| # renormalize the loss weights each step when training | ||
| if self.training: | ||
| self.renormalize() | ||
| return grad_norm_loss | ||
|
|
||
| def renormalize(self): | ||
| normalize_coeff = self.n_loss / paddle.sum(self.loss_weights) | ||
| self.loss_weights = self.create_parameter( | ||
| shape=[self.n_loss], | ||
| attr=Assign(self.loss_weights * normalize_coeff), | ||
| dtype=self._dtype, | ||
| is_bias=False) | ||
| self.set_grad() | ||
|
|
||
| def reset_initial_losses(self): | ||
| self.initial_losses = None | ||
|
|
||
| def set_grad(self): | ||
| x = paddle.ones_like(self.loss_weights) | ||
| x *= self.loss_weights | ||
| x.backward() | ||
| self.loss_weights.grad.set_value(paddle.zeros_like(self.loss_weights)) | ||
|
|
||
| def get_weights(self): | ||
| return self.loss_weights.numpy() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,162 @@ | ||
| """ | ||
| # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """ | ||
| from functools import partial | ||
|
|
||
| import numpy as np | ||
| import paddle | ||
| import paddlescience as psci | ||
| import pytest | ||
|
|
||
| from apibase import APIBase | ||
|
|
||
| GLOBAL_SEED = 22 | ||
| np.random.seed(GLOBAL_SEED) | ||
| paddle.seed(GLOBAL_SEED) | ||
| paddle.disable_static() | ||
|
|
||
| loss_func = [ | ||
| paddle.sum, paddle.mean, partial( | ||
| paddle.norm, p=2), partial( | ||
| paddle.norm, p=3) | ||
| ] | ||
|
|
||
|
|
||
| def randtool(dtype, low, high, shape, seed=None): | ||
| if seed is not None: | ||
| np.random.seed(seed) | ||
|
|
||
| if dtype == "int": | ||
| return np.random.randint(low, high, shape) | ||
|
|
||
| elif dtype == "float": | ||
| return low + (high - low) * np.random.random(shape) | ||
|
|
||
|
|
||
| def cal_gradnorm(ins, | ||
| num_ins, | ||
| num_outs, | ||
| num_layers, | ||
| hidden_size, | ||
| n_loss=3, | ||
| alpha=0.5, | ||
| activation='tanh', | ||
| weight_attr=None): | ||
|
|
||
| net = psci.network.FCNet( | ||
| num_ins=num_ins, | ||
| num_outs=num_outs, | ||
| num_layers=num_layers, | ||
| hidden_size=hidden_size, | ||
| activation=activation) | ||
|
|
||
| for i in range(num_layers): | ||
| net._weights[i] = paddle.ones_like(net._weights[i]) | ||
| net._weights[i].stop_gradient = False | ||
|
|
||
| grad_norm = psci.network.GradNorm( | ||
| net=net, n_loss=n_loss, alpha=alpha, weight_attr=weight_attr) | ||
| res = grad_norm.nn_func(ins) | ||
|
|
||
| losses = [] | ||
| for idx in range(n_loss): | ||
| losses.append(loss_func[idx](res)) | ||
| weighted_loss = grad_norm.loss_weights * paddle.concat(losses) | ||
| loss = paddle.sum(weighted_loss) | ||
| loss.backward(retain_graph=True) | ||
| grad_norm_loss = grad_norm.get_grad_norm_loss(losses) | ||
| return grad_norm_loss | ||
|
|
||
|
|
||
| class TestGradNorm(APIBase): | ||
| def hook(self): | ||
| """ | ||
| implement | ||
| """ | ||
| self.types = [np.float32] | ||
| # self.debug = True | ||
| # enable check grad | ||
| self.static = False | ||
| self.enable_backward = False | ||
| self.rtol = 1e-7 | ||
|
|
||
|
|
||
| obj = TestGradNorm(cal_gradnorm) | ||
|
|
||
|
|
||
| @pytest.mark.api_network_GradNorm | ||
| def test_GradNorm0(): | ||
| xy_data = np.array([[0.1, 0.5, 0.3, 0.4, 0.2]]) | ||
| u = np.array([1.138526], dtype=np.float32) | ||
| obj.run(res=u, | ||
| ins=xy_data, | ||
| num_ins=5, | ||
| num_outs=3, | ||
| num_layers=2, | ||
| hidden_size=1) | ||
|
|
||
|
|
||
| @pytest.mark.api_network_GradNorm | ||
| def test_GradNorm1(): | ||
| xy_data = randtool("float", 0, 10, (9, 2), GLOBAL_SEED) | ||
| u = np.array([20.636574]) | ||
| obj.run(res=u, | ||
| ins=xy_data, | ||
| num_ins=2, | ||
| num_outs=3, | ||
| num_layers=2, | ||
| hidden_size=1, | ||
| n_loss=4) | ||
|
|
||
|
|
||
| @pytest.mark.api_network_GradNorm | ||
| def test_GradNorm2(): | ||
| xy_data = randtool("float", 0, 1, (9, 3), GLOBAL_SEED) | ||
| u = np.array([7.633053]) | ||
| obj.run(res=u, | ||
| ins=xy_data, | ||
| num_ins=3, | ||
| num_outs=1, | ||
| num_layers=2, | ||
| hidden_size=1, | ||
| activation='sigmoid') | ||
|
|
||
|
|
||
| @pytest.mark.api_network_GradNorm | ||
| def test_GradNorm3(): | ||
| xy_data = randtool("float", 0, 1, (9, 4), GLOBAL_SEED) | ||
| u = np.array([41.803569]) | ||
| obj.run(res=u, | ||
| ins=xy_data, | ||
| num_ins=4, | ||
| num_outs=3, | ||
| num_layers=2, | ||
| hidden_size=10, | ||
| activation='sigmoid', | ||
| n_loss=2, | ||
| alpha=0.2) | ||
|
|
||
|
|
||
| @pytest.mark.api_network_GradNorm | ||
| def test_GradNorm4(): | ||
| xy_data = randtool("float", 0, 1, (9, 5), GLOBAL_SEED) | ||
| u = np.array([12.606881]) | ||
| obj.run(res=u, | ||
| ins=xy_data, | ||
| num_ins=5, | ||
| num_outs=1, | ||
| num_layers=3, | ||
| hidden_size=2, | ||
| weight_attr=[1.0, 2.0, 3.0]) | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些数值是在相同的初始化方法、随机种子、输入的情况下,使用该仓库的逻辑通过Pytorch计算得来,代码如下: