- 
                Notifications
    You must be signed in to change notification settings 
- Fork 2
          [Feature][Example] Support  return_peers in _allocate_tensor and add ag_gemm_ipc example
          #25
        
          New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9de9e1b
              e9d360a
              fd39138
              a053313
              abfbcd4
              499a01b
              b555cff
              4904152
              f43e3cb
              fe734cc
              4b664ab
              8c9c0e8
              70e45ab
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|  | @@ -97,7 +97,6 @@ jobs: | |||||||||||||||||||||||||||||||||||||||||
| PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user | ||||||||||||||||||||||||||||||||||||||||||
| # flash attention usually requires no isolation build | ||||||||||||||||||||||||||||||||||||||||||
| pip install flash_attn==2.5.8 --no-user --no-build-isolation | ||||||||||||||||||||||||||||||||||||||||||
| pip install . --no-user | ||||||||||||||||||||||||||||||||||||||||||
| touch "$MARKER" | ||||||||||||||||||||||||||||||||||||||||||
| fi | ||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||
|  | @@ -111,11 +110,49 @@ jobs: | |||||||||||||||||||||||||||||||||||||||||
| source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" | ||||||||||||||||||||||||||||||||||||||||||
| cd examples | ||||||||||||||||||||||||||||||||||||||||||
| unset PYTHONPATH | ||||||||||||||||||||||||||||||||||||||||||
| python -m pytest -n 4 **/test*.py -v -r fE | ||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||
| # find and run distributed tests with TILELANG_USE_DISTRIBUTED=1 | ||||||||||||||||||||||||||||||||||||||||||
| mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true) | ||||||||||||||||||||||||||||||||||||||||||
| if [ "${#DIST_TESTS[@]}" -gt 0 ]; then | ||||||||||||||||||||||||||||||||||||||||||
| echo "Running distributed examples with TILELANG_USE_DISTRIBUTED=1:" | ||||||||||||||||||||||||||||||||||||||||||
| printf '%s\n' "${DIST_TESTS[@]}" | ||||||||||||||||||||||||||||||||||||||||||
| TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE | ||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||
| echo "No distributed examples found." | ||||||||||||||||||||||||||||||||||||||||||
| fi | ||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||
| # run remaining example tests (non-distributed) | ||||||||||||||||||||||||||||||||||||||||||
| mapfile -t OTHER_TESTS < <(find . -type f -name 'test*.py' ! -path '*/distributed/*' 2>/dev/null || true) | ||||||||||||||||||||||||||||||||||||||||||
| if [ "${#OTHER_TESTS[@]}" -gt 0 ]; then | ||||||||||||||||||||||||||||||||||||||||||
| echo "Running non-distributed examples:" | ||||||||||||||||||||||||||||||||||||||||||
| printf '%s\n' "${OTHER_TESTS[@]}" | ||||||||||||||||||||||||||||||||||||||||||
| python -m pytest -n 4 "${OTHER_TESTS[@]}" -v -r fE | ||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||
| echo "No non-distributed example tests found." | ||||||||||||||||||||||||||||||||||||||||||
| fi | ||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||
| - name: Run tests | ||||||||||||||||||||||||||||||||||||||||||
| run: | | ||||||||||||||||||||||||||||||||||||||||||
| source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" | ||||||||||||||||||||||||||||||||||||||||||
| cd testing/python | ||||||||||||||||||||||||||||||||||||||||||
| unset PYTHONPATH | ||||||||||||||||||||||||||||||||||||||||||
| python -m pytest -n 4 -v -r fE | ||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||
| # run distributed tests first with env var | ||||||||||||||||||||||||||||||||||||||||||
| mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true) | ||||||||||||||||||||||||||||||||||||||||||
| if [ "${#DIST_TESTS[@]}" -gt 0 ]; then | ||||||||||||||||||||||||||||||||||||||||||
| echo "Running distributed tests with TILELANG_USE_DISTRIBUTED=1:" | ||||||||||||||||||||||||||||||||||||||||||
| printf '%s\n' "${DIST_TESTS[@]}" | ||||||||||||||||||||||||||||||||||||||||||
| TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE | ||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||
| echo "No distributed tests found under testing/python." | ||||||||||||||||||||||||||||||||||||||||||
| fi | ||||||||||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +140
     to 
      +148
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Run distributed tests (testing/python) serially and add strict shell flags Same flakiness risk here from -n 4 on GPU-distributed tests; also add set -euo pipefail for safety. Apply this diff: -        # run distributed tests first with env var
+        # run distributed tests first with env var
+        set -euo pipefail
         mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
         if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
           echo "Running distributed tests with TILELANG_USE_DISTRIBUTED=1:"
           printf '%s\n' "${DIST_TESTS[@]}"
-          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE
+          # run distributed tests serially to avoid GPU/process-group contention
+          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 1 "${DIST_TESTS[@]}" -v -r fE
         else
           echo "No distributed tests found under testing/python."
         fi📝 Committable suggestion
 
        Suggested change
       
 🤖 Prompt for AI Agents | ||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||
| # run remaining tests | ||||||||||||||||||||||||||||||||||||||||||
| mapfile -t OTHER_TESTS < <(find . -type f -name 'test*.py' ! -path '*/distributed/*' 2>/dev/null || true) | ||||||||||||||||||||||||||||||||||||||||||
| if [ "${#OTHER_TESTS[@]}" -gt 0 ]; then | ||||||||||||||||||||||||||||||||||||||||||
| echo "Running non-distributed tests:" | ||||||||||||||||||||||||||||||||||||||||||
| printf '%s\n' "${OTHER_TESTS[@]}" | ||||||||||||||||||||||||||||||||||||||||||
| python -m pytest -n 4 "${OTHER_TESTS[@]}" -v -r fE | ||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||
| echo "No non-distributed tests found under testing/python." | ||||||||||||||||||||||||||||||||||||||||||
| fi | ||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,198 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import tilelang | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import tilelang.language as T | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import argparse | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.distributed as dist | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.multiprocessing | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tilelang.distributed import init_dist | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from cuda import cudart | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tilelang.distributed import set_signal, wait_eq | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tilelang.disable_cache() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def gemm_kernel(M, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| N, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| K, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_rank, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_M, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_N, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_K, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| threads, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype="float16", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| accum_dtype="float"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @T.prim_func | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def main( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A: T.Tensor((M, K), dtype), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B: T.Tensor((K, N // num_rank), dtype), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C: T.Tensor((M, N // num_rank), dtype), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A_shared = T.alloc_shared((block_M, block_K), dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B_shared = T.alloc_shared((block_K, block_N), dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C_local = T.alloc_fragment((block_M, block_N), accum_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.clear(C_local) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(A[by * block_M, k * block_K], A_shared) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(B[k * block_K, bx * block_N], B_shared) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.gemm(A_shared, B_shared, C_local) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(C_local, C[by * block_M, bx * block_N]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +33
     to 
      +44
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix grid size: using N instead of N_per_rank causes out-of-bounds on B/C Kernel grid must reflect local N dimension per rank. Using global N overshoots. Apply this diff: -        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+        with T.Kernel(T.ceildiv(N // num_rank, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):📝 Committable suggestion
 
        Suggested change
       
 🤖 Prompt for AI Agents | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return main | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| signal_target, rank, local_world_size, world_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| intranode_ag_stream): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| local_rank = rank % local_world_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +48
     to 
      +51
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Parameter naming bug:  Misnaming invites misuse. Rename to  Apply this diff: -def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N,
+def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, K,
                                       signal_target, rank, local_world_size, world_size,
                                       intranode_ag_stream):Follow-up diffs below update its uses. 📝 Committable suggestion
 
        Suggested change
       
 🤖 Prompt for AI Agents | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| n_nodes = world_size // local_world_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| node_rank = rank // local_world_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i in range(1, local_world_size): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| segment = rank * M_per_rank * N | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| local_dst_rank = (local_rank + local_world_size - i) % local_world_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Using copy engine to perform intranode transmission | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Sending rank-th local tensor to other ranks inside the node. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (err,) = cudart.cudaMemcpyAsync( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dst_ptr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| src_ptr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| M_per_rank * N * local_tensor.element_size(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cudart.cudaMemcpyKind.cudaMemcpyDefault, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| intranode_ag_stream.cuda_stream, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Notify the peer that the transmission is done. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +55
     to 
      +69
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Use K for segment sizing and add CUDA error checks Copy size/offset must be based on the width of A (K). Also check  Apply this diff: -    for i in range(1, local_world_size):
-        segment = rank * M_per_rank * N
+    for i in range(1, local_world_size):
+        segment = rank * M_per_rank * K
         local_dst_rank = (local_rank + local_world_size - i) % local_world_size
-        src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size()
-        dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size()
+        src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size()
+        dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size()
         # Using copy engine to perform intranode transmission
         # Sending rank-th local tensor to other ranks inside the node.
         (err,) = cudart.cudaMemcpyAsync(
             dst_ptr,
             src_ptr,
-            M_per_rank * N * local_tensor.element_size(),
+            M_per_rank * K * local_tensor.element_size(),
             cudart.cudaMemcpyKind.cudaMemcpyDefault,
             intranode_ag_stream.cuda_stream,
         )
+        CUDA_CHECK(err)
         # Notify the peer that the transmission is done.
         set_signal(signal_buffer[local_dst_rank][rank], signal_target, intranode_ag_stream)📝 Committable suggestion
 
        Suggested change
       
 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| set_signal(signal_buffer[local_dst_rank][rank], signal_target, intranode_ag_stream) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i in range(1, n_nodes): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recv_segment = recv_rank * M_per_rank * N | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Waiting for the internode data ready | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| wait_eq(signal_buffer[local_rank][recv_rank], signal_target, intranode_ag_stream) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| src_ptr = ag_buffer[local_rank].data_ptr() + recv_segment * local_tensor.element_size() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for j in range(1, local_world_size): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| local_dst_rank = (local_rank + local_world_size - j) % local_world_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dst_ptr = ag_buffer[local_dst_rank].data_ptr( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) + recv_segment * local_tensor.element_size() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Sending (local_rank + j*local_world_size) % world_size -th local tensor to other ranks inside the node. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (err,) = cudart.cudaMemcpyAsync( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dst_ptr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| src_ptr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| M_per_rank * N * local_tensor.element_size(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cudart.cudaMemcpyKind.cudaMemcpyDefault, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| intranode_ag_stream.cuda_stream, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Notify the peer that the transmission is done. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| set_signal(signal_buffer[local_dst_rank][recv_rank], signal_target, intranode_ag_stream) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +72
     to 
      +92
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Likewise for inter-node segment sizing and error checks Mirror the K-based sizing and add error handling. Apply this diff: -    for i in range(1, n_nodes):
-        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
-        recv_segment = recv_rank * M_per_rank * N
+    for i in range(1, n_nodes):
+        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
+        recv_segment = recv_rank * M_per_rank * K
         # Waiting for the internode data ready
         wait_eq(signal_buffer[local_rank][recv_rank], signal_target, intranode_ag_stream)
         src_ptr = ag_buffer[local_rank].data_ptr() + recv_segment * local_tensor.element_size()
         for j in range(1, local_world_size):
             local_dst_rank = (local_rank + local_world_size - j) % local_world_size
             dst_ptr = ag_buffer[local_dst_rank].data_ptr(
             ) + recv_segment * local_tensor.element_size()
             # Sending (local_rank + j*local_world_size) % world_size -th local tensor to other ranks inside the node.
             (err,) = cudart.cudaMemcpyAsync(
                 dst_ptr,
                 src_ptr,
-                M_per_rank * N * local_tensor.element_size(),
+                M_per_rank * K * local_tensor.element_size(),
                 cudart.cudaMemcpyKind.cudaMemcpyDefault,
                 intranode_ag_stream.cuda_stream,
             )
+            CUDA_CHECK(err)
             # Notify the peer that the transmission is done.
             set_signal(signal_buffer[local_dst_rank][recv_rank], signal_target, intranode_ag_stream)📝 Committable suggestion
 
        Suggested change
       
 🧰 Tools🪛 Ruff (0.13.1)83-83: Unpacked variable  Prefix it with an underscore or any other dummy variable pattern (RUF059) 🤖 Prompt for AI Agents | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, N, signal_target, rank, group, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| local_world_size, world_size, gemm_kernel, ag_stream): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dist.barrier(group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # all_gather A to ag_buffer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with torch.cuda.stream(ag_stream): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, N, signal_target, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pass correct K dimension to cp_engine_producer_all_gather_put. The function call passes  Apply this diff: -        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, N, signal_target,
+        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, K, signal_target,
                                           rank, local_world_size, world_size, ag_stream)📝 Committable suggestion
 
        Suggested change
       
 🤖 Prompt for AI Agents | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rank, local_world_size, world_size, ag_stream) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| current_stream = torch.cuda.current_stream() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| current_stream.wait_stream(ag_stream) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dist.barrier(group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.cuda.synchronize() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.cuda.synchronize() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.distributed.barrier(group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gemm_kernel(ag_buffer[rank], B, C) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.cuda.synchronize() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.distributed.barrier(group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +110
     to 
      +114
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use local peer for kernel input; drop redundant sync Fix out-of-range access when rank >= local_world_size. Also remove duplicate synchronize. Apply this diff: -    torch.cuda.synchronize()
-    torch.distributed.barrier(group)
-    gemm_kernel(ag_buffer[rank], B, C)
+    gemm_kernel(ag_buffer[local_rank], B, C)
     torch.cuda.synchronize()
-    torch.distributed.barrier(group)
+    torch.distributed.barrier(group)📝 Committable suggestion
 
        Suggested change
       
 🤖 Prompt for AI Agents | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return C | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def torch_ag_gemm( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pg: torch.distributed.ProcessGroup, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| local_input: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| local_weight: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ag_out: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.distributed.all_gather_into_tensor(ag_out, local_input, pg) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ag_gemm_output = torch.matmul(ag_out, local_weight) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return ag_gemm_output | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype = torch.float16 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| M = args.M if args else 8192 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| N = args.N if args else 8192 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| K = args.K if args else 8192 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| M_per_rank = M // num_local_ranks | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| N_per_rank = N // num_local_ranks | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BLOCK_M = 128 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BLOCK_N = 128 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BLOCK_K = 64 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| threads = 256 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rank, num_ranks, group = init_dist(local_rank, num_local_ranks) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| allocator = tilelang.get_allocator( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| size=2**30, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device="cuda", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| is_distributed=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| local_rank=local_rank, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_local_ranks=num_local_ranks, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| group=group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel = tilelang.compile(gemm_kernel(M, N, K, num_ranks, BLOCK_M, BLOCK_N, BLOCK_K, threads)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +130
     to 
      +151
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Compute per-rank sizes using the global world size and after init_dist Current M_per_rank/N_per_rank use num_local_ranks, which breaks multi-node (OOB segment indexing). Apply this diff: -    M = args.M if args else 8192
-    N = args.N if args else 8192
-    K = args.K if args else 8192
-    M_per_rank = M // num_local_ranks
-    N_per_rank = N // num_local_ranks
+    M = args.M if args else 8192
+    N = args.N if args else 8192
+    K = args.K if args else 8192
@@
-    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    M_per_rank = M // num_ranks
+    N_per_rank = N // num_ranks📝 Committable suggestion
 
        Suggested change
       
 🤖 Prompt for AI Agents | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel.initialize(allocator=allocator) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if local_rank == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(kernel.get_kernel_source()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| signal_buffer = tilelang.tensor((num_local_ranks,), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.int32, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| allocator=allocator, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return_peers=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| signal_buffer[rank].fill_(0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ag_buffer[rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +156
     to 
      +166
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix peer buffer/signal shapes and local indexing 
 Apply this diff: -    A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_()
-    B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_()
-    C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator)
-    ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True)
-    signal_buffer = tilelang.tensor((num_local_ranks,),
-                                    torch.int32,
-                                    allocator=allocator,
-                                    return_peers=True)
-    signal_buffer[rank].fill_(0)
-    ag_buffer[rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)
+    A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_()
+    B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_()
+    C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator)
+    ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True)
+    signal_buffer = tilelang.tensor((num_ranks,),
+                                    torch.int32,
+                                    allocator=allocator,
+                                    return_peers=True)
+    signal_buffer[local_rank].fill_(0)
+    ag_buffer[local_rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)📝 Committable suggestion
 
        Suggested change
       
 🤖 Prompt for AI Agents | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dist.barrier(group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ag_stream = torch.cuda.Stream() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| signal_target = 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| group, num_local_ranks, num_local_ranks, kernel, ag_stream) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +172
     to 
      +174
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pass the correct world size (global) to  You currently pass  Apply this diff: -    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
-                            group, num_local_ranks, num_local_ranks, kernel, ag_stream)
+    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
+                            group, num_local_ranks, num_ranks, kernel, ag_stream)📝 Committable suggestion
 
        Suggested change
       
 🤖 Prompt for AI Agents | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch_ag_buffer = torch.empty([M, K], dtype=dtype, device="cuda") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch_C = torch_ag_gemm(group, A, B, torch_ag_buffer) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if torch.allclose(torch_C, tilelang_C, atol=1e-6, rtol=1e-6): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"rank {local_rank} check passed.✅") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"rank {local_rank} check failed.❌") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"torch_C: {torch_C}, tilelang_C: {tilelang_C}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Test failed") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dist.destroy_process_group() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| parser = argparse.ArgumentParser() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| '--num-processes', type=int, default=2, help='Number of processes to spawn (default: 2)') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument('--M', type=int, default=8192, help='M dimension') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument('--N', type=int, default=8192, help='N dimension') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument('--K', type=int, default=8192, help='K dimension') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args = parser.parse_args() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_processes = args.num_processes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.multiprocessing.spawn(main, args=(num_processes, args), nprocs=num_processes) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Run distributed example tests serially and harden shell options
Parallelizing distributed GPU tests with pytest-xdist (-n 4) commonly oversubscribes GPUs/process groups and causes flaky hangs/OOMs. Also add strict shell flags to fail fast on CI.
Apply this diff:
📝 Committable suggestion
🤖 Prompt for AI Agents