Skip to content

Commit 1aa2bde

Browse files
authored
[bug fix] fix spectral_norm bug (#35005)
1 parent 096b0f2 commit 1aa2bde

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

python/paddle/fluid/dygraph/nn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3062,6 +3062,12 @@ def __init__(self,
30623062
self._dtype = dtype
30633063

30643064
self._weight_shape = list(weight_shape)
3065+
assert np.prod(self._weight_shape) > 0,\
3066+
"Any dimension of `weight_shape` cannot be equal to 0."
3067+
assert dim < len(self._weight_shape), \
3068+
("The input `dim` should be less than the "
3069+
"length of `weight_shape`, but received dim="
3070+
"{}".format(dim))
30653071
h = self._weight_shape[self._dim]
30663072
w = np.prod(self._weight_shape) // h
30673073

python/paddle/fluid/layers/nn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3720,6 +3720,10 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
37203720
# create intput and parameters
37213721
inputs = {'Weight': weight}
37223722
input_shape = weight.shape
3723+
assert weight.numel() > 0, "Any dimension of input cannot be equal to 0."
3724+
assert dim < len(input_shape), ("The input `dim` should be less than the "
3725+
"rank of `weight`, but received dim="
3726+
"{}".format(dim))
37233727
h = input_shape[dim]
37243728
w = np.prod(input_shape) // h
37253729

0 commit comments

Comments
 (0)