V0.4 #515

Open

fwilliams wants to merge 2 commits into main from v0.4

Conversation

@fwilliams
Collaborator

No description provided.

@fwilliams fwilliams requested a review from a team as a code owner March 6, 2026 18:43
blackencino and others added 2 commits March 6, 2026 14:49
# PredGatherIGemm: Alternative Sparse Convolution Backend

## Summary

This PR adds a new sparse convolution backend -- **PredGatherIGemm** --
that uses CUTLASS/CuTe implicit-GEMM (IGEMM) with predicated `cp.async`
gather loads on SM80+ (Ampere and later) GPUs. It processes one output
NanoVDB leaf node per CTA, using TF32 tensor-core arithmetic for the
computation.

The backend is integrated into the `ConvolutionPlan` framework as a
selectable backend (`expert_config={"backend": "pred_gather_igemm"}`),
with the existing GatherScatterDefault backend remaining the default.
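
As a sketch of how that selection might be modeled from Python -- only the `"pred_gather_igemm"` config string and the existence of a default GatherScatterDefault backend come from this PR; the helper name and dictionary below are hypothetical, not the actual fvdb code:

```python
# Hypothetical model of expert_config-driven backend selection.
# Only the backend name strings come from the PR description.
KNOWN_BACKENDS = {"gather_scatter_default", "pred_gather_igemm"}

def select_backend(expert_config=None):
    """Return the backend requested via expert_config, else the default."""
    if expert_config is None:
        return "gather_scatter_default"
    backend = expert_config.get("backend", "gather_scatter_default")
    if backend not in KNOWN_BACKENDS:
        raise ValueError(f"unknown backend: {backend!r}")
    return backend
```

With this sketch, `select_backend({"backend": "pred_gather_igemm"})` opts in to the new backend, while omitting `expert_config` keeps the existing default.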

## Constraints

The PredGatherIGemm backend is intentionally limited in scope compared
to the default GatherScatterDefault backend:

- **CUDA only**, requires SM80+ (Ampere or later)
- **Float32 only** (internally promoted to TF32)
- **Forward pass only** -- no transpose, no analytical backward
  (backward falls back to GatherScatterDefault when used via autograd)
- **Uniform kernel sizes** only: 3, 5, or 7 (x=y=z)
- **Uniform strides** only: 1 or 2 (x=y=z)
- **Channel counts** must be multiples of 32
- **Batch size 1** only
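
Taken together, these constraints amount to a small eligibility check. A minimal sketch -- the function name and signature are hypothetical; the limits themselves are exactly the ones listed above:

```python
def pred_gather_igemm_supported(dtype, kernel_size, stride,
                                in_channels, out_channels,
                                batch_size, sm_arch):
    """Illustrative check of the documented PredGatherIGemm constraints."""
    return (
        sm_arch >= 80                  # CUDA only, SM80+ (Ampere or later)
        and dtype == "float32"         # internally promoted to TF32
        and kernel_size in (3, 5, 7)   # uniform kernel, x=y=z
        and stride in (1, 2)           # uniform stride, x=y=z
        and in_channels % 32 == 0      # channel counts: multiples of 32
        and out_channels % 32 == 0
        and batch_size == 1
    )
```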

Kernel size and stride are dispatched at compile time using the
project's `dispatch` framework, giving 6 template instantiations
(3 kernel sizes x 2 strides).
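
The 6 instantiations are the cross product of kernel sizes {3, 5, 7} and strides {1, 2}. A Python model of the dispatch table -- the real dispatch is a compile-time C++ mechanism, and the kernel names here are purely illustrative:

```python
KERNEL_SIZES = (3, 5, 7)
STRIDES = (1, 2)

# One entry per compile-time template instantiation (names invented
# for illustration; the real instantiations live in PredGatherIGemm.cu).
DISPATCH_TABLE = {
    (k, s): f"pred_gather_igemm_k{k}_s{s}"
    for k in KERNEL_SIZES
    for s in STRIDES
}

def dispatch(kernel_size, stride):
    """Look up the instantiation for a (kernel_size, stride) pair."""
    try:
        return DISPATCH_TABLE[(kernel_size, stride)]
    except KeyError:
        raise ValueError(
            f"no instantiation for kernel={kernel_size}, stride={stride}")
```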

## Performance Characteristics

Benchmarked on SM120 with Cin=64, Cout=128, kernel 3x3x3, stride 1:

| Scenario | PredGatherIGemm | GS + topology | GS (topology cached) |
|---|---|---|---|
| 1M dense (75% leaf occ) | **5.2 ms** | 45.8 ms | 31.2 ms |
| 2M dense (75% leaf occ) | **10.2 ms** | 89.1 ms | 64.9 ms |
| 4M sparse (25% leaf occ) | 32.0 ms | 21.1 ms | **14.8 ms** |
| 8M sparse (10% leaf occ) | 43.6 ms | 8.1 ms | **5.0 ms** |

The IGEMM backend is significantly faster for **dense or near-dense**
grids (high leaf-node occupancy), where its one-leaf-per-CTA approach
keeps the GPU fully occupied. At low occupancy the per-CTA work becomes
sparse and the GatherScatterDefault backend -- which operates on
compacted index pairs -- wins decisively.
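
This suggests a simple occupancy-based rule of thumb for choosing between the two backends. The sketch below is purely illustrative: the 50% threshold is read off the benchmark table (the crossover lies somewhere between 25% and 75% leaf occupancy), not a tuned or shipped value, and the function is not part of this PR:

```python
def choose_backend(leaf_occupancy, threshold=0.5):
    """Illustrative heuristic: pick a backend from leaf-node occupancy.

    Dense grids favor PredGatherIGemm's one-leaf-per-CTA scheme;
    sparse grids favor GatherScatterDefault's compacted index pairs.
    The threshold is a guess from the benchmark table, not a tuned value.
    """
    if not 0.0 <= leaf_occupancy <= 1.0:
        raise ValueError("occupancy must be a fraction in [0, 1]")
    return ("pred_gather_igemm" if leaf_occupancy >= threshold
            else "gather_scatter_default")
```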

## Files Changed

### New files

- `src/fvdb/detail/ops/convolution/PredGatherIGemm.h` -- public header
- `src/fvdb/detail/ops/convolution/PredGatherIGemm.cu` -- CUTLASS IGEMM
  kernel, CuTe layouts, dispatch table, and entry point
- `src/tests/PredGatherIGemmTest.cu` -- C++ gtests: correctness
  validation against GatherScatterDefault across all 6 kernel/stride
  combinations, plus speed comparison benchmarks
- `tests/unit/test_conv_pred_gather_igemm.py` -- Python tests:
  forward-pass validation against dense PyTorch conv3d ground truth and
  cross-backend comparison with GatherScatterDefault

### Modified files

- `src/fvdb/GridBatch.h` / `src/fvdb/GridBatch.cpp` -- added static
  `predGatherIGemmConv` method
- `src/python/Bindings.cpp` -- pybind11 binding for
  `pred_gather_igemm_conv`
- `fvdb/_fvdb_cpp.pyi` -- type stub for the new binding
- `fvdb/convolution_plan.py` -- `_PredGatherIGemmBackend`, autograd
  wrapper (`_PredGatherIGemmConvFn`), backend selection logic in
  `_build_backend`
- `src/CMakeLists.txt` / `src/tests/CMakeLists.txt` -- added new source
  and test files to the build

## Test Plan

- `ninja PredGatherIGemmTest && ./src/tests/PredGatherIGemmTest` --
  runs the C++ gtest suite (correctness + benchmarks)
- `python -m pytest tests/unit/test_conv_pred_gather_igemm.py -v` --
runs the Python test suite (forward-only, TF32-tolerant comparisons
against
  dense ground truth and GatherScatterDefault)

---------

Signed-off-by: Christopher Horvath <chorvath@nvidia.com>
Signed-off-by: Francis Williams <fwilliams@users.noreply.github.com>
fwilliams added a commit that referenced this pull request Mar 8, 2026
## Summary

- Add `@probabilistic_test` decorator to `test_jsum_list_of_lists`,
  matching the pattern already used by `test_jsum`
- The bfloat16 numerical precision causes non-deterministic `allclose`
  failures when comparing `jsum` against `scatter_sum`
- The decorator runs 20 iterations and passes if >= 80% succeed, which
  is the same strategy used for the flat-list variant
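
A decorator with these semantics might look like the following sketch. The real `@probabilistic_test` lives in the fvdb test suite and may differ in detail; this version only illustrates the run-20-iterations, pass-if-at-least-80%-succeed policy described above:

```python
import functools

def probabilistic_test(iterations=20, min_pass_rate=0.8):
    """Re-run a flaky test; pass if enough iterations succeed.

    Illustrative sketch of the described policy, not the fvdb code.
    """
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            passes = 0
            for _ in range(iterations):
                try:
                    test_fn(*args, **kwargs)
                    passes += 1
                except AssertionError:
                    pass  # count the failure, keep iterating
            if passes < min_pass_rate * iterations:
                raise AssertionError(
                    f"only {passes}/{iterations} iterations passed "
                    f"(need >= {min_pass_rate:.0%})")
        return wrapper
    return decorator
```

This turns a test that fails nondeterministically (here, from bfloat16 `allclose` tolerance) into one that only fails when the failure rate is genuinely high.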

This was observed as a transient CI failure on the V0.4 release branch
(PR #515).

## Test plan

- [ ] CI passes (codestyle + unit tests)
- [ ] Verify `test_jsum_list_of_lists_1_cuda` (bfloat16) no longer
  flakes


Made with [Cursor](https://cursor.com)

Signed-off-by: Francis Williams <francis@fwilliams.info>