WIP: Implement Pad mode=reflect #181
base: master
Hey, I went ahead and extended your branch with an implementation of Could you add some new test cases for that and try it?
Thank you very much! I would very much appreciate an explanation. I can write up test cases when I find some time.
Okay, I'll try to explain the most important parts, let me know if you need some more details. The most essential thing you need to understand about WGSL is the concept of invocations and workgroups.
Every invocation of the shader gets a The calculation for this is done here in the wonnx/wonnx/templates/matrix/pad.wgsl Lines 17 to 25 in fbb7ab1
This can be a bit challenging to read because wonnx doesn't just use plain WGSL shaders, but adds a layer of Tera templates on top of it. It's often helpful to take a look at the raw WGSL source after evaluating the template, so let's do that for the section above (for a network I'm currently playing around with; the exact numbers do not matter here):
rest = gidx;
let d_0 = rest / 16384u;
rest = gidx % 16384u;
let d_1 = rest / 64u;
rest = gidx % 64u;
let d_2 = rest / 8u;
rest = gidx % 8u;
let d_3 = rest;
As you can see, this is simply the reverse calculation that recovers the indices into a multi-dimensional array from the index into the equivalent one-dimensional array. Next, we look at these indices and check whether we are in a part of the array that needs padding (with a constant value) or in a place that just needs copying of existing data. wonnx/wonnx/templates/matrix/pad.wgsl Lines 27 to 38 in fbb7ab1
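For readers more comfortable with Rust than WGSL, the index recovery from the expanded snippet can be mirrored on the CPU. This is only a sketch: the function name `unravel` and the idea of passing the strides as a slice are assumptions made for illustration and are not part of wonnx.

```rust
/// Splits a flat, row-major index into one index per dimension.
/// `strides[i]` is the number of elements covered by one step along
/// dimension `i`; the last dimension has stride 1.
/// Hypothetical helper, not part of wonnx.
fn unravel(gidx: u32, strides: &[u32]) -> Vec<u32> {
    let mut rest = gidx;
    strides
        .iter()
        .map(|&stride| {
            let d = rest / stride; // index along this dimension
            rest %= stride; // what is left for the remaining dimensions
            d
        })
        .collect()
}
```

With the strides from the snippet (`[16384, 64, 8, 1]`), the flat index `16384 + 2 * 64 + 3 * 8 + 5` comes back as `[1, 2, 3, 5]`.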
The code distinguishes two cases:
To do this, the code goes over each dimension and checks if we are within the original values or the padding values for that axis. After this, we just write the output value of our invocation: wonnx/wonnx/templates/matrix/pad.wgsl Lines 40 to 53 in fbb7ab1
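The two steps just described (check each dimension, then write either the constant or the copied input value) can be sketched on the CPU as follows. The function `pad_constant` and its parameters are hypothetical and only imitate the structure of the shader: each iteration of the outer loop plays the role of one invocation and produces exactly one output element.

```rust
/// Constant-mode padding of a row-major tensor.
/// Hypothetical CPU sketch, not the wonnx implementation.
fn pad_constant(
    input: &[f32],
    in_shape: &[usize],
    pads_begin: &[usize],
    pads_end: &[usize],
    value: f32,
) -> Vec<f32> {
    let rank = in_shape.len();
    let out_shape: Vec<usize> = (0..rank)
        .map(|i| pads_begin[i] + in_shape[i] + pads_end[i])
        .collect();
    let out_len: usize = out_shape.iter().product();

    (0..out_len)
        .map(|gidx| {
            let mut rest = gidx;
            let mut in_index = 0;
            let mut in_stride = 1;
            let mut is_pad = false;
            // Walk the dimensions from last to first, peeling off one index each.
            for dim in (0..rank).rev() {
                let d = rest % out_shape[dim];
                rest /= out_shape[dim];
                if d < pads_begin[dim] || d >= pads_begin[dim] + in_shape[dim] {
                    // Outside the original data on this axis: padding area.
                    is_pad = true;
                } else {
                    in_index += (d - pads_begin[dim]) * in_stride;
                }
                in_stride *= in_shape[dim];
            }
            if is_pad { value } else { input[in_index] }
        })
        .collect()
}
```

For example, padding a 2x2 input `[1, 2, 3, 4]` with one column of zeros on each side of the last axis yields `[0, 1, 2, 0, 0, 3, 4, 0]`.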
If we are in a padding area, we write the constant value; otherwise we copy from the input. Now let's go back to the previous section, but this time looking at my modified code: https://github.com/mayjs/wonnx/blob/866a5c4e3f314a0db140fd209698b435713f3e51/wonnx/templates/matrix/pad.wgsl#L27-L49 Maybe this is easier to read if you look at the post-template code (With some additional comments for
var id_0 = 0u;
if (d_0 < 0u) {
// We are in the lower coordinate area of padding in dimension 0, calculate the reflected coordinate
id_0 = (0u - d_0) % 1u;
} else if (d_0 > 1u) {
// We are in the upper coordinate area of padding in dimension 0, calculate the reflected coordinate
id_0 = 2u * 1u - d_0;
} else {
// We are in the non-padding area in dimension 0, so we just copy from the input
id_0 = d_0 - 0u;
}
var id_1 = 0u;
if (d_1 < 0u) {
// We are in the lower coordinate area of padding in dimension 1, calculate the reflected coordinate
id_1 = (0u - d_1) % 256u;
} else if (d_1 > 256u) {
// We are in the upper coordinate area of padding in dimension 1, calculate the reflected coordinate
id_1 = 2u * 256u - d_1;
} else {
// We are in the non-padding area in dimension 1, so we just copy from the input
id_1 = d_1 - 0u;
}
var id_2 = 0u;
if (d_2 < 1u) {
// We are in the lower coordinate area of padding in dimension 2, calculate the reflected coordinate
id_2 = (1u - d_2) % 6u;
} else if (d_2 > 6u) {
// We are in the upper coordinate area of padding in dimension 2, calculate the reflected coordinate
id_2 = 2u * 6u - d_2;
} else {
// We are in the non-padding area in dimension 2, so we just copy from the input
id_2 = d_2 - 1u;
}
var id_3 = 0u;
if (d_3 < 1u) {
// We are in the lower coordinate area of padding in dimension 3, calculate the reflected coordinate
id_3 = (1u - d_3) % 6u;
} else if (d_3 > 6u) {
// We are in the upper coordinate area of padding in dimension 3, calculate the reflected coordinate
id_3 = 2u * 6u - d_3;
} else {
// We are in the non-padding area in dimension 3, so we just copy from the input
id_3 = d_3 - 1u;
}
Notice that some of these conditions will never be true if a dimension is not padded at all. By looking at the expanded code, you can also tell that the logic for the upper areas will handle an edge case incorrectly. I hope this is helpful to you. Feel free to ask more questions if you need more help :)
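As a point of comparison for the per-dimension branches above, here is a CPU-side sketch of a reflected index computed with signed arithmetic. It assumes NumPy-style `reflect` semantics (mirror about the first and last element without repeating them, and keep mirroring if the padding is wider than the data). Whether that matches what the shader is meant to do in every case is an assumption, and `reflect_index` is a hypothetical helper, not something taken from wonnx.

```rust
/// Maps an output coordinate `d` on one axis back to an input coordinate.
/// `pad_begin` is the amount of padding before the data and `len` the
/// input length on that axis. Assumes NumPy-style "reflect" behaviour.
fn reflect_index(d: u32, pad_begin: u32, len: u32) -> u32 {
    if len <= 1 {
        // Nothing to mirror around; every position maps to the only element.
        return 0;
    }
    // One full back-and-forth sweep over the data covers 2 * (len - 1) steps.
    let period = 2 * (len as i64 - 1);
    // Signed position relative to the first input element, folded into one period.
    let m = (d as i64 - pad_begin as i64).rem_euclid(period);
    // The second half of the period walks backwards through the data.
    (if m >= len as i64 { period - m } else { m }) as u32
}
```

For an axis of length 3 with two padded positions on each side, the seven output coordinates map to the input coordinates `2, 1, 0, 1, 2, 1, 0`. Doing the subtraction in `i64` sidesteps the unsigned comparisons that can never be true.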
I very much appreciate the thorough writeup. Hopefully I can utilize it to implement more shaders for wonnx.
Pulled in your changes!
I added this test:
#[test]
fn test_pad_reflect_complex() {
let mut input_data = HashMap::new();
#[rustfmt::skip]
let data = [
1.0, 1.2, 1.3,
2.3, 3.4, 4.5,
4.5, 5.7, 6.8,
].to_vec();
input_data.insert("X".to_string(), data.as_slice().into());
let model = model(graph(
vec![tensor("X", &[3, 2])],
vec![tensor("Y", &[3, 4])],
vec![],
vec![initializer_int64("pads", vec![2, 2, 0, 0], vec![4])],
vec![node(vec!["X", "pads"], vec!["Y"], "Pad", "Pad", vec![
attribute("mode", "reflect"),
])],
));
let session =
pollster::block_on(wonnx::Session::from_model(model)).expect("session did not create");
let result = pollster::block_on(session.run(&input_data)).unwrap();
#[rustfmt::skip]
let test_y = vec![
1.2, 1.3, 1.0, 1.2, 1.3, 1.2, 1.0,
3.4, 4.5, 2.3, 3.4, 4.5, 3.4, 2.3,
5.7, 6.8, 4.5, 5.7, 6.8, 5.7, 4.5,
];
let actual: &[_] = (&result["Y"]).try_into().unwrap();
// No arithmetic is done, so `assert_eq!` can be used.
assert_eq!(actual, &test_y);
}
However this fails with:
I'm pretty sure the expected pad output I wrote is right, but feel free to point out any mistake.
I think you just forgot to adjust the "X" input size for your new input tensor dimensions. It should probably be Regarding your output values: I think your pad values should be Are you sure that the output values at the start of each dimension are correct? I would have expected the values to be the other way around, e.g. Not sure if your padding value is large enough to trigger the bug I expect, but it's a good test case either way :)
Ah my bad, thanks for pointing out my error.
I finally got around to fixing it, lol. @mayjs you implied that there might be another bug? |
Hey, sorry for the late answer, I'm glad that you fixed it :) I think that part referred to configurations where the padding size is larger than the input size. In those cases a wrap-around is required (at least I think it should wrap), but the logic in that shader would just produce a subtraction underflow.
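To make the underflow concrete, the upper-padding expression from the expanded shader (`2u * 6u - d_2`) can be replayed in Rust with explicit overflow handling. The helper names are made up for illustration. The wrapping variant assumes that unsigned subtraction in the shader wraps modulo 2^32, which is my understanding of WGSL `u32` arithmetic but is not confirmed anywhere in this thread.

```rust
/// `2 * bound - d` with checked subtraction, so an underflow shows up as `None`.
/// Hypothetical helper that mirrors the expression in the expanded shader.
fn upper_reflect_checked(d: u32, bound: u32) -> Option<u32> {
    (2 * bound).checked_sub(d)
}

/// The same expression with wrapping subtraction (assumed shader behaviour).
fn upper_reflect_wrapping(d: u32, bound: u32) -> u32 {
    (2 * bound).wrapping_sub(d)
}
```

With `bound = 6`, a coordinate of `7` still gives a small index, but a coordinate of `13` has no valid result under checked arithmetic and turns into `u32::MAX` when it wraps, which would be far outside the input buffer.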
WIP, following the general outline from here to implement reflect mode.
I have just about zero experience with compute shaders or WGPU, so some help in actually implementing the reflect mode in the shader would be much appreciated.