[Feat]support sp all2all without transpose example #16
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a new sequence-parallel all-to-all NVSHMEM example with a PyTorch golden reference, introduces a distributed initializer utility, fixes environment flag handling and NVSHMEM path exports, adjusts a Cython env import, and updates a compile call to pass specific TileLang pass configurations.

Sequence Diagram(s)

```mermaid
sequenceDiagram
  autonumber
  participant User
  participant Script as example_pre_attn_all2all.py
  participant Dist as torch.distributed
  participant NV as NVSHMEM
  participant TL as TileLang Kernel
  participant TorchRef as PyTorch Ref
  User->>Script: Run with args
  Script->>Dist: init_process_group (NCCL)
  Script->>NV: optional init
  Script->>Script: compile TileLang kernel (pass_configs)
  Script->>TorchRef: all_to_all reference
  Script->>TL: launch NVSHMEM kernel
  TL-->>Script: output tensor
  TorchRef-->>Script: reference tensor
  Script->>Script: compare & report
  Script->>Dist: destroy_process_group
```

```mermaid
sequenceDiagram
  autonumber
  participant App
  participant Utils as init_distributed()
  participant Env as OS Env
  participant Dist as torch.distributed
  participant CUDA as torch.cuda
  participant NV as NVSHMEM
  App->>Utils: call(return_tp_group, init_nvshmem)
  Utils->>Env: read WORLD_SIZE, RANK, LOCAL_RANK
  Utils->>Dist: init_process_group(backend=NCCL, timeout=1800s)
  Utils->>CUDA: set_device(LOCAL_RANK) & sync
  Utils->>Dist: new_group(world)
  alt init_nvshmem
    Utils->>NV: init()
  end
  Utils-->>App: WORLD_SIZE, RANK, LOCAL_RANK, (TP_GROUP)
```
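Read as pseudocode, the initializer flow above corresponds roughly to the sketch below (the signature, return shape, and pynvshmem call are inferred from the diagram and may not match the actual tilelang/distributed/utils.py implementation):

```python
import os
from datetime import timedelta

import torch
import torch.distributed as dist


def init_distributed(return_tp_group=False, init_nvshmem=True):
    # Read launcher-provided environment (e.g. torchrun).
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])

    dist.init_process_group(
        backend="nccl", world_size=world_size, rank=rank,
        timeout=timedelta(seconds=1800))
    torch.cuda.set_device(local_rank)
    torch.cuda.synchronize()

    tp_group = dist.new_group(ranks=list(range(world_size)))

    if init_nvshmem:
        import pynvshmem  # assumed optional dependency
        pynvshmem.init_nvshmem_by_uniqueid(tp_group)

    if return_tp_group:
        return world_size, rank, local_rank, tp_group
    return world_size, rank, local_rank
```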
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run … We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Pull Request Overview
This PR adds support for sequence parallel all-to-all operations without transpose, introducing a new distributed communication pattern for attention mechanisms. The changes include environment variable fixes, utility function enhancements, and a comprehensive example implementation.
- Fixes environment variable access patterns by using the .get() method consistently
- Adds a new init_distributed utility function for simplified distributed setup
- Introduces a complete sequence parallel all-to-all example with PyTorch golden reference verification
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description | 
|---|---|
| tilelang/jit/adapter/cython/cython_wrapper.pyx | Updates import path for environment module | 
| tilelang/env.py | Fixes environment variable access to use the .get() method consistently | 
| tilelang/distributed/utils.py | Adds new init_distributed function for distributed initialization | 
| examples/distributed/ipc/example_pull_warp.py | Minor formatting adjustment to tensor parameter | 
| examples/distributed/example_pre_attn_all2all.py | New comprehensive example implementing sequence parallel all-to-all | 
| examples/distributed/example_allgather_gemm.py | Adds compilation configuration options | 
Actionable comments posted: 8
🧹 Nitpick comments (10)
examples/distributed/example_allgather_gemm.py (1)
63-69: Pass configs are valid; consider making them togglable.
These flags look correct for NVSHMEM paths. For easier bisecting/perf tuning across GPUs, consider threading them through CLI/env (e.g., TILELANG_DISABLE_TMA, TILELANG_DISABLE_WARP_SPEC) instead of hard-coding.
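A rough sketch of one way to wire that up (the environment-variable names and the pass-config key strings below are placeholders, not confirmed TileLang options; check the keys the example currently hard-codes before adopting them):

```python
import os


def _flag(name: str, default: str = "1") -> bool:
    # Treat "1"/"true"/"on" (case-insensitive) as enabled.
    return os.environ.get(name, default).lower() in ("1", "true", "on")


# Hypothetical toggles; replace the key strings with the ones used in the example.
pass_configs = {
    "tl.disable_tma_lower": _flag("TILELANG_DISABLE_TMA"),
    "tl.disable_warp_specialized": _flag("TILELANG_DISABLE_WARP_SPEC"),
}

# kernel = tilelang.compile(func, pass_configs=pass_configs)
```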
tilelang/distributed/utils.py (2)
48-72: Prefer consistent namespace and optional barrier before NVSHMEM init.
Use either dist.* or torch.distributed.* consistently within this file. Also consider a barrier before NVSHMEM init for stricter ordering in mixed launches.
Proposed tweaks:
```diff
-    torch.cuda.synchronize()
-    if init_nvshmem:
+    torch.cuda.synchronize()
+    # Optional: ensure all ranks are ready before NVSHMEM init
+    # torch.distributed.barrier()
+    if init_nvshmem:
         import pynvshmem
         pynvshmem.init_nvshmem_by_uniqueid(TP_GROUP)
```
48-72: Avoid duplication with profiler.init_distributed().
This replicates logic in tilelang/profiler/__init__.py:init_distributed. Centralize to a single helper to prevent drift (timeouts, group semantics).
I can refactor profiler to call this util (or vice versa) and add light unit coverage for both paths.
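One possible direction, sketched under the assumption that the two helpers take compatible arguments (not verified here): have the profiler version delegate to the shared utility.

```python
# tilelang/profiler/__init__.py -- hypothetical delegation sketch
from tilelang.distributed.utils import init_distributed as _init_distributed_impl


def init_distributed(return_tp_group: bool = False, init_nvshmem: bool = True):
    """Thin wrapper so timeout and group semantics live in a single place."""
    return _init_distributed_impl(
        return_tp_group=return_tp_group, init_nvshmem=init_nvshmem)
```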
tilelang/env.py (4)
213-225: Boolean detection fixed; confirm acceptability of static evaluation.
Using .get().lower() fixes the prior bug, but it freezes USE_DISTRIBUTED at import time. If you expect toggling at runtime (tests), prefer a property or method returning a live view.
Possible approach (outside this hunk):
- Keep _USE_DISTRIBUTED = EnvVar("TILELANG_USE_DISTRIBUTED", "0")
- Add property USE_DISTRIBUTED -> self._USE_DISTRIBUTED.get().lower() in ("1","true","on")
- Compute NVSHMEM_* lazily or in init.
Would you like me to draft that change?
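For reference, a minimal sketch of the property-based approach (it assumes the module's existing EnvVar helper and Environment class; not the actual implementation):

```python
class Environment:
    # Keep the raw handle; do not freeze the boolean at import time.
    _USE_DISTRIBUTED = EnvVar("TILELANG_USE_DISTRIBUTED", "0")

    @property
    def USE_DISTRIBUTED(self) -> bool:
        # Re-read the env var on every access so tests can toggle it at runtime.
        return self._USE_DISTRIBUTED.get().lower() in ("1", "true", "on")
```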
220-221: Use os.path.join for NVSHMEM paths (cross-platform, avoids double slashes).
Apply:
```diff
-    NVSHMEM_INCLUDE_DIR: str = NVSHMEM_SRC + "/build/src/include"
-    NVSHMEM_LIB_PATH: str = NVSHMEM_SRC + "/build/src/lib"
+    NVSHMEM_INCLUDE_DIR: str = os.path.join(NVSHMEM_SRC, "build", "src", "include")
+    NVSHMEM_LIB_PATH: str = os.path.join(NVSHMEM_SRC, "build", "src", "lib")
```
213-221: Avoid repeated EnvVar construction for NVSHMEM_SRC.
Minor readability: fetch once, reuse.
Apply:
```diff
-    if EnvVar("NVSHMEM_SRC", None).get() is not None:
-        NVSHMEM_SRC = EnvVar("NVSHMEM_SRC", None).get()
+    _nvshmem_src = EnvVar("NVSHMEM_SRC", None).get()
+    if _nvshmem_src is not None:
+        NVSHMEM_SRC = _nvshmem_src
```
213-225: Refactor remaining env imports
There are still legacy from tilelang import env statements across multiple modules (e.g. tilelang/utils/sparse.py, tilelang/profiler/__init__.py, tilelang/jit/adapter/libgen.py, tilelang/jit/adapter/wrapper.py, tilelang/jit/kernel.py, tilelang/engine/lower.py, tilelang/cache/__init__.py, tilelang/autotuner/tuner.py, tilelang/cache/kernel_cache.py). Replace each with direct imports of only the variables you actually use, and verify no code calls env.USE_DISTRIBUTED.get() or .lower().

tilelang/jit/adapter/cython/cython_wrapper.pyx (1)
10-11: Guard against late toggling of USE_DISTRIBUTED (pynvshmem import).
If USE_DISTRIBUTED were ever toggled at runtime post-import, pynvshmem may be undefined when forward() runs. If you plan to allow toggling, import lazily inside forward() before first use.
Do you intend USE_DISTRIBUTED to be runtime-static? If not, I can patch forward() to import pynvshmem on demand.
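A lazy-import pattern along these lines could work (sketched in plain Python; the wrapper's actual structure and attribute names are assumptions):

```python
_pynvshmem = None


def _get_pynvshmem():
    # Import on first use so the module still loads when distributed mode is
    # enabled only after import time.
    global _pynvshmem
    if _pynvshmem is None:
        import pynvshmem
        _pynvshmem = pynvshmem
    return _pynvshmem


def forward(self, *args):
    if env.USE_DISTRIBUTED:  # assumed live flag, see the env.py comment above
        pynvshmem = _get_pynvshmem()
        ...  # distributed path
    ...  # regular path
```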
examples/distributed/example_pre_attn_all2all.py (2)
49-58: Use direct list comprehension for output list creation
The loop variable pe_idx is not used within the loop body. Consider using a more Pythonic list comprehension.

```diff
-output_list = []
-for pe_idx in range(world_size):
-    # Receive [BATCH_SIZE, HEADS_PER_PE, SEQ_PER_PE, HEAD_DIM] from each PE
-    recv_data = torch.empty(
-        batch_size,
-        heads_per_pe,
-        seq_per_pe,
-        head_dim,
-        dtype=data_src.dtype,
-        device=data_src.device)
-    output_list.append(recv_data)
+# Prepare output list for all_to_all
+# Receive [BATCH_SIZE, HEADS_PER_PE, SEQ_PER_PE, HEAD_DIM] from each PE
+output_list = [
+    torch.empty(
+        batch_size,
+        heads_per_pe,
+        seq_per_pe,
+        head_dim,
+        dtype=data_src.dtype,
+        device=data_src.device)
+    for _ in range(world_size)
+]
```
133-133: Document magic number for signal operation
The value 10 appears to be NVSHMEM_SIGNAL_ADD, but it's better to use a named constant or add a comment.

```diff
+# NVSHMEM_SIGNAL_ADD = 10
 T.signal_op(
     T.address_of(signal[mype[0]]),
     1,
     10,  # NVSHMEM_SIGNAL_ADD
     target_pe)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📥 Commits
Reviewing files that changed from the base of the PR and between 10f4fdb and 2b5f147e203c90a2e2471f2d4405cef4941abf21.
📒 Files selected for processing (6)
- examples/distributed/example_allgather_gemm.py (1 hunks)
- examples/distributed/example_pre_attn_all2all.py (1 hunks)
- examples/distributed/ipc/example_pull_warp.py (1 hunks)
- tilelang/distributed/utils.py (2 hunks)
- tilelang/env.py (1 hunks)
- tilelang/jit/adapter/cython/cython_wrapper.pyx (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/distributed/example_allgather_gemm.py (1)
tilelang/jit/__init__.py (1)
compile(32-81)
examples/distributed/example_pre_attn_all2all.py (7)
tilelang/utils/tensor.py (1)
TensorSupplyType(11-18)
tilelang/distributed/utils.py (1)
perf_fn(206-227)
tilelang/env.py (1)
disable_cache(246-247)
tilelang/language/allocate.py (1)
alloc_local(38-49)
tilelang/language/distributed/multi_device/nvshmem.py (5)
get_pe(6-8)
putmem_nbi_block(101-109)
fence(61-63)
signal_op(164-172)
signal_wait_until(175-183)
tilelang/language/tir/op.py (1)
address_of(463-479)
tilelang/jit/__init__.py (1)
compile(32-81)
tilelang/distributed/utils.py (3)
tilelang/profiler/__init__.py (1)
init_distributed(73-92)
tilelang/env.py (1)
get(136-139)
tilelang/distributed/pynvshmem/python/pynvshmem/__init__.py (1)
init_nvshmem_by_uniqueid(40-53)
🪛 Ruff (0.12.2)
examples/distributed/example_pre_attn_all2all.py
4-4: os imported but unused
Remove unused import: os
(F401)
7-7: tilelang.profiler.TensorSupplyType imported but unused
Remove unused import: tilelang.profiler.TensorSupplyType
(F401)
8-8: tilelang.distributed.utils.perf_fn imported but unused
Remove unused import: tilelang.distributed.utils.perf_fn
(F401)
29-29: Local variable rank is assigned to but never used
Remove assignment to unused variable rank
(F841)
49-49: Loop control variable pe_idx not used within loop body
(B007)
266-266: Local variable verification_passed is assigned to but never used
Remove assignment to unused variable verification_passed
(F841)
🔇 Additional comments (3)
examples/distributed/ipc/example_pull_warp.py (1)
18-21: Formatting-only change LGTM.
Signature readability is improved; no semantic impact.
tilelang/jit/adapter/cython/cython_wrapper.pyx (1)
10-11: Import change aligns with new env semantics.
This makes env.USE_DISTRIBUTED resolve to the Environment instance attribute. Looks good.
examples/distributed/example_pre_attn_all2all.py (1)
14-76: Excellent golden reference implementation!
The PyTorch distributed all-to-all reference implementation is well-structured, with clear documentation explaining the input/output tensor shapes and transformation steps. This provides a solid baseline for verification.
Thanks Jiaqing! The main branch has been updated; please resolve conflicts and merge.
Actionable comments posted: 3
♻️ Duplicate comments (7)
examples/distributed/example_pre_attn_all2all.py (7)
4-8: Remove unused imports to satisfy Ruff and keep the example minimal
os, TensorSupplyType, and perf_fn are unused. Apply:

```diff
-import os
 import tilelang
 import tilelang.language as T
-from tilelang.profiler import TensorSupplyType
-from tilelang.distributed.utils import init_distributed, dtype_map, perf_fn
+from tilelang.distributed.utils import init_distributed, dtype_map
 import argparse
```
28-30: Remove unused rank variable
Assigned and never used.

```diff
 world_size = dist.get_world_size(group)
-rank = dist.get_rank(group)
```
214-216: Gate kernel source printing on --print_source as intended
Currently always prints on rank 0.

```diff
-    if RANK == 0:
+    if RANK == 0 and args.print_source:
         print("\nTileLang Kernel Source:")
         print(kernel.get_kernel_source())
```
250-254: Synchronize NVSHMEM ranks after kernel or document why not
Given the non-blocking put and inter-PE signaling, a barrier is often used post-kernel for clarity and safety in examples.
Option A (simple):
```diff
-    #pynvshmem.nvshmem_barrier_all()
+    pynvshmem.nvshmem_barrier_all()
```

Option B (configurable):
- Add parser.add_argument("--barrier", action="store_true") and guard the call with if args.barrier: ..., as in the sketch below.
167-178: Tighten CLI: fix typo and consider using warmup/repeat for benchmarking
- Help text nit: missing space after comma.
- --warmup and --repeat are parsed but unused; either wire them up or drop them.

Apply the typo fix:
```diff
-    parser.add_argument(
-        "--num_heads", type=int, default=16, help="Number of attention heads,combine QKV")
+    parser.add_argument(
+        "--num_heads", type=int, default=16, help="Number of attention heads, combine QKV")
```

Optional: integrate simple timing using torch.cuda.Event() around tilelang_all_to_all(), or re-introduce perf_fn and benchmark just the kernel execution while allocating NVSHMEM tensors once outside the timed region.
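A sketch of the CUDA-event timing variant (the wiring to tilelang_all_to_all and the args flags is an assumption, not taken from the example):

```python
import torch


def time_kernel_ms(fn, warmup=5, repeat=20):
    # Warm up to exclude first-launch/compilation overhead from the measurement.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(repeat):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / repeat  # average milliseconds per call


# Hypothetical wiring into the example's CLI flags:
# avg_ms = time_kernel_ms(lambda: tilelang_all_to_all(...), args.warmup, args.repeat)
# if RANK == 0:
#     print(f"TileLang A2A avg: {avg_ms:.3f} ms")
```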
116-124: Compute transfer size from dtype (not hard-coded 2 bytes)
Hard-coding float16 breaks for other dtypes and mismatches --dtype. Apply:

```diff
-    transfer_size = SEQ_PER_PE * HEAD_DIM * 2  # float16 = 2 bytes
+    transfer_size = SEQ_PER_PE * HEAD_DIM * DTYPE_BYTES
```

And define DTYPE_BYTES once before @T.prim_func:

```python
# Outside prim_func, near the sequence_parallel_all_to_all signature
DTYPE_BYTES_MAP = {
    "float16": 2, "bfloat16": 2, "float32": 4, "float64": 8,
    "int8": 1, "uint8": 1, "int32": 4, "uint32": 4, "int64": 8, "uint64": 8,
}
DTYPE_BYTES = DTYPE_BYTES_MAP.get(dtype)
assert DTYPE_BYTES is not None, f"Unsupported dtype: {dtype}"
```
266-270: Fail fast on verification errors and aggregate across PEs
Currently the result is ignored, and CI won't catch mismatches.

```diff
-    verification_passed = verify_results(tilelang_output, torch_output, RANK)
+    verification_passed = verify_results(tilelang_output, torch_output, RANK)
+    verification_results = [None] * WORLD_SIZE
+    dist.all_gather_object(verification_results, verification_passed, group=TP_GROUP)
+    if RANK == 0 and not all(verification_results):
+        print("❌ Verification failed on one or more PEs!")
+        dist.destroy_process_group()
+        raise SystemExit(1)
```
🧹 Nitpick comments (2)
examples/distributed/example_pre_attn_all2all.py (2)
49-59: Avoid unused loop index warning
Use _ since the index isn't used in the loop body.

```diff
-    for pe_idx in range(world_size):
+    for _ in range(world_size):
```
236-256: Benchmark correctly: allocate NVSHMEM tensors once, time only the kernel
Current structure would time allocation if you hook up benchmarking later. Allocate once, then pass closures to a perf helper.
Sketch:
```python
# allocate once outside
data_src = pynvshmem.nvshmem_create_tensor([...], dtype_torch)
data_dst = pynvshmem.nvshmem_create_tensor([...], dtype_torch)
signal = pynvshmem.nvshmem_create_tensor([PE_num], torch.uint64)

def run_kernel():
    data_src.copy_(input_data); data_dst.zero_(); signal.zero_()
    kernel(data_src, data_dst, signal)
    pynvshmem.nvshmem_barrier_all()
    return data_dst

# warmup/repeat timing if desired
# _, ms = perf_fn(run_kernel, args.warmup, args.repeat)
# if RANK == 0: print(f"TileLang A2A avg: {ms:.3f} ms")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📥 Commits
Reviewing files that changed from the base of the PR and between 2b5f147e203c90a2e2471f2d4405cef4941abf21 and a7005f4.
📒 Files selected for processing (5)
- examples/distributed/example_allgather_gemm.py (1 hunks)
- examples/distributed/example_pre_attn_all2all.py (1 hunks)
- tilelang/distributed/utils.py (2 hunks)
- tilelang/env.py (1 hunks)
- tilelang/jit/adapter/cython/cython_wrapper.pyx (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
- tilelang/distributed/utils.py
- examples/distributed/example_allgather_gemm.py
- tilelang/jit/adapter/cython/cython_wrapper.pyx
- tilelang/env.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/distributed/example_pre_attn_all2all.py (6)
tilelang/distributed/utils.py (1)
perf_fn(206-227)
tilelang/env.py (1)
disable_cache(246-247)
tilelang/language/allocate.py (1)
alloc_local(38-49)
tilelang/language/distributed/multi_device/nvshmem.py (5)
get_pe(6-8)
putmem_nbi_block(101-109)
fence(61-63)
signal_op(164-172)
signal_wait_until(175-183)
tilelang/language/tir/op.py (1)
address_of(463-479)
tilelang/jit/__init__.py (1)
compile(32-81)
🪛 Ruff (0.12.2)
examples/distributed/example_pre_attn_all2all.py
4-4: os imported but unused
Remove unused import: os
(F401)
7-7: tilelang.profiler.TensorSupplyType imported but unused
Remove unused import: tilelang.profiler.TensorSupplyType
(F401)
8-8: tilelang.distributed.utils.perf_fn imported but unused
Remove unused import: tilelang.distributed.utils.perf_fn
(F401)
29-29: Local variable rank is assigned to but never used
Remove assignment to unused variable rank
(F841)
49-49: Loop control variable pe_idx not used within loop body
(B007)
266-266: Local variable verification_passed is assigned to but never used
Remove assignment to unused variable verification_passed
(F841)
🔇 Additional comments (1)
examples/distributed/example_pre_attn_all2all.py (1)
11-11: Confirm the correct cache-disabling API
tilelang.disable_cache() may be a mis-call; in many places the API is from tilelang.env import env; env.disable_cache(). If needed, switch to:

```diff
-tilelang.disable_cache()
+from tilelang.env import env
+env.disable_cache()
```
```python
if tx == 0:
    T.signal_op(
        T.address_of(signal[mype[0]]),
        1,
        10,  # NVSHMEM_SIGNAL_ADD
        target_pe)
    for k in T.serial(PE_num):
        T.signal_wait_until(T.address_of(signal[k]), 0, NUM_BLOCKS_X)
```
🛠️ Refactor suggestion
Replace magic numbers in NVSHMEM signaling with named constants
10 and 0 are unclear; use named constants for maintainability.
Apply:
```diff
-                T.signal_op(
+                T.signal_op(
                     T.address_of(signal[mype[0]]),
                     1,
-                    10,  # NVSHMEM_SIGNAL_ADD
+                    NVSHMEM_SIGNAL_ADD,
                     target_pe)
                 for k in T.serial(PE_num):
-                    T.signal_wait_until(T.address_of(signal[k]), 0, NUM_BLOCKS_X)
+                    T.signal_wait_until(T.address_of(signal[k]), NVSHMEM_CMP_EQ, NUM_BLOCKS_X)
```

Add once near the top of the file:
```python
NVSHMEM_SIGNAL_ADD = 10
NVSHMEM_CMP_EQ = 0
```

🤖 Prompt for AI Agents
In examples/distributed/example_pre_attn_all2all.py around lines 129 to 137, the
NVSHMEM signal calls use magic numbers (10 and 0); define named constants
NVSHMEM_SIGNAL_ADD = 10 and NVSHMEM_CMP_EQ = 0 near the top of the file and
replace the literal 10 with NVSHMEM_SIGNAL_ADD and the literal 0 used in
signal_wait_until with NVSHMEM_CMP_EQ so the intent is clear and maintainable.
```python
# Initialize distributed environment
WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True)
PE_num = WORLD_SIZE
```
🛠️ Refactor suggestion
Ensure tensors go to the correct GPU in multi-process setups
Set CUDA device using LOCAL_RANK after initialization to avoid all ranks using GPU 0.
```diff
 WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True)
 PE_num = WORLD_SIZE
+torch.cuda.set_device(LOCAL_RANK)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# Initialize distributed environment
WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True)
PE_num = WORLD_SIZE
torch.cuda.set_device(LOCAL_RANK)
```
🤖 Prompt for AI Agents
In examples/distributed/example_pre_attn_all2all.py around lines 184 to 187,
after calling init_distributed(return_tp_group=True) you must set the CUDA
device to LOCAL_RANK so each process uses its assigned GPU; call
torch.cuda.set_device(LOCAL_RANK) (or equivalent) immediately after
initialization and ensure subsequent tensor/model creation or .to(device) uses
that device to avoid all ranks defaulting to GPU 0.
```python
    input_data = torch.rand([args.batch_size, args.num_heads, SEQ_PER_PE, args.head_dim],
                            dtype=dtype_torch,
                            device='cuda')
```
🛠️ Refactor suggestion
Create tensors on the rank-local device
Be explicit to avoid device mismatch when the default device isn’t set globally.
```diff
-    input_data = torch.rand([args.batch_size, args.num_heads, SEQ_PER_PE, args.head_dim],
-                            dtype=dtype_torch,
-                            device='cuda')
+    input_data = torch.rand(
+        [args.batch_size, args.num_heads, SEQ_PER_PE, args.head_dim],
+        dtype=dtype_torch,
+        device=f'cuda:{LOCAL_RANK}',
+    )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    input_data = torch.rand(
        [args.batch_size, args.num_heads, SEQ_PER_PE, args.head_dim],
        dtype=dtype_torch,
        device=f'cuda:{LOCAL_RANK}',
    )
```
🤖 Prompt for AI Agents
In examples/distributed/example_pre_attn_all2all.py around lines 222 to 224, the
tensor is created with device='cuda' which can cause device-mismatch across
ranks; instead determine the rank-local CUDA device (e.g., device =
torch.device(f"cuda:{args.local_rank}") or device =
torch.device(f"cuda:{torch.cuda.current_device()}") if args.local_rank isn’t
available) and pass that device to torch.rand so the tensor is allocated on the
correct per-rank GPU.
Summary by CodeRabbit
New Features
Bug Fixes
Chores