
wrong output of fp32 Matmul #1190

Closed
sustcsonglin opened this issue Feb 13, 2023 · 16 comments

@sustcsonglin

I compare the outputs of https://github.com/openai/triton/blob/main/python/triton/ops/matmul.py and torch.matmul

fp16 works as expected, while for fp32 there is a large difference.

@Jokeren
Contributor

Jokeren commented Feb 13, 2023

What inputs did you use?

@Jokeren
Contributor

Jokeren commented Feb 13, 2023

And is it fp32 or tf32?

@sustcsonglin
Author

sustcsonglin commented Feb 13, 2023

I use the default initialization of torch.rand, so I think it is fp32

@Jokeren
Contributor

Jokeren commented Feb 13, 2023

Well, if you are using a GPU that supports tensor cores, the default should be TF32.

Let me be more specific on the "inputs". Can you please at least share the M, N, K, and dtype information? It would be even better if you can attach a simple .py file.

Example: https://github.com/openai/triton/blob/main/python/test/unit/operators/test_matmul.py

I do notice that we haven't tested fp32 with tensor cores disabled, so maybe there is indeed something wrong.
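
A quick way to check and control which mode torch.matmul will use (a minimal sketch, assuming an Ampere-or-newer GPU):

import torch

# On Ampere and newer GPUs, PyTorch may route fp32 matmuls through TF32 tensor cores.
# These flags report and control that behaviour.
print(torch.backends.cuda.matmul.allow_tf32)   # TF32 for matmuls
print(torch.backends.cudnn.allow_tf32)         # TF32 for cuDNN convolutions
torch.backends.cuda.matmul.allow_tf32 = False  # force true fp32 matmul in torch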

@Jokeren
Contributor

Jokeren commented Feb 13, 2023

https://pytorch.org/docs/stable/notes/cuda.html

More reference:
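
The snippet below relies on the setup from the same notes page: double-precision matrices a_full and b_full whose product ab_full serves as the ground truth, roughly:

import torch

# Double-precision reference (ground truth), as in the linked PyTorch notes.
a_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
b_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
ab_full = a_full @ b_full
mean = ab_full.abs().mean()

a = a_full.float()
b = b_full.float()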

# Do matmul at TF32 mode.
torch.backends.cuda.matmul.allow_tf32 = True
ab_tf32 = a @ b  # takes 0.016s on GA100
error = (ab_tf32 - ab_full).abs().max()  # 0.1747
relative_error = error / mean  # 0.0022

# Do matmul with TF32 disabled.
torch.backends.cuda.matmul.allow_tf32 = False
ab_fp32 = a @ b  # takes 0.11s on GA100
error = (ab_fp32 - ab_full).abs().max()  # 0.0031
relative_error = error / mean  # 0.000039

@sustcsonglin
Author

sustcsonglin commented Feb 13, 2023

"_matmul.apply" is directly from https://github.com/openai/triton/blob/main/python/triton/ops/matmul.py

I have tried both torch.backends.cuda.matmul.allow_tf32 = False and True

    import torch
    from triton.ops.matmul import _matmul  # the matmul implementation linked above

    matmul = _matmul.apply
    a = torch.rand(1024, 1024).cuda().float()
    b = torch.rand(1024, 512).cuda().float()
    print(a.dtype, b.dtype)
    print((matmul(a, b) - a @ b).abs().max())

The outcome

torch.float32 torch.float32
tensor(0.1930, device='cuda:0')

@Jokeren
Contributor

Jokeren commented Feb 13, 2023

I don't see obvious errors in this case. Note that the matrices you are testing are large, so the accumulated floating-point differences are expected to be somewhat larger than you might think.

BTW, fp32 should be more accurate, but we haven't optimized its performance yet.

import torch
import triton
import triton.language as tl

import numpy as np

a = torch.rand(1024, 1024).cuda().float()
b = torch.rand(1024, 512).cuda().float()

torch.backends.cuda.matmul.allow_tf32 = True
c_ref = a @ b
c = triton.ops.matmul(a, b)
torch.allclose(c, c_ref, rtol=0.01)

torch.backends.cuda.matmul.allow_tf32 = False
c_ref = a @ b
c = triton.ops.matmul(a, b)
torch.allclose(c, c_ref)
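
One way to judge which result is closer to the exact product is to compare both against a float64 reference; a minimal sketch (not part of the test suite):

import torch
import triton

torch.manual_seed(123)
a = torch.rand(1024, 1024, device='cuda')
b = torch.rand(1024, 512, device='cuda')

# float64 product as the ground truth
c_true = (a.double() @ b.double()).float()

torch.backends.cuda.matmul.allow_tf32 = False
c_ref = a @ b                  # cuBLAS fp32
c = triton.ops.matmul(a, b)    # Triton fp32

print((c_ref - c_true).abs().max())  # error of cuBLAS vs ground truth
print((c - c_true).abs().max())      # error of Triton vs ground truth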

@Jokeren
Contributor

Jokeren commented Feb 13, 2023

To verify, please print c and c_ref and take a look

@sustcsonglin
Author

sustcsonglin commented Feb 13, 2023

I use NVIDIA A40

import torch
import triton
import triton.language as tl

import numpy as np

a = torch.rand(1024, 1024).cuda().float()
b = torch.rand(1024, 512).cuda().float()

torch.backends.cuda.matmul.allow_tf32 = True
c_ref = a @ b
c = triton.ops.matmul(a, b)
print(torch.allclose(c, c_ref, rtol=0.01))
print((c - c_ref).abs().max())

torch.backends.cuda.matmul.allow_tf32 = False
c_ref = a @ b
c = triton.ops.matmul(a, b)
print(torch.allclose(c, c_ref))
print((c - c_ref).abs().max())
print(c)
print(c_ref)

Outcomes:

True
tensor(0.1962, device='cuda:0')
False
tensor(0.1929, device='cuda:0')
tensor([[243.1322, 246.3515, 240.2199, ..., 234.1416, 233.5045, 244.3806],
        [257.5991, 257.9790, 252.8911, ..., 244.5308, 249.1992, 254.9188],
        [260.1879, 263.5882, 259.1734, ..., 248.4590, 249.3200, 263.4294],
        ...,
        [248.6701, 257.1014, 255.7793, ..., 242.5764, 242.7238, 258.2006],
        [262.4697, 272.3884, 269.6973, ..., 256.3765, 252.2404, 261.1209],
        [259.7090, 269.3224, 263.1116, ..., 249.2758, 250.8823, 260.8963]], device='cuda:0')
tensor([[243.2919, 246.5142, 240.3795, ..., 234.2974, 233.6607, 244.5448],
        [257.7720, 258.1540, 253.0603, ..., 244.6950, 249.3705, 255.0928],
        [260.3587, 263.7630, 259.3438, ..., 248.6195, 249.4847, 263.6058],
        ...,
        [248.8374, 257.2736, 255.9496, ..., 242.7423, 242.8906, 258.3785],
        [262.6431, 272.5712, 269.8774, ..., 256.5466, 252.4107, 261.2944],
        [259.8804, 269.4980, 263.2852, ..., 249.4434, 251.0513, 261.0689]], device='cuda:0')

@Jokeren
Contributor

Jokeren commented Feb 13, 2023

So it passed torch.allclose(c, c_ref), right?

@Jokeren
Contributor

Jokeren commented Feb 13, 2023

@sustcsonglin Oh sorry, please remember to add torch.manual_seed(123); I forgot to copy it here...

This is just to ensure you get deterministic results; without a fixed seed, results are expected to differ between runs.

@sustcsonglin
Author

sustcsonglin commented Feb 13, 2023

import torch
import triton
import triton.language as tl

import numpy as np

torch.manual_seed(123)
a = torch.rand(1024, 1024).cuda().float()
b = torch.rand(1024, 512).cuda().float()

torch.backends.cuda.matmul.allow_tf32 = True
c_ref = a @ b
c = triton.ops.matmul(a, b)
print(torch.allclose(c, c_ref, rtol=0.01))
print((c - c_ref).abs().max())
print(c)
print(c_ref)

torch.backends.cuda.matmul.allow_tf32 = False
c_ref = a @ b
c = triton.ops.matmul(a, b)
print(torch.allclose(c, c_ref))
print((c - c_ref).abs().max())
print(c)
print(c_ref)

[screenshot of the printed outputs]

The first case passes torch.allclose(c, c_ref, rtol=0.01), while the second fails torch.allclose(c, c_ref) with the default tolerances.

@Jokeren
Contributor

Jokeren commented Feb 13, 2023

In that case, you can set torch.allclose(c, c_ref, rtol=0.01, atol=1e-3) and I would expect the second case to pass on your machine. It also varies among different GPUs...
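
For context, torch.allclose passes when |c - c_ref| <= atol + rtol * |c_ref| elementwise; with entries around 250 and a max difference around 0.19 (the numbers above), the default tolerances fail but rtol=0.01 passes easily. A rough check:

# Rough numbers taken from the printed outputs above.
val, diff = 250.0, 0.19
print(diff <= 1e-8 + 1e-5 * val)   # default atol/rtol: allowed ~0.0025 -> False
print(diff <= 1e-3 + 1e-2 * val)   # atol=1e-3, rtol=0.01: allowed ~2.5 -> True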

@sustcsonglin
Author

Thanks for your patience. Outputs are almost identical when using fp16, and I do not know the reason for such differences when using fp32. Also, I am not sure whether rtol=0.01 is too large.

@Jokeren
Contributor

Jokeren commented Feb 13, 2023

We use different accumulation logic than cuBLAS.

It may be worth investigating the precision further, but fp32 is not our focus for now since it is much slower than tf32.
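
A small CPU sketch of the effect (this is not Triton's actual kernel, just an illustration that fp32 results depend on accumulation order):

import torch

torch.manual_seed(123)
a = torch.rand(1024, 1024)
b = torch.rand(1024, 512)

# Accumulating over K block by block, as a tiled kernel does, gives a slightly
# different fp32 result than a single pass; neither is "wrong".
full = a @ b
blocked = sum(a[:, k:k + 128] @ b[k:k + 128, :] for k in range(0, 1024, 128))
print((full - blocked).abs().max())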

@sustcsonglin
Author

Thanks
