Closed
Description
Checklist
- I have checked FAQs and existing issues for similar problems
- Please report this bug in English to ensure wider understanding and support
Describe the Bug
Env: triton==3.2.0, H800
When I run the reproduction script in "Steps to Reproduce the Bug", I get the following assertion failure:
python3: /project/lib/Tools/LinearLayout.cpp:562: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeOuts(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalOutDimSize() == std::accumulate( newOutDims.begin(), newOutDims.end(), 1, [&](int32_t acc, auto &outDim) { return acc * outDim.second; })' failed.
With triton==3.1.0, this error does not occur.
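For readers unfamiliar with the assertion: judging from the message text alone, it seems to require that the product of the new output dimension sizes equals the layout's total output size. Below is a minimal plain-Python sketch of that invariant. It is not Triton's implementation, and the names `total_size` and `check_reshape`, the dimension names, and the sizes are all hypothetical, chosen only for illustration.

```python
from functools import reduce
from operator import mul


def total_size(out_dims):
    """Product of the sizes in a list of (name, size) pairs; 1 for an empty list."""
    return reduce(mul, (size for _, size in out_dims), 1)


def check_reshape(old_out_dims, new_out_dims):
    """Return True when the reshape keeps the total number of elements unchanged."""
    return total_size(old_out_dims) == total_size(new_out_dims)


# Hypothetical dimension names and sizes, for illustration only.
old = [("dim0", 64), ("dim1", 256)]
ok = check_reshape(old, [("dim0", 128), ("dim1", 128)])   # same element count
bad = check_reshape(old, [("dim0", 64), ("dim1", 128)])   # element count differs
print(ok, bad)
```

A mismatch like the second case is the kind of condition the assertion would reject, although which dimensions actually disagree in this kernel is not shown by the message.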
Steps to Reproduce the Bug
import torch
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h
k = torch.randn(1, 8192, 36, 256, device="cuda", dtype=torch.bfloat16)
w = torch.randn(1, 8192, 36, 256, device="cuda", dtype=torch.bfloat16)
u = torch.randn(1, 8192, 36, 256, device="cuda", dtype=torch.bfloat16)
g = torch.randn(1, 8192, 36, device="cuda", dtype=torch.float32)
head_first = False
chunk_size = 64
chunk_gated_delta_rule_fwd_h(k=k, w=w, u=u, g=g, head_first=head_first, chunk_size=chunk_size)
Expected Behavior
The call should complete without error, as it does with triton==3.1.0.
Environment Information
- Torch: 2.3.1
- Triton: 3.2.0
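To make version reports like the one above easier to reproduce, a small helper along these lines could print the installed versions in the same list format. This is only a sketch that uses the standard library's `importlib.metadata`. The function name `package_versions` is made up for this example, and it reads the versions from installed package metadata rather than importing the packages.

```python
from importlib.metadata import PackageNotFoundError, version


def package_versions(names=("torch", "triton")):
    """Map each distribution name to its installed version, or 'not installed'."""
    report = {}
    for name in names:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            # The distribution is absent from the current environment.
            report[name] = "not installed"
    return report


versions = package_versions()
for name, ver in versions.items():
    print(f"- {name}: {ver}")
```

Since the bug depends on the Triton version, running this in both the working (3.1.0) and failing (3.2.0) environments would confirm which versions are actually being picked up.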