[Bug] Triton error when running the chunk gated DeltaNet op #196

Closed
@LouChao98

Checklist

  • I have checked FAQs and existing issues for similar problems
  • Please report this bug in English to ensure wider understanding and support

Describe the Bug

Env: triton==3.2.0, H800

When execution reaches the Triton kernel chunk_gated_delta_rule_fwd_kernel_h, the process aborts with this assertion failure:

python3: /project/lib/Tools/LinearLayout.cpp:562: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeOuts(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalOutDimSize() == std::accumulate( newOutDims.begin(), newOutDims.end(), 1, [&](int32_t acc, auto &outDim) { return acc * outDim.second; })' failed.
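The assertion comes from Triton's C++ LinearLayout::reshapeOuts: a layout may only be reshaped into new output dimensions whose sizes multiply to the layout's total output size. A minimal Python restatement of that check follows. The function name and the example sizes are illustrative only; they are not Triton's API and not values taken from the failing kernel.

```python
from functools import reduce
from operator import mul
from typing import Sequence, Tuple


def reshape_outs_invariant_holds(total_out_size: int,
                                 new_out_dims: Sequence[Tuple[str, int]]) -> bool:
    """Python mirror (hypothetical helper) of the condition asserted in
    LinearLayout::reshapeOuts: the product of the new output dimension
    sizes must equal the layout's total output size."""
    # Equivalent of std::accumulate(begin, end, 1, acc * outDim.second)
    product = reduce(mul, (size for _, size in new_out_dims), 1)
    return total_out_size == product


# Illustrative values only (not taken from the failing kernel):
print(reshape_outs_invariant_holds(64 * 256, [("dim0", 64), ("dim1", 256)]))  # True
print(reshape_outs_invariant_holds(64 * 256, [("dim0", 64), ("dim1", 128)]))  # False
```

Which layout inside the kernel violates this condition under triton 3.2.0 is not established by the report itself.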

With triton==3.1.0, the same code runs without this error.
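Because the failure depends on the Triton version, a script can check the installed version before launching the kernel. A minimal sketch, assuming the whole 3.2.x series behaves like 3.2.0 (the report only confirms that 3.2.0 fails and 3.1.0 works); all helper names are hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def major_minor(ver: str) -> Tuple[int, int]:
    """Parse strings like '3.2.0' or '3.1.0+git1a2b' into (major, minor)."""
    parts = ver.split("+")[0].split(".")
    return int(parts[0]), int(parts[1])


def installed_triton_major_minor() -> Optional[Tuple[int, int]]:
    """Return (major, minor) of the installed triton package, or None if absent."""
    try:
        return major_minor(version("triton"))
    except PackageNotFoundError:
        return None


def matches_reported_failing_series(ver: str) -> bool:
    # Assumption: every 3.2.x release is treated as affected; only 3.2.0 was reported.
    return major_minor(ver) == (3, 2)


if __name__ == "__main__":
    print("installed triton:", installed_triton_major_minor())
```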

Steps to Reproduce the Bug

import torch
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h

k = torch.randn(1, 8192, 36, 256, device="cuda", dtype=torch.bfloat16)
w = torch.randn(1, 8192, 36, 256, device="cuda", dtype=torch.bfloat16)
u = torch.randn(1, 8192, 36, 256, device="cuda", dtype=torch.bfloat16)
g = torch.randn(1, 8192, 36, device="cuda", dtype=torch.float32)
head_first = False
chunk_size = 64

chunk_gated_delta_rule_fwd_h(k=k, w=w, u=u, g=g, head_first=head_first, chunk_size=chunk_size)
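For scale, the reproduction splits the T = 8192 sequence into chunks of 64 tokens, and the gated delta rule carries a K x V recurrent state per head (K = V = 256 here, read off the last dimensions of k and u). A small arithmetic sketch in plain Python, no GPU needed:

```python
# Sizes copied from the reproduction (head_first=False, so inputs are [B, T, H, K]).
B, T, H = 1, 8192, 36
K = 256  # last dim of k and w
V = 256  # last dim of u
chunk_size = 64

# Number of chunks along the sequence (ceiling division).
num_chunks = (T + chunk_size - 1) // chunk_size

# Elements in one K x V recurrent state for a single (batch, head) pair.
state_elems_per_head = K * V

print(num_chunks)            # 128
print(state_elems_per_head)  # 65536
```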

Expected Behavior

The call completes without error, as it does under triton==3.1.0.

Environment Information

  1. Torch: 2.3.1
  2. Triton: 3.2.0
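The environment details above (Torch and Triton versions, plus the H800 GPU mentioned earlier) can be gathered in one step. A minimal sketch using the standard library plus an optional torch import; collect_env is a hypothetical helper name, not part of fla or Triton. PyTorch also ships python -m torch.utils.collect_env for a fuller report.

```python
import platform
from importlib.metadata import PackageNotFoundError, version
from typing import Dict


def collect_env() -> Dict[str, str]:
    """Gather Python, torch, triton and (if available) GPU/CUDA info as strings."""
    info: Dict[str, str] = {"python": platform.python_version()}
    for pkg in ("torch", "triton"):
        try:
            info[pkg] = version(pkg)
        except PackageNotFoundError:
            info[pkg] = "not installed"
    try:
        import torch  # optional: only used to query the GPU
        if torch.cuda.is_available():
            info["gpu"] = torch.cuda.get_device_name(0)
            info["cuda"] = str(torch.version.cuda)
    except ImportError:
        pass
    return info


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key}: {value}")
```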

Labels: bug (Something isn't working)
