Skip to content

Commit

Permalink
Fix tests (#1089)
Browse files Browse the repository at this point in the history
* Fix tests

* Fix fmt

* Fix CI
  • Loading branch information
nathanielsimard authored Dec 21, 2023
1 parent 1fd07fc commit d82e6b1
Show file tree
Hide file tree
Showing 34 changed files with 166 additions and 198 deletions.
7 changes: 2 additions & 5 deletions burn-autodiff/src/tests/log1p.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@ mod tests {

#[test]
fn should_diff_log1p() {
let data_1 = Data::<f32, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, 10.0]]);

let tensor_1 = TestAutodiffTensor::from_data_devauto(data_1).require_grad();
let tensor_2 = TestAutodiffTensor::from_data_devauto(data_2).require_grad();
let tensor_1 = TestAutodiffTensor::from([[0.0, 1.0], [3.0, 4.0]]).require_grad();
let tensor_2 = TestAutodiffTensor::from([[6.0, 7.0], [9.0, 10.0]]).require_grad();

let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log1p());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
Expand Down
2 changes: 1 addition & 1 deletion burn-candle/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl From<candle_core::Device> for CandleDevice {
match device.location() {
DeviceLocation::Cpu => CandleDevice::Cpu,
DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id),
DeviceLocation::Metal => panic!("Metal unsupported"),
_ => panic!("Device unsupported: {device:?}"),
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions burn-candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ mod tests {
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
pub type ReferenceTensor<const D: usize> = burn_tensor::Tensor<ReferenceBackend, D>;
pub type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
pub type TestTensorBool<const D: usize> =
burn_tensor::Tensor<TestBackend, D, burn_tensor::Bool>;

type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;
Expand Down
2 changes: 2 additions & 0 deletions burn-ndarray/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ mod tests {
type TestBackend = crate::NdArray<f32>;
type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
type TestTensorBool<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Bool>;

use alloc::format;
use alloc::vec;

Expand Down
1 change: 1 addition & 0 deletions burn-tch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod tests {
type TestBackend = crate::LibTorch<f32>;
type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
type TestTensorBool<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Bool>;

burn_tensor::testgen_all!();
burn_autodiff::testgen_all!();
Expand Down
11 changes: 11 additions & 0 deletions burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ where
pub(crate) primitive: K::Primitive<D>,
}

impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
where
B: Backend,
K: BasicOps<B>,
T: Into<Data<K::Elem, D>>,
{
fn from(value: T) -> Self {
Tensor::from_data(value.into(), &Default::default())
}
}

impl<B, const D: usize, K> Tensor<B, D, K>
where
B: Backend,
Expand Down
8 changes: 2 additions & 6 deletions burn-tensor/src/tests/activation/gelu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@ mod tests {

#[test]
fn test_gelu() {
let data = Data::from([[
let tensor = TestTensor::from([[
0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737,
]]);
let tensor = Tensor::<TestBackend, 2>::from_data_devauto(data)
.clone()
.clone();

let data_actual = activation::gelu(tensor).to_data();
let data_actual = activation::gelu(tensor).into_data();

let data_expected = Data::from([[
0.3851, 0.8207, 0.2714, 0.0777, 0.6351, 0.2704, 0.1419, 0.3687, 0.4993, 0.5051,
Expand Down
5 changes: 2 additions & 3 deletions burn-tensor/src/tests/activation/mish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ mod tests {

#[test]
fn test_mish() {
let data = Data::from([[-0.4240, -0.9574, -0.2215], [-0.5767, 0.7218, -0.1620]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let tensor = TestTensor::from([[-0.4240, -0.9574, -0.2215], [-0.5767, 0.7218, -0.1620]]);

let data_actual = activation::mish(tensor).to_data();
let data_actual = activation::mish(tensor).into_data();

let data_expected = Data::from([[-0.1971, -0.3006, -0.1172], [-0.2413, 0.5823, -0.0888]]);
data_actual.assert_approx_eq(&data_expected, 4);
Expand Down
5 changes: 2 additions & 3 deletions burn-tensor/src/tests/activation/quiet_softmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ mod tests {

#[test]
fn test_quiet_softmax_d2() {
let data = Data::from([[1.0, 7.0], [13.0, -3.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let tensor = TestTensor::from([[1.0, 7.0], [13.0, -3.0]]);

let data_actual = activation::quiet_softmax(tensor, 1).to_data();
let data_actual = activation::quiet_softmax(tensor, 1).into_data();

let data_expected = Data::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]);
data_actual.assert_approx_eq(&data_expected, 4);
Expand Down
5 changes: 2 additions & 3 deletions burn-tensor/src/tests/activation/relu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ mod tests {

#[test]
fn test_relu_d2() {
let data = Data::from([[0.0, -1.0, 2.0], [3.0, -4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data_devauto(data);
let tensor = TestTensor::from([[0.0, -1.0, 2.0], [3.0, -4.0, 5.0]]);

let data_actual = activation::relu(tensor).to_data();
let data_actual = activation::relu(tensor).into_data();

let data_expected = Data::from([[0.0, 0.0, 2.0], [3.0, 0.0, 5.0]]);
assert_eq!(data_expected, data_actual);
Expand Down
10 changes: 4 additions & 6 deletions burn-tensor/src/tests/activation/sigmoid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,19 @@ mod tests {

#[test]
fn test_sigmoid() {
let data = Data::from([[1.0, 7.0], [13.0, -3.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data_devauto(data);
let tensor = TestTensor::from([[1.0, 7.0], [13.0, -3.0]]);

let data_actual = activation::sigmoid(tensor).to_data();
let data_actual = activation::sigmoid(tensor).into_data();

let data_expected = Data::from([[0.7311, 0.9991], [1.0, 0.0474]]);
data_actual.assert_approx_eq(&data_expected, 4);
}

#[test]
fn test_sigmoid_overflow() {
let data = Data::from([f32::MAX, f32::MIN]);
let tensor = Tensor::<TestBackend, 1>::from_data_devauto(data);
let tensor = TestTensor::from([f32::MAX, f32::MIN]);

let data_actual = activation::sigmoid(tensor).to_data();
let data_actual = activation::sigmoid(tensor).into_data();

let data_expected = Data::from([1.0, 0.0]);
data_actual.assert_approx_eq(&data_expected, 4);
Expand Down
5 changes: 2 additions & 3 deletions burn-tensor/src/tests/activation/silu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ mod tests {

#[test]
fn test_silu() {
let data = Data::from([[1.0, 2.0], [3.0, 4.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data_devauto(data);
let tensor = TestTensor::from([[1.0, 2.0], [3.0, 4.0]]);

let data_actual = activation::silu(tensor).to_data();
let data_actual = activation::silu(tensor).into_data();

let data_expected = Data::from([[0.7311, 1.7616], [2.8577, 3.9281]]);
data_actual.assert_approx_eq(&data_expected, 4);
Expand Down
5 changes: 2 additions & 3 deletions burn-tensor/src/tests/activation/softmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ mod tests {

#[test]
fn test_softmax_d2() {
let data = Data::from([[1.0, 7.0], [13.0, -3.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data_devauto(data);
let tensor = TestTensor::from([[1.0, 7.0], [13.0, -3.0]]);

let data_actual = activation::softmax(tensor, 1).to_data();
let data_actual = activation::softmax(tensor, 1).into_data();

let data_expected = Data::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]);
data_actual.assert_approx_eq(&data_expected, 4);
Expand Down
8 changes: 5 additions & 3 deletions burn-tensor/src/tests/activation/softplus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ mod tests {

#[test]
fn test_softplus_d2() {
let data = Data::from([[-0.4240, -0.9574, -0.2215], [-0.5767, 0.7218, -0.1620]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let tensor = Tensor::<TestBackend, 2>::from([
[-0.4240, -0.9574, -0.2215],
[-0.5767, 0.7218, -0.1620],
]);

let data_actual_beta1 = activation::softplus(tensor.clone(), 1.0).to_data();
let data_actual_beta1 = activation::softplus(tensor.clone(), 1.0).into_data();
let data_expected_beta1 = Data::from([[0.5034, 0.3249, 0.5885], [0.4458, 1.1178, 0.6154]]);
data_actual_beta1.assert_approx_eq(&data_expected_beta1, 4);

Expand Down
5 changes: 2 additions & 3 deletions burn-tensor/src/tests/activation/tanh_activation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ mod tests {

#[test]
fn test_tanh() {
let data = Data::from([[1., 2.], [3., 4.]]);
let tensor = Tensor::<TestBackend, 2>::from_data_devauto(data);
let tensor = TestTensor::from([[1., 2.], [3., 4.]]);

let data_actual = activation::tanh(tensor).to_data();
let data_actual = activation::tanh(tensor).into_data();

let data_expected = Data::from([[0.7616, 0.9640], [0.9951, 0.9993]]);
data_actual.assert_approx_eq(&data_expected, 4);
Expand Down
9 changes: 3 additions & 6 deletions burn-tensor/src/tests/module/adaptive_avgpool1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mod tests {
length_out: 4,
};

test.assert_output(TestTensor::from_floats_devauto([[
test.assert_output(TestTensor::from([[
[0.5, 2.5, 4.5, 6.5],
[8.5, 10.5, 12.5, 14.5],
]]));
Expand All @@ -28,10 +28,7 @@ mod tests {
length_out: 3,
};

test.assert_output(TestTensor::from_floats_devauto([[
[1.0, 3.0, 5.0],
[8.0, 10.0, 12.0],
]]));
test.assert_output(TestTensor::from([[[1.0, 3.0, 5.0], [8.0, 10.0, 12.0]]]));
}

#[test]
Expand All @@ -43,7 +40,7 @@ mod tests {
length_out: 8,
};

test.assert_output(TestTensor::from_floats_devauto([[
test.assert_output(TestTensor::from([[
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0],
[4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0],
]]));
Expand Down
8 changes: 4 additions & 4 deletions burn-tensor/src/tests/module/adaptive_avgpool2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ mod tests {
width_out: 4,
};

test.assert_output(TestTensor::from_floats_devauto([[
test.assert_output(TestTensor::from([[
[
[3.5000, 4.5000, 6.5000, 7.5000],
[15.5000, 16.5000, 18.5000, 19.5000],
Expand All @@ -42,7 +42,7 @@ mod tests {
width_out: 2,
};

test.assert_output(TestTensor::from_floats_devauto([[
test.assert_output(TestTensor::from([[
[[5.0000, 8.0000], [15.5000, 18.5000], [26.0000, 29.0000]],
[[40.0000, 43.0000], [50.5000, 53.5000], [61.0000, 64.0000]],
]]));
Expand All @@ -59,7 +59,7 @@ mod tests {
width_out: 4,
};

test.assert_output(TestTensor::from_floats_devauto([[
test.assert_output(TestTensor::from([[
[
[0.0000, 0.5000, 1.5000, 2.0000],
[1.5000, 2.0000, 3.0000, 3.5000],
Expand Down Expand Up @@ -89,7 +89,7 @@ mod tests {
impl AdaptiveAvgPool2dTestCase {
fn assert_output(self, y: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let x = TestTensor::from_data_devauto(
let x = TestTensor::from(
TestTensorInt::arange_devauto(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
Expand Down
8 changes: 4 additions & 4 deletions burn-tensor/src/tests/module/avgpool1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ mod tests {
count_include_pad: true,
};

test.assert_output(TestTensor::from_floats_devauto([[[1., 2., 3., 4.]]]));
test.assert_output(TestTensor::from([[[1., 2., 3., 4.]]]));
}

#[test]
Expand All @@ -31,7 +31,7 @@ mod tests {
count_include_pad: true,
};

test.assert_output(TestTensor::from_floats_devauto([[
test.assert_output(TestTensor::from([[
[0.3333, 2.0000, 4.0000],
[4.3333, 8.0000, 10.0000],
]]));
Expand All @@ -49,7 +49,7 @@ mod tests {
count_include_pad: false,
};

test.assert_output(TestTensor::from_floats_devauto([[
test.assert_output(TestTensor::from([[
[0.5000, 2.0000, 4.0000],
[6.5000, 8.0000, 10.0000],
]]));
Expand All @@ -68,7 +68,7 @@ mod tests {
impl AvgPool1dTestCase {
fn assert_output(self, y: TestTensor<3>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
let x = TestTensor::from_data_devauto(
let x = TestTensor::from(
TestTensorInt::arange_devauto(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
Expand Down
8 changes: 4 additions & 4 deletions burn-tensor/src/tests/module/avgpool2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mod tests {
count_include_pad: true,
};

test.assert_output(TestTensor::from_floats_devauto([[[
test.assert_output(TestTensor::from([[[
[7., 8., 9., 10.],
[13., 14., 15., 16.],
[19., 20., 21., 22.],
Expand All @@ -44,7 +44,7 @@ mod tests {
count_include_pad: true,
};

test.assert_output(TestTensor::from_floats_devauto([[[
test.assert_output(TestTensor::from([[[
[1.1667, 3.0000, 4.3333, 2.5000],
[3.2500, 7.5000, 9.5000, 5.2500],
[6.2500, 13.5000, 15.5000, 8.2500],
Expand All @@ -68,7 +68,7 @@ mod tests {
count_include_pad: false,
};

test.assert_output(TestTensor::from_floats_devauto([[[
test.assert_output(TestTensor::from([[[
[3.5000, 4.5000, 6.5000, 7.5000],
[6.5000, 7.5000, 9.5000, 10.5000],
[12.5000, 13.5000, 15.5000, 16.5000],
Expand All @@ -93,7 +93,7 @@ mod tests {
impl AvgPool2dTestCase {
fn assert_output(self, y: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let x = TestTensor::from_data_devauto(
let x = TestTensor::from(
TestTensorInt::arange_devauto(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
Expand Down
6 changes: 3 additions & 3 deletions burn-tensor/src/tests/module/conv1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ mod tests {
length: 4,
};

test.assert_output(TestTensor::from_floats_devauto([
test.assert_output(TestTensor::from([
[[43., 67., 82., 49.], [104., 176., 227., 158.]],
[[139., 187., 202., 113.], [392., 584., 635., 414.]],
]));
Expand All @@ -39,7 +39,7 @@ mod tests {
length: 4,
};

test.assert_output(TestTensor::from_floats_devauto([
test.assert_output(TestTensor::from([
[[62., 38.], [159., 111.]],
[[158., 102.], [447., 367.]],
]));
Expand All @@ -59,7 +59,7 @@ mod tests {
length: 4,
};

test.assert_output(TestTensor::from_floats_devauto([
test.assert_output(TestTensor::from([
[[2., 5., 8., 3.], [42., 63., 75., 47.]],
[[26., 29., 32., 11.], [114., 159., 171., 103.]],
]));
Expand Down
Loading

0 comments on commit d82e6b1

Please sign in to comment.