tl.broadcast bug? #2157

Open
debadeepta opened this issue Aug 23, 2023 · 5 comments

@debadeepta

GPU: A6000
Version: 2.1.0 (nightly build)

TLDR:

Broadcasting a vector to the shape of a 2D matrix and then multiplying them produces wrong results.

Details:

In the minimal reproducible kernel below, I store the output of tl.broadcast_to into a tensor T; inspecting v and T shows the broadcast output is clearly wrong.

import torch
import triton
import triton.language as tl

@triton.jit
def vector_matrix_rowmul(X, V, Y, T, N: tl.constexpr, d: tl.constexpr):

    pid = tl.program_id(0)

    # load X
    X_block_ptr = tl.make_block_ptr(base=X, shape=(N, d), strides=(d, 1), 
                                    offsets=(0, 0), block_shape=(N, d),
                                    order=(1, 0))
    
    x = tl.load(X_block_ptr)

    # load V
    V_block_ptr = tl.make_block_ptr(base=V, shape=(N,), strides=(1,), offsets=(0,),
                                    block_shape=(N,), order=(0,))
    
    v = tl.load(V_block_ptr)

    z = tl.view(v, (N, 1))
    t = tl.broadcast_to(z, (N, d))

    # store t to T to check broadcast's output 
    T_block_ptr = tl.make_block_ptr(base=T, shape=(N, d), strides=(d, 1),
                                    offsets=(0, 0), block_shape=(N, d),
                                    order=(1, 0))
    tl.store(T_block_ptr, t)

    y = t * x

    # store result
    Y_block_ptr = tl.make_block_ptr(base=Y, shape=(N, d), strides=(d, 1),
                                    offsets=(0, 0), block_shape=(N, d),
                                    order=(1, 0))
    tl.store(Y_block_ptr, y)




if __name__ == '__main__':
    torch.manual_seed(42)
    N = 32
    d = 16
    Q = torch.randn(N, d).contiguous().cuda()
    v = torch.randn(N).contiguous().cuda()

    # PyTorch
    O_pyt = torch.reshape(v, (N, 1)) * Q

    # Triton
    O_triton = torch.zeros_like(Q)
    T = torch.zeros_like(Q)
    grid = (1,)
    vector_matrix_rowmul[grid](Q, v, O_triton, T, N, d)

    isclose = torch.allclose(O_pyt, O_triton, atol=1e-3)
    print(f"is close: {isclose}")

    print('done')
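
For completeness, one way to inspect the stored broadcast output is to compare T against the PyTorch reference (a sketch reusing v, T, N, and d from the script above; this check and the expected_T helper are mine, not part of the original report):

# compare the kernel's stored broadcast output against PyTorch's broadcast
expected_T = v.reshape(N, 1).expand(N, d)
print(f"broadcast matches: {torch.allclose(T, expected_T, atol=1e-3)}")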
@debadeepta
Author

I am also confused about why both implicit and explicit broadcasting are broken, while indexing with mytensor[:, None] works correctly. See the code snippet below.

v = tl.load(V_block_ptr)
z = tl.view(v, (N, 1))
t = tl.broadcast_to(z, (N, d))

# correct
y = v[:, None] * x

# incorrect
# y = t * x

# incorrect
# y = z * x

@peterbell10
Contributor

From tl.view's docstring (emphasis mine):

Returns a tensor with the same elements as input but a different shape.
The order of the elements may not be preserved.

To preserve ordering you should use indexing with None as your example shows, or equivalently use tl.expand_dims.

IMO tl.view should probably be renamed since its behavior is so different from torch.view.
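
For illustration, here is a sketch of both order-preserving alternatives inside the kernel (reusing v, x, N, and d from the repro above; this snippet is mine, not from the thread):

# option 1: tl.expand_dims preserves element order
z = tl.expand_dims(v, 1)            # shape (N, 1)
t = tl.broadcast_to(z, (N, d))      # each row repeats one value of v
y = t * x

# option 2: indexing with None, equivalent and implicit
y = v[:, None] * x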

@debadeepta
Author

I didn't pay attention to that; I just assumed it had the same semantics as torch.view.

Yes, indeed, the name should be changed.

@debadeepta
Author

I think tl.broadcast_to is still broken:

For some vector v of shape [N, 1]:
[screenshot: the values of v]

I get this matrix as the output of tl.broadcast_to(v, (N, d)):

[screenshot: the broadcast output matrix]

Each row should be the same value repeated, right?
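
For reference, the expected result matches PyTorch's broadcasting semantics, where every row repeats a single value of v (a toy sketch with made-up values, not from the thread):

import torch

v = torch.tensor([[1.], [2.], [3.]])  # shape (3, 1)
t = v.expand(3, 4)                    # broadcast to (3, 4)
# t is:
# tensor([[1., 1., 1., 1.],
#         [2., 2., 2., 2.],
#         [3., 3., 3., 3.]])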

@xijiu9

xijiu9 commented Dec 2, 2023

[screenshot of output]

Maybe this problem is fixed? I built Triton from source to install it.
