Summary
In the current FMHA / attention-related CUTLASS wrapper (and the MLA phase1 path), four areas pose risks to correctness, robustness, and debuggability:
- Varlen option derivation is logically unsafe
- Input tensor validation is too weak
- op.can_implement() failure messages are not very informative
- Workspace handling is left as TODO and may cause runtime issues
(1) Varlen option derivation is logically unsafe (important)
Observed in:
get_options() where:
options.q = q.size(0) / options.b;
options.k = k.size(0) / options.b;
This logic only makes sense for dense batched (fixed sequence length) inputs.
In varlen mode, q.size(0) and k.size(0) typically hold the total token count across the batch (i.e. the last element of the cumulative sequence-length prefix sum).
Dividing by b yields the average sequence length, which is smaller than the actual maximum sequence length whenever the sequences differ in length.
Why this is risky
Varlen kernels usually rely on:
max_seqlen_q
max_seqlen_kv
Using total_tokens / b may:
- Underestimate the required shape
- Break stride or shape assumptions
- Lead to can_implement failures or undefined behavior
Example
b = 2
seqlen = [1, 4096]
total_tokens = 4097
options.q = total_tokens / b = 2048
actual max_seqlen = 4096
The derived options.q (2048) is half the true maximum (4096), so the problem shape is inconsistent with the data.
Suggested fix
When kIsVarlen (or equivalent flag) is set:
- Derive options.q and options.k from max_seqlen_q / max_seqlen_kv, or
- Leave them as placeholders and finalize the shape only in initialize_varlen()
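As a minimal sketch of the first option, the maximum sequence length can be recovered from the prefix-sum array itself when max_seqlen_q / max_seqlen_kv are not passed in explicitly. The struct and function names below (VarlenExtents, derive_varlen_extents) are hypothetical and not part of the wrapper; the sketch assumes the array holds b + 1 non-decreasing offsets starting at 0 and is readable on the host.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical container for the two extents discussed above.
struct VarlenExtents {
  int64_t total_tokens;  // last element of the prefix-sum array
  int64_t max_seqlen;    // longest individual sequence in the batch
};

// Assumption: cu_seqlens holds b + 1 non-decreasing offsets, starting at 0.
inline VarlenExtents derive_varlen_extents(const std::vector<int64_t>& cu_seqlens) {
  assert(cu_seqlens.size() >= 2 && cu_seqlens.front() == 0);
  VarlenExtents extents{cu_seqlens.back(), 0};
  for (std::size_t i = 1; i < cu_seqlens.size(); ++i) {
    // Length of sequence i-1 is the difference of adjacent offsets.
    const int64_t len = cu_seqlens[i] - cu_seqlens[i - 1];
    assert(len >= 0);
    extents.max_seqlen = std::max(extents.max_seqlen, len);
  }
  return extents;
}
```

For the example above (offsets 0, 1, 4097), this reports a maximum of 4096, whereas total_tokens / b gives 2048.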
(2) Input tensor validation is too weak
Observed near kernel argument preparation / launch path.
Current checks only validate some stride conditions.
To avoid silently wrong results or out-of-bounds memory accesses, stronger validation is recommended.
Suggested checks
- All tensors are CUDA tensors
- All tensors use expected dtype (bf16/fp16/fp8 as required)
- All tensors are on the same device
- Layout/stride matches what the kernel supports
Example:
TORCH_CHECK(q.is_cuda() && k.is_cuda() && v.is_cuda() && o.is_cuda(),
"All tensors must be CUDA tensors");
TORCH_CHECK(q.scalar_type() == at::kBFloat16,
"Unsupported dtype for q");
TORCH_CHECK(q.device() == k.device() &&
k.device() == v.device() &&
v.device() == o.device(),
"All tensors must be on the same device");
TORCH_CHECK(q.is_contiguous() || is_supported_strided(q),
"Unsupported layout/stride for q");
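The is_supported_strided helper used above is not defined anywhere yet. One possible shape for it, written against plain size/stride vectors so it can be tested without a GPU, is sketched below. The rule it encodes is an assumption for illustration only: the innermost dimension must be densely packed, and every outer stride must cover at least one full row and be a multiple of align_elems elements (8 two-byte elements corresponds to 16 bytes). The real constraints depend on the kernel being wrapped.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical layout check; the rule and the default alignment are
// assumptions and must be matched to what the kernel actually supports.
inline bool is_supported_strided(const std::vector<int64_t>& sizes,
                                 const std::vector<int64_t>& strides,
                                 int64_t align_elems = 8) {
  if (sizes.empty() || sizes.size() != strides.size()) return false;
  // Innermost dimension must be contiguous.
  if (strides.back() != 1) return false;
  const int64_t row_len = sizes.back();
  for (std::size_t i = 0; i + 1 < strides.size(); ++i) {
    // Outer strides must not make rows overlap...
    if (strides[i] < row_len) return false;
    // ...and must respect the assumed alignment, counted in elements.
    if (strides[i] % align_elems != 0) return false;
  }
  return true;
}
```

A tensor-taking overload could forward q.sizes() and q.strides() to this function, which would also accept a q slice taken from a packed QKV buffer where only the outer strides differ from the contiguous case.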
(3) Improve can_implement error handling
Observed in:
CUTLASS_CHECK(op.can_implement(arguments));
If this fails, the user often gets an unhelpful error without context.
Suggested improvement
Capture the status and provide a descriptive error:
auto status = op.can_implement(arguments);
TORCH_CHECK(status == cutlass::Status::kSuccess,
"can_implement failed: ",
cutlassGetStatusString(status),
" | q shape=", ...,
" | q stride=", ...,
" | k shape=", ...,
" | dtype=", ...);
This significantly improves debuggability.
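To keep the message readable, shapes and strides can be rendered through a small formatting helper. The sketch below uses only the standard library; format_dims and describe_tensor are hypothetical names, and the output format is just one reasonable choice.

```cpp
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Renders a list of extents as "[a, b, c]".
inline std::string format_dims(const std::vector<int64_t>& dims) {
  std::ostringstream os;
  os << '[';
  for (std::size_t i = 0; i < dims.size(); ++i) {
    if (i != 0) os << ", ";
    os << dims[i];
  }
  os << ']';
  return os.str();
}

// Builds one "name shape=[...] stride=[...]" fragment for an error message.
inline std::string describe_tensor(const std::string& name,
                                   const std::vector<int64_t>& sizes,
                                   const std::vector<int64_t>& strides) {
  return name + " shape=" + format_dims(sizes) + " stride=" + format_dims(strides);
}
```

Fragments built this way can be concatenated into the failure message, one per tensor, so a rejected problem can be reproduced from the log alone.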
(4) Workspace TODO may cause runtime or portability issues
Observed where workspace logic is currently commented out or marked as TODO.
Depending on the CUTLASS configuration or architecture, workspace may be required.
Ignoring it can cause:
- Runtime launch failures
- Architecture-dependent behavior
- Hard-to-diagnose issues
Suggested fix
Query and validate workspace size:
size_t ws = Operation::get_workspace_size(arguments);
TORCH_CHECK(workspace_nbytes >= ws,
"workspace too small: need ", ws, " bytes");
CUTLASS_CHECK(op.initialize(arguments, workspace_ptr));
CUTLASS_CHECK(op.run(stream));
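If the caller does not supply a workspace, the wrapper could own one and grow it on demand instead of failing. The sketch below shows only the grow-only bookkeeping, using host memory as a stand-in; WorkspaceCache is a hypothetical name, and a real version would allocate device memory (for example through the framework's allocator) rather than a std::vector.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical grow-only workspace holder. Host memory is used here purely
// as a stand-in for a device allocation.
class WorkspaceCache {
 public:
  // Returns a buffer of at least required_bytes bytes, or nullptr when the
  // operation reports that it needs no workspace.
  uint8_t* reserve(std::size_t required_bytes) {
    if (required_bytes == 0) return nullptr;
    if (buffer_.size() < required_bytes) buffer_.resize(required_bytes);
    return buffer_.data();
  }

  std::size_t capacity() const { return buffer_.size(); }

 private:
  std::vector<uint8_t> buffer_;
};
```

The value returned by get_workspace_size would be passed to reserve, and the resulting pointer to op.initialize. Because the buffer never shrinks, repeated calls with the same problem size do not reallocate.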
Expected benefits
- More stable varlen behavior
- Early detection of invalid inputs
- Much clearer debugging when can_implement fails
- Better portability across architectures and configurations