diff --git a/crates/burn-ndarray/src/ops/base.rs b/crates/burn-ndarray/src/ops/base.rs
index 7b94ea7999..ea9ae3a119 100644
--- a/crates/burn-ndarray/src/ops/base.rs
+++ b/crates/burn-ndarray/src/ops/base.rs
@@ -480,12 +480,17 @@ where
         rhs: NdArrayTensor,
         var_name: impl FnMut(&E, &OtherE) -> E,
     ) -> NdArrayTensor {
-        NdArrayTensor::new(
-            Zip::from(lhs.array.view())
-                .and(rhs.array.view())
-                .map_collect(var_name)
-                .into_shared(),
-        )
+        // Try to broadcast `lhs` to the shape of `rhs`; if that is not possible, keep
+        // `lhs` as a plain view of its own data.
+        let lhs = lhs
+            .array
+            .broadcast(rhs.array.dim())
+            .unwrap_or(lhs.array.view());
+        // Then try to broadcast `rhs` to the (possibly broadcast) shape of `lhs`, with
+        // the same fallback to a plain view.
+        let rhs = rhs.array.broadcast(lhs.dim()).unwrap_or(rhs.array.view());
+
+        NdArrayTensor::new(Zip::from(lhs).and(rhs).map_collect(var_name).into_shared())
     }
 
     pub(crate) fn elementwise_op_scalar(
diff --git a/crates/burn-tensor/src/tests/ops/powf.rs b/crates/burn-tensor/src/tests/ops/powf.rs
index d82d717748..4997b1de58 100644
--- a/crates/burn-tensor/src/tests/ops/powf.rs
+++ b/crates/burn-tensor/src/tests/ops/powf.rs
@@ -54,4 +54,24 @@ mod tests {
 
         output.into_data().assert_approx_eq(&expected, 3);
     }
+
+    /// Checks that `powf` accepts a single-element tensor on either side of the operation.
+    #[test]
+    fn should_support_powf_broadcasted() {
+        let device = Default::default();
+        // NOTE(review): the generic arguments were missing from this patch text;
+        // `<TestBackend, 1>` is assumed from the 1-D float inputs -- confirm upstream.
+        let tensor_1 = Tensor::<TestBackend, 1>::from_floats([2.0, 3.0, 4.0], &device);
+        let tensor_2 = Tensor::from_floats([1.0], &device);
+
+        // Broadcast rhs: x^1 == x, so the output must equal `tensor_1`.
+        let output = tensor_1.clone().powf(tensor_2.clone());
+        output.into_data().assert_approx_eq(&tensor_1.to_data(), 3);
+
+        // Broadcast lhs: 1^x == 1 for every exponent.
+        let output = tensor_2.powf(tensor_1);
+        output
+            .into_data()
+            .assert_approx_eq(&TensorData::from([1.0, 1.0, 1.0]), 3);
+    }
 }