Skip to content

Commit 6ca8c15

Browse files
committed
format code
1 parent 1e0e208 commit 6ca8c15

File tree

2 files changed

+3
-6
lines changed

2 files changed

+3
-6
lines changed

examples/distributed/example_overlapping_allgather.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,19 @@
1-
import argparse
21
import torch
32
import torch.distributed as dist
43
import pynvshmem
54
import tilelang
65
import tilelang.language as T
76
import os
8-
from tilelang.distributed.utils import init_distributed, dtype_map, perf_fn
9-
from tilelang.distributed.utils import init_dist
7+
from tilelang.distributed.utils import init_distributed
108
from tilelang.env import env
119
from packaging import version
1210
import importlib.metadata
1311

1412
cuda_python_version = importlib.metadata.version("cuda-python")
1513
if version.parse(cuda_python_version) >= version.parse("12.8.0"):
16-
from cuda.bindings import driver as cuda
1714
from cuda.bindings import runtime as cudart
1815
else:
19-
from cuda import cuda, cudart
16+
from cuda import cudart
2017
# NODES=2 NODE_RANK=0 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
2118
# NODES=2 NODE_RANK=1 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
2219

tilelang/distributed/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def init_distributed(return_tp_group=False, init_nvshmem=True, return_lc_group=F
8282
base = (RANK // local_world_size) * local_world_size
8383
LC_GROUP = torch.distributed.new_group(
8484
list(range(base, base + local_world_size)), backend="nccl")
85-
85+
8686
return WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP
8787
elif return_tp_group:
8888
return WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP

0 commit comments

Comments
 (0)