Fix PReLU Broadcasting Bug for Multiple Parameters #565
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
#################Summary#################
Fixed a bug in the PReLU function in jittor/nn.py where broadcasting the weight parameter caused errors when num_parameters was greater than 1. The previous implementation did not correctly broadcast the weights to match the input dimensions, leading to runtime errors.
#################Changes Made#################
Modified the execute method in PReLU class to correctly broadcast weight parameter for cases where num_parameters is greater than 1.
#################Original Code:#################
def init(self, num_parameters=1, init_=0.25):
self.num_parameters = num_parameters
self.weight = init.constant((num_parameters,), "float32", init_)
def execute(self, x):
if self.num_parameters != 1:
assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU"
return jt.maximum(0, x) + self.weight.broadcast(x, [0,2,3]) * jt.minimum(0, x)
else:
return jt.maximum(0, x) + self.weight * jt.minimum(0, x)
############Updated Code:##############
def init(self, num_parameters=1, init_=0.25):
self.num_parameters = num_parameters
self.weight = init.constant((num_parameters,), "float32", init_)
def execute(self, x):
if self.num_parameters != 1:
assert self.num_parameters == x.shape[1], f"num_parameters does not match input channels in PReLU"
weight_broadcasted = self.weight.broadcast([x.shape[0], self.num_parameters, *([1] * (len(x.shape) - 2))])
return jt.maximum(0, x) + weight_broadcasted * jt.minimum(0, x)
else:
return jt.maximum(0, x) + self.weight * jt.minimum(0, x)
#################Testing#################
Tested the updated PReLU function with various configurations to ensure proper functionality:
import jittor as jt
from jittor import nn
Create input data with the specified shape
def create_input_data(shape):
num_elements = 1
for dim in shape:
num_elements *= dim
return jt.array(list(range(-num_elements // 2, num_elements // 2)), dtype=jt.float32).reshape(shape)
Test the PReLU activation function
def test_prelu(num_parameters, input_shape):
prelu_layer = nn.PReLU(num_parameters=num_parameters)
input_data = create_input_data(input_shape)
print(f"Testing PReLU with num_parameters={num_parameters} and input_shape={input_shape}")
print(f"Input Data:\n{input_data.numpy()}")
output_data = prelu_layer(input_data)
print(f"Output Data (PReLU):\n{output_data.numpy()}\n")
if name == "main":
test_configs = [
(1, (5,)), # Single parameter
(5, (5, 5)), # Five parameters matching the number of channels
(3, (3, 3)), # Three parameters matching the number of channels
]
for num_parameters, input_shape in test_configs:
test_prelu(num_parameters, input_shape)
#################Test Results:#################
Testing PReLU with num_parameters=1 and input_shape=(5,) Input Data:
[-3. -2. -1. 0. 1.]
Output Data (PReLU):
[-0.75 -0.5 -0.25 0. 1. ]
Testing PReLU with num_parameters=5 and input_shape=(5, 5) Input Data:
[[-13. -12. -11. -10. -9.]
[ -8. -7. -6. -5. -4.]
[ -3. -2. -1. 0. 1.]
[ 2. 3. 4. 5. 6.]
[ 7. 8. 9. 10. 11.]]
Output Data (PReLU):
[[-3.25 -3. -2.75 -2.5 -2.25]
[-2. -1.75 -1.5 -1.25 -1. ]
[-0.75 -0.5 -0.25 0. 1. ]
[ 2. 3. 4. 5. 6. ]
[ 7. 8. 9. 10. 11. ]]
Testing PReLU with num_parameters=3 and input_shape=(3, 3) Input Data:
[[-5. -4. -3.]
[-2. -1. 0.]
[ 1. 2. 3.]]
Output Data (PReLU):
[[-1.25 -1. -0.75]
[-0.5 -0.25 0. ]
[ 1. 2. 3. ]]
##################################
This fix ensures that the PReLU activation function can handle multiple parameters correctly by properly broadcasting the weight parameter to match the input tensor dimensions.