Skip to content

Commit

Permalink
Add ConvTranspose1d ONNX op (#2349)
Browse files Browse the repository at this point in the history
* add python, onnx model, burn/node for conv_transpose1d

* modify supported-onnx-ops.md about convtranspose1d

* add conv_transpose1d function in op_configuration and to_burn

* add test and modify dim_inference

* apply cargo fmt to formatting

* fix reviewer point-outs by adding symetry checks for padding in 2d,3d

* fix pads initilization and check ways
  • Loading branch information
tiruka authored Oct 21, 2024
1 parent 4d31d19 commit b7887b0
Show file tree
Hide file tree
Showing 11 changed files with 425 additions and 15 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ represent the corresponding Burn Op.
| [Conv2d][34] |||
| [Conv3d][34] |||
| [ConvInteger][37] |||
| [ConvTranspose1d][38] | ||
| [ConvTranspose1d][38] | ||
| [ConvTranspose2d][38] |||
| [ConvTranspose3d][38] |||
| [Cos][39] |||
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ fn main() {
.input("tests/conv1d/conv1d.onnx")
.input("tests/conv2d/conv2d.onnx")
.input("tests/conv3d/conv3d.onnx")
.input("tests/conv_transpose1d/conv_transpose1d.onnx")
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
.input("tests/conv_transpose3d/conv_transpose3d.onnx")
.input("tests/cos/cos.onnx")
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python3

# used to generate model: conv_transpose1d.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.transposed_conv = nn.ConvTranspose1d(
in_channels=4,
out_channels=6,
kernel_size=3,
stride=2,
padding=1,
dilation=2,
output_padding=1,
groups=2
)

def forward(self, x):
return self.transposed_conv(x)


def main():

# Set seed for reproducibility
torch.manual_seed(42)

torch.set_printoptions(precision=8)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

file_name = "conv_transpose1d.onnx"
test_input = torch.ones(2, 4, 10, device=device)
torch.onnx.export(model, test_input, file_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data shape of ones: {}".format(test_input.shape))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))

sum = output.sum().item()

print("Test output sum: {}".format(sum))


if __name__ == "__main__":
main()
23 changes: 23 additions & 0 deletions crates/burn-import/onnx-tests/tests/test_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ include_models!(
conv1d,
conv2d,
conv3d,
conv_transpose1d,
conv_transpose2d,
conv_transpose3d,
cos,
Expand Down Expand Up @@ -1508,6 +1509,28 @@ mod tests {
output.to_data().assert_approx_eq(&expected, 4);
}

#[test]
fn conv_transpose1d() {
// Initialize the model with weights (loaded from the exported file)
let model: conv_transpose1d::Model<Backend> = conv_transpose1d::Model::default();

// Run the model with ones as input for easier testing
let input = Tensor::<Backend, 3>::ones([2, 4, 10], &Default::default());

let output = model.forward(input);

let expected_shape = Shape::from([2, 6, 22]);
assert_eq!(output.shape(), expected_shape);

// We are using the sum of the output tensor to test the correctness of the conv_transpose1d node
// because the output tensor is too large to compare with the expected tensor.
let output_sum = output.sum().into_scalar();

let expected_sum = 33.810329; // example result running the corresponding PyTorch model (conv_transpose1d.py)

assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
}

#[test]
fn conv_transpose2d() {
// Initialize the model with weights (loaded from the exported file)
Expand Down
10 changes: 7 additions & 3 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ use super::{
argmax::ArgMaxNode, avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode,
batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode,
constant::ConstantNode, constant_of_shape::ConstantOfShapeNode, conv1d::Conv1dNode,
conv2d::Conv2dNode, conv3d::Conv3dNode, conv_transpose_2d::ConvTranspose2dNode,
conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode,
gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
conv2d::Conv2dNode, conv3d::Conv3dNode, conv_transpose_1d::ConvTranspose1dNode,
conv_transpose_2d::ConvTranspose2dNode, conv_transpose_3d::ConvTranspose3dNode,
dropout::DropoutNode, expand::ExpandNode, gather::GatherNode,
gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
Expand Down Expand Up @@ -93,6 +94,7 @@ pub enum Node<PS: PrecisionSettings> {
Conv1d(Conv1dNode),
Conv2d(Conv2dNode),
Conv3d(Conv3dNode),
ConvTranspose1d(ConvTranspose1dNode),
ConvTranspose2d(ConvTranspose2dNode),
ConvTranspose3d(ConvTranspose3dNode),
PRelu(PReluNode),
Expand Down Expand Up @@ -142,6 +144,7 @@ macro_rules! match_all {
Node::Conv1d(node) => $func(node),
Node::Conv2d(node) => $func(node),
Node::Conv3d(node) => $func(node),
Node::ConvTranspose1d(node) => $func(node),
Node::ConvTranspose2d(node) => $func(node),
Node::ConvTranspose3d(node) => $func(node),
Node::PRelu(node) => $func(node),
Expand Down Expand Up @@ -199,6 +202,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Conv1d(_) => "conv1d",
Node::Conv2d(_) => "conv2d",
Node::Conv3d(_) => "conv3d",
Node::ConvTranspose1d(_) => "conv_transpose1d",
Node::ConvTranspose2d(_) => "conv_transpose2d",
Node::ConvTranspose3d(_) => "conv_transpose3d",
Node::PRelu(_) => "prelu",
Expand Down
201 changes: 201 additions & 0 deletions crates/burn-import/src/burn/node/conv_transpose_1d.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
use super::{Node, NodeCodegen, SerializationBackend};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
use burn::{
module::{ConstantRecord, Param, ParamId},
nn::conv::{ConvTranspose1dConfig, ConvTranspose1dRecord},
record::{PrecisionSettings, Record},
tensor::{Tensor, TensorData},
};
use proc_macro2::TokenStream;
use quote::quote;
use serde::Serialize;

#[derive(Debug, Clone)]
pub struct ConvTranspose1dNode {
pub field: OtherType,
pub input: TensorType,
pub output: TensorType,
pub data_weights: TensorData,
pub data_bias: Option<TensorData>,
pub config: ConvTranspose1dConfig,
}

impl ConvTranspose1dNode {
pub fn new<S: AsRef<str>>(
name: S,
input: TensorType,
output: TensorType,
data_weights: TensorData,
data_bias: Option<TensorData>,
config: ConvTranspose1dConfig,
) -> Self {
Self {
field: OtherType::new(
name,
quote! {
ConvTranspose1d<B>
},
),
input,
output,
data_weights,
data_bias,
config,
}
}
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for ConvTranspose1dNode {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn field_type(&self) -> Option<Type> {
Some(Type::Other(self.field.clone()))
}

fn field_init(&self) -> Option<TokenStream> {
let name = &self.field.name;
let channels = self.config.channels.to_tokens();
let kernel_size = self.config.kernel_size.to_tokens();
let stride = self.config.stride.to_tokens();
let dilation = self.config.dilation.to_tokens();
let groups = self.config.groups.to_tokens();
let padding = self.config.padding.to_tokens();
let padding_out = self.config.padding_out.to_tokens();
let bias = self.config.bias;

let tokens = quote! {
let #name = ConvTranspose1dConfig::new(#channels, #kernel_size)
.with_stride(#stride)
.with_padding(#padding)
.with_padding_out(#padding_out)
.with_dilation(#dilation)
.with_groups(#groups)
.with_bias(#bias)
.init(device);
};

Some(tokens)
}

fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let device = Default::default();
let record = ConvTranspose1dRecord::<SerializationBackend> {
weight: Param::initialized(
ParamId::new(),
Tensor::from_data(
self.data_weights.clone().convert::<PS::FloatElem>(),
&device,
),
),
bias: self.data_bias.as_ref().map(|bias| {
Param::initialized(
ParamId::new(),
Tensor::from_data(bias.clone().convert::<PS::FloatElem>(), &device),
)
}),
stride: ConstantRecord::new(),
kernel_size: ConstantRecord::new(),
dilation: ConstantRecord::new(),
groups: ConstantRecord::new(),
padding: ConstantRecord::new(),
padding_out: ConstantRecord::new(),
channels: [ConstantRecord::new(); 2],
};

let item = Record::into_item::<PS>(record);
item.serialize(serializer)
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let field = &self.field.name;

quote! {
let #output = self.#field.forward(#input);
}
}
fn register_imports(&self, imports: &mut BurnImports) {
imports.register("burn::nn::conv::ConvTranspose1d");
imports.register("burn::nn::conv::ConvTranspose1dConfig");
}

fn into_node(self) -> Node<PS> {
Node::ConvTranspose1d(self)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::burn::{
graph::BurnGraph,
node::{conv_transpose_1d::ConvTranspose1dNode, test::assert_tokens},
TensorType,
};
use burn::{nn::conv::ConvTranspose1dConfig, record::FullPrecisionSettings};

#[test]
fn test_codegen() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(ConvTranspose1dNode::new(
"conv_transpose_1d",
TensorType::new_float("input", 3),
TensorType::new_float("output", 3),
TensorData::from([2f32]),
None,
ConvTranspose1dConfig::new([3, 3], 3).with_padding(0),
));

graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
use burn::nn::conv::ConvTranspose1d;
use burn::nn::conv::ConvTranspose1dConfig;

#[derive(Module, Debug)]
pub struct Model <B: Backend> {
conv_transpose_1d: ConvTranspose1d<B>,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
let conv_transpose_1d = ConvTranspose1dConfig::new([3, 3], 3)
.with_stride(1)
.with_padding(0)
.with_padding_out(0)
.with_dilation(1)
.with_groups(1)
.with_bias(true)
.init(device);

Self {
conv_transpose_1d,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let output = self.conv_transpose_1d.forward(input);

output
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub(crate) mod constant_of_shape;
pub(crate) mod conv1d;
pub(crate) mod conv2d;
pub(crate) mod conv3d;
pub(crate) mod conv_transpose_1d;
pub(crate) mod conv_transpose_2d;
pub(crate) mod conv_transpose_3d;
pub(crate) mod dropout;
Expand Down
Loading

0 comments on commit b7887b0

Please sign in to comment.