
Conversation

@JiaqingFu commented Aug 29, 2025

Summary by CodeRabbit

  • New Features

    • Added a sequence-parallel all-to-all example with a PyTorch reference and NVSHMEM-accelerated path.
    • Introduced a utility to initialize distributed execution with optional NVSHMEM setup.
  • Bug Fixes

    • Fixed distributed mode detection from environment variables.
    • Corrected NVSHMEM source path retrieval and exposed include/library paths when distributed mode is enabled.
  • Chores

    • Updated an example to use safer compiler pass settings.
    • Adjusted internal bindings to align with the updated environment configuration.

Copilot AI review requested due to automatic review settings August 29, 2025 07:11

coderabbitai bot commented Aug 29, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a new sequence-parallel all-to-all NVSHMEM example with a PyTorch golden reference, introduces a distributed initializer utility, fixes environment flag handling and NVSHMEM path exports, adjusts a Cython env import, and updates a compile call to pass specific TileLang pass configurations.

Changes

  • Distributed examples (examples/distributed/example_pre_attn_all2all.py, examples/distributed/example_allgather_gemm.py): New NVSHMEM sequence-parallel all-to-all example with PyTorch reference, CLI, verification, and kernel compile. The existing allgather GEMM example now passes pass_configs (tl.disable_tma_lower=True, tl.disable_warp_specialized=True) to tilelang.compile.
  • Distributed utilities (tilelang/distributed/utils.py): Added init_distributed(return_tp_group=False, init_nvshmem=True) to initialize the NCCL process group, set the CUDA device, optionally create a TP group, and optionally initialize NVSHMEM.
  • Environment configuration (tilelang/env.py): Corrected USE_DISTRIBUTED evaluation, fixed NVSHMEM_SRC retrieval, and added NVSHMEM_INCLUDE_DIR and NVSHMEM_LIB_PATH when distributed mode is enabled; set to None otherwise.
  • Cython adapter binding (tilelang/jit/adapter/cython/cython_wrapper.pyx): Changed the env import to from tilelang.env import env, affecting how env.USE_DISTRIBUTED is resolved.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant Script as example_pre_attn_all2all.py
  participant Dist as torch.distributed
  participant NV as NVSHMEM
  participant TL as TileLang Kernel
  participant TorchRef as PyTorch Ref

  User->>Script: Run with args
  Script->>Dist: init_process_group (NCCL)
  Script->>NV: optional init
  Script->>Script: compile TileLang kernel (pass_configs)
  Script->>TorchRef: all_to_all reference
  Script->>TL: launch NVSHMEM kernel
  TL-->>Script: output tensor
  TorchRef-->>Script: reference tensor
  Script->>Script: compare & report
  Script->>Dist: destroy_process_group
sequenceDiagram
  autonumber
  participant App
  participant Utils as init_distributed()
  participant Env as OS Env
  participant Dist as torch.distributed
  participant CUDA as torch.cuda
  participant NV as NVSHMEM

  App->>Utils: call(return_tp_group, init_nvshmem)
  Utils->>Env: read WORLD_SIZE, RANK, LOCAL_RANK
  Utils->>Dist: init_process_group(backend=NCCL, timeout=1800s)
  Utils->>CUDA: set_device(LOCAL_RANK) & sync
  Utils->>Dist: new_group(world)
  alt init_nvshmem
    Utils->>NV: init()
  end
  Utils-->>App: WORLD_SIZE, RANK, LOCAL_RANK, (TP_GROUP)
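A rough sketch of an initializer matching this flow (the actual implementation lives in tilelang/distributed/utils.py and may differ in details; pynvshmem is assumed to be importable when NVSHMEM support is built):

# Sketch only: mirrors the diagram above, not the exact library code.
import datetime
import os

import torch
import torch.distributed as dist


def init_distributed(return_tp_group=False, init_nvshmem=True):
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])

    # NCCL process group with a generous timeout for large jobs.
    dist.init_process_group(
        backend="nccl",
        world_size=world_size,
        rank=rank,
        timeout=datetime.timedelta(seconds=1800),
    )
    torch.cuda.set_device(local_rank)

    # Tensor-parallel group spanning all ranks.
    tp_group = dist.new_group(ranks=list(range(world_size)))
    torch.cuda.synchronize()

    if init_nvshmem:
        import pynvshmem
        pynvshmem.init_nvshmem_by_uniqueid(tp_group)

    if return_tp_group:
        return world_size, rank, local_rank, tp_group
    return world_size, rank, local_rank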

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I thump my paws, packets fly,
Heads and sequences leap sky-high—
All-to-all, we weave and share,
NVSHMEM sparks the data air.
Env flags trimmed, the paths align,
Kernels hum in tidy time—
Carrots queued, results divine. 🥕✨


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Copilot AI left a comment
Pull Request Overview

This PR adds support for sequence parallel all-to-all operations without transpose, introducing a new distributed communication pattern for attention mechanisms. The changes include environment variable fixes, utility function enhancements, and a comprehensive example implementation.

  • Fixes environment variable access patterns by using the .get() method consistently
  • Adds a new init_distributed utility function for simplified distributed setup
  • Introduces a complete sequence parallel all-to-all example with PyTorch golden reference verification

Reviewed Changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

Show a summary per file
  • tilelang/jit/adapter/cython/cython_wrapper.pyx: Updates import path for the environment module
  • tilelang/env.py: Fixes environment variable access to use the .get() method consistently
  • tilelang/distributed/utils.py: Adds new init_distributed function for distributed initialization
  • examples/distributed/ipc/example_pull_warp.py: Minor formatting adjustment to a tensor parameter
  • examples/distributed/example_pre_attn_all2all.py: New comprehensive example implementing sequence-parallel all-to-all
  • examples/distributed/example_allgather_gemm.py: Adds compilation configuration options


@coderabbitai bot left a comment

Actionable comments posted: 8

🧹 Nitpick comments (10)
examples/distributed/example_allgather_gemm.py (1)

63-69: Pass configs are valid; consider making them togglable.

These flags look correct for NVSHMEM paths. For easier bisecting/perf tuning across GPUs, consider threading them through CLI/env (e.g., TILELANG_DISABLE_TMA, TILELANG_DISABLE_WARP_SPEC) instead of hard-coding.
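
A minimal sketch of that idea, assuming the hypothetical TILELANG_DISABLE_TMA / TILELANG_DISABLE_WARP_SPEC variables named above and the pass_configs keys used in this PR:

import os

import tilelang

# Hypothetical env toggles (not existing TileLang settings); defaults keep the
# current hard-coded behaviour.
disable_tma = os.environ.get("TILELANG_DISABLE_TMA", "1").lower() in ("1", "true", "on")
disable_warp_spec = os.environ.get("TILELANG_DISABLE_WARP_SPEC", "1").lower() in ("1", "true", "on")

kernel = tilelang.compile(
    allgather_gemm_func,  # placeholder for the example's kernel function
    pass_configs={
        "tl.disable_tma_lower": disable_tma,
        "tl.disable_warp_specialized": disable_warp_spec,
    },
)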

tilelang/distributed/utils.py (2)

48-72: Prefer consistent namespace and optional barrier before NVSHMEM init.

Use either dist.* or torch.distributed.* consistently within this file. Also consider a barrier before NVSHMEM init for stricter ordering in mixed launches.

Proposed tweaks:

-    torch.cuda.synchronize()
-    if init_nvshmem:
+    torch.cuda.synchronize()
+    # Optional: ensure all ranks are ready before NVSHMEM init
+    # torch.distributed.barrier()
+    if init_nvshmem:
         import pynvshmem
         pynvshmem.init_nvshmem_by_uniqueid(TP_GROUP)

48-72: Avoid duplication with profiler.init_distributed().

This replicates logic in tilelang/profiler/__init__.py:init_distributed. Centralize to a single helper to prevent drift (timeouts, group semantics).

I can refactor profiler to call this util (or vice versa) and add light unit coverage for both paths.

tilelang/env.py (4)

213-225: Boolean detection fixed; confirm acceptability of static evaluation.

Using .get().lower() fixes the prior bug, but it freezes USE_DISTRIBUTED at import time. If you expect toggling at runtime (tests), prefer a property or method returning a live view.

Possible approach (outside this hunk):

  • Keep _USE_DISTRIBUTED = EnvVar("TILELANG_USE_DISTRIBUTED", "0")
  • Add property USE_DISTRIBUTED -> self._USE_DISTRIBUTED.get().lower() in ("1","true","on")
  • Compute NVSHMEM_* lazily or in init.

Would you like me to draft that change?
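
A rough sketch of that property-based approach (the EnvVar stand-in mirrors the existing helper in tilelang/env.py; names and details are illustrative):

import os


class EnvVar:  # stand-in for the existing helper in tilelang/env.py
    def __init__(self, name, default):
        self.name, self.default = name, default

    def get(self):
        return os.environ.get(self.name, self.default)


class Environment:
    _USE_DISTRIBUTED = EnvVar("TILELANG_USE_DISTRIBUTED", "0")
    _NVSHMEM_SRC = EnvVar("NVSHMEM_SRC", None)

    @property
    def USE_DISTRIBUTED(self) -> bool:
        # Re-read on every access so tests can toggle the flag at runtime.
        return self._USE_DISTRIBUTED.get().lower() in ("1", "true", "on")

    @property
    def NVSHMEM_INCLUDE_DIR(self):
        src = self._NVSHMEM_SRC.get()
        if not self.USE_DISTRIBUTED or src is None:
            return None
        return os.path.join(src, "build", "src", "include")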

220-221: Use os.path.join for NVSHMEM paths (cross-platform, avoids double slashes).

Apply:

-        NVSHMEM_INCLUDE_DIR: str = NVSHMEM_SRC + "/build/src/include"
-        NVSHMEM_LIB_PATH: str = NVSHMEM_SRC + "/build/src/lib"
+        NVSHMEM_INCLUDE_DIR: str = os.path.join(NVSHMEM_SRC, "build", "src", "include")
+        NVSHMEM_LIB_PATH: str = os.path.join(NVSHMEM_SRC, "build", "src", "lib")

213-221: Avoid repeated EnvVar construction for NVSHMEM_SRC.

Minor readability: fetch once, reuse.

Apply:

-        if EnvVar("NVSHMEM_SRC", None).get() is not None:
-            NVSHMEM_SRC = EnvVar("NVSHMEM_SRC", None).get()
+        _nvshmem_src = EnvVar("NVSHMEM_SRC", None).get()
+        if _nvshmem_src is not None:
+            NVSHMEM_SRC = _nvshmem_src

213-225: Refactor remaining env imports
There are still legacy from tilelang import env statements across multiple modules (e.g. tilelang/utils/sparse.py, tilelang/profiler/__init__.py, tilelang/jit/adapter/libgen.py, tilelang/jit/adapter/wrapper.py, tilelang/jit/kernel.py, tilelang/engine/lower.py, tilelang/cache/__init__.py, tilelang/autotuner/tuner.py, tilelang/cache/kernel_cache.py). Replace each with direct imports of only the variables you actually use, and verify no code calls env.USE_DISTRIBUTED.get() or .lower().

tilelang/jit/adapter/cython/cython_wrapper.pyx (1)

10-11: Guard against late toggling of USE_DISTRIBUTED (pynvshmem import).

If USE_DISTRIBUTED were ever toggled at runtime post-import, pynvshmem may be undefined when forward() runs. If you plan to allow toggling, import lazily inside forward() before first use.

Do you intend USE_DISTRIBUTED to be runtime-static? If not, I can patch forward() to import pynvshmem on demand.
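
If runtime toggling is desired, a minimal sketch of the lazy-import pattern (the real forward() signature in cython_wrapper.pyx may differ):

from tilelang.env import env


def forward(self, *args, **kwargs):
    if env.USE_DISTRIBUTED:
        # Deferred import: a late toggle of USE_DISTRIBUTED no longer leaves
        # pynvshmem undefined; Python caches the module after the first call.
        import pynvshmem
        pynvshmem.nvshmem_barrier_all()  # placeholder for the NVSHMEM-dependent path
    ...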

examples/distributed/example_pre_attn_all2all.py (2)

49-58: Use direct list comprehension for output list creation

The loop variable pe_idx is not used within the loop body. Consider using a more Pythonic list comprehension.

-output_list = []
-for pe_idx in range(world_size):
-    # Receive [BATCH_SIZE, HEADS_PER_PE, SEQ_PER_PE, HEAD_DIM] from each PE
-    recv_data = torch.empty(
-        batch_size,
-        heads_per_pe,
-        seq_per_pe,
-        head_dim,
-        dtype=data_src.dtype,
-        device=data_src.device)
-    output_list.append(recv_data)
+# Prepare output list for all_to_all
+# Receive [BATCH_SIZE, HEADS_PER_PE, SEQ_PER_PE, HEAD_DIM] from each PE
+output_list = [
+    torch.empty(
+        batch_size,
+        heads_per_pe,
+        seq_per_pe,
+        head_dim,
+        dtype=data_src.dtype,
+        device=data_src.device)
+    for _ in range(world_size)
+]

133-133: Document magic number for signal operation

The value 10 appears to be NVSHMEM_SIGNAL_ADD, but it's better to use a named constant or add a comment.

+# NVSHMEM_SIGNAL_ADD = 10
 T.signal_op(
     T.address_of(signal[mype[0]]),
     1,
     10,  # NVSHMEM_SIGNAL_ADD
     target_pe)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 10f4fdb and 2b5f147e203c90a2e2471f2d4405cef4941abf21.

📒 Files selected for processing (6)
  • examples/distributed/example_allgather_gemm.py (1 hunks)
  • examples/distributed/example_pre_attn_all2all.py (1 hunks)
  • examples/distributed/ipc/example_pull_warp.py (1 hunks)
  • tilelang/distributed/utils.py (2 hunks)
  • tilelang/env.py (1 hunks)
  • tilelang/jit/adapter/cython/cython_wrapper.pyx (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/distributed/example_allgather_gemm.py (1)
tilelang/jit/__init__.py (1)
  • compile (32-81)
examples/distributed/example_pre_attn_all2all.py (7)
tilelang/utils/tensor.py (1)
  • TensorSupplyType (11-18)
tilelang/distributed/utils.py (1)
  • perf_fn (206-227)
tilelang/env.py (1)
  • disable_cache (246-247)
tilelang/language/allocate.py (1)
  • alloc_local (38-49)
tilelang/language/distributed/multi_device/nvshmem.py (5)
  • get_pe (6-8)
  • putmem_nbi_block (101-109)
  • fence (61-63)
  • signal_op (164-172)
  • signal_wait_until (175-183)
tilelang/language/tir/op.py (1)
  • address_of (463-479)
tilelang/jit/__init__.py (1)
  • compile (32-81)
tilelang/distributed/utils.py (3)
tilelang/profiler/__init__.py (1)
  • init_distributed (73-92)
tilelang/env.py (1)
  • get (136-139)
tilelang/distributed/pynvshmem/python/pynvshmem/__init__.py (1)
  • init_nvshmem_by_uniqueid (40-53)
🪛 Ruff (0.12.2)
examples/distributed/example_pre_attn_all2all.py

4-4: os imported but unused

Remove unused import: os

(F401)


7-7: tilelang.profiler.TensorSupplyType imported but unused

Remove unused import: tilelang.profiler.TensorSupplyType

(F401)


8-8: tilelang.distributed.utils.perf_fn imported but unused

Remove unused import: tilelang.distributed.utils.perf_fn

(F401)


29-29: Local variable rank is assigned to but never used

Remove assignment to unused variable rank

(F841)


49-49: Loop control variable pe_idx not used within loop body

(B007)


266-266: Local variable verification_passed is assigned to but never used

Remove assignment to unused variable verification_passed

(F841)

🔇 Additional comments (3)
examples/distributed/ipc/example_pull_warp.py (1)

18-21: Formatting-only change LGTM.

Signature readability is improved; no semantic impact.

tilelang/jit/adapter/cython/cython_wrapper.pyx (1)

10-11: Import change aligns with new env semantics.

This makes env.USE_DISTRIBUTED resolve to the Environment instance attribute. Looks good.

examples/distributed/example_pre_attn_all2all.py (1)

14-76: Excellent golden reference implementation!

The PyTorch distributed all-to-all reference implementation is well-structured with clear documentation explaining the input/output tensor shapes and transformation steps. This provides a solid baseline for verification.

@chengyupku

Thanks Jiaqing! The main branch has been updated; please resolve the conflicts and merge.

@coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (7)
examples/distributed/example_pre_attn_all2all.py (7)

4-8: Remove unused imports to satisfy Ruff and keep the example minimal

os, TensorSupplyType, and perf_fn are unused.

Apply:

-import os
 import tilelang
 import tilelang.language as T
-from tilelang.profiler import TensorSupplyType
-from tilelang.distributed.utils import init_distributed, dtype_map, perf_fn
+from tilelang.distributed.utils import init_distributed, dtype_map
 import argparse

28-30: Remove unused rank variable

Assigned and never used.

 world_size = dist.get_world_size(group)
-rank = dist.get_rank(group)

214-216: Gate kernel source printing on --print_source as intended

Currently always prints on rank 0.

-    if RANK == 0:
+    if RANK == 0 and args.print_source:
         print("\nTileLang Kernel Source:")
         print(kernel.get_kernel_source())

250-254: Synchronize NVSHMEM ranks after kernel or document why not

Given the non-blocking put and inter-PE signaling, a barrier is often used post-kernel for clarity and safety in examples.

Option A (simple):

-        #pynvshmem.nvshmem_barrier_all()
+        pynvshmem.nvshmem_barrier_all()

Option B (configurable):

  • Add parser.add_argument("--barrier", action="store_true") and guard the call with if args.barrier: ....
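
A minimal sketch of Option B (assumes the example's argparse setup and pynvshmem import):

import argparse

import pynvshmem  # assumed available in the example's environment

parser = argparse.ArgumentParser()
parser.add_argument(
    "--barrier", action="store_true", help="call nvshmem_barrier_all() after the kernel")
args, _ = parser.parse_known_args()

# ... after launching the kernel:
if args.barrier:
    pynvshmem.nvshmem_barrier_all()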

167-178: Tighten CLI: fix typo and consider using warmup/repeat for benchmarking

  • Help text nit: missing space after comma.
  • --warmup and --repeat are parsed but unused; either wire them up or drop them.

Apply the typo fix:

-    parser.add_argument(
-        "--num_heads", type=int, default=16, help="Number of attention heads,combine QKV")
+    parser.add_argument(
+        "--num_heads", type=int, default=16, help="Number of attention heads, combine QKV")

Optional: integrate simple timing using torch.cuda.Event() around tilelang_all_to_all() or re-introduce perf_fn and benchmark just the kernel execution while allocating NVSHMEM tensors once outside the timed region.
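
A small sketch of the torch.cuda.Event approach, wired to the existing --warmup/--repeat flags (tilelang_all_to_all is the example's launcher; the exact call site may differ):

import torch


def time_kernel(run, warmup, repeat):
    """CUDA-event timing of a callable; returns average milliseconds per call."""
    for _ in range(warmup):
        run()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(repeat):
        run()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / repeat

# e.g. avg_ms = time_kernel(lambda: tilelang_all_to_all(...), args.warmup, args.repeat)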


116-124: Compute transfer size from dtype (not hard-coded 2 bytes)

Hard-coding float16 breaks for other dtypes and mismatches --dtype.

Apply:

-                transfer_size = SEQ_PER_PE * HEAD_DIM * 2  # float16 = 2 bytes
+                transfer_size = SEQ_PER_PE * HEAD_DIM * DTYPE_BYTES

And define DTYPE_BYTES once before @T.prim_func:

# Outside prim_func, near sequence_parallel_all_to_all signature
DTYPE_BYTES_MAP = {
    "float16": 2, "bfloat16": 2, "float32": 4, "float64": 8,
    "int8": 1, "uint8": 1, "int32": 4, "uint32": 4, "int64": 8, "uint64": 8,
}
DTYPE_BYTES = DTYPE_BYTES_MAP.get(dtype)
assert DTYPE_BYTES is not None, f"Unsupported dtype: {dtype}"
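
An alternative that avoids maintaining a byte-size map, assuming dtype_map (already imported in the example) maps the dtype string to its torch.dtype:

import torch

# element_size() gives the byte width of the dtype directly.
DTYPE_BYTES = torch.empty((), dtype=dtype_map[dtype]).element_size()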

266-270: Fail fast on verification errors and aggregate across PEs

Currently the result is ignored, and CI won’t catch mismatches.

-    verification_passed = verify_results(tilelang_output, torch_output, RANK)
+    verification_passed = verify_results(tilelang_output, torch_output, RANK)
+    verification_results = [None] * WORLD_SIZE
+    dist.all_gather_object(verification_results, verification_passed, group=TP_GROUP)
+    if RANK == 0 and not all(verification_results):
+        print("❌ Verification failed on one or more PEs!")
+        dist.destroy_process_group()
+        raise SystemExit(1)
🧹 Nitpick comments (2)
examples/distributed/example_pre_attn_all2all.py (2)

49-59: Avoid unused loop index warning

Use _ since the index isn’t used in the loop body.

- for pe_idx in range(world_size):
+ for _ in range(world_size):

236-256: Benchmark correctly: allocate NVSHMEM tensors once, time only the kernel

Current structure would time allocation if you hook up benchmarking later. Allocate once, then pass closures to a perf helper.

Sketch:

# allocate once outside
data_src = pynvshmem.nvshmem_create_tensor([...], dtype_torch)
data_dst = pynvshmem.nvshmem_create_tensor([...], dtype_torch)
signal   = pynvshmem.nvshmem_create_tensor([PE_num], torch.uint64)

def run_kernel():
    data_src.copy_(input_data); data_dst.zero_(); signal.zero_()
    kernel(data_src, data_dst, signal)
    pynvshmem.nvshmem_barrier_all()
    return data_dst

# warmup/repeat timing if desired
# _, ms = perf_fn(run_kernel, args.warmup, args.repeat)
# if RANK == 0: print(f"TileLang A2A avg: {ms:.3f} ms")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 2b5f147e203c90a2e2471f2d4405cef4941abf21 and a7005f4.

📒 Files selected for processing (5)
  • examples/distributed/example_allgather_gemm.py (1 hunks)
  • examples/distributed/example_pre_attn_all2all.py (1 hunks)
  • tilelang/distributed/utils.py (2 hunks)
  • tilelang/env.py (1 hunks)
  • tilelang/jit/adapter/cython/cython_wrapper.pyx (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
  • tilelang/distributed/utils.py
  • examples/distributed/example_allgather_gemm.py
  • tilelang/jit/adapter/cython/cython_wrapper.pyx
  • tilelang/env.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/distributed/example_pre_attn_all2all.py (6)
tilelang/distributed/utils.py (1)
  • perf_fn (206-227)
tilelang/env.py (1)
  • disable_cache (246-247)
tilelang/language/allocate.py (1)
  • alloc_local (38-49)
tilelang/language/distributed/multi_device/nvshmem.py (5)
  • get_pe (6-8)
  • putmem_nbi_block (101-109)
  • fence (61-63)
  • signal_op (164-172)
  • signal_wait_until (175-183)
tilelang/language/tir/op.py (1)
  • address_of (463-479)
tilelang/jit/__init__.py (1)
  • compile (32-81)
🪛 Ruff (0.12.2)
examples/distributed/example_pre_attn_all2all.py

4-4: os imported but unused

Remove unused import: os

(F401)


7-7: tilelang.profiler.TensorSupplyType imported but unused

Remove unused import: tilelang.profiler.TensorSupplyType

(F401)


8-8: tilelang.distributed.utils.perf_fn imported but unused

Remove unused import: tilelang.distributed.utils.perf_fn

(F401)


29-29: Local variable rank is assigned to but never used

Remove assignment to unused variable rank

(F841)


49-49: Loop control variable pe_idx not used within loop body

(B007)


266-266: Local variable verification_passed is assigned to but never used

Remove assignment to unused variable verification_passed

(F841)

🔇 Additional comments (1)
examples/distributed/example_pre_attn_all2all.py (1)

11-11: Confirm the correct cache-disabling API

tilelang.disable_cache() may be a mis-call; in many places the API is from tilelang.env import env; env.disable_cache().

If needed, switch to:

-tilelang.disable_cache()
+from tilelang.env import env
+env.disable_cache()

Comment on lines +129 to +137
if tx == 0:
    T.signal_op(
        T.address_of(signal[mype[0]]),
        1,
        10,  # NVSHMEM_SIGNAL_ADD
        target_pe)
    for k in T.serial(PE_num):
        T.signal_wait_until(T.address_of(signal[k]), 0, NUM_BLOCKS_X)


🛠️ Refactor suggestion

Replace magic numbers in NVSHMEM signaling with named constants

10 and 0 are unclear; use named constants for maintainability.

Apply:

-                T.signal_op(
+                T.signal_op(
                     T.address_of(signal[mype[0]]),
                     1,
-                    10,  # NVSHMEM_SIGNAL_ADD
+                    NVSHMEM_SIGNAL_ADD,
                     target_pe)
                 for k in T.serial(PE_num):
-                    T.signal_wait_until(T.address_of(signal[k]), 0, NUM_BLOCKS_X)
+                    T.signal_wait_until(T.address_of(signal[k]), NVSHMEM_CMP_EQ, NUM_BLOCKS_X)

Add once near the top of the file:

NVSHMEM_SIGNAL_ADD = 10
NVSHMEM_CMP_EQ = 0
🤖 Prompt for AI Agents
In examples/distributed/example_pre_attn_all2all.py around lines 129 to 137, the
NVSHMEM signal calls use magic numbers (10 and 0); define named constants
NVSHMEM_SIGNAL_ADD = 10 and NVSHMEM_CMP_EQ = 0 near the top of the file and
replace the literal 10 with NVSHMEM_SIGNAL_ADD and the literal 0 used in
signal_wait_until with NVSHMEM_CMP_EQ so the intent is clear and maintainable.

Comment on lines +184 to +187
# Initialize distributed environment
WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True)
PE_num = WORLD_SIZE


🛠️ Refactor suggestion

Ensure tensors go to the correct GPU in multi-process setups

Set CUDA device using LOCAL_RANK after initialization to avoid all ranks using GPU 0.

 WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True)
 PE_num = WORLD_SIZE
+torch.cuda.set_device(LOCAL_RANK)
🤖 Prompt for AI Agents
In examples/distributed/example_pre_attn_all2all.py around lines 184 to 187,
after calling init_distributed(return_tp_group=True) you must set the CUDA
device to LOCAL_RANK so each process uses its assigned GPU; call
torch.cuda.set_device(LOCAL_RANK) (or equivalent) immediately after
initialization and ensure subsequent tensor/model creation or .to(device) uses
that device to avoid all ranks defaulting to GPU 0.

Comment on lines +222 to +224
input_data = torch.rand([args.batch_size, args.num_heads, SEQ_PER_PE, args.head_dim],
                        dtype=dtype_torch,
                        device='cuda')

🛠️ Refactor suggestion

Create tensors on the rank-local device

Be explicit to avoid device mismatch when the default device isn’t set globally.

-    input_data = torch.rand([args.batch_size, args.num_heads, SEQ_PER_PE, args.head_dim],
-                            dtype=dtype_torch,
-                            device='cuda')
+    input_data = torch.rand(
+        [args.batch_size, args.num_heads, SEQ_PER_PE, args.head_dim],
+        dtype=dtype_torch,
+        device=f'cuda:{LOCAL_RANK}',
+    )
🤖 Prompt for AI Agents
In examples/distributed/example_pre_attn_all2all.py around lines 222 to 224, the
tensor is created with device='cuda' which can cause device-mismatch across
ranks; instead determine the rank-local CUDA device (e.g., device =
torch.device(f"cuda:{args.local_rank}") or device =
torch.device(f"cuda:{torch.cuda.current_device()}") if args.local_rank isn’t
available) and pass that device to torch.rand so the tensor is allocated on the
correct per-rank GPU.
