WIP: Implement Pad mode=reflect #181

Open
wants to merge 3 commits into base: master

Conversation

redthing1

WIP, following the general outline from here to implement reflect mode.

I have just about zero experience with compute shaders or WGPU, so some help in actually implementing the reflect mode in the shader would be much appreciated.

@mayjs
Contributor

mayjs commented Aug 14, 2023

Hey, I went ahead and extended your branch with an implementation of mode=reflect in the shader :)
You can take a look at it here: 866a5c4
I'm not sure if that works correctly on the high end of each axis - we may need a modulo operation there as well, or some additional checks...

Could you add some new test cases for that and try it?
Let me know if you need some help understanding how the shader works - I'm still a WGSL beginner myself, but I believe I could write up an explanation for you.

@redthing1
Author

Thank you very much! I would very much appreciate an explanation. I can write up test cases when I find some time.

@mayjs
Contributor

mayjs commented Aug 15, 2023

Okay, I'll try to explain the most important parts, let me know if you need some more details.

The most essential thing you need to understand about WGSL is the concept of invocations and workgroups.
I found this explanation very helpful to understand it, since it contains some good visualizations.
Here's the gist of it:

  • An invocation is a single execution of the main function in a shader
  • A workgroup is a "cube" of invocations, i.e. a set of invocations where each invocation is identified by a (x, y, z) coordinate tuple
  • Invocations in the same workgroup are executed concurrently
  • Workgroup sizes are limited, so you have to create multiple instances in the x, y and z direction.

Every invocation of the shader gets a global_invocation_id, which is a global identifier for this invocation (i.e. a 3-tuple of IDs that is unique for all workgroup instances).
In wonnx, you usually have one invocation for each output value, so we need to use this ID to determine the matrix indices of the output value the invocation is responsible for.
Since the Pad operator only uses the X direction for workgroup invocations (this is the case for most operations right now), we can base the entire calculation on the x coordinate of the global_invocation_id.
This is usually stored in a variable named gidx.

The calculation for this is done here in the Pad implementation:

rest = gidx;
{%- for chunks in o_chunks[0] -%}
{% if loop.last %}
let d_{{ loop.index0 }} = rest;
{% else %}
let d_{{ loop.index0 }} = rest / {{ chunks }}u;
rest = gidx % {{ chunks }}u;
{% endif %}
{%- endfor -%}

This can be a bit challenging to read because wonnx doesn't just use plain wgsl shaders, but adds a layer of tera templates on top of it.
It's often helpful to take a look at the raw wgsl source after evaluating the template, so let's do that for the section above (for a network I'm currently playing around with; the exact numbers do not matter here):

rest = gidx;
let d_0 = rest / 16384u; 
rest = gidx % 16384u; 

let d_1 = rest / 64u; 
rest = gidx % 64u; 

let d_2 = rest / 8u; 
rest = gidx % 8u; 

let d_3 = rest; 

As you can see, this is pretty much just a simple reverse calculation of the indices in a multi-dimensional array from the index in an equivalent single-dimension array.
The variables d_0, d_1, d_2 and d_3 will contain the indices for each dimension.
We need these values to check what we need to do for our output position.
The actual input and output values can only be accessed as a single-dimension array in the shader, so we'll need to recalculate a single coordinate later on using the i_chunks template variable.
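The same decomposition can be sketched in plain Rust; this is not wonnx code, just a model of what the generated shader does, with the chunk sizes taken from the expanded example above:

```rust
// Decompose a flat index (gidx) into per-dimension indices using chunk
// sizes. The chunk size of a dimension is the product of the sizes of all
// later dimensions; the last dimension has an implicit chunk size of 1.
fn decompose(gidx: usize, chunks: &[usize]) -> Vec<usize> {
    let mut indices = Vec::with_capacity(chunks.len() + 1);
    let mut rest = gidx;
    for &chunk in chunks {
        indices.push(rest / chunk);
        rest %= chunk;
    }
    indices.push(rest);
    indices
}

fn main() {
    // gidx for position (2, 3, 4, 5) with chunks [16384, 64, 8]:
    let gidx = 2 * 16384 + 3 * 64 + 4 * 8 + 5;
    assert_eq!(decompose(gidx, &[16384, 64, 8]), vec![2, 3, 4, 5]);
}
```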

Next, we look at these indices and check if we are in a part of the array that needs padding (with a constant value) or if we are in a place that just needs copying of existing data.
Let's start by looking at the original code (without my addition):

var pad = false;
{% for pad in pad_info %}
let id_{{ loop.index0 }} = d_{{ loop.index0 }} - {{ pad.copy_start }}u;
if (d_{{ loop.index0 }} < {{ pad.copy_start }}u) {
pad = true;
}
if (d_{{ loop.index0 }} > {{ pad.end_pad_start }}u) {
pad = true;
}
{% endfor %}

The code distinguishes two cases:

  1. Values in non-padding areas where we just need to copy the corresponding input value
  2. Values that need to be padded, i.e. filled with a constant value

To do this, the code goes over each dimension and checks if we are within the original values or the padding values for that axis.
If we are in a padding area in any axis, we need to pad with a constant value, so we'll set pad to true.
In addition to checking the padding condition, the code also calculates id_X indices for each dimension.
These indices are the indices we'd need to copy from for each dimension.
They are unused if we need to pad (since we just write a constant value), but will be used to read from the input in case we are not in the padding area.
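The decision can be modeled in Rust like this (a sketch, not wonnx code; `copy_start` and `end_pad_start` are the fields from `pad_info` in the template, and unlike the shader this version short-circuits instead of computing indices that would go unused):

```rust
// Constant-mode padding decision for one output position.
// Returns None if the position falls in a padding area on any axis
// (write the constant value), or Some(input_indices) to copy from the input.
fn copy_indices(d: &[u32], pad_info: &[(u32, u32)]) -> Option<Vec<u32>> {
    let mut ids = Vec::with_capacity(d.len());
    for (&di, &(copy_start, end_pad_start)) in d.iter().zip(pad_info) {
        if di < copy_start || di > end_pad_start {
            return None; // padding area on this axis
        }
        ids.push(di - copy_start); // shift back into input coordinates
    }
    Some(ids)
}

fn main() {
    // A dimension of size 3 padded by 1 on each side: copy_start = 1,
    // end_pad_start = 3 (output indices 0 and 4 are padding).
    assert_eq!(copy_indices(&[0], &[(1, 3)]), None);
    assert_eq!(copy_indices(&[2], &[(1, 3)]), Some(vec![1]));
    assert_eq!(copy_indices(&[4], &[(1, 3)]), None);
}
```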

After this, we just write the output value of our invocation:

if (pad) {
output_0.data[gidx] = {{ scalar_type }}({{ constant_value }});
} else {
let index =
{%- for chunk in i_chunks | first -%}
{%- if not loop.first %}
+
{%- endif -%}
id_{{ loop.index0 }} * {{ chunk }}u
{%- endfor -%}
;
output_0.data[gidx] = input_0.data[index];
}

If we are in a padding area, we write the constant value, otherwise we copy from the input.
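The index expression is just the inverse of the earlier decomposition; as a Rust sketch (again not wonnx code, and the chunk values are made up):

```rust
// Recombine per-dimension input indices into a flat index using the input
// chunk sizes, as the i_chunks expression in the template does. The last
// dimension's chunk size is 1.
fn flat_index(ids: &[usize], chunks: &[usize]) -> usize {
    ids.iter().zip(chunks).map(|(&id, &chunk)| id * chunk).sum()
}

fn main() {
    // For a hypothetical input with chunk sizes [48, 8, 1]:
    assert_eq!(flat_index(&[1, 2, 3], &[48, 8, 1]), 67);
}
```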

Now let's go back to the previous section, but this time looking at my modified code: https://github.com/mayjs/wonnx/blob/866a5c4e3f314a0db140fd209698b435713f3e51/wonnx/templates/matrix/pad.wgsl#L27-L49
I just modified the code to never set pad = true in reflect mode.
Instead we calculate the reflected coordinates for each dimension where the index is in the padding area and then use the already existing copy-logic later.

Maybe this is easier to read if you look at the post-template code (With some additional comments for mode=reflect):

var id_0 = 0u;
if (d_0 < 0u) {
    // We are in the lower coordinate area of padding in dimension 0, calculate the reflected coordinate
    id_0 = (0u - d_0) % 1u;
} else if (d_0 > 1u) {
    // We are in the upper coordinate area of padding in dimension 0, calculate the reflected coordinate
    id_0 = 2u * 1u - d_0;
} else {
    // We are in the non-padding area in dimension 0, so we just copy from the input
    id_0 = d_0 - 0u;
}
    		
var id_1 = 0u;
if (d_1 < 0u) {
    // We are in the lower coordinate area of padding in dimension 1, calculate the reflected coordinate
    id_1 = (0u - d_1) % 256u;
} else if (d_1 > 256u) { 
    // We are in the upper coordinate area of padding in dimension 1, calculate the reflected coordinate
    id_1 = 2u * 256u - d_1;
} else {
    // We are in the non-padding area in dimension 1, so we just copy from the input
    id_1 = d_1 - 0u;
}
    		
var id_2 = 0u;
if (d_2 < 1u) {
    // We are in the lower coordinate area of padding in dimension 2, calculate the reflected coordinate
    id_2 = (1u - d_2) % 6u;
} else if (d_2 > 6u) {
    // We are in the upper coordinate area of padding in dimension 2, calculate the reflected coordinate
    id_2 = 2u * 6u - d_2;
} else {
    // We are in the non-padding area in dimension 2, so we just copy from the input
    id_2 = d_2 - 1u;
}
    		
var id_3 = 0u;
if (d_3 < 1u) {
    // We are in the lower coordinate area of padding in dimension 3, calculate the reflected coordinate
    id_3 = (1u - d_3) % 6u;
} else if (d_3 > 6u) {
    // We are in the upper coordinate area of padding in dimension 3, calculate the reflected coordinate
    id_3 = 2u * 6u - d_3;
} else {
    // We are in the non-padding area in dimension 3, so we just copy from the input
    id_3 = d_3 - 1u;
}

Notice that some of these conditions will never be true if a dimension is not padded at all.
In this example we are only padding in the H and W dimensions (i.e. dimension index 2 and 3), so the logic for id_0 and id_1 will always end up in the else case.

By looking at the expanded code, you can also tell that the logic for the upper areas will handle an edge case incorrectly.
Consider the line id_3 = 2u * 6u - d_3;
If our highest coordinate in the padding area exceeds the value 12, the subtraction will underflow, thus resulting in an invalid index.
You should be able to reproduce that issue with a suitable test case; it can then be fixed by adding a check for the underflow condition.
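In Rust terms, the failure mode is easy to demonstrate with checked arithmetic (a sketch; `copy_end` stands for the literal 6u above, and in WGSL the unchecked subtraction would wrap around to a huge index instead of returning None):

```rust
// Upper-area reflect calculation with an explicit underflow check.
// In the shader, 2u * copy_end - d wraps around for d > 2 * copy_end,
// producing an invalid (huge) index; checked_sub makes that case visible.
fn upper_reflect(d: u32, copy_end: u32) -> Option<u32> {
    (2 * copy_end).checked_sub(d)
}

fn main() {
    assert_eq!(upper_reflect(7, 6), Some(5));  // in range, as in the example
    assert_eq!(upper_reflect(13, 6), None);    // would underflow in the shader
}
```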

I hope this is helpful to you. Feel free to ask more questions if you need more help :)

@redthing1
Author

I very much appreciate the thorough writeup. Hopefully I can utilize it to implement more shaders for wonnx.

@redthing1
Author

Pulled in your changes!

@redthing1
Author

I added this test:

#[test]
fn test_pad_reflect_complex() {
    let mut input_data = HashMap::new();
    #[rustfmt::skip]
    let data = [
        1.0, 1.2, 1.3,
        2.3, 3.4, 4.5,
        4.5, 5.7, 6.8,
    ].to_vec();
    input_data.insert("X".to_string(), data.as_slice().into());

    let model = model(graph(
        vec![tensor("X", &[3, 2])],
        vec![tensor("Y", &[3, 4])],
        vec![],
        vec![initializer_int64("pads", vec![2, 2, 0, 0], vec![4])],
        vec![node(vec!["X", "pads"], vec!["Y"], "Pad", "Pad", vec![
            attribute("mode", "reflect"),
        ])],
    ));

    let session =
        pollster::block_on(wonnx::Session::from_model(model)).expect("session did not create");
    let result = pollster::block_on(session.run(&input_data)).unwrap();

    #[rustfmt::skip]
    let test_y = vec![
        1.2, 1.3, 1.0, 1.2, 1.3, 1.2, 1.0,
        3.4, 4.5, 2.3, 3.4, 4.5, 3.4, 2.3,
        5.7, 6.8, 4.5, 5.7, 6.8, 5.7, 4.5,
    ];
    let actual: &[_] = (&result["Y"]).try_into().unwrap();
    // No arithmetic is done, so `assert_eq!` can be used.
    assert_eq!(actual, &test_y);
}

However this fails with:


failures:

---- test_pad_reflect_complex stdout ----
[2023-08-22T03:51:41Z ERROR wgpu::backend::direct] Handling wgpu errors as fatal by default
thread 'test_pad_reflect_complex' panicked at 'wgpu error: Validation Error

Caused by:
    In Queue::write_buffer
    Copy of 0..36 would end up overrunning the bounds of the Destination buffer of size 32

', /Users/user/.cargo/registry/src/index.crates.io-6f17d22bba15001f/wgpu-0.16.0/src/backend/direct.rs:3019:5
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace


failures:
    test_pad_reflect_complex

test result: FAILED. 21 passed; 1 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.10s

I'm pretty sure the expected pad output I wrote is right, but feel free to point out any mistake.
Also, not sure why it crashes.

@mayjs
Contributor

mayjs commented Aug 22, 2023

I think you just forgot to adjust the "X" input size for your new input tensor dimensions. It should probably be &[3, 3] right?

Regarding your output values: I think your pad values should be [0, 2, 0, 2] to pad the x dimension at the beginning and end (which I assume you wanted to do by looking at your output).

Are you sure that the output values at the start of each dimension are correct? I would have expected the values to be the other way around, e.g. 1.3, 1.2, 1.0, 1.2, 1.3, 1.2, 1.0. But you can easily validate those values by either using Python and numpy, or Rust with ndarray and ndarray-ndimage.
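If you want to stay in plain Rust without extra crates, a one-dimensional reflect pad is short enough to write by hand for sanity-checking expected values (a sketch intended to match numpy's mode="reflect"; not wonnx code):

```rust
// 1-D reflect padding, intended to match numpy's
// np.pad(data, (before, after), mode="reflect"): out-of-range coordinates
// are folded back into the input with period 2 * (n - 1).
fn reflect_pad_1d(data: &[f64], before: usize, after: usize) -> Vec<f64> {
    let n = data.len() as i64;
    assert!(n > 1, "reflect needs at least two elements");
    let period = 2 * (n - 1);
    (0..(before as i64 + n + after as i64))
        .map(|o| {
            let i = o - before as i64;    // coordinate in input space
            let j = i.rem_euclid(period); // fold into one period
            let idx = if j < n { j } else { period - j };
            data[idx as usize]
        })
        .collect()
}

fn main() {
    let row = reflect_pad_1d(&[1.0, 1.2, 1.3], 2, 2);
    assert_eq!(row, vec![1.3, 1.2, 1.0, 1.2, 1.3, 1.2, 1.0]);
}
```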

Not sure if your padding value is large enough to trigger the bug I expect, but it's a good test case either way :)

@redthing1
Author

Ah my bad, thanks for pointing out my error.

@redthing1
Author

I finally got around to fixing it, lol. @mayjs you implied that there might be another bug?

redthing1 marked this pull request as ready for review November 5, 2023 05:08
@mayjs
Contributor

mayjs commented Dec 2, 2023

I finally got around to fixing it, lol. @mayjs you implied that there might be another bug?

Hey, sorry for the late answer, I'm glad that you fixed it :)
Looking at my initial explanation, I found the part about the issue with id_3 = 2u * 6u - d_3;.

I think that part referred to configurations where the padding size is larger than the input size. In those cases a wrap-around is required (at least I think it should wrap), but the logic in that shader would just produce a subtraction underflow.
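For reference, a wrap-around-safe reflect index can be sketched in Rust using the period 2 * (n - 1), where n is the input size along the axis; this matches numpy's mode="reflect", though whether ONNX specifies the same behavior for pads larger than the input is worth double-checking against the spec:

```rust
// Map an arbitrary coordinate (negative or past the end) onto a valid
// input index with reflect semantics. rem_euclid keeps the folded value
// in 0..period even for negative inputs, so no subtraction can underflow.
fn reflect_index(i: i64, n: i64) -> i64 {
    assert!(n > 1, "reflect needs at least two elements");
    let period = 2 * (n - 1);
    let j = i.rem_euclid(period);
    if j < n { j } else { period - j }
}

fn main() {
    // For an axis of size 6, coordinates far outside still land in 0..6:
    let mapped: Vec<i64> = (-3..9).map(|i| reflect_index(i, 6)).collect();
    assert_eq!(mapped, vec![3, 2, 1, 0, 1, 2, 3, 4, 5, 4, 3, 2]);
}
```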
