[CUDA/CUTLASS] Improvements for varlen option derivation, input validation, can_implement errors, and workspace handling #161

@red1239109-cmd

Summary

In the current FMHA / attention-related CUTLASS wrapper (and the MLA phase1 path), four areas pose risks to correctness, robustness, and debuggability:

  1. Varlen option derivation is logically unsafe
  2. Input tensor validation is too weak
  3. op.can_implement() failure messages are not very informative
  4. Workspace handling is left as TODO and may cause runtime issues

(1) Varlen option derivation is logically unsafe (important)

Observed in:
get_options() where:

options.q = q.size(0) / options.b;
options.k = k.size(0) / options.b;

This logic only makes sense for dense batched (fixed sequence length) inputs.

In varlen mode, q.size(0) and k.size(0) often represent the total token counts across the whole batch (e.g. the last element of the prefix-sum array).
Dividing by b yields the average sequence length (rounded down), which can be far smaller than the actual maximum sequence length.

Why this is risky

Varlen kernels usually rely on:

  • max_seqlen_q
  • max_seqlen_kv

Using total_tokens / b may:

  • Underestimate the required shape
  • Break stride or shape assumptions
  • Lead to can_implement failures or undefined behavior

Example

b = 2
seqlen = [1, 4096]
total_tokens = 4097

options.q = total_tokens / b = 2048   (integer division)
actual max_seqlen = 4096

This creates an inconsistent problem shape.

Suggested fix

When kIsVarlen (or equivalent flag) is set:

  • Derive options.q and options.k from max_seqlen_q / max_seqlen_kv, or
  • Leave them as placeholders and finalize the shape only in initialize_varlen()
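
The gap between the two derivations can be shown with a small host-side sketch. The helper names below (max_seqlen_from_cu_seqlens, average_seqlen) are hypothetical and not part of the wrapper, and the sketch assumes the sequence boundaries are available on the host as a prefix-sum array of length b + 1. In practice that array usually lives on the device, so the caller would more likely pass max_seqlen_q / max_seqlen_kv explicitly.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: longest sequence described by a prefix-sum array
// cu_seqlens of length b + 1 (cu_seqlens[0] == 0, cu_seqlens[b] == total tokens).
inline int64_t max_seqlen_from_cu_seqlens(const std::vector<int64_t>& cu_seqlens) {
  assert(cu_seqlens.size() >= 2 && "need at least one sequence");
  int64_t max_len = 0;
  for (std::size_t i = 1; i < cu_seqlens.size(); ++i) {
    const int64_t len = cu_seqlens[i] - cu_seqlens[i - 1];
    assert(len >= 0 && "prefix sums must be non-decreasing");
    max_len = std::max(max_len, len);
  }
  return max_len;
}

// What the current derivation effectively computes: total tokens / batch size.
inline int64_t average_seqlen(const std::vector<int64_t>& cu_seqlens) {
  assert(cu_seqlens.size() >= 2 && "need at least one sequence");
  const int64_t b = static_cast<int64_t>(cu_seqlens.size()) - 1;
  return cu_seqlens.back() / b;  // integer division, as in get_options()
}
```

For the example above (b = 2, seqlen = [1, 4096]) the prefix-sum array is {0, 1, 4097}: the first helper gives 4096 and the second gives 2048.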

(2) Input tensor validation is too weak

Observed near kernel argument preparation / launch path.

Current checks only validate some stride conditions.
To avoid silently wrong results or out-of-bounds memory accesses, stronger validation is recommended.

Suggested checks

  • All tensors are CUDA tensors
  • All tensors use expected dtype (bf16/fp16/fp8 as required)
  • All tensors are on the same device
  • Layout/stride matches what the kernel supports

Example:

TORCH_CHECK(q.is_cuda() && k.is_cuda() && v.is_cuda() && o.is_cuda(),
            "All tensors must be CUDA tensors");

TORCH_CHECK(q.scalar_type() == at::kBFloat16,
            "Unsupported dtype for q");

TORCH_CHECK(q.device() == k.device() &&
            k.device() == v.device() &&
            v.device() == o.device(),
            "All tensors must be on the same device");

TORCH_CHECK(q.is_contiguous() || is_supported_strided(q),
            "Unsupported layout/stride for q");

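The is_supported_strided helper used in the last check does not exist yet; it is a placeholder name. One possible shape for it is sketched below, working on plain size/stride vectors so it can be tested without a tensor library. The rule it encodes (innermost dimension contiguous, outer strides at least as large as the innermost extent) is an assumption about what the kernel supports, not something confirmed from the CUTLASS code, and would need to be adjusted to the real layout requirements.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical layout check. Assumed rule (not taken from the kernel source):
//   - sizes and strides have the same, non-zero rank
//   - the innermost dimension has stride 1
//   - every outer dimension with more than one element has a stride that is
//     at least the innermost extent, so rows do not overlap
inline bool is_supported_strided(const std::vector<int64_t>& sizes,
                                 const std::vector<int64_t>& strides) {
  if (sizes.empty() || sizes.size() != strides.size()) return false;
  const std::size_t last = sizes.size() - 1;
  if (strides[last] != 1) return false;
  for (std::size_t i = 0; i < last; ++i) {
    if (sizes[i] > 1 && strides[i] < sizes[last]) return false;
  }
  return true;
}
```

With this rule a padded layout such as sizes {4, 8, 64} with strides {1024, 128, 1} is accepted, while a transposed view whose innermost stride is not 1 is rejected.
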
(3) Improve can_implement error handling

Observed in:

CUTLASS_CHECK(op.can_implement(arguments));

If this fails, the user often gets an unhelpful error without context.

Suggested improvement

Capture the status and provide a descriptive error:

auto status = op.can_implement(arguments);

TORCH_CHECK(status == cutlass::Status::kSuccess,
            "can_implement failed: ",
            cutlassGetStatusString(status),
            " | q shape=", ...,
            " | q stride=", ...,
            " | k shape=", ...,
            " | dtype=", ...);

This significantly improves debuggability.
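
The elided shape and stride fields in the message above could be filled by a small formatting helper. The sketch below uses hypothetical names (format_dims, describe_tensor) and plain vectors rather than tensor types, so it only illustrates the intended message layout.

```cpp
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: render a list of extents or strides as "[a, b, c]".
inline std::string format_dims(const std::vector<int64_t>& dims) {
  std::ostringstream os;
  os << '[';
  for (std::size_t i = 0; i < dims.size(); ++i) {
    if (i != 0) os << ", ";
    os << dims[i];
  }
  os << ']';
  return os.str();
}

// Hypothetical helper: one "name shape=[...] stride=[...]" fragment per tensor,
// suitable for appending to a can_implement failure message.
inline std::string describe_tensor(const std::string& name,
                                   const std::vector<int64_t>& sizes,
                                   const std::vector<int64_t>& strides) {
  return name + " shape=" + format_dims(sizes) + " stride=" + format_dims(strides);
}
```

For a 2 x 3 row-major tensor named q, describe_tensor produces "q shape=[2, 3] stride=[3, 1]".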


(4) Workspace TODO may cause runtime or portability issues

Observed where workspace logic is currently commented out or marked as TODO.

Depending on the CUTLASS configuration or architecture, workspace may be required.
Ignoring it can cause:

  • Runtime launch failures
  • Architecture-dependent behavior
  • Hard-to-diagnose issues

Suggested fix

Query and validate workspace size:

size_t ws = Operation::get_workspace_size(arguments);

TORCH_CHECK(workspace_nbytes >= ws,
            "workspace too small: need ", ws, " bytes, have ",
            workspace_nbytes, " bytes");

CUTLASS_CHECK(op.initialize(arguments, workspace_ptr));
CUTLASS_CHECK(op.run(stream));

Expected benefits

  • More stable varlen behavior
  • Early detection of invalid inputs
  • Much clearer debugging when can_implement fails
  • Better portability across architectures and configurations
