
[Operator] Add weight_norm op [MooreThreads] #177

Merged. 1 commit merged into FlagOpen:master from the add_weight_norm branch on Sep 29, 2024.

Conversation

@TZWX-0 (Contributor) commented Aug 23, 2024

No description provided.

@TZWX-0 force-pushed the add_weight_norm branch 2 times, most recently from 54570d7 to e7fa442, on August 23, 2024 07:27.
@TZWX-0 (Contributor, Author) commented Aug 23, 2024

Performance of some cases on an NVIDIA A100, tensor dtype float16:

[performance screenshot]

Review thread on src/flag_gems/__init__.py (outdated, resolved)
@tongxin (Contributor) left a comment:

Please simplify the implementation by considering only dim == 0 or dim == v.dim() - 1.
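
For context, a minimal sketch (illustrative only, not the PR's kernel code) of why restricting dim to the first or last axis simplifies things: in either case v can be viewed as a 2-D problem with the norm reduced over a single axis.

import math
import torch

# Illustrative helper, not from the PR: with dim limited to 0 or v.dim() - 1,
# the weight can be flattened to 2-D and the norm becomes a single-axis reduction.
def as_2d(v: torch.Tensor, dim: int) -> torch.Tensor:
    if dim == 0:
        # keep the first axis, reduce over everything after it
        return v.reshape(v.shape[0], math.prod(v.shape[1:]))
    if dim == v.dim() - 1:
        # keep the last axis, reduce over everything before it
        return v.reshape(math.prod(v.shape[:-1]), v.shape[-1])
    raise ValueError("only dim == 0 or dim == v.dim() - 1 are supported")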

@tongxin (Contributor) commented Aug 28, 2024

It seems there is an issue with weight_norm in PyTorch. For the same input, the output results for float16 and float32 are completely different, likely due to a broadcasting problem. Based on the formula, the correct results should match those produced by the half (float16) implementation. Moreover, for float32 the broadcasting logic behaves strangely, with different dim values leading to inconsistent broadcasting behavior.

For instance, when v has shape (2, 1, 1) and g has shape (2, 2, 2): with dim = 0 the output shape is (2, 1, 1), while with dim = 1 the output shape is (2, 2, 2).
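
A minimal reproduction sketch of that shape case (the expected shapes are the ones reported above; the exact behavior may differ across PyTorch versions):

import torch

# Reproduction sketch only; the shapes noted in the comments are what the
# commenter observed and may vary with the PyTorch version.
v = torch.ones(2, 1, 1, dtype=torch.float32, device="cuda")
g = torch.ones(2, 2, 2, dtype=torch.float32, device="cuda")
for dim in (0, 1):
    try:
        out = torch._weight_norm(v, g, dim)
        print(dim, tuple(out.shape))  # reported: dim=0 -> (2, 1, 1), dim=1 -> (2, 2, 2)
    except Exception as err:
        print(dim, "raised:", err)

The float16 vs. float32 discrepancy can then be seen directly with the snippet below: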

import torch

v = torch.ones([2, 2], dtype=torch.float16).to("cuda")
g = torch.tensor([1, 2, 3, 4], dtype=torch.float16).to("cuda").reshape(2, 2)
golden_output = torch._weight_norm(v, g, dim = 0)
print(golden_output)


v = torch.ones([2, 2], dtype=torch.float32).to("cuda")
g = torch.tensor([1, 2, 3, 4], dtype=torch.float32).to("cuda").reshape(2, 2)
golden_output = torch._weight_norm(v, g, dim = 0)
print(golden_output)

The output is:

tensor([[0.7070, 1.4141],
        [2.1211, 2.8281]], device='cuda:0', dtype=torch.float16)
tensor([[0.7071, 0.7071],
        [1.4142, 1.4142]], device='cuda:0')

g is supposed to be a scalar factor for the dim dimension. For instance, if dim is 0, g.shape should be something like [1, N].

@TZWX-0 force-pushed the add_weight_norm branch 2 times, most recently from d5c2125 to 8ff8cad, on September 2, 2024 06:59.
v_value = tl.load(v + row_offset * N + col_offset, mask=mask)
v_block += v_value * v_value

normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
Contributor:

I think we should be reducing on the first dimension, i.e., axis=0.

Contributor (Author):

v_block is stored in row-major order, so I perform the sum along the rows regardless of whether the reduction dimension is the first or the last (the x/y indices are permuted for the last). The test encountered an error because REDUCTION_SHAPES = (200, 40999, 3) with dim = 1 is not supported for weight normalization; this issue has now been resolved.
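
A toy illustration (plain PyTorch, not the Triton kernel) of that point: if a tile is loaded transposed, summing its squares along the rows matches summing the original block's squares along the first dimension.

import torch

# Illustrative only, not the PR's kernel code.
v = torch.arange(6, dtype=torch.float32).reshape(2, 3)

reduce_first = (v * v).sum(dim=0)               # reduce over the first dimension
via_transpose = (v.t() * v.t()).sum(dim=1)      # load the tile transposed, sum along rows

print(torch.allclose(reduce_first, via_transpose))  # True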

Contributor:

Reducing on dim 1 is only correct provided the inputs are transposed up front. It looks like that's not the case in WeightNorm.forward. Can we further verify that?

Contributor (Author):

If the reduction were done over the wrong dimension, the result would definitely differ from the golden reference, but currently they are consistent. The transpose happens within the kernel: threads load values along the row direction from global memory but store them along the column direction of v_block.

Contributor:

            M = v.shape[0]
            N = math.prod(v.shape[1:])
            grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)

Above is the blocking scheme in the code, where M is the reduction dim size. It's clear the reduction axis is split. I don't see how transpose could be done in the kernel...

@TZWX-0 (Contributor, Author) commented Sep 26, 2024:

In the kernel:

# when the reduce dim is the first dimension
tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]
v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

# when the reduce dim is the last dimension
ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]
v_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)

How about verifying this with a simple instance, for example reduce shape = (2, 2)? If the reduce dim were wrong in the kernel, the result would not be consistent with the golden reference.
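
A sketch of such a check against the PyTorch reference (it exercises torch._weight_norm and a hand-written formula rather than the FlagGems kernel, since the question here is only the dim convention):

import torch

# Verification sketch only: for dim = 0 the norm is taken over every dimension
# except dim, so g carries one scale per row of the (2, 2) input.
v = torch.randn(2, 2, dtype=torch.float32, device="cuda")
g = torch.randn(2, 1, dtype=torch.float32, device="cuda")

norm = v.norm(p=2, dim=1, keepdim=True)   # reduce over all dims except dim = 0
reference = g * v / norm

golden = torch._weight_norm(v, g, dim=0)
print(torch.allclose(golden, reference, atol=1e-5))  # expected: True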

Contributor:

My bad... I took for granted that the input dim is the dimension to be contracted off.

@TZWX-0 force-pushed the add_weight_norm branch 2 times, most recently from 45ab7ec to 3bea848, on September 26, 2024 01:46.
Two review threads on tests/test_norm_ops.py (outdated, resolved)
@tongxin (Contributor) left a comment:

LG

@zhzhcookie merged commit fba6cdb into FlagOpen:master on Sep 29, 2024.
3 of 4 checks passed.
DuanYaQi pushed a commit that referenced this pull request on Oct 15, 2024.