1 change: 1 addition & 0 deletions paddlescience/network/__init__.py
@@ -13,4 +13,5 @@
# limitations under the License.

from .network_fc import FCNet
from .grad_norm import GradNorm
import paddle.nn.initializer as initializer
130 changes: 130 additions & 0 deletions paddlescience/network/grad_norm.py
@@ -0,0 +1,130 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn
from paddle.nn.initializer import Assign
from .network_base import NetworkBase


class GradNorm(NetworkBase):
r"""
Gradient normalization for adaptive loss balancing.
Parameters:
net(NetworkBase): The network, which must provide a "get_shared_layer" method.
n_loss(int): The number of losses; must be greater than 1.
alpha(float): The hyperparameter that controls the strength of loss balancing; must be non-negative.
weight_attr(list, tuple): The initial values for "loss_weights". If not specified, "loss_weights" is initialized to 1.
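
Example:
A minimal usage sketch (argument values are illustrative and mirror the tests below):

net = FCNet(num_ins=2, num_outs=3, num_layers=2, hidden_size=20)
grad_norm = GradNorm(net=net, n_loss=3, alpha=0.5)
# per step: build the individual losses, weight them with grad_norm.loss_weights,
# backpropagate their weighted sum, then call grad_norm.get_grad_norm_loss(losses)
# to update the loss weights.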
"""

def __init__(self, net, n_loss, alpha, weight_attr=None):
super().__init__()
if not isinstance(net, NetworkBase):
raise TypeError("'net' must be a NetworkBase subclass instance.")
if not hasattr(net, 'get_shared_layer'):
raise TypeError("'net' must have 'get_shared_layer' method.")
if n_loss <= 1:
raise ValueError(
"'n_loss' must be greater than 1, but got {}".format(n_loss))
if alpha < 0:
raise ValueError(
"'alpha' must be non-negative, but got {}".format(alpha))
if weight_attr is not None:
if len(weight_attr) != n_loss:
raise ValueError(
"'weight_attr' must have the same length as 'n_loss'.")

self.n_loss = n_loss
self.net = net
self.loss_weights = self.create_parameter(
shape=[n_loss],
attr=Assign(weight_attr if weight_attr else [1] * n_loss),
dtype=self._dtype,
is_bias=False)
self.set_grad()
self.alpha = float(alpha)
self.initial_losses = None

def nn_func(self, ins):
return self.net.nn_func(ins)

def __getattr__(self, __name):
try:
return super().__getattr__(__name)
except AttributeError:
return getattr(self.net, __name)

def get_grad_norm_loss(self, losses):
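# GradNorm (Chen et al., 2018): scale each loss's gradient norm w.r.t. the
# shared layer W by its loss weight and pull it towards the common target
# mean_norm * r_i ** alpha, where r_i is the inverse training rate of loss i.
# Only the loss weights receive gradients from the resulting L1 loss.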
if isinstance(losses, list):
losses = paddle.concat(losses)

if self.initial_losses is None:
self.initial_losses = losses.numpy()

W = self.net.get_shared_layer()

# set grad to zero
if self.loss_weights.grad is not None:
self.loss_weights.grad.set_value(
paddle.zeros_like(self.loss_weights))

# calculate each loss's gradient norm w.r.t. the shared layer
norms = []
for i in range(losses.shape[0]):
grad = paddle.autograd.grad(losses[i], W, retain_graph=True)
norms.append(paddle.norm(self.loss_weights[i] * grad[0], p=2))
norms = paddle.concat(norms)

# calculate the inverse training rate
loss_ratio = losses.numpy() / self.initial_losses
inverse_train_rate = loss_ratio / np.mean(loss_ratio)

# calculate the mean gradient norm
mean_norm = np.mean(norms.numpy())

# build the balancing target as a constant tensor so it carries no gradient
constant_term = paddle.to_tensor(
mean_norm * np.power(inverse_train_rate, self.alpha),
dtype=self._dtype)
# calculate the grad norm loss
grad_norm_loss = paddle.norm(norms - constant_term, p=1)
# update the grad of loss weights
self.loss_weights.grad.set_value(
paddle.autograd.grad(grad_norm_loss, self.loss_weights)[0])
# renormalize the loss weights each step when training
if self.training:
self.renormalize()
return grad_norm_loss

def renormalize(self):
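# rescale the loss weights so that they again sum to n_loss after each update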
normalize_coeff = self.n_loss / paddle.sum(self.loss_weights)
self.loss_weights = self.create_parameter(
shape=[self.n_loss],
attr=Assign(self.loss_weights * normalize_coeff),
dtype=self._dtype,
is_bias=False)
self.set_grad()

def reset_initial_losses(self):
self.initial_losses = None

def set_grad(self):
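# run a dummy backward pass so that 'loss_weights.grad' is allocated
# (non-None), then zero it; get_grad_norm_loss later writes into this
# buffer via grad.set_value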
x = paddle.ones_like(self.loss_weights)
x *= self.loss_weights
x.backward()
self.loss_weights.grad.set_value(paddle.zeros_like(self.loss_weights))

def get_weights(self):
return self.loss_weights.numpy()
3 changes: 3 additions & 0 deletions paddlescience/network/network_fc.py
@@ -249,3 +249,6 @@ def reconstruct(self, param_data):
self._biases.append(new_param)
else:
self._weights.append(new_param)

def get_shared_layer(self):
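# the last weight matrix serves as the shared layer W against which
# GradNorm measures each loss's gradient norm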
return self._weights[-1]
162 changes: 162 additions & 0 deletions tests/test_api/test_GradNorm.py
@@ -0,0 +1,162 @@
"""
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from functools import partial

import numpy as np
import paddle
import paddlescience as psci
import pytest

from apibase import APIBase

GLOBAL_SEED = 22
np.random.seed(GLOBAL_SEED)
paddle.seed(GLOBAL_SEED)
paddle.disable_static()

loss_func = [
paddle.sum, paddle.mean, partial(
paddle.norm, p=2), partial(
paddle.norm, p=3)
]


def randtool(dtype, low, high, shape, seed=None):
if seed is not None:
np.random.seed(seed)

if dtype == "int":
return np.random.randint(low, high, shape)

elif dtype == "float":
return low + (high - low) * np.random.random(shape)


def cal_gradnorm(ins,
num_ins,
num_outs,
num_layers,
hidden_size,
n_loss=3,
alpha=0.5,
activation='tanh',
weight_attr=None):

net = psci.network.FCNet(
num_ins=num_ins,
num_outs=num_outs,
num_layers=num_layers,
hidden_size=hidden_size,
activation=activation)

for i in range(num_layers):
net._weights[i] = paddle.ones_like(net._weights[i])
net._weights[i].stop_gradient = False

grad_norm = psci.network.GradNorm(
net=net, n_loss=n_loss, alpha=alpha, weight_attr=weight_attr)
res = grad_norm.nn_func(ins)

losses = []
for idx in range(n_loss):
losses.append(loss_func[idx](res))
weighted_loss = grad_norm.loss_weights * paddle.concat(losses)
loss = paddle.sum(weighted_loss)
loss.backward(retain_graph=True)
grad_norm_loss = grad_norm.get_grad_norm_loss(losses)
return grad_norm_loss


class TestGradNorm(APIBase):
def hook(self):
"""
implement
"""
self.types = [np.float32]
# self.debug = True
# dynamic graph only; the framework's backward check is disabled
self.static = False
self.enable_backward = False
self.rtol = 1e-7


obj = TestGradNorm(cal_gradnorm)


@pytest.mark.api_network_GradNorm
def test_GradNorm0():
xy_data = np.array([[0.1, 0.5, 0.3, 0.4, 0.2]])
u = np.array([1.138526], dtype=np.float32)
Contributor Author comment:

These expected values were computed with PyTorch under the same initialization method, random seed, and inputs, following this repository's logic. The code is as follows:

import torch
import torch.nn as nn
from functools import partial
import numpy as np
from torch.nn.init import constant_

class FCNet(nn.Module):
    def __init__(self,
                 num_ins,
                 num_outs,
                 num_layers,
                 hidden_size,
                 activation='tanh',
                 n_loss=1):
        super(FCNet, self).__init__()

        self.num_ins = num_ins
        self.num_outs = num_outs
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.weights = nn.Parameter(torch.ones(n_loss).float())
        # self.weights = nn.Parameter(torch.tensor([1.0, 2.0, 3.0]).float())
        if activation == 'sigmoid':
            self.activation = torch.sigmoid
        elif activation == 'tanh':
            self.activation = torch.tanh
        else:
            assert 0, "Unsupported activation type."
        w = []
        self.num_layers = num_layers
        for i in range(num_layers):
            if i == 0:
                lsize = num_ins
                rsize = hidden_size
            elif i == (num_layers - 1):
                lsize = hidden_size
                rsize = num_outs
            else:
                lsize = hidden_size
                rsize = hidden_size
            w.append(nn.Linear(lsize, rsize, bias=False))
        self.fc = nn.ModuleList(w)
        self._init_weights()

    def _init_weights(self):
        for i in self.fc:
            if isinstance(i, nn.Linear):
                constant_(i.weight, 1)

    def forward(self, inp):
        u = inp
        for i in range(self.num_layers - 1):
            u = self.fc[i](u)
            u = self.activation(u)
        return self.fc[-1](u)

loss_func = [torch.sum, torch.mean, partial(torch.norm, p=2), partial(torch.norm, p=3)]

def cal_gradnorm(ins,
                num_ins,
                num_outs,
                num_layers,
                hidden_size,
                n_loss,
                alpha,
                activation='tanh',
                weight_attr=None):
    net = FCNet(
        num_ins=num_ins,
        num_outs=num_outs,
        num_layers=num_layers,
        hidden_size=hidden_size,
        activation=activation,
        n_loss=n_loss)

    res = net(ins)
    print(res)
    losses = []
    for idx in range(n_loss):
        losses.append(loss_func[idx](res))
    losses = torch.stack(losses)
    weighted_loss = losses * net.weights
    loss = torch.sum(weighted_loss)
    loss.backward(retain_graph=True)
    initial_task_loss = losses.detach().numpy()
    net.weights.grad.data = net.weights.grad.data * 0.0
    W = net.fc[-1]
    norms = []
    for i in range(n_loss):
        # get the gradient of this task loss with respect to the shared parameters
        gygw = torch.autograd.grad(losses[i], W.parameters(), retain_graph=True)
        # compute the norm
        norms.append(torch.norm(torch.mul(net.weights[i], gygw[0])))
    norms = torch.stack(norms)
    print("norms: ", norms)

    if torch.cuda.is_available():
        loss_ratio = losses.data.cpu().numpy() / initial_task_loss
    else:
        loss_ratio = losses.data.numpy() / initial_task_loss

    inverse_train_rate = loss_ratio / np.mean(loss_ratio)
    print("inverse_train_rate: ", inverse_train_rate)

    if torch.cuda.is_available():
        mean_norm = np.mean(norms.data.cpu().numpy())
    else:
        mean_norm = np.mean(norms.data.numpy())
    
    constant_term = torch.tensor(mean_norm * (inverse_train_rate ** alpha), requires_grad=False)

    print("constant_term: ", constant_term)

    if torch.cuda.is_available():
        constant_term = constant_term.cuda()
    grad_norm_loss = torch.sum(torch.abs(norms - constant_term))
    net.weights.grad = torch.autograd.grad(grad_norm_loss, net.weights)[0]
    print(net.weights.grad)
    return grad_norm_loss


def randtool(dtype, low, high, shape):
    """
    np random tools
    """
    if dtype == "int":
        return np.random.randint(low, high, shape)

    elif dtype == "float":
        return low + (high - low) * np.random.random(shape)


if __name__ == '__main__':
    np.random.seed(22)
    xy_data = randtool("float", 0, 10, (9, 2))
    print(xy_data)
    # xy_data = torch.tensor(np.array([[0.1, 0.5, 0.2, 0.4]]), dtype=torch.float32)
    # xy_data = torch.tensor(np.array([[0.1, 0.5, 0.3, 0.4, 0.2]]), dtype=torch.float32)
    # res = cal_gradnorm(xy_data, 4, 3, 5, 20, activation='sigmoid', n_loss=3, alpha=0.5)
    res = cal_gradnorm(torch.tensor(xy_data, dtype=torch.float32), 2, 3, 2, 1, activation='tanh', n_loss=4, alpha=0.5)
    print(res.item())
    

obj.run(res=u,
ins=xy_data,
num_ins=5,
num_outs=3,
num_layers=2,
hidden_size=1)


@pytest.mark.api_network_GradNorm
def test_GradNorm1():
xy_data = randtool("float", 0, 10, (9, 2), GLOBAL_SEED)
u = np.array([20.636574])
obj.run(res=u,
ins=xy_data,
num_ins=2,
num_outs=3,
num_layers=2,
hidden_size=1,
n_loss=4)


@pytest.mark.api_network_GradNorm
def test_GradNorm2():
xy_data = randtool("float", 0, 1, (9, 3), GLOBAL_SEED)
u = np.array([7.633053])
obj.run(res=u,
ins=xy_data,
num_ins=3,
num_outs=1,
num_layers=2,
hidden_size=1,
activation='sigmoid')


@pytest.mark.api_network_GradNorm
def test_GradNorm3():
xy_data = randtool("float", 0, 1, (9, 4), GLOBAL_SEED)
u = np.array([41.803569])
obj.run(res=u,
ins=xy_data,
num_ins=4,
num_outs=3,
num_layers=2,
hidden_size=10,
activation='sigmoid',
n_loss=2,
alpha=0.2)


@pytest.mark.api_network_GradNorm
def test_GradNorm4():
xy_data = randtool("float", 0, 1, (9, 5), GLOBAL_SEED)
u = np.array([12.606881])
obj.run(res=u,
ins=xy_data,
num_ins=5,
num_outs=1,
num_layers=3,
hidden_size=2,
weight_attr=[1.0, 2.0, 3.0])