[Benchmark] Add benchmark for IPC communication and support for fence operators #22
Conversation
- Introduced a new benchmark script `benchmark_ipc_p2p.py` to evaluate the performance of IPC communication against NVSHMEM primitives.
- Add docstring for `T.putmem_block`.
- Introduced two new benchmark scripts: `benchmark_nvshmem_p2p.py` for measuring NVSHMEM-based communication bandwidth and `benchmark_unrolledcp_p2p.py` for evaluating an unrolled-copy IPC method.
- Added a README.md file to document the benchmarks and their usage.
- Updated installation documentation to clarify MPI requirements for NVSHMEM support.
Walkthrough

Adds two IPC P2P benchmark scripts and a README, introduces three TileLang fence intrinsics (fence_cta/fence_gpu/fence_sys) with Python wrappers, maps them in CUDA codegen including NVSHMEM block put/get support, and includes small docs changes and an NVSHMEM docstring.
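For intuition, the new fences might appear in kernels as in the fragment below. This is illustrative only: the buffer writes are elided, the fragment is not a standalone kernel, and the scopes follow the wrapper docstrings reviewed further down.

```python
import tilelang.language as T

# Fragment in the style of the benchmark kernels in this PR (not standalone):
with T.Kernel(1, threads=128):
    # ... write data that a peer will read (e.g., an IPC/NVSHMEM put) ...
    T.fence_cta()  # make prior writes visible within this block
    T.fence_gpu()  # make prior writes visible to all blocks on this device
    T.fence_sys()  # make prior writes visible across the node
```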
Changes

Sequence Diagram(s)

sequenceDiagram
autonumber
participant L as Launcher
participant R0 as Rank 0
participant R1 as Rank 1
participant D as torch.distributed
participant K as TileLang / NVSHMEM kernels
L->>R0: spawn(rank=0)
L->>R1: spawn(rank=1)
R0->>D: init_process_group(world_size=2)
R1->>D: init_process_group(world_size=2)
Note over R0,R1: allocate symmetric src/dst buffers
par Push phase
R0->>K: launch push kernel (peer=1)
R1->>K: launch push kernel (peer=0)
K-->>K: nvshmemx_putmem_block / T.push_warp
K-->>K: fence_sys()
end
R0->>D: barrier()
R1->>D: barrier()
par Pull phase
R0->>K: launch pull kernel (peer=1)
R1->>K: launch pull kernel (peer=0)
K-->>K: nvshmemx_getmem_block / T.pull_warp
K-->>K: fence_sys()
end
R0->>D: barrier()
R1->>D: barrier()
R0-->>L: print bandwidth table
R0->>D: destroy_process_group()
R1->>D: destroy_process_group()
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~55 minutes
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the required pre-submission checks. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/target/codegen_cuda.cc (1)
183-186: Fix malformed NVSHMEM includes (extra '>'); this will not compile. There's a trailing '>' in both include lines.
Apply:
-  if (use_nvshmem_) {
-    decl_stream << "#include <nvshmem.h>>\n";
-    decl_stream << "#include <nvshmemx.h>>\n";
-  }
+  if (use_nvshmem_) {
+    decl_stream << "#include <nvshmem.h>\n";
+    decl_stream << "#include <nvshmemx.h>\n";
+  }
🧹 Nitpick comments (16)
docs/get_started/Installation.md (1)
120-120: Clarify MPI requirement for NVSHMEM build (docs/get_started/Installation.md:120): Replace the `pip install mpich  # building NVSHMEM needs MPI` line with a note that this only installs the Python package and does not provide the MPI compiler wrapper. Recommend installing a full MPI implementation (MPICH or OpenMPI) via the OS package manager (e.g. `sudo apt-get install mpich`), Homebrew (`brew install mpich`), or conda-forge (`conda install -c conda-forge mpich`), and verify `mpicc` is on your PATH (`which mpicc`).
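A quick way to verify the wrapper is available (a small sketch; `which mpicc` from the shell works equally well):

```python
import shutil

# Fails fast if no MPI compiler wrapper is on PATH; pip's mpich package alone won't provide one.
mpicc = shutil.which("mpicc")
if mpicc is None:
    raise RuntimeError("mpicc not found on PATH; install MPICH or OpenMPI via your package manager")
print(f"MPI compiler wrapper: {mpicc}")
```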
tilelang/language/distributed/multi_device/nvshmem.py (1)

98-104: Docstring fixes: src symmetry and byte count.
`src` need not be a symmetric address (a local pointer is fine for PUT). `nelems` is actually bytes; the name/README mismatch can confuse users. Suggest clarifying the wording to avoid misuse.
 def putmem_block(*args):
-    """Put data from local memory to remote memory at block granularity.
-    Args:
-        dest: Symmetric address of the destination data object.
-        src: Symmetric address of the object containing the data to be copied.
-        nelems: Number of elements to be transferred (in bytes).
-        pe: The PE ID of the destination PE.
-    """
+    """Put data from local memory to remote memory at block granularity.
+    Args:
+        dest: Symmetric address on the destination PE to write to.
+        src: Local address of the source data (need not be symmetric).
+        nelems: Number of bytes to transfer.
+        pe: Destination PE ID.
+    """

benchmark/distributed/ipc_impls/README.md (2)
4-4: Minor grammar tweak: “avoid NVLink bandwidth as the bottleneck” → “avoid NVLink bandwidth being the bottleneck.”
-We launch only one block on each rank to avoid NVLink bandwidth as the bottleneck.
+We launch only one block on each rank to avoid NVLink bandwidth being the bottleneck.
17-34: Fix markdownlint MD058: add blank lines around the table. Insert a blank line before and after the table so linters pass.
-## Results on Hopper connected by NVLink
+## Results on Hopper connected by NVLink
+
 | Size (Bytes) | NVSHMEM Push BW (GB/s) | NVSHMEM Pull BW (GB/s) | TileScale Push BW (GB/s) | TileScale Pull BW (GB/s) |
 ...
-| 4,194,304 | 10.6560 | 2.2474 | 11.9145 | 2.2845 |
+| 4,194,304 | 10.6560 | 2.2474 | 11.9145 | 2.2845 |
+
-> **Note:** All data presented above are unidirectional bandwidth.
+> **Note:** All data presented above are unidirectional bandwidth.

benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py (6)
1-5: Typos and incorrect usage path.
- “NVHSMEM” → “NVSHMEM”.
- Update the usage path to the actual location under `ipc_impls/`.

-# This benchmark aims to measure the bandwidth of NVHSMEM-based communication.
+# This benchmark aims to measure the bandwidth of NVSHMEM-based communication.
 ...
-# Usage: GPUS=2 bash tilelang/distributed/launch.sh benchmark/distributed/benchmark_nvshmem_p2p.py
+# Usage: GPUS=2 bash tilelang/distributed/launch.sh benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py
26-33: Ensure remote operation completion for accurate timing.
`put` can be asynchronous; `fence_sys()` orders but doesn’t guarantee completion. Add `T.quiet()` after the NVSHMEM op to time completed transfers. Keeping `get` consistent also helps.

 T.putmem_block(
     T.address_of(dst),
     T.address_of(src),
     size * 4,
     T.get_pe() ^ 1,
 )
+T.quiet()
 T.fence_sys()

 T.getmem_block(
     T.address_of(dst),
     T.address_of(src),
     size * 4,
     T.get_pe() ^ 1,
 )
+T.quiet()
 T.fence_sys()

Also applies to: 45-52
56-60: Prefer explicit exceptions over assert for user input. Asserts can be stripped with `-O`. Raise a `ValueError` for robustness.

-    assert num_ranks == 2, "this benchmark only supports 2 ranks"
-    assert args.threads % 32 == 0, "threads must be divisible by 32"
+    if num_ranks != 2:
+        raise ValueError("this benchmark only supports 2 ranks")
+    if args.threads % 32 != 0:
+        raise ValueError("threads must be divisible by 32")
56-58: Silence the linter for unused `rank`. Rename it to `_rank` to address ARG001 without changing call sites.

-def benchmark_nvshmem_bw(rank: int, num_ranks: int, group: dist.ProcessGroup, size: int,
+def benchmark_nvshmem_bw(_rank: int, num_ranks: int, group: dist.ProcessGroup, size: int,
                          args: argparse.Namespace):
93-97: Help text/default mismatch.
`--repeat` default is 50, but the help says 10. Align the help.

-    parser.add_argument(
-        "--repeat", type=int, default=50, help="number of repeat iterations (default: 10)")
+    parser.add_argument(
+        "--repeat", type=int, default=50, help="number of repeat iterations (default: 50)")
68-71: Minor: avoid redundant synchronizations.
`perf_fn` already synchronizes the current stream and waits on events; the explicit `torch.cuda.synchronize()` is usually unnecessary.

-    torch.cuda.synchronize()
+    # torch.cuda.synchronize() is unnecessary; perf_fn synchronizes.
 ...
-    torch.cuda.synchronize()
+    # torch.cuda.synchronize() is unnecessary; perf_fn synchronizes.

Also applies to: 81-85
benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py (3)
62-66: Replace `assert` with explicit argument validation and mark the unused arg. `assert` can be stripped with `-O`; use exceptions. `rank` is unused; prefix it to silence linters.

-def benchmark_ipc_bw(rank: int, num_ranks: int, group: dist.ProcessGroup, size: int,
+def benchmark_ipc_bw(_rank: int, num_ranks: int, group: dist.ProcessGroup, size: int,
                      args: argparse.Namespace, allocator):
-    assert num_ranks == 2, "this benchmark only supports 2 ranks"
-    assert args.threads % 32 == 0, "threads must be divisible by 32"
+    if num_ranks != 2:
+        raise ValueError("This benchmark only supports 2 ranks")
+    if args.threads % 32 != 0:
+        raise ValueError("threads must be divisible by 32")
27-33: Addressing mode: symmetric vs UVA — verify the allocator guarantees symmetric addresses. You pass `dst_pe`/`src_pe = rank ^ 1`, which selects the symmetric-address path. Ensure the allocator provides symmetric allocations across PEs; otherwise use the UVA path (`dst_pe`/`src_pe = None`) with IPC-mapped peer pointers.

I can help add a switch to choose symmetric vs UVA at runtime; see the sketch below.
Also applies to: 51-57
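A minimal sketch of such a runtime switch, assuming (per the note above) that `dst_pe`/`src_pe = None` selects the UVA path; the flag name is hypothetical:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--addressing",
    choices=["symmetric", "uva"],
    default="symmetric",
    help="symmetric: pass dst_pe/src_pe = rank ^ 1; uva: pass None and use IPC-mapped peer pointers")
args, _ = parser.parse_known_args()

rank = 0  # placeholder; the benchmark derives this from the process group
peer = None if args.addressing == "uva" else rank ^ 1
# `peer` then feeds push_warp/pull_warp as dst_pe/src_pe.
```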
10-12: Nit: set the NCCL env var before any CUDA/torch init, and consider INFO for troubleshooting. Minor: move the env assignment to the very top; consider `INFO` when diagnosing issues.

src/op/sync.h (1)
72-94: Docstrings call these “synchronize all threads”; they are memory fences (ordering/visibility), not thread barriers. To avoid misuse, reword as “memory fence” and drop “synchronize all threads”.
- * \brief Synchronize all threads at the GPU level (visible to all blocks on the
- *        current device)
+ * \brief Create a memory fence at GPU scope (makes prior global/surface writes
+ *        visible to all blocks on the current device)
 ...
- * \brief Synchronize all threads at the system level (visible in a node)
+ * \brief Create a memory fence at system scope (makes prior writes visible across the node)

tilelang/language/builtin.py (1)
419-431: Docstrings: clarify these are memory fences, not thread synchronizations. Align wording with fence semantics; “synchronize all threads” suggests a barrier.
-def fence_cta():
-    """Create a memory fence at the block level (visible to all threads in the current block)."""
+def fence_cta():
+    """Memory fence at block scope (makes prior writes visible within the CTA)."""
 ...
-def fence_gpu():
-    """Synchronize all threads at the GPU level (visible to all blocks on the current device)."""
+def fence_gpu():
+    """Memory fence at device scope (makes prior writes visible to all blocks on the GPU)."""
 ...
-def fence_sys():
-    """Synchronize all threads at the system level (visible in a node)."""
+def fence_sys():
+    """Memory fence at system scope (makes prior writes visible across the node)."""
1651-1658: New fence intrinsics mapping: OK. Emits tl::memory_fence_{cta,gpu,sys} and enables distributed headers; no NVSHMEM dependency. LGTM.
If feasible, avoid toggling use_distributed_ solely for fences to prevent emitting the 8KB meta_data constant. A separate flag (e.g., use_sync_intrinsics_) could include sync.h without meta_data.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- benchmark/distributed/ipc_impls/README.md (1 hunks)
- benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py (1 hunks)
- benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py (1 hunks)
- docs/get_started/Installation.md (1 hunks)
- src/op/sync.cc (1 hunks)
- src/op/sync.h (1 hunks)
- src/target/codegen_cuda.cc (3 hunks)
- tilelang/language/builtin.py (1 hunks)
- tilelang/language/distributed/multi_device/nvshmem.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
tilelang/language/builtin.py (1)
- tilelang/language/tir/op.py (1): call_intrin (119-144)

src/op/sync.h (1)
- tilelang/language/builtin.py (3): fence_cta (419-421), fence_gpu (424-426), fence_sys (429-431)

benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py (5)
- tilelang/distributed/utils.py (1): perf_fn (217-238)
- tilelang/language/distributed/multi_device/nvshmem.py (3): putmem_block (97-105), get_pe (6-8), getmem_block (77-78)
- tilelang/language/tir/op.py (1): address_of (463-479)
- tilelang/language/builtin.py (1): fence_sys (429-431)
- tilelang/jit/__init__.py (1): compile (32-81)

src/op/sync.cc (1)
- tilelang/language/builtin.py (3): fence_cta (419-421), fence_gpu (424-426), fence_sys (429-431)

benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py (9)
- tilelang/distributed/utils.py (2): init_dist (34-56), perf_fn (217-238)
- tilelang/env.py (1): disable_cache (247-248)
- tilelang/language/distributed/common.py (3): get_rank (8-11), push_warp (20-42), pull_warp (45-67)
- tilelang/language/tir/op.py (1): address_of (463-479)
- tilelang/language/builtin.py (1): fence_sys (429-431)
- tilelang/jit/__init__.py (1): compile (32-81)
- tilelang/jit/kernel.py (1): initialize (400-409)
- tilelang/utils/tensor.py (1): tensor (45-57)
- tilelang/utils/allocator.py (1): get_allocator (226-238)

src/target/codegen_cuda.cc (1)
- tilelang/language/builtin.py (3): fence_cta (419-421), fence_gpu (424-426), fence_sys (429-431)
🪛 markdownlint-cli2 (0.17.2)
benchmark/distributed/ipc_impls/README.md
18-18: Tables should be surrounded by blank lines
(MD058, blanks-around-tables)
🪛 Ruff (0.12.2)
benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py
56-56: Unused function argument: rank
(ARG001)
58-58: Use of assert detected
(S101)
59-59: Use of assert detected
(S101)
benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py
62-62: Unused function argument: rank
(ARG001)
64-64: Use of assert detected
(S101)
65-65: Use of assert detected
(S101)
🔇 Additional comments (5)
benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py (1)
26-31: Verify address_of usage with Tensors.
`address_of` typically expects a BufferLoad (e.g., `dst[0]`), not the buffer object directly. If your overload doesn’t handle raw tensors, this will fail at compile time. Recommend using `T.address_of(dst[0])` and `T.address_of(src[0])`, or confirm the current usage compiles.

-    T.putmem_block(
-        T.address_of(dst),
-        T.address_of(src),
+    T.putmem_block(
+        T.address_of(dst[0]),
+        T.address_of(src[0]),
         size * 4,
         T.get_pe() ^ 1,
     )

benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py (1)
78-79: Confirm the `size` unit expected by tl.push_warp/pull_warp and the bandwidth calc. If the intrinsics expect bytes (not elements), the reported GB/s will be off by 4×. If they expect elements, the current math is correct. Please confirm and adjust either the kernel `size` argument or the GB/s formula accordingly.

Also applies to: 92-93
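To make the 4× concrete, here is a quick arithmetic check with hypothetical numbers (float32, so 4 bytes per element):

```python
size = 1 << 20      # kernel `size`, interpreted as a count of float32 elements
elapsed_s = 1e-3    # hypothetical measured time per iteration

bw_if_elements = size * 4 / elapsed_s / 1e9  # ~4.19 GB/s: correct when `size` counts elements
bw_if_bytes = size / elapsed_s / 1e9         # ~1.05 GB/s: what the same formula yields if `size` were bytes
print(bw_if_elements / bw_if_bytes)          # exactly 4.0
```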
src/target/codegen_cuda.cc (2)
1553-1565: Confirm TL::PutmemBlock args map to nvshmemx_putmem_block(dest, src, nbytes, pe).
The CUDA codegen emits `op->args[0..3]` in this order; ensure the TIR builder invokes `PutmemBlock` with (dest, src, size, pe) to avoid silent data corruption.
1589-1600: Mapping of tl::GetmemBlock to nvshmemx_getmem_block is correct. The NVSHMEM intrinsic signature (`__device__ void nvshmemx_getmem_block(TYPE *dest, const TYPE *source, size_t nelems, int pe)`) matches `op->args[0..3]` as (local_dst, remote_src, nbytes, pe). (docs.nvidia.com)

src/op/sync.cc (1)
101-109: tl.fence_{cta,gpu,sys}: C++ declarations and Python bindings confirmed. Declarations in src/op/sync.h and Python wrappers in tilelang/language/builtin.py are present and imported in tilelang/language/__init__.py.
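For reference, the Python wrappers plausibly reduce to thin `call_intrin` calls; the sketch below is only a guess at their shape (the actual bodies are at the cited lines of tilelang/language/builtin.py, and the "handle" dtype and `tl.fence_sys` op name are assumptions based on the registration discussed above):

```python
from tilelang.language.tir.op import call_intrin

def fence_sys():
    """Memory fence at system scope (makes prior writes visible across the node)."""
    # Dtype and op name are assumptions, not confirmed against the source.
    return call_intrin("handle", "tl.fence_sys")
```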
    warp_idx = T.get_thread_binding(0) // 32
    warp_copy_size = T.ceildiv(size, threads // 32)
    warp_start = warp_copy_size * warp_idx
    T.push_warp(
        src=T.address_of(src[warp_start]),
        dst=T.address_of(dst[warp_start]),
        size=warp_copy_size,
        dst_pe=rank[0] ^ 1,
        unroll_factor=unroll_factor)
    T.fence_sys()
Guard the tail to avoid OOB reads/writes for non-multiple sizes.
`warp_copy_size = ceildiv(size, num_warps)` can overshoot on the last warp; `push_warp` will copy `warp_copy_size` starting at `warp_start`, potentially crossing `size`. Clamp the copy size and skip inactive warps.
Apply in both push/pull kernels:
- T.push_warp(
- src=T.address_of(src[warp_start]),
- dst=T.address_of(dst[warp_start]),
- size=warp_copy_size,
- dst_pe=rank[0] ^ 1,
- unroll_factor=unroll_factor)
+ if warp_start < size:
+ copy_sz = T.min(warp_copy_size, size - warp_start)
+ T.push_warp(
+ src=T.address_of(src[warp_start]),
+ dst=T.address_of(dst[warp_start]),
+ size=copy_sz,
+ dst_pe=rank[0] ^ 1,
+            unroll_factor=unroll_factor)
🤖 Prompt for AI Agents
In benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py around lines 24
to 33, the calculation warp_copy_size = ceildiv(size, threads // 32) can cause
the last warp to read/write past the buffer; clamp each warp's effective size to
max(0, min(warp_copy_size, size - warp_start)) and skip the warp when that
clamped size is <= 0 so no OOB operations occur; apply the same guard (compute
warp_start, compute remaining = size - warp_start, if remaining <= 0
return/skip, else use copy_size = min(warp_copy_size, remaining)) in both the
push and pull kernels.
    warp_copy_size = T.ceildiv(size, threads // 32)
    warp_start = warp_copy_size * warp_idx
    T.pull_warp(
        src=T.address_of(src[warp_start]),
        dst=T.address_of(dst[warp_start]),
        size=warp_copy_size,
        src_pe=rank[0] ^ 1,
        unroll_factor=unroll_factor)
    T.fence_sys()
Same tail-safety needed for pull path.
Mirror the clamp/guard from push:
- T.pull_warp(
- src=T.address_of(src[warp_start]),
- dst=T.address_of(dst[warp_start]),
- size=warp_copy_size,
- src_pe=rank[0] ^ 1,
- unroll_factor=unroll_factor)
+ if warp_start < size:
+ copy_sz = T.min(warp_copy_size, size - warp_start)
+ T.pull_warp(
+ src=T.address_of(src[warp_start]),
+ dst=T.address_of(dst[warp_start]),
+ size=copy_sz,
+ src_pe=rank[0] ^ 1,
+            unroll_factor=unroll_factor)
🤖 Prompt for AI Agents
In benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py around lines 49
to 57, the pull path lacks the tail-safety clamp used in the push path; mirror
that logic by computing a clamped copy size per warp (e.g., min(warp_copy_size,
max(0, size - warp_start))) and skip or set size to 0 for warps beyond the tail
so you don't read past the buffer; ensure the address passed to T.pull_warp uses
the same guarded warp_start and that we only call T.pull_warp when the effective
copy size > 0 (leave T.fence_sys() after the pull).
Actionable comments posted: 2
♻️ Duplicate comments (3)
benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py (3)
21-34: Guard the tail to avoid OOB in push when size isn’t a multiple of the warp count. The current `warp_copy_size = ceildiv(size, threads // 32)` can overrun on the last warp.

 with T.Kernel(1, threads=threads):
     rank = T.alloc_local([1], "uint64")
     rank[0] = T.get_rank()
     warp_idx = T.get_thread_binding(0) // 32
     warp_copy_size = T.ceildiv(size, threads // 32)
     warp_start = warp_copy_size * warp_idx
-    T.push_warp(
-        src=T.address_of(src[warp_start]),
-        dst=T.address_of(dst[warp_start]),
-        size=warp_copy_size,
-        dst_pe=rank[0] ^ 1,
-        unroll_factor=unroll_factor)
+    if warp_start < size:
+        copy_sz = T.min(warp_copy_size, size - warp_start)
+        T.push_warp(
+            src=T.address_of(src[warp_start]),
+            dst=T.address_of(dst[warp_start]),
+            size=copy_sz,
+            dst_pe=rank[0] ^ 1,
+            unroll_factor=unroll_factor)
     T.fence_sys()
102-108: Pass the local rank topology to the allocator (not global rank/world size). Using global values can misconfigure symmetric mappings on multi-node setups.

 allocator = tilelang.get_allocator(
     size=2**30,
     device="cuda",
     is_distributed=True,
-    local_rank=rank,
-    num_local_ranks=num_ranks,
+    local_rank=local_rank,
+    num_local_ranks=num_local_ranks,
     group=group)
45-58: Mirror the tail guard for the pull path. Without this, the last warp may read past `src`.

 with T.Kernel(1, threads=threads):
     rank = T.alloc_local([1], "uint64")
     rank[0] = T.get_rank()
     warp_idx = T.get_thread_binding(0) // 32
     warp_copy_size = T.ceildiv(size, threads // 32)
     warp_start = warp_copy_size * warp_idx
-    T.pull_warp(
-        src=T.address_of(src[warp_start]),
-        dst=T.address_of(dst[warp_start]),
-        size=warp_copy_size,
-        src_pe=rank[0] ^ 1,
-        unroll_factor=unroll_factor)
+    if warp_start < size:
+        copy_sz = T.min(warp_copy_size, size - warp_start)
+        T.pull_warp(
+            src=T.address_of(src[warp_start]),
+            dst=T.address_of(dst[warp_start]),
+            size=copy_sz,
+            src_pe=rank[0] ^ 1,
+            unroll_factor=unroll_factor)
     T.fence_sys()
🧹 Nitpick comments (7)
benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py (4)
1-5: Fix typos and wrong usage path in header comments.
- “NVHSMEM” → “NVSHMEM”.
- Update the Usage path to include ipc_impls.

-# This benchmark aims to measure the bandwidth of NVHSMEM-based communication.
+# This benchmark aims to measure the bandwidth of NVSHMEM-based communication.
 ...
-# Usage: GPUS=2 bash tilelang/distributed/launch.sh benchmark/distributed/benchmark_nvshmem_p2p.py
+# Usage: GPUS=2 bash tilelang/distributed/launch.sh benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py
6-16: Set NCCL_DEBUG before importing torch to ensure it takes effect. Move the env var assignment above the torch/torch.distributed imports.
 import os
+os.environ.setdefault('NCCL_DEBUG', 'WARN')
 import tilelang
 import tilelang.language as T
 import argparse
 import torch
 import torch.distributed as dist
 from tilelang.distributed.utils import init_distributed, perf_fn
 import pynvshmem
-os.environ['NCCL_DEBUG'] = 'WARN'
-
56-60: Prefer explicit argument validation over assert for runtime checks. Asserts can be stripped with -O and won’t validate in production runs.
-    assert num_ranks == 2, "this benchmark only supports 2 ranks"
-    assert args.threads % 32 == 0, "threads must be divisible by 32"
+    if num_ranks != 2:
+        raise ValueError("this benchmark only supports 2 ranks")
+    if args.threads % 32 != 0:
+        raise ValueError("threads must be divisible by 32")
56-56: Unused parameter `rank`. Either remove it from the signature/call site or use it for logging/validation to appease linters.
benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py (3)
10-12: Set NCCL_DEBUG before the torch import to be effective. Move the env var assignment above the torch/torch.distributed imports.
-tilelang.disable_cache()
-os.environ['NCCL_DEBUG'] = 'WARN'
+tilelang.disable_cache()
+# Ensure NCCL debug level is set before importing torch.distributed
+os.environ.setdefault('NCCL_DEBUG', 'WARN')
62-66: Replace asserts with explicit validation for CLI/runtime robustness.

-    assert num_ranks == 2, "this benchmark only supports 2 ranks"
-    assert args.threads % 32 == 0, "threads must be divisible by 32"
+    if num_ranks != 2:
+        raise ValueError("this benchmark only supports 2 ranks")
+    if args.threads % 32 != 0:
+        raise ValueError("threads must be divisible by 32")
62-62: Unused parameter `rank`. Not used in the function body; consider removing it and updating the call, or log with it.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py (1 hunks)
- benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py (9)
- src/tl_templates/cuda/reduce.h (1): T (75-147)
- tilelang/distributed/utils.py (2): init_dist (34-56), perf_fn (217-238)
- tilelang/env.py (1): disable_cache (247-248)
- tilelang/language/distributed/common.py (3): get_rank (8-11), push_warp (20-42), pull_warp (45-67)
- tilelang/language/tir/op.py (1): address_of (463-479)
- tilelang/language/builtin.py (1): fence_sys (429-431)
- tilelang/jit/kernel.py (1): initialize (400-409)
- tilelang/utils/tensor.py (1): tensor (45-57)
- tilelang/utils/allocator.py (1): get_allocator (226-238)

benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py (5)
- tilelang/distributed/utils.py (1): perf_fn (217-238)
- tilelang/language/distributed/multi_device/nvshmem.py (3): putmem_block (97-105), get_pe (6-8), getmem_block (77-78)
- tilelang/language/tir/op.py (1): address_of (463-479)
- tilelang/language/builtin.py (1): fence_sys (429-431)
- tilelang/jit/__init__.py (1): compile (32-81)
🪛 Ruff (0.12.2)
benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py
62-62: Unused function argument: rank
(ARG001)
64-64: Use of assert detected
(S101)
65-65: Use of assert detected
(S101)
benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py
56-56: Unused function argument: rank
(ARG001)
58-58: Use of assert detected
(S101)
59-59: Use of assert detected
(S101)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: bot-task
T.putmem_block(
    T.address_of(dst),
    T.address_of(src),
    size * 4,
    T.get_pe() ^ 1,
)
T.fence_sys()
address_of requires a BufferLoad; use dst[0]/src[0] to pass base pointers.
Passing the buffer var itself is invalid; this can miscompile or crash.
 with T.Kernel(1, threads=threads):
     T.putmem_block(
-        T.address_of(dst),
-        T.address_of(src),
+        T.address_of(dst[0]),
+        T.address_of(src[0]),
         size * 4,
         T.get_pe() ^ 1,
     )
     T.fence_sys()
🤖 Prompt for AI Agents
In benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py around lines 26-33,
the call site passes the buffer variable to T.address_of which requires a
BufferLoad and will miscompile or crash; replace T.address_of(dst) and
T.address_of(src) with the base pointer expressions dst[0] and src[0]
respectively so the underlying BufferLoad is emitted and the correct base
addresses are passed to T.putmem_block (leave size, peer, and T.fence_sys()
unchanged).
T.getmem_block(
    T.address_of(dst),
    T.address_of(src),
    size * 4,
    T.get_pe() ^ 1,
)
T.fence_sys()
Same fix for pull path: pass base element address, not buffer var.
 with T.Kernel(1, threads=threads):
     T.getmem_block(
-        T.address_of(dst),
-        T.address_of(src),
+        T.address_of(dst[0]),
+        T.address_of(src[0]),
         size * 4,
         T.get_pe() ^ 1,
     )
     T.fence_sys()
🤖 Prompt for AI Agents
In benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py around lines 45 to
52, the pull path currently passes the buffer variable to T.getmem_block via
T.address_of(src); change it to pass the base element address (e.g.,
T.address_of(src[0]) or the buffer's element base) instead of the buffer
variable so the remote memory copy uses the element address rather than the
Python buffer object.
This pull request introduces a new benchmark script to measure the bandwidth of the IPC (Inter-Process Communication) implementation and compare its efficiency with NVSHMEM primitives. Additionally, it improves code documentation for a distributed memory operation. The main focus is on providing a robust benchmarking tool for evaluating IPC communication performance in distributed GPU environments.
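For orientation, bandwidth measurement in such benchmarks typically follows a CUDA-event timing pattern like the sketch below (a generic illustration with assumed names; the repository's actual `perf_fn` may differ):

```python
import torch

def measure_bw_gbps(run_kernel, nbytes: int, warmup: int = 5, repeat: int = 50) -> float:
    """Time run_kernel with CUDA events and return unidirectional bandwidth in GB/s."""
    for _ in range(warmup):
        run_kernel()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(repeat):
        run_kernel()
    end.record()
    torch.cuda.synchronize()  # wait for the timed work to finish
    elapsed_s = start.elapsed_time(end) / 1e3 / repeat  # elapsed_time returns milliseconds
    return nbytes / elapsed_s / 1e9
```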
New benchmarking tool:

- `benchmark/distributed/benchmark_ipc_p2p.py`, which benchmarks IPC push and pull bandwidth between two ranks, using custom kernels built with tilelang and compared against NVSHMEM. The script includes kernel definitions, setup for distributed execution, and performance measurement logic.

Documentation improvements:

- Added a docstring for the `putmem_block` function in `tilelang/language/distributed/multi_device/nvshmem.py` to clarify its purpose and usage, detailing the arguments and behavior for putting data from local to remote memory at block granularity.

Summary by CodeRabbit
New Features
Documentation