Merged
2 changes: 1 addition & 1 deletion benchmark/matmul/benchmark_matmul.py
@@ -53,7 +53,7 @@ def get_configs(args, kwargs):
from tilelang.carver.roller.rasterization import NoRasterization
import torch

- arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
+ arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
Contributor comment (severity: high):

This change corrects a critical bug in the architecture detection logic. The previous implementation had CUDA and CDNA architectures swapped, which would cause incorrect architecture-specific code paths to be taken. This fix ensures the correct platform is identified based on the presence of torch.version.hip.
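As a standalone illustration (not part of this diff), here is a minimal sketch of the detection pattern the fix relies on. The import path tilelang.carver.arch for the CUDA and CDNA classes is an assumption; what is certain is that torch.version.hip is None on CUDA builds of PyTorch and holds a version string on ROCm builds.

import torch
# Assumed import path for the architecture classes used in the diff above.
from tilelang.carver.arch import CUDA, CDNA

# torch.version.hip is None on CUDA builds of PyTorch and a ROCm version
# string on HIP builds, so it distinguishes the two platforms.
if torch.version.hip is None:
    arch = CUDA("cuda")  # NVIDIA GPU: use the CUDA architecture model
else:
    arch = CDNA("hip")   # AMD GPU (ROCm/HIP): use the CDNA architecture model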

topk = 10

carve_template = MatmulTemplate(
2 changes: 1 addition & 1 deletion benchmark/matmul/benchmark_matmul_intrinsic.py
@@ -187,7 +187,7 @@ def get_configs(args, kwargs):
from tilelang.carver.roller.rasterization import NoRasterization
import torch

- arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
+ arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
Contributor comment (severity: high):

This change corrects a critical bug in the architecture detection logic. The previous implementation had CUDA and CDNA architectures swapped, which would cause incorrect architecture-specific code paths to be taken. This fix ensures the correct platform is identified based on the presence of torch.version.hip.

topk = 10

carve_template = MatmulTemplate(
2 changes: 1 addition & 1 deletion examples/analyze/example_conv_analyze.py
@@ -96,7 +96,7 @@ def conv(

def main():
my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
- cuda_device = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
+ cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
Contributor comment (severity: high):

This change corrects a critical bug in the architecture detection logic. The previous implementation had CUDA and CDNA architectures swapped, which would cause incorrect architecture-specific code paths to be taken. This fix ensures the correct platform is identified based on the presence of torch.version.hip.

result = Analyzer.analysis(my_func, cuda_device)
print(result)
print(f"Analyzed FLOPs: {result.total_flops}")
2 changes: 1 addition & 1 deletion examples/analyze/example_gemm_analyze.py
@@ -49,7 +49,7 @@ def matmul(
def main():
my_func = kernel(128, 128, 32, 3, 128, True)

- cuda_device = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
+ cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
Contributor comment (severity: high):

This change corrects a critical bug in the architecture detection logic. The previous implementation had CUDA and CDNA architectures swapped, which would cause incorrect architecture-specific code paths to be taken. This fix ensures the correct platform is identified based on the presence of torch.version.hip.

result = Analyzer.analysis(my_func, cuda_device)

print(f"Analyzed FLOPs: {result.total_flops}")
2 changes: 1 addition & 1 deletion examples/gemm/example_gemm_autotune.py
@@ -16,7 +16,7 @@ def ref_program(A, B):

def get_configs(M, N, K, with_roller=False, topk=20):
if with_roller:
- arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
+ arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
Contributor comment (severity: high):

This change corrects a critical bug in the architecture detection logic. The previous implementation had CUDA and CDNA architectures swapped, which would cause incorrect architecture-specific code paths to be taken. This fix ensures the correct platform is identified based on the presence of torch.version.hip.

carve_template = MatmulTemplate(
M=M,
N=N,
5 changes: 2 additions & 3 deletions tilelang/carver/roller/policy/tensorcore.py
@@ -281,10 +281,9 @@ def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int):

factors = factorize(np.prod(space) // warps)

- def _score(node, thread): # small is better
+ def _score(node, warp_tile): # small is better
score = 0
- block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)]
- shape = node.propagate_inputs_on_reduction(block_tile)
+ shape = node.propagate_inputs_on_reduction(warp_tile)
Comment on lines +284 to +286
Contributor comment (severity: high):

This is an excellent correction to the scoring function's logic. The previous implementation was flawed, as it used the number of warps per dimension to estimate data traffic instead of the actual warp tile shape. By passing warp_tile directly to propagate_inputs_on_reduction, the score now correctly reflects the data movement cost for a given warp tile configuration, leading to better schedule selection. The parameter rename also improves clarity.
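To make the reviewed change concrete, a simplified, self-contained sketch of the scoring idea follows. The helper names propagate_inputs and bandwidth are hypothetical stand-ins for node.propagate_inputs_on_reduction and self.arch.bandwidth[1]; the point is that each input buffer's footprint is derived from the warp tile itself and charged against memory bandwidth, so a smaller score means less data movement.

import numpy as np

def score_warp_tile(warp_tile, propagate_inputs, bandwidth):
    # propagate_inputs maps a warp tile to the shape each input buffer
    # contributes when that tile is computed (hypothetical stand-in for
    # node.propagate_inputs_on_reduction).
    score = 0.0
    input_shapes = propagate_inputs(warp_tile)
    for shape in input_shapes:
        # Elements touched divided by available bandwidth: smaller is better.
        score += np.prod(shape) / bandwidth
    return score

# Toy usage: a 16x16 warp tile of a GEMM with K assumed to be 64 reads
# A[16, 64] and B[64, 16]; the lambda below encodes that assumption.
example = score_warp_tile([16, 16], lambda t: [(t[0], 64), (64, t[1])], bandwidth=1.0e12)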

input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
for i, _ in enumerate(input_buffers):
score += np.prod(shape[i]) / self.arch.bandwidth[1]