Fail to compile T.copy for fp8 on sm89 #981

@bingps

When running python examples/deepseek_v32/fp8_lighting_indexer.py on sm89, compilation fails with an error like:

/tmp/tmpts7207id.cu(44): error: identifier "make_fp8_e4_16_t" is undefined
        condval = make_fp8_e4_16_t(fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ), fp8_e4_t(0x0p+0f ));

This is because T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) is lowered to CUDA code like:

  #pragma unroll
  for (int i = 0; i < 2; ++i) {
    fp8_e4_16_t condval;
    if ((((((((int)blockIdx.x) * 128) + (i * 64)) + (((int)threadIdx.x) >> 2)) < (seq_len * 32)) && ((((((int)blockIdx.x) * 128) + (i * 64)) + (((int)threadIdx.x) >> 2)) < (seq_len * 32)))) {
      condval = *(fp8_e4_16_t*)(IndexQ + (((((int64_t)((int)blockIdx.x)) * (int64_t)8192) + (((int64_t)i) * (int64_t)4096)) + (((int64_t)((int)threadIdx.x)) * (int64_t)16)));
    } else {
      condval = make_fp8_e4_16_t(fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/), fp8_e4_t(0x0p+0f/*0.000000e+00*/));
    }
    *(fp8_e4_16_t*)(((fp8_e4_t*)buf_dyn_shmem) + (((((i * 4096) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 49152)) = condval;
  }

but make_fp8_e4_16_t is not defined.
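The generated loop is a guarded 16-byte vector load: when the row index is in bounds it reinterprets 16 consecutive fp8 bytes as one fp8_e4_16_t, and otherwise it substitutes an all-zero vector, which is where make_fp8_e4_16_t comes in. A minimal host-side sketch of those semantics, assuming raw bytes stand in for fp8 values (0x00 encodes +0.0 in E4M3); Vec16 and guarded_load16 are hypothetical names, not TileLang code:

```cpp
#include <array>
#include <cstdint>
#include <cstring>

// Hypothetical 16-byte vector standing in for fp8_e4_16_t.
struct Vec16 {
  std::array<uint8_t, 16> bytes;
};

// Mirrors the shape of the generated loop: if `row` is in bounds, copy the
// 16 consecutive bytes at `src + offset`; otherwise return an all-zero
// vector (the role make_fp8_e4_16_t(fp8_e4_t(0), ...) plays in the CUDA).
inline Vec16 guarded_load16(const uint8_t* src, int64_t offset, int row, int num_rows) {
  Vec16 v{};  // value-initialized: every byte is 0x00 (+0.0 in E4M3)
  if (row < num_rows) {
    std::memcpy(v.bytes.data(), src + offset, 16);
  }
  return v;
}
```

Because the out-of-bounds branch only needs "sixteen zero bytes", any helper that builds that value would satisfy the codegen, as long as it has the name and signature the generated code calls.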

I made a temporary fix in bingps@2115e1f by adding a definition to src/tl_templates/cuda/cuda_fp8.h; with it, fp8_lighting_indexer.py passes its tests.
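For reference, a definition along these lines might look like the sketch below. It is not the contents of bingps@2115e1f. It assumes fp8_e4_16_t is a 16-byte-aligned struct of two fp8_e4_8_t halves (each made of two fp8_e4_4_t), and it replaces fp8_e4_t (an alias for __nv_fp8_e4m3 under CUDA) with a one-byte host stand-in so it builds without the CUDA toolkit. TL_HD and byte_at are hypothetical names; check the real layout in src/tl_templates/cuda/cuda_fp8.h before reusing any of it.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical qualifier macro so the sketch builds with nvcc or a host compiler.
#ifdef __CUDACC__
#define TL_HD __host__ __device__
#else
#define TL_HD
#endif

// ASSUMPTION: host stand-in for fp8_e4_t; here it is one raw E4M3 byte.
struct fp8_e4_t {
  uint8_t raw;
};

// ASSUMPTION: 16-wide vector built from nested 4- and 8-wide halves.
struct alignas(4) fp8_e4_4_t { fp8_e4_t x, y, z, w; };
struct alignas(8) fp8_e4_8_t { fp8_e4_4_t x, y; };
struct alignas(16) fp8_e4_16_t { fp8_e4_8_t x, y; };
static_assert(sizeof(fp8_e4_16_t) == 16, "must match a 16-byte vector load");

TL_HD inline fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t a, fp8_e4_t b, fp8_e4_t c, fp8_e4_t d) {
  return fp8_e4_4_t{a, b, c, d};
}

TL_HD inline fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t a0, fp8_e4_t a1, fp8_e4_t a2, fp8_e4_t a3,
                                        fp8_e4_t a4, fp8_e4_t a5, fp8_e4_t a6, fp8_e4_t a7) {
  return fp8_e4_8_t{make_fp8_e4_4_t(a0, a1, a2, a3), make_fp8_e4_4_t(a4, a5, a6, a7)};
}

// Same 16-argument shape the codegen emits; argument i ends up in byte i.
TL_HD inline fp8_e4_16_t make_fp8_e4_16_t(
    fp8_e4_t a0, fp8_e4_t a1, fp8_e4_t a2, fp8_e4_t a3,
    fp8_e4_t a4, fp8_e4_t a5, fp8_e4_t a6, fp8_e4_t a7,
    fp8_e4_t a8, fp8_e4_t a9, fp8_e4_t a10, fp8_e4_t a11,
    fp8_e4_t a12, fp8_e4_t a13, fp8_e4_t a14, fp8_e4_t a15) {
  return fp8_e4_16_t{make_fp8_e4_8_t(a0, a1, a2, a3, a4, a5, a6, a7),
                     make_fp8_e4_8_t(a8, a9, a10, a11, a12, a13, a14, a15)};
}

// Hypothetical helper for inspecting the packed byte layout on the host.
inline uint8_t byte_at(const fp8_e4_16_t& v, int i) {
  assert(i >= 0 && i < 16);
  uint8_t bytes[16];
  std::memcpy(bytes, &v, sizeof(bytes));
  return bytes[i];
}
```

In the real header the element type would stay fp8_e4_t itself, so the generated fp8_e4_t(0x0p+0f) arguments convert from float as they already do, and the helpers must be visible to device code.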

Are such manual fixes required, or is there some config that auto-generates these make-functions and is disabled on sm89? 🤔
