Skip to content

Commit a2b8e46

Browse files
committed
[Typo] Correct architecture selection for CUDA and CDNA
1 parent 6d26745 commit a2b8e46

File tree

5 files changed

+5
-5
lines changed

5 files changed

+5
-5
lines changed

benchmark/matmul/benchmark_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def get_configs(args, kwargs):
5353
from tilelang.carver.roller.rasterization import NoRasterization
5454
import torch
5555

56-
arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
56+
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
5757
topk = 10
5858

5959
carve_template = MatmulTemplate(

benchmark/matmul/benchmark_matmul_intrinsic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def get_configs(args, kwargs):
187187
from tilelang.carver.roller.rasterization import NoRasterization
188188
import torch
189189

190-
arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
190+
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
191191
topk = 10
192192

193193
carve_template = MatmulTemplate(

examples/analyze/example_conv_analyze.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def conv(
9696

9797
def main():
9898
my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
99-
cuda_device = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
99+
cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
100100
result = Analyzer.analysis(my_func, cuda_device)
101101
print(result)
102102
print(f"Analyzed FLOPs: {result.total_flops}")

examples/analyze/example_gemm_analyze.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def matmul(
4949
def main():
5050
my_func = kernel(128, 128, 32, 3, 128, True)
5151

52-
cuda_device = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
52+
cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
5353
result = Analyzer.analysis(my_func, cuda_device)
5454

5555
print(f"Analyzed FLOPs: {result.total_flops}")

examples/gemm/example_gemm_autotune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def ref_program(A, B):
1616

1717
def get_configs(M, N, K, with_roller=False, topk=20):
1818
if with_roller:
19-
arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
19+
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
2020
carve_template = MatmulTemplate(
2121
M=M,
2222
N=N,

0 commit comments

Comments
 (0)