131 changes: 67 additions & 64 deletions python/paddle/fluid/tests/unittests/test_sparse_norm_op.py
@@ -18,87 +18,90 @@
import paddle
from paddle.incubate.sparse import nn
import paddle.fluid as fluid
from paddle.fluid.framework import _test_eager_guard
import copy


class TestSparseBatchNorm(unittest.TestCase):

def test(self):
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
with _test_eager_guard():
paddle.seed(0)
channels = 4
shape = [2, 3, 6, 6, channels]
# there are no zeros in dense_x
dense_x = paddle.randn(shape)
dense_x.stop_gradient = False

batch_norm = paddle.nn.BatchNorm3D(channels, data_format="NDHWC")
dense_y = batch_norm(dense_x)
dense_y.backward(dense_y)

sparse_dim = 4
dense_x2 = copy.deepcopy(dense_x)
dense_x2.stop_gradient = False
sparse_x = dense_x2.to_sparse_coo(sparse_dim)
sparse_batch_norm = paddle.incubate.sparse.nn.BatchNorm(channels)
# set same params
sparse_batch_norm._mean.set_value(batch_norm._mean)
sparse_batch_norm._variance.set_value(batch_norm._variance)
sparse_batch_norm.weight.set_value(batch_norm.weight)

sparse_y = sparse_batch_norm(sparse_x)
# compare the result with dense batch_norm
assert np.allclose(dense_y.flatten().numpy(),
sparse_y.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)

# test backward
sparse_y.backward(sparse_y)
assert np.allclose(dense_x.grad.flatten().numpy(),
sparse_x.grad.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)
paddle.seed(0)
channels = 4
shape = [2, 3, 6, 6, channels]
# there are no zeros in dense_x
dense_x = paddle.randn(shape)
dense_x.stop_gradient = False

batch_norm = paddle.nn.BatchNorm3D(channels, data_format="NDHWC")
dense_y = batch_norm(dense_x)
dense_y.backward(dense_y)

sparse_dim = 4
dense_x2 = copy.deepcopy(dense_x)
dense_x2.stop_gradient = False
sparse_x = dense_x2.to_sparse_coo(sparse_dim)
sparse_batch_norm = paddle.incubate.sparse.nn.BatchNorm(channels)
# set same params
sparse_batch_norm._mean.set_value(batch_norm._mean)
sparse_batch_norm._variance.set_value(batch_norm._variance)
sparse_batch_norm.weight.set_value(batch_norm.weight)

sparse_y = sparse_batch_norm(sparse_x)
# compare the result with dense batch_norm
assert np.allclose(dense_y.flatten().numpy(),
sparse_y.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)

# test backward
sparse_y.backward(sparse_y)
assert np.allclose(dense_x.grad.flatten().numpy(),
sparse_x.grad.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})

def test_error_layout(self):
with _test_eager_guard():
with self.assertRaises(ValueError):
shape = [2, 3, 6, 6, 3]
x = paddle.randn(shape)
sparse_x = x.to_sparse_coo(4)
sparse_batch_norm = paddle.incubate.sparse.nn.BatchNorm(
3, data_format='NCDHW')
sparse_batch_norm(sparse_x)
with self.assertRaises(ValueError):
shape = [2, 3, 6, 6, 3]
x = paddle.randn(shape)
sparse_x = x.to_sparse_coo(4)
sparse_batch_norm = paddle.incubate.sparse.nn.BatchNorm(
3, data_format='NCDHW')
sparse_batch_norm(sparse_x)
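
A short usage sketch of the behavior pinned down here (variable names are illustrative): the sparse BatchNorm supports only the channels-last NDHWC layout, so requesting NCDHW raises ValueError:

import paddle

sparse_x = paddle.randn([2, 3, 6, 6, 3]).to_sparse_coo(4)
try:
    # construction or the forward call is expected to reject NCDHW
    bn = paddle.incubate.sparse.nn.BatchNorm(3, data_format='NCDHW')
    bn(sparse_x)
except ValueError as err:
    print("NCDHW rejected:", err)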

def test2(self):
with _test_eager_guard():
paddle.seed(123)
channels = 3
x_data = paddle.randn((1, 6, 6, 6, channels)).astype('float32')
dense_x = paddle.to_tensor(x_data)
sparse_x = dense_x.to_sparse_coo(4)
batch_norm = paddle.incubate.sparse.nn.BatchNorm(channels)
batch_norm_out = batch_norm(sparse_x)
print(batch_norm_out.shape)
# [1, 6, 6, 6, 3]
paddle.seed(123)
channels = 3
x_data = paddle.randn((1, 6, 6, 6, channels)).astype('float32')
dense_x = paddle.to_tensor(x_data)
sparse_x = dense_x.to_sparse_coo(4)
batch_norm = paddle.incubate.sparse.nn.BatchNorm(channels)
batch_norm_out = batch_norm(sparse_x)
dense_bn = paddle.nn.BatchNorm1D(channels)
dense_x = dense_x.reshape((-1, dense_x.shape[-1]))
dense_out = dense_bn(dense_x)
assert np.allclose(dense_out.numpy(), batch_norm_out.values().numpy())
# batch_norm_out.shape: [1, 6, 6, 6, 3]
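
The BatchNorm1D comparison relies on randn producing no zeros: the COO values are then exactly the dense entries in row-major order, so per-channel statistics agree. A small sketch of that assumption:

import paddle

x = paddle.randn((1, 6, 6, 6, 3))
flat = x.reshape((-1, 3))           # shape (216, 3): one row per voxel
vals = x.to_sparse_coo(4).values()  # shape (216, 3): same rows, COO order
print(paddle.allclose(flat.mean(axis=0), vals.mean(axis=0)))  # True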


class TestSyncBatchNorm(unittest.TestCase):

def test_sync_batch_norm(self):
with _test_eager_guard():
x = np.array([[[[0.3, 0.4], [0.3, 0.07]],
[[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')
x = paddle.to_tensor(x)
x = x.to_sparse_coo(len(x.shape) - 1)

if paddle.is_compiled_with_cuda():
sync_batch_norm = nn.SyncBatchNorm(2)
hidden1 = sync_batch_norm(x)
print(hidden1)
x = np.array([[[[0.3, 0.4], [0.3, 0.07]],
[[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')
x = paddle.to_tensor(x)
sparse_x = x.to_sparse_coo(len(x.shape) - 1)

if paddle.is_compiled_with_cuda():
sparse_sync_bn = nn.SyncBatchNorm(2)
sparse_hidden = sparse_sync_bn(sparse_x)

dense_sync_bn = paddle.nn.SyncBatchNorm(2)
x = x.reshape((-1, x.shape[-1]))
dense_hidden = dense_sync_bn(x)
assert np.allclose(sparse_hidden.values().numpy(),
dense_hidden.numpy())
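
An illustrative sketch of what to_sparse_coo(len(x.shape) - 1) produces here: the trailing channel axis stays dense, which is the layout the sparse SyncBatchNorm expects (shapes follow from COO semantics):

import paddle

x = paddle.to_tensor([[[[0.3, 0.4], [0.3, 0.07]],
                       [[0.83, 0.37], [0.18, 0.93]]]], dtype='float32')
sx = x.to_sparse_coo(len(x.shape) - 1)  # sparse over all but the last dim
print(sx.indices().shape)  # [3, 4]: one row of indices per sparse dim
print(sx.values().shape)   # [4, 2]: the channel axis is kept dense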

def test_convert(self):
base_model = paddle.nn.Sequential(nn.Conv3D(3, 5, 3), nn.BatchNorm(5),
6 changes: 4 additions & 2 deletions python/paddle/incubate/sparse/nn/layer/norm.py
@@ -229,6 +229,7 @@ class SyncBatchNorm(paddle.nn.SyncBatchNorm):

Shapes:
input: Tensor with dimension from 2 to 5.

output: Tensor with the same shape as input.

Examples:
@@ -278,7 +279,7 @@ def forward(self, x):

@classmethod
def convert_sync_batchnorm(cls, layer):
"""
r"""
Helper function to convert :class:`paddle.incubate.sparse.nn.BatchNorm` layers in the model to :class:`paddle.incubate.sparse.nn.SyncBatchNorm` layers.

Parameters:
@@ -290,13 +291,14 @@ def convert_sync_batchnorm(cls, layer):
Examples:

.. code-block:: python

import paddle
import paddle.incubate.sparse.nn as nn

model = paddle.nn.Sequential(nn.Conv3D(3, 5, 3), nn.BatchNorm(5))
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

"""

layer_output = layer
if isinstance(layer, _BatchNormBase):
if layer._weight_attr != None and not isinstance(
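
The method body is truncated above; as a rough sketch (not the actual implementation), the conversion follows the usual replace-then-recurse pattern over sublayers, reusing each layer's stored hyperparameters:

import paddle

def convert_all(layer):
    # hypothetical helper mirroring convert_sync_batchnorm's structure
    new_layer = layer
    if isinstance(layer, paddle.incubate.sparse.nn.BatchNorm):
        new_layer = paddle.incubate.sparse.nn.SyncBatchNorm(
            layer._num_features, layer._momentum, layer._epsilon)
    for name, sublayer in layer.named_children():
        new_layer.add_sublayer(name, convert_all(sublayer))
    return new_layer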